Files
@ 655956f6ba89
Branch filter:
Location: OneEye/exp/kerokero/train.py - annotation
655956f6ba89
1.1 KiB
text/x-python
training and testing model
655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 | import argparse
from keras.layers import Conv2D,Dropout,Dense,Flatten
from keras.models import Sequential,load_model
from prepare_data import loadDataset
# Command-line interface for the training script.
parser = argparse.ArgumentParser()
# Positional: dataset location, handed to loadDataset().
parser.add_argument("data_dir")
# Optional path of a saved model to continue from instead of a fresh network.
parser.add_argument("--load_model")
# Checkpoint path template; it is str.format()-ed with the 1-based number of
# the training round that has just finished.
parser.add_argument(
    "--save_model",
    default="/tmp/gogo-{0:03}.h5",
)
# Total number of epochs to train up to.
parser.add_argument(
    "--epochs",
    type=int,
    default=100,
)
# Index of the first training round to run. NOTE: despite the name this is
# counted in 10-epoch rounds, not in single epochs.
parser.add_argument(
    "--initial_epoch",
    type=int,
    default=0,
)
args = parser.parse_args()
# Obtain the model to train: either restore a saved one or build a new one.
# Only one of the two is ever created, so no throwaway network is built and
# compiled when --load_model is given.
if args.load_model:
    # Continue from a model previously written by model.save().
    model = load_model(args.load_model)
else:
    # Small fully connected regressor: a 224x224 input is flattened and passed
    # through two ReLU hidden layers to 8 outputs (Dense default activation,
    # i.e. linear).
    model = Sequential([
        Flatten(input_shape=(224, 224)),
        Dense(128, activation="relu"),
        Dropout(0.1),
        Dense(64, activation="relu"),
        Dense(8),
    ])
    model.compile(
        optimizer='adam',
        loss='mse',
        metrics=['mae', 'accuracy'],
    )
# Read the training and test splits from the directory given on the command
# line; progress messages bracket the call.
print("loading data...")
dataset = loadDataset(args.data_dir)
(trainImages, trainLabels), (testImages, testLabels) = dataset
print("done")
# Train in rounds of EPOCHS_PER_ROUND epochs and write a checkpoint after every
# round, so a run can be resumed via --load_model together with --initial_epoch.
# NOTE: --initial_epoch is a round index (units of EPOCHS_PER_ROUND epochs).
EPOCHS_PER_ROUND = 10
# Ceiling division: when --epochs is not a multiple of EPOCHS_PER_ROUND the
# trailing partial round is still trained instead of being silently dropped.
total_rounds = -(-args.epochs // EPOCHS_PER_ROUND)
for round_idx in range(args.initial_epoch, total_rounds):
    first_epoch = round_idx * EPOCHS_PER_ROUND
    # Keras treats `epochs` as the index of the final epoch when
    # `initial_epoch` is given; cap it so the last round stops at --epochs.
    last_epoch = min(first_epoch + EPOCHS_PER_ROUND, args.epochs)
    model.fit(
        trainImages,
        trainLabels,
        epochs=last_epoch,
        initial_epoch=first_epoch,
        batch_size=128,
        validation_split=1/9,
    )
    # Checkpoints are numbered from 1: the file written after round N is
    # formatted with N.
    model.save(args.save_model.format(round_idx + 1))
# Evaluate on the test split and print whatever loss/metric values Keras returns.
test_scores = model.evaluate(testImages, testLabels)
print(test_scores)
|