diff --git a/exp/kerokero/train.py b/exp/kerokero/train.py --- a/exp/kerokero/train.py +++ b/exp/kerokero/train.py @@ -1,4 +1,5 @@ import os +import math from time import time import argparse import logging as log @@ -11,7 +12,7 @@ import keras.losses import keras.metrics import config as cfg -from k_util import averageDistance +from k_util import averageDistance,generateData keras.losses.averageDistance=averageDistance keras.metrics.averageDistance=averageDistance @@ -78,23 +79,29 @@ if args.load_model: log.info("loading data...") with np.load(args.data) as data: - trainImages=data["trainImages"] - trainLabels=data["trainLabels"] - testImages=data["testImages"] - testLabels=data["testLabels"] + trainImages=data["trainImages"].reshape((-1,224,224,1)) + trainLabels=data["trainLabels"].reshape((-1,4,2)) + testImages=data["testImages"].reshape((-1,224,224,1)) + testLabels=data["testLabels"].reshape((-1,4,2)) log.info("done") +n=len(trainImages) +k=round(n*0.9) +n_=n-k +(trainImages,valImages)=(np.float32(trainImages[:k]),np.float32(trainImages[k:])) +(trainLabels,valLabels)=(np.float32(trainLabels[:k]),np.float32(trainLabels[k:])) + tensorboard=TensorBoard(log_dir=os.path.join(args.log_dir,"{}".format(time()))) checkpoint=ModelCheckpoint(args.save_model,monitor="val_loss",period=10) -model.fit( - trainImages.reshape((-1,224,224,1)), - trainLabels.reshape((-1,4,2)), +model.fit_generator( + generateData(trainImages,trainLabels,batch_size=20), epochs=args.epochs, initial_epoch=args.initial_epoch, - batch_size=20, - validation_split=0.2, + steps_per_epoch=math.ceil(n_/20), + validation_data=generateData(valImages,valLabels,batch_size=20), + validation_steps=math.ceil(k/20), callbacks=[tensorboard,checkpoint] ) -log.info(model.evaluate(testImages.reshape((-1,224,224,1)),testLabels.reshape((-1,4,2)))) +log.info(model.evaluate(testImages,testLabels))