import argparse
from keras.layers import Conv2D,Dropout,Dense,Flatten,MaxPooling2D,BatchNormalization
from keras.models import Sequential,load_model
from prepare_data import loadDataset
# Command-line interface for the training script. Arguments are parsed once at
# import time into the module-level ``args`` used by the training loop below.
parser = argparse.ArgumentParser(
    description="Train (or resume training) the CNN in 10-epoch chunks.")
parser.add_argument(
    "data_dir",
    help="dataset directory passed to prepare_data.loadDataset")
parser.add_argument(
    "--load_model",
    help="path of a saved Keras model to resume from "
         "(default: build a new CNN)")
parser.add_argument(
    "--save_model",
    default="/tmp/gogo-{0:03}.h5",
    help="checkpoint path format string; {0} is replaced by the 1-based "
         "chunk number (default: %(default)s)")
parser.add_argument(
    "--epochs",
    type=int,
    default=100,
    help="total number of training epochs, run in chunks of 10 "
         "(default: %(default)s)")
parser.add_argument(
    "--initial_epoch",
    type=int,
    default=0,
    # The training loop treats this as a chunk index, not a raw epoch number.
    help="index of the first 10-epoch chunk to train, i.e. training resumes "
         "at epoch 10 * INITIAL_EPOCH (default: %(default)s)")
args = parser.parse_args()
def createFullyConnected():
    """Build and compile a small fully-connected model.

    The network flattens a single-channel 224x224 input, passes it through
    two ReLU hidden layers (128 and 64 units, with 10% dropout after the
    first) and ends in a linear Dense layer with 8 outputs.

    Returns:
        A compiled ``keras.models.Sequential`` using the Adam optimizer,
        mean-squared-error loss and ``mae``/``accuracy`` metrics.
    """
    model = Sequential([
        # Declare the same (224, 224, 1) layout that createCNN() uses and that
        # the training loop produces via reshape((-1, 224, 224, 1)), so both
        # models can be trained on the same arrays. A (224, 224) input_shape
        # would reject that 4-D input.
        Flatten(input_shape=(224, 224, 1)),
        Dense(128, activation="relu"),
        Dropout(0.1),
        Dense(64, activation="relu"),
        Dense(8),  # 8 linear outputs (no activation)
    ])
    model.compile(
        optimizer='adam',
        loss='mse',
        # NOTE(review): 'accuracy' is of doubtful meaning for a linear
        # MSE-trained head; kept for parity with createCNN().
        metrics=['mae', 'accuracy'],
    )
    return model
def createCNN():
    """Build and compile the convolutional model used for training.

    Four convolution stages (16, 32, 64 and 128 filters) each end with 2x2,
    stride-2 max pooling and batch normalization. Stages 1, 2 and 4 also
    apply dropout (0.1, 0.2, 0.4). A dense head of 500 -> 128 -> 8 units
    follows. The input is a single-channel 224x224 image.

    Returns:
        A compiled ``keras.models.Sequential`` (Adam optimizer, MSE loss,
        ``mae`` and ``accuracy`` metrics).
    """
    def downsample():
        # 2x2 window, stride 2, no padding: floor-halves each spatial dim.
        return MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="valid")

    stack = [
        # Stage 1: here dropout comes *before* the pooling layer.
        Conv2D(filters=16, kernel_size=2, padding="same", activation="relu",
               input_shape=(224, 224, 1)),
        Dropout(0.1),
        downsample(),
        BatchNormalization(),
        # Stage 2
        Conv2D(32, (5, 5), activation="relu"),
        downsample(),
        Dropout(0.2),
        BatchNormalization(),
        # Stage 3 (no dropout)
        Conv2D(64, (5, 5), activation="relu"),
        downsample(),
        BatchNormalization(),
        # Stage 4
        Conv2D(128, (3, 3), activation="relu"),
        downsample(),
        Dropout(0.4),
        BatchNormalization(),
        # Dense head
        Flatten(),
        Dense(500, activation="relu"),
        Dropout(0.1),
        Dense(128, activation="relu"),
        Dropout(0.1),
        Dense(8),
    ]
    net = Sequential(stack)
    net.compile(optimizer='adam', loss='mse', metrics=['mae', 'accuracy'])
    return net
# Resume from a saved model if one was given; otherwise build a fresh CNN.
if args.load_model:
    model = load_model(args.load_model)
else:
    model = createCNN()

print("loading data...")
((trainImages, trainLabels), (testImages, testLabels)) = loadDataset(args.data_dir)
print("done")

# createCNN() declares input_shape=(224, 224, 1). Reshape both splits once, up
# front, so fit() and evaluate() see the same 4-D layout. The test split
# previously went to evaluate() without this reshape.
# (Assumes testImages is the same array type as trainImages.)
trainImages = trainImages.reshape((-1, 224, 224, 1))
testImages = testImages.reshape((-1, 224, 224, 1))

EPOCHS_PER_CHUNK = 10
# Round up, so a trailing partial chunk (e.g. --epochs 25 -> 3 chunks, the
# last covering epochs 20-25) is trained instead of silently dropped.
num_chunks = (args.epochs + EPOCHS_PER_CHUNK - 1) // EPOCHS_PER_CHUNK
# NOTE: --initial_epoch is a chunk index here. Chunk i covers epochs
# [i * 10, min((i + 1) * 10, args.epochs)).
for i in range(args.initial_epoch, num_chunks):
    model.fit(
        trainImages,
        trainLabels,
        epochs=min((i + 1) * EPOCHS_PER_CHUNK, args.epochs),
        initial_epoch=i * EPOCHS_PER_CHUNK,
        batch_size=128,
        validation_split=0.2,
    )
    # Checkpoint after every chunk; the file index is the 1-based chunk number.
    model.save(args.save_model.format(i + 1))
print(model.evaluate(testImages, testLabels))