Changeset - 4381957e967e
[Not reviewed]
branch: default
Laman - 2019-05-09 22:25:27

keras checkpoint saving
1 file changed with 14 insertions and 10 deletions:
0 comments (0 inline, 0 general)
exp/kerokero/train.py
import os
from time import time
import argparse
import logging as log

import numpy as np
from keras.layers import Conv2D,Dropout,Dense,Flatten,MaxPooling2D,BatchNormalization,GlobalAveragePooling2D
from keras.models import Sequential,load_model
from keras.callbacks import TensorBoard
from keras.callbacks import TensorBoard,ModelCheckpoint

import config as cfg
import ftp

parser=argparse.ArgumentParser()
parser.add_argument("data")
parser.add_argument("--load_model")
parser.add_argument("--save_model",default="/tmp/gogo-{0:03}.h5")
parser.add_argument("--epochs",type=int,default=100)
parser.add_argument("--initial_epoch",type=int,default=0)
parser.add_argument("--log_dir",default="/tmp/tflogs")
args=parser.parse_args()


def createFullyConnected():
	model=Sequential([
		Flatten(input_shape=(224,224)),
		Dense(128, activation="relu"),
		Dropout(0.1),
		Dense(64, activation="relu"),
		Dense(8)
	])

	model.compile(
		optimizer='adam',
		loss='mse',
		metrics=['mae','accuracy']
	)
	return model

def createCNN():
	model=Sequential()

	model.add(BatchNormalization(input_shape=(224,224,1)))

	model.add(Conv2D(24,(5,5),border_mode="same",init="he_normal",activation="relu",input_shape=(224,224,1),dim_ordering="tf"))
	model.add(MaxPooling2D(pool_size=(2,2),strides=(2,2),border_mode="valid"))

	model.add(Conv2D(36,(5,5),activation="relu"))
	model.add(MaxPooling2D(pool_size=(2,2),strides=(2,2),border_mode="valid"))

	model.add(Conv2D(48,(5,5),activation="relu"))
	model.add(MaxPooling2D(pool_size=(2,2),strides=(2,2),border_mode="valid"))

	model.add(Conv2D(64,(3,3),activation="relu"))
	model.add(MaxPooling2D(pool_size=(2,2),strides=(2,2),border_mode="valid"))

	model.add(Conv2D(64,(3,3),activation="relu"))

	model.add(GlobalAveragePooling2D())

	model.add(Dense(500,activation="relu"))
	model.add(Dense(90,activation="relu"))
	model.add(Dense(8))

	model.compile(optimizer="rmsprop",loss="mse",metrics=["mae","accuracy"])
	return model


model=createCNN()
if args.load_model:
	model=load_model(args.load_model)

log.info("loading data...")
with np.load(args.data) as data:
	trainImages=data["trainImages"]
	trainLabels=data["trainLabels"]
	testImages=data["testImages"]
	testLabels=data["testLabels"]
log.info("done")

tensorboard = TensorBoard(log_dir=os.path.join(cfg.thisDir,"../logs","{}".format(time())))
BIG_STEP=20
for i in range(args.initial_epoch//BIG_STEP,args.epochs//BIG_STEP):
	model.fit(trainImages.reshape((-1,224,224,1)),trainLabels,epochs=(i+1)*BIG_STEP,initial_epoch=i*BIG_STEP,batch_size=20,validation_split=0.2,callbacks=[tensorboard])
	path=args.save_model.format((i+1)*BIG_STEP)
	log.info("saving model...")
	model.save(path)
	# ftp.push(path)
tensorboard=TensorBoard(log_dir=os.path.join(args.log_dir,"{}".format(time())))
checkpoint=ModelCheckpoint(args.save_model,monitor="val_loss",period=10)

model.fit(
	trainImages.reshape((-1,224,224,1)),
	trainLabels,
	epochs=args.epochs,
	initial_epoch=args.initial_epoch,
	batch_size=20,
	validation_split=0.2,
	callbacks=[tensorboard,checkpoint]
)

log.info(model.evaluate(testImages.reshape((-1,224,224,1)),testLabels))
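
Note on the new checkpoint code (an editorial sketch, not part of the changeset): the commit replaces the manual save-every-BIG_STEP loop with a ModelCheckpoint callback (period=10, i.e. a snapshot every 10 epochs) and moves the TensorBoard log directory behind the new --log_dir option. One caveat worth flagging: Keras fills the ModelCheckpoint filepath using keyword arguments (epoch plus the logged metrics), so the positional placeholder in the default --save_model pattern, "/tmp/gogo-{0:03}.h5", which the old args.save_model.format((i+1)*BIG_STEP) call relied on, would not format there (str.format raises an IndexError when a positional field is given only keyword arguments). A named placeholder of roughly this shape is what ModelCheckpoint expects; the path below is illustrative, not the changeset's default:

from keras.callbacks import ModelCheckpoint

# ModelCheckpoint substitutes {epoch} (and any logged metric such as {val_loss:.3f})
# itself when it writes the file, e.g. /tmp/gogo-010.h5 after epoch 10.
checkpoint=ModelCheckpoint(
	"/tmp/gogo-{epoch:03d}.h5",
	monitor="val_loss",
	period=10   # save a snapshot every 10 epochs
)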
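
A further reading note on unchanged context in createCNN (again a sketch, not part of the changeset): the first Conv2D and MaxPooling2D calls use Keras 1 keyword names (border_mode, init, dim_ordering). Keras 2 generally still accepts them through its legacy-interface conversion, with deprecation warnings; the Keras 2 spellings would be roughly:

from keras.layers import Conv2D, MaxPooling2D

# Keras 2 names for the legacy keyword arguments used above:
#   border_mode  -> padding
#   init         -> kernel_initializer
#   dim_ordering -> data_format ("tf" -> "channels_last")
conv=Conv2D(24,(5,5),padding="same",kernel_initializer="he_normal",
            activation="relu",input_shape=(224,224,1),data_format="channels_last")
pool=MaxPooling2D(pool_size=(2,2),strides=(2,2),padding="valid")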