Files
@ d9cf0ed8e7fd
Branch filter:
Location: OneEye/exp/kerokero/train.py
d9cf0ed8e7fd
3.6 KiB
text/x-python
Inception transfer learning (failed)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 | import os
from time import time
import argparse
import logging as log
import numpy as np
from keras.layers import Conv2D,Dropout,Dense,Flatten,MaxPooling2D,GlobalAveragePooling2D,BatchNormalization
from keras.models import Sequential,load_model,Model
from keras.optimizers import SGD
from keras.callbacks import TensorBoard
from keras.applications.inception_v3 import InceptionV3,preprocess_input
import config as cfg
import ftp
def _build_parser():
    """Assemble the command-line interface of the training script.

    Arguments registered:
        data: positional path, later passed to ``np.load``.
        --load_model: path of a saved model to load instead of building one.
        --save_model: format string for checkpoint paths; it is formatted
            with the epoch number the checkpoint was written at.
        --epochs: integer, last epoch to train to (default 100).
        --initial_epoch: integer, epoch to start from (default 0).
        --log_dir: parent directory of the TensorBoard run directory.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("data")
    cli.add_argument("--load_model")
    cli.add_argument("--save_model", default="/tmp/gogo-{0:03}.h5")
    cli.add_argument("--epochs", type=int, default=100)
    cli.add_argument("--initial_epoch", type=int, default=0)
    cli.add_argument("--log_dir", default="/tmp/tflogs")
    return cli


parser = _build_parser()
args = parser.parse_args()
def createFullyConnected():
    """Build and compile a small fully connected network.

    Layer order: Flatten over a (224, 224) input, Dense(128, relu),
    Dropout(0.1), Dense(64, relu) and a final Dense(8) without activation.

    Returns:
        The compiled ``Sequential`` model (Adam optimizer, MSE loss, with
        MAE and accuracy reported as metrics).
    """
    net = Sequential()
    net.add(Flatten(input_shape=(224, 224)))
    net.add(Dense(128, activation="relu"))
    net.add(Dropout(0.1))
    net.add(Dense(64, activation="relu"))
    net.add(Dense(8))
    net.compile(optimizer='adam', loss='mse', metrics=['mae', 'accuracy'])
    return net
def createCNN():
    """Build and compile a four-stage convolutional network.

    The input is a (224, 224, 1) tensor. Each convolutional stage is followed
    by 2x2 max pooling with stride 2 and batch normalisation; stages 1, 2 and
    4 additionally apply dropout. Two dense layers (500 and 128 units, each
    followed by dropout) feed a final Dense(8) without activation.

    Returns:
        The compiled ``Sequential`` model (Adam optimizer, MSE loss, with
        MAE and accuracy reported as metrics).
    """
    # Every pooling layer in the network shares the same configuration.
    pooling = dict(pool_size=(2, 2), strides=(2, 2), padding="valid")

    stack = [
        # Stage 1
        Conv2D(filters=16, kernel_size=2, padding="same", activation="relu",
               input_shape=(224, 224, 1)),
        Dropout(0.1),
        MaxPooling2D(**pooling),
        BatchNormalization(),
        # Stage 2
        Conv2D(32, (5, 5), activation="relu"),
        MaxPooling2D(**pooling),
        Dropout(0.2),
        BatchNormalization(),
        # Stage 3
        Conv2D(64, (5, 5), activation="relu"),
        MaxPooling2D(**pooling),
        BatchNormalization(),
        # Stage 4
        Conv2D(128, (3, 3), activation="relu"),
        MaxPooling2D(**pooling),
        Dropout(0.4),
        BatchNormalization(),
        # Dense head
        Flatten(),
        Dense(500, activation="relu"),
        Dropout(0.1),
        Dense(128, activation="relu"),
        Dropout(0.1),
        Dense(8),
    ]

    cnn = Sequential(stack)
    cnn.compile(optimizer='adam', loss='mse', metrics=['mae', 'accuracy'])
    return cnn
def createPretrained():
    """Build and compile an InceptionV3-based model with a new 8-output head.

    The InceptionV3 network is created with ImageNet weights, without its
    original top, for (224, 224, 3) inputs. Its output goes through global
    average pooling, a Dense(1024, relu) layer and a final Dense(8) without
    activation. All layers of the InceptionV3 base are marked non-trainable
    before compiling, so only the newly added layers are trained.

    Returns:
        The compiled ``Model`` (Adam optimizer, MSE loss, with MAE and
        accuracy reported as metrics).
    """
    backbone = InceptionV3(
        weights="imagenet",
        include_top=False,
        input_shape=(224, 224, 3),
    )

    pooled = GlobalAveragePooling2D()(backbone.output)
    hidden = Dense(1024, activation="relu")(pooled)
    head = Dense(8)(hidden)
    net = Model(inputs=backbone.input, outputs=head)

    # Freeze the whole base network; this must happen before compile().
    for base_layer in backbone.layers:
        base_layer.trainable = False

    net.compile(optimizer='adam', loss='mse', metrics=['mae', 'accuracy'])
    return net
# Resume from a saved model when --load_model is given; otherwise start from
# a freshly built InceptionV3-based model.
model = load_model(args.load_model) if args.load_model else createPretrained()
# The root logger defaults to WARNING, so without configuration the log.info()
# calls of this script (including the final evaluation result) are discarded.
# basicConfig() does nothing when a handler is already installed on the root
# logger, hence the level is additionally set explicitly.
log.basicConfig(level=log.INFO)
log.getLogger().setLevel(log.INFO)

log.info("loading data...")
# The archive at args.data must contain the four arrays read below. Images go
# through the InceptionV3 preprocess_input; labels are used as stored.
with np.load(args.data) as data:
    trainImages = preprocess_input(data["trainImages"])
    trainLabels = data["trainLabels"]
    testImages = preprocess_input(data["testImages"])
    testLabels = data["testLabels"]
log.info("done")
# Each run writes its TensorBoard events into a sub-directory of --log_dir
# named after the current timestamp.
run_dir = os.path.join(args.log_dir, "{}".format(time()))
tensorboard = TensorBoard(log_dir=run_dir)

if not args.load_model:
    # Warm-up for a freshly built model: 10 epochs with the model exactly as
    # createPretrained() compiled it (InceptionV3 base frozen).
    model.fit(
        trainImages.reshape((-1, 224, 224, 3)),
        trainLabels,
        epochs=10,
        batch_size=128,
        validation_split=0.2,
        callbacks=[tensorboard],
    )
    # Keep the first 249 layers frozen and make every later layer trainable.
    # NOTE(review): the freeze/recompile steps are assumed to belong to this
    # branch -- confirm against the original indentation.
    for position, layer in enumerate(model.layers):
        layer.trainable = position >= 249
    # Recompile so the changed trainable flags take effect, using SGD with a
    # small learning rate.
    model.compile(optimizer=SGD(lr=0.0001, momentum=0.9), loss='mse')
# Training runs in chunks that end on multiples of BIG_STEP epochs; a
# checkpoint is written after every chunk.
BIG_STEP = 20

# Walk explicit start/stop epochs instead of integer-dividing both bounds by
# BIG_STEP: the division silently dropped trailing epochs when --epochs was
# not a multiple of BIG_STEP, and re-trained already finished epochs when
# --initial_epoch was not a multiple. For multiples of BIG_STEP the chunks and
# checkpoint names are the same as before.
start_epoch = args.initial_epoch
while start_epoch < args.epochs:
    # Stop at the next multiple of BIG_STEP, but never beyond --epochs.
    stop_epoch = min((start_epoch // BIG_STEP + 1) * BIG_STEP, args.epochs)
    # With initial_epoch set, Keras treats `epochs` as the index of the final
    # epoch rather than as a number of additional epochs.
    model.fit(
        trainImages.reshape((-1, 224, 224, 3)),
        trainLabels,
        epochs=stop_epoch,
        initial_epoch=start_epoch,
        batch_size=128,
        validation_split=0.2,
        callbacks=[tensorboard],
    )
    path = args.save_model.format(stop_epoch)
    log.info("saving model...")
    model.save(path)
    # ftp.push(path)
    start_epoch = stop_epoch

# Final evaluation on the held-out test arrays, reported through log.info.
# NOTE(review): assumed to run once after the loop -- confirm against the
# original indentation.
log.info(model.evaluate(testImages.reshape((-1, 224, 224, 3)), testLabels))
|