diff --git a/exp/kerokero/train.py b/exp/kerokero/train.py
--- a/exp/kerokero/train.py
+++ b/exp/kerokero/train.py
@@ -4,9 +4,11 @@ import argparse
 import logging as log
 import numpy as np
 
-from keras.layers import Conv2D,Dropout,Dense,Flatten,MaxPooling2D,BatchNormalization
-from keras.models import Sequential,load_model
+from keras.layers import Conv2D,Dropout,Dense,Flatten,MaxPooling2D,GlobalAveragePooling2D,BatchNormalization
+from keras.models import Sequential,load_model,Model
+from keras.optimizers import SGD
 from keras.callbacks import TensorBoard
+from keras.applications.inception_v3 import InceptionV3,preprocess_input
 
 import config as cfg
 import ftp
@@ -17,6 +19,7 @@ parser.add_argument("--load_model")
 parser.add_argument("--save_model",default="/tmp/gogo-{0:03}.h5")
 parser.add_argument("--epochs",type=int,default=100)
 parser.add_argument("--initial_epoch",type=int,default=0)
+parser.add_argument("--log_dir",default="/tmp/tflogs")
 args=parser.parse_args()
 
 
@@ -72,24 +75,50 @@ def createCNN():
     return model
 
 
-model=createCNN()
+def createPretrained():
+    base=InceptionV3(weights="imagenet",include_top=False,input_shape=(224,224,3))
+
+    x=base.output
+    x=GlobalAveragePooling2D()(x)
+    x=Dense(1024,activation="relu")(x)
+    predictions=Dense(8)(x)
+
+    model=Model(inputs=base.input,outputs=predictions)
+    for layer in base.layers:
+        layer.trainable=False
+
+    model.compile(optimizer='adam',loss='mse',metrics=['mae','accuracy'])
+    return model
+
+
 if args.load_model:
     model=load_model(args.load_model)
+else:
+    model=createPretrained()
 
 log.info("loading data...")
 with np.load(args.data) as data:
-    trainImages=data["trainImages"]
+    trainImages=preprocess_input(data["trainImages"])
     trainLabels=data["trainLabels"]
-    testImages=data["testImages"]
+    testImages=preprocess_input(data["testImages"])
     testLabels=data["testLabels"]
 log.info("done")
 
-tensorboard = TensorBoard(log_dir=os.path.join(cfg.thisDir,"../logs","{}".format(time())))
-BIG_STEP=50
+tensorboard = TensorBoard(log_dir=os.path.join(args.log_dir,"{}".format(time())))
+
+if not args.load_model:
+    model.fit(trainImages.reshape((-1,224,224,3)),trainLabels,epochs=10,batch_size=128,validation_split=0.2,callbacks=[tensorboard])
+for layer in model.layers[:249]:
+    layer.trainable = False
+for layer in model.layers[249:]:
+    layer.trainable = True
+model.compile(optimizer=SGD(lr=0.0001,momentum=0.9),loss='mse')
+
+BIG_STEP=20
 for i in range(args.initial_epoch//BIG_STEP,args.epochs//BIG_STEP):
-    model.fit(trainImages.reshape((-1,224,224,1)),trainLabels,epochs=(i+1)*BIG_STEP,initial_epoch=i*BIG_STEP,batch_size=128,validation_split=0.2,callbacks=[tensorboard])
+    model.fit(trainImages.reshape((-1,224,224,3)),trainLabels,epochs=(i+1)*BIG_STEP,initial_epoch=i*BIG_STEP,batch_size=128,validation_split=0.2,callbacks=[tensorboard])
     path=args.save_model.format((i+1)*BIG_STEP)
     log.info("saving model...")
     model.save(path)
     # ftp.push(path)
-log.info(model.evaluate(testImages.reshape((-1,224,224,1)),testLabels))
+log.info(model.evaluate(testImages.reshape((-1,224,224,3)),testLabels))