# Patch for exp/kerokero/train.py: replace the hand-rolled "train in segments, then save" loop
# with a single model.fit() call that saves through a Keras ModelCheckpoint callback.
#
# Hunk 1 (imports):
#   - imports ModelCheckpoint alongside TensorBoard from keras.callbacks
#   - drops `import ftp` (its only visible use was the commented-out `ftp.push(path)` removed below)
#
# Hunk 2 (training):
#   - TensorBoard log directory is now built from args.log_dir instead of cfg.thisDir + "../logs"
#   - removes BIG_STEP=20 and the for-loop that called model.fit() once per 20-epoch segment and
#     then model.save(args.save_model.format(<last epoch of the segment>))
#   - adds ModelCheckpoint(args.save_model, monitor="val_loss", period=10) and passes it, together
#     with the TensorBoard callback, to one model.fit() spanning args.initial_epoch..args.epochs
#   - there is no explicit model.save() any more; saving happens only through the checkpoint callback
#
# NOTE(review): args.log_dir is not added by these hunks - confirm the argument parser defines it.
# NOTE(review): the old code filled args.save_model via str.format() with a positional epoch number;
#   ModelCheckpoint formats its filepath with named fields - confirm the template is compatible.
# NOTE(review): this patch text is stored on a single physical line; its line breaks and indentation
#   must be restored before it can be applied with `git apply` / `patch`.
diff --git a/exp/kerokero/train.py b/exp/kerokero/train.py --- a/exp/kerokero/train.py +++ b/exp/kerokero/train.py @@ -6,10 +6,9 @@ import logging as log import numpy as np from keras.layers import Conv2D,Dropout,Dense,Flatten,MaxPooling2D,BatchNormalization,GlobalAveragePooling2D from keras.models import Sequential,load_model -from keras.callbacks import TensorBoard +from keras.callbacks import TensorBoard,ModelCheckpoint import config as cfg -import ftp parser=argparse.ArgumentParser() parser.add_argument("data") @@ -78,12 +77,17 @@ with np.load(args.data) as data: testLabels=data["testLabels"] log.info("done") -tensorboard = TensorBoard(log_dir=os.path.join(cfg.thisDir,"../logs","{}".format(time()))) -BIG_STEP=20 -for i in range(args.initial_epoch//BIG_STEP,args.epochs//BIG_STEP): - model.fit(trainImages.reshape((-1,224,224,1)),trainLabels,epochs=(i+1)*BIG_STEP,initial_epoch=i*BIG_STEP,batch_size=20,validation_split=0.2,callbacks=[tensorboard]) - path=args.save_model.format((i+1)*BIG_STEP) - log.info("saving model...") - model.save(path) - # ftp.push(path) +tensorboard=TensorBoard(log_dir=os.path.join(args.log_dir,"{}".format(time()))) +checkpoint=ModelCheckpoint(args.save_model,monitor="val_loss",period=10) + +model.fit( + trainImages.reshape((-1,224,224,1)), + trainLabels, + epochs=args.epochs, + initial_epoch=args.initial_epoch, + batch_size=20, + validation_split=0.2, + callbacks=[tensorboard,checkpoint] +) + log.info(model.evaluate(testImages.reshape((-1,224,224,1)),testLabels))