Files @ 655956f6ba89
Branch filter:

Location: OneEye/exp/kerokero/train.py - annotation

Laman
training and testing model
import argparse

from keras.layers import Conv2D,Dropout,Dense,Flatten
from keras.models import Sequential,load_model

from prepare_data import loadDataset


# Command-line interface. Example:
#   python train.py DATA_DIR --epochs 50 --save_model /tmp/gogo-{0:03}.h5
# Training proceeds in 10-epoch chunks; a checkpoint is written after each chunk.
parser=argparse.ArgumentParser(description="Train the model and evaluate it on the test set.")
parser.add_argument("data_dir",help="dataset directory, passed to prepare_data.loadDataset")
parser.add_argument("--load_model",help="path of a saved model to continue training instead of building a new one")
parser.add_argument("--save_model",default="/tmp/gogo-{0:03}.h5",help="checkpoint path template; {0} is replaced by the 1-based number of the 10-epoch chunk just finished (default: %(default)s)")
parser.add_argument("--epochs",type=int,default=100,help="total number of training epochs (default: %(default)s)")
# NOTE: despite its name, this counts 10-epoch chunks, not epochs: to resume from
# checkpoint number K, pass K (training then continues at epoch K*10).
parser.add_argument("--initial_epoch",type=int,default=0,help="number of 10-epoch chunks already done, i.e. the number of the checkpoint being resumed (default: %(default)s)")
args=parser.parse_args()

# Either resume from a saved model or build a new regression network that
# flattens each input image and predicts 8 numbers.
# NOTE(review): assumes loadDataset yields single-channel 224x224 images and
# 8-value label vectors (perhaps 4 (x,y) points) — confirm against prepare_data.
if args.load_model:
	# load_model restores architecture, weights and the compile settings stored
	# by model.save, so building and compiling a fresh model first is wasted work.
	model=load_model(args.load_model)
else:
	model=Sequential([
		Flatten(input_shape=(224,224)),
		Dense(128, activation="relu"),
		Dropout(0.1),
		Dense(64, activation="relu"),
		Dense(8) # linear output layer: unbounded regression targets
	])

	model.compile(
		optimizer='adam',
		loss='mse',
		# 'accuracy' is intentionally omitted: it is not meaningful for
		# continuous multi-output regression targets.
		metrics=['mae']
	)

print("loading data...")
((trainImages,trainLabels),(testImages,testLabels))=loadDataset(args.data_dir)
print("done")

# Train in chunks of CHUNK epochs, saving a numbered checkpoint after each one.
# args.initial_epoch is a chunk index (the number of the checkpoint being
# resumed), not an epoch count.
CHUNK=10
# Ceiling division so a trailing partial chunk is still trained; floor division
# would silently skip it (e.g. --epochs 25 would stop after 20, --epochs 5 after 0).
chunkCount=-(-args.epochs//CHUNK)
for i in range(args.initial_epoch,chunkCount):
	model.fit(
		trainImages,trainLabels,
		epochs=min((i+1)*CHUNK,args.epochs), # last chunk may be shorter than CHUNK
		initial_epoch=i*CHUNK,
		batch_size=128,
		validation_split=1/9 # hold out 1/9 of the training samples for validation
	)
	model.save(args.save_model.format(i+1))
# Final loss and metric values on the held-out test set.
print(model.evaluate(testImages,testLabels))