diff --git a/exp/kerokero/train.py b/exp/kerokero/train.py --- a/exp/kerokero/train.py +++ b/exp/kerokero/train.py @@ -4,11 +4,17 @@ import argparse import logging as log import numpy as np -from keras.layers import Conv2D,Dropout,Dense,Flatten,MaxPooling2D,BatchNormalization,GlobalAveragePooling2D +from keras.layers import Conv2D,Dropout,Dense,Flatten,MaxPooling2D,BatchNormalization,GlobalAveragePooling2D,Reshape from keras.models import Sequential,load_model from keras.callbacks import TensorBoard,ModelCheckpoint +import keras.losses +import keras.metrics import config as cfg +from k_util import averageDistance + +keras.losses.averageDistance=averageDistance +keras.metrics.averageDistance=averageDistance parser=argparse.ArgumentParser() parser.add_argument("data") @@ -41,17 +47,17 @@ def createCNN(): model.add(BatchNormalization(input_shape=(224,224,1))) - model.add(Conv2D(24,(5,5),border_mode="same",init="he_normal",activation="relu",input_shape=(224,224,1),dim_ordering="tf")) - model.add(MaxPooling2D(pool_size=(2,2),strides=(2,2),border_mode="valid")) + model.add(Conv2D(24,(5,5),padding="same",kernel_initializer="he_normal",activation="relu",input_shape=(224,224,1),data_format="channels_last")) + model.add(MaxPooling2D(pool_size=(2,2),strides=(2,2),padding="valid")) model.add(Conv2D(36,(5,5),activation="relu")) - model.add(MaxPooling2D(pool_size=(2,2),strides=(2,2),border_mode="valid")) + model.add(MaxPooling2D(pool_size=(2,2),strides=(2,2),padding="valid")) model.add(Conv2D(48,(5,5),activation="relu")) - model.add(MaxPooling2D(pool_size=(2,2),strides=(2,2),border_mode="valid")) + model.add(MaxPooling2D(pool_size=(2,2),strides=(2,2),padding="valid")) model.add(Conv2D(64,(3,3),activation="relu")) - model.add(MaxPooling2D(pool_size=(2,2),strides=(2,2),border_mode="valid")) + model.add(MaxPooling2D(pool_size=(2,2),strides=(2,2),padding="valid")) model.add(Conv2D(64,(3,3),activation="relu")) @@ -60,8 +66,9 @@ def createCNN(): model.add(Dense(500,activation="relu")) model.add(Dense(90,activation="relu")) model.add(Dense(8)) + model.add(Reshape((4,2))) - model.compile(optimizer="rmsprop",loss="mse",metrics=["mae","accuracy"]) + model.compile(optimizer="rmsprop",loss=averageDistance,metrics=["mae","accuracy"]) return model @@ -82,7 +89,7 @@ checkpoint=ModelCheckpoint(args.save_mod model.fit( trainImages.reshape((-1,224,224,1)), - trainLabels, + trainLabels.reshape((-1,4,2)), epochs=args.epochs, initial_epoch=args.initial_epoch, batch_size=20, @@ -90,4 +97,4 @@ model.fit( callbacks=[tensorboard,checkpoint] ) -log.info(model.evaluate(testImages.reshape((-1,224,224,1)),testLabels)) +log.info(model.evaluate(testImages.reshape((-1,224,224,1)),testLabels.reshape((-1,4,2))))