diff --git a/exp/kerokero/train.py b/exp/kerokero/train.py --- a/exp/kerokero/train.py +++ b/exp/kerokero/train.py @@ -1,6 +1,6 @@ import argparse -from keras.layers import Conv2D,Dropout,Dense,Flatten +from keras.layers import Conv2D,Dropout,Dense,Flatten,MaxPooling2D,BatchNormalization from keras.models import Sequential,load_model from prepare_data import loadDataset @@ -14,19 +14,60 @@ parser.add_argument("--epochs",type=int, parser.add_argument("--initial_epoch",type=int,default=0) args=parser.parse_args() -model=Sequential([ - Flatten(input_shape=(224,224)), - Dense(128, activation="relu"), - Dropout(0.1), - Dense(64, activation="relu"), - Dense(8) -]) + +def createFullyConnected(): + model=Sequential([ + Flatten(input_shape=(224,224)), + Dense(128, activation="relu"), + Dropout(0.1), + Dense(64, activation="relu"), + Dense(8) + ]) + + model.compile( + optimizer='adam', + loss='mse', + metrics=['mae','accuracy'] + ) + return model -model.compile( - optimizer='adam', - loss='mse', - metrics=['mae','accuracy'] -) +def createCNN(): + model=Sequential() + + model.add(Conv2D(filters=16,kernel_size=2,padding="same",activation="relu",input_shape=(224,224,1))) + model.add(Dropout(0.1)) + model.add(MaxPooling2D(pool_size=(2,2),strides=(2,2),padding="valid")) + model.add(BatchNormalization()) + + model.add(Conv2D(32,(5,5),activation="relu")) + model.add(MaxPooling2D(pool_size=(2,2),strides=(2,2),padding="valid")) + model.add(Dropout(0.2)) + model.add(BatchNormalization()) + + model.add(Conv2D(64,(5,5),activation="relu")) + model.add(MaxPooling2D(pool_size=(2,2),strides=(2,2),padding="valid")) + model.add(BatchNormalization()) + + model.add(Conv2D(128,(3,3),activation="relu")) + model.add(MaxPooling2D(pool_size=(2,2),strides=(2,2),padding="valid")) + model.add(Dropout(0.4)) + model.add(BatchNormalization()) + + model.add(Flatten()) + + model.add(Dense(500, activation="relu")) + model.add(Dropout(0.1)) + + model.add(Dense(128, activation="relu")) + model.add(Dropout(0.1)) + + model.add(Dense(8)) + + model.compile(optimizer='adam',loss='mse',metrics=['mae','accuracy']) + return model + + +model=createCNN() if args.load_model: model=load_model(args.load_model) @@ -35,6 +76,6 @@ print("loading data...") print("done") for i in range(args.initial_epoch,args.epochs//10): - model.fit(trainImages,trainLabels,epochs=(i+1)*10,initial_epoch=i*10,batch_size=128,validation_split=1/9) + model.fit(trainImages.reshape((-1,224,224,1)),trainLabels,epochs=(i+1)*10,initial_epoch=i*10,batch_size=128,validation_split=0.2) model.save(args.save_model.format(i+1)) print(model.evaluate(testImages,testLabels))