"""Train a small Keras regression network on 224x224 single-channel images.

The network maps each image to 8 linear outputs and is trained with a
mean-squared-error loss. Training proceeds in chunks of CHECKPOINT_EVERY
epochs; after each chunk the model is saved to ``--save_model`` (formatted
with the checkpoint number), and the final model is evaluated on the test set.
"""
import argparse

from keras.layers import Conv2D, Dropout, Dense, Flatten, MaxPooling2D, BatchNormalization
from keras.models import Sequential, load_model

from prepare_data import loadDataset

# Spatial size of the input images; the CNN additionally expects a trailing
# single-channel axis.
IMAGE_SIZE = (224, 224)
# Number of epochs trained between two checkpoints written to --save_model.
CHECKPOINT_EVERY = 10


def createFullyConnected():
    """Build and compile a small fully connected regressor.

    Takes (224, 224) inputs with no channel axis, flattens them and maps them
    to 8 linear outputs trained with mean-squared error. main() does not use
    it; it is kept as an alternative baseline model.
    """
    model = Sequential([
        Flatten(input_shape=IMAGE_SIZE),
        Dense(128, activation="relu"),
        Dropout(0.1),
        Dense(64, activation="relu"),
        Dense(8),
    ])
    # NOTE(review): 'accuracy' is not meaningful for a continuous MSE
    # regression target; kept so the reported metrics stay unchanged.
    model.compile(optimizer="adam", loss="mse", metrics=["mae", "accuracy"])
    return model


def createCNN():
    """Build and compile the convolutional regressor used for training.

    Four conv -> max-pool stages (with dropout and batch normalization),
    followed by two dense layers and an 8-unit linear output. Input shape is
    (224, 224, 1); the loss is mean-squared error.
    """
    model = Sequential([
        Conv2D(filters=16, kernel_size=2, padding="same", activation="relu",
               input_shape=IMAGE_SIZE + (1,)),
        Dropout(0.1),
        MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="valid"),
        BatchNormalization(),

        Conv2D(32, (5, 5), activation="relu"),
        MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="valid"),
        Dropout(0.2),
        BatchNormalization(),

        Conv2D(64, (5, 5), activation="relu"),
        MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="valid"),
        BatchNormalization(),

        Conv2D(128, (3, 3), activation="relu"),
        MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="valid"),
        Dropout(0.4),
        BatchNormalization(),

        Flatten(),
        Dense(500, activation="relu"),
        Dropout(0.1),
        Dense(128, activation="relu"),
        Dropout(0.1),
        Dense(8),
    ])
    # NOTE(review): 'accuracy' is not meaningful for a continuous MSE
    # regression target; kept so the reported metrics stay unchanged.
    model.compile(optimizer="adam", loss="mse", metrics=["mae", "accuracy"])
    return model


def _parse_args(argv=None):
    """Parse the command line (``sys.argv[1:]`` when *argv* is None)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("data_dir", help="directory passed to loadDataset()")
    parser.add_argument("--load_model", help="resume from this saved model file")
    parser.add_argument("--save_model", default="/tmp/gogo-{0:03}.h5",
                        help="checkpoint path template, formatted with the checkpoint number")
    parser.add_argument("--epochs", type=int, default=100,
                        help="total number of epochs to train up to")
    parser.add_argument("--initial_epoch", type=int, default=0,
                        help="epoch to resume from (same unit as --epochs)")
    return parser.parse_args(argv)


def _add_channel_axis(images):
    """Reshape a batch of 224x224 images to (N, 224, 224, 1) for the CNN.

    Assumes *images* is a numpy-style array holding 224*224 values per
    image — TODO confirm against prepare_data.loadDataset.
    """
    return images.reshape((-1,) + IMAGE_SIZE + (1,))


def _training_chunks(initial_epoch, total_epochs, chunk=CHECKPOINT_EVERY):
    """Yield (start, end) epoch ranges covering [initial_epoch, total_epochs).

    Each range spans at most *chunk* epochs. The last one is shortened instead
    of dropped when the span is not a multiple of *chunk*.
    """
    for start in range(initial_epoch, total_epochs, chunk):
        yield start, min(start + chunk, total_epochs)


def _checkpoint_index(end_epoch, chunk=CHECKPOINT_EVERY):
    """Return the 1-based checkpoint number for a chunk ending at *end_epoch*.

    Rounds up, so a full chunk ending at epoch 10*k gets number k, and a
    trailing partial chunk gets its own number instead of overwriting the
    previous checkpoint.
    """
    return (end_epoch + chunk - 1) // chunk


def main(argv=None):
    """Load or build the model, train it in checkpointed chunks, then evaluate it."""
    args = _parse_args(argv)

    # Only build a fresh CNN when we are not resuming from a saved model.
    model = load_model(args.load_model) if args.load_model else createCNN()

    print("loading data...")
    (train_images, train_labels), (test_images, test_labels) = loadDataset(args.data_dir)
    print("done")

    # Train and test inputs must go through the same reshape; the model's
    # input layer expects an explicit channel axis.
    train_x = _add_channel_axis(train_images)
    test_x = _add_channel_axis(test_images)

    for start, end in _training_chunks(args.initial_epoch, args.epochs):
        # Keras treats `epochs` as the absolute epoch to stop at, so passing
        # initial_epoch=start continues the epoch numbering across chunks.
        model.fit(train_x, train_labels, epochs=end, initial_epoch=start,
                  batch_size=128, validation_split=0.2)
        model.save(args.save_model.format(_checkpoint_index(end)))

    print(model.evaluate(test_x, test_labels))


if __name__ == "__main__":
    main()