# HG changeset patch # User Laman # Date 2019-05-09 22:25:27 # Node ID 4381957e967eef6f36b894d06633a91609b92294 # Parent ecf98a415d975b60960c84ce1bc4c5fa68cc3769 keras checkpoint saving diff --git a/exp/kerokero/train.py b/exp/kerokero/train.py --- a/exp/kerokero/train.py +++ b/exp/kerokero/train.py @@ -6,10 +6,9 @@ import logging as log import numpy as np from keras.layers import Conv2D,Dropout,Dense,Flatten,MaxPooling2D,BatchNormalization,GlobalAveragePooling2D from keras.models import Sequential,load_model -from keras.callbacks import TensorBoard +from keras.callbacks import TensorBoard,ModelCheckpoint import config as cfg -import ftp parser=argparse.ArgumentParser() parser.add_argument("data") @@ -78,12 +77,17 @@ with np.load(args.data) as data: testLabels=data["testLabels"] log.info("done") -tensorboard = TensorBoard(log_dir=os.path.join(cfg.thisDir,"../logs","{}".format(time()))) -BIG_STEP=20 -for i in range(args.initial_epoch//BIG_STEP,args.epochs//BIG_STEP): - model.fit(trainImages.reshape((-1,224,224,1)),trainLabels,epochs=(i+1)*BIG_STEP,initial_epoch=i*BIG_STEP,batch_size=20,validation_split=0.2,callbacks=[tensorboard]) - path=args.save_model.format((i+1)*BIG_STEP) - log.info("saving model...") - model.save(path) - # ftp.push(path) +tensorboard=TensorBoard(log_dir=os.path.join(args.log_dir,"{}".format(time()))) +checkpoint=ModelCheckpoint(args.save_model,monitor="val_loss",period=10) + +model.fit( + trainImages.reshape((-1,224,224,1)), + trainLabels, + epochs=args.epochs, + initial_epoch=args.initial_epoch, + batch_size=20, + validation_split=0.2, + callbacks=[tensorboard,checkpoint] +) + log.info(model.evaluate(testImages.reshape((-1,224,224,1)),testLabels))