Reviewer note: the line structure of this patch was restored from a
whitespace-collapsed copy. Indentation is assumed to be four spaces and the
positions of blank context lines are a best guess derived from the hunk line
counts; use `git apply --ignore-whitespace` if the target file uses different
indentation. Hunk 3/4 headers were updated for the added docstring/comments.

diff --git a/exp/kerokero/train.py b/exp/kerokero/train.py
--- a/exp/kerokero/train.py
+++ b/exp/kerokero/train.py
@@ -5,10 +5,9 @@
 import argparse
 import logging as log
 import numpy as np
-from keras.layers import Conv2D,Dropout,Dense,Flatten,MaxPooling2D,BatchNormalization,GlobalAveragePooling2D,Reshape
-from keras.models import Sequential,load_model
+from keras.layers import Conv2D,Dropout,Dense,Flatten,MaxPooling2D,BatchNormalization,GlobalAveragePooling2D,Reshape,concatenate
+from keras.models import Sequential,load_model,Model,Input
 from keras.callbacks import TensorBoard,ModelCheckpoint
-import keras.losses
 import keras.metrics
 
 import config as cfg
@@ -21,9 +20,10 @@ parser=argparse.ArgumentParser()
 parser.add_argument("data")
 parser.add_argument("--load_model")
 parser.add_argument("--save_model",default="/tmp/gogo-{0:03}.h5")
+parser.add_argument("--load_hints",help="saved Keras model whose frozen output is fed to the new network as hints")
+parser.add_argument("--log_dir",default="/tmp/tflogs")
 parser.add_argument("--epochs",type=int,default=100)
 parser.add_argument("--initial_epoch",type=int,default=0)
-parser.add_argument("--log_dir",default="/tmp/tflogs")
 args=parser.parse_args()
 
 
@@ -73,9 +73,53 @@ def createCNN():
     return model
 
 
-model=createCNN()
+def createHinted():
+    """Build and compile a CNN that also consumes a frozen, pre-trained model's output.
+
+    The model stored at args.load_hints is loaded, all of its layers are
+    marked non-trainable, and its output for the same 224x224x1 input is
+    flattened and concatenated with the pooled features of a new
+    convolutional stack before the dense head.
+
+    Returns a compiled Model whose per-sample output has shape (4,2).
+    """
+    inputs=Input((224,224,1))
+    base=load_model(args.load_hints)
+    # freeze the hint network so that only the new layers are trained
+    for layer in base.layers:
+        layer.trainable=False
+    hints=base(inputs)
+
+    x=BatchNormalization()(inputs)
+    x=Conv2D(24,(5,5),padding="same",kernel_initializer="he_normal",activation="relu",input_shape=(224,224,1),data_format="channels_last")(x)
+    x=MaxPooling2D(pool_size=(2,2),strides=(2,2),padding="valid")(x)
+    x=Conv2D(36,(5,5),activation="relu")(x)
+    x=MaxPooling2D(pool_size=(2,2),strides=(2,2),padding="valid")(x)
+    x=Conv2D(48,(5,5),activation="relu")(x)
+    x=MaxPooling2D(pool_size=(2,2),strides=(2,2),padding="valid")(x)
+    x=Conv2D(64,(3,3),activation="relu")(x)
+    x=MaxPooling2D(pool_size=(2,2),strides=(2,2),padding="valid")(x)
+    x=Conv2D(64,(3,3),activation="relu")(x)
+    x=GlobalAveragePooling2D()(x)
+
+    x=concatenate([x,Flatten()(hints)])
+    x=Dense(500,activation="relu")(x)
+    x=Dense(90,activation="relu")(x)
+
+    predictions=Reshape((4,2))(Dense(8)(x))
+
+    model=Model(inputs=inputs,outputs=predictions)
+
+    model.compile(optimizer='rmsprop',loss=averageDistance,metrics=['mae','accuracy'])
+    return model
+
+
 if args.load_model:
     model=load_model(args.load_model)
+else:
+    # --load_hints is optional: without it fall back to the plain CNN instead of calling load_model(None)
+    model=createHinted() if args.load_hints else createCNN()
+    model.summary()
 
 log.info("loading data...")
 with np.load(args.data) as data:
@@ -104,4 +148,4 @@ model.fit_generator(
     callbacks=[tensorboard,checkpoint]
 )
 
-log.info(model.evaluate(testImages,testLabels))
+print(model.evaluate(testImages,testLabels))