diff --git a/exp/kerokero/train.py b/exp/kerokero/train.py --- a/exp/kerokero/train.py +++ b/exp/kerokero/train.py @@ -5,8 +5,8 @@ import argparse import logging as log import numpy as np -from keras.layers import Conv2D,Dropout,Dense,Flatten,MaxPooling2D,BatchNormalization,GlobalAveragePooling2D,Reshape,concatenate -from keras.models import Sequential,load_model,Model,Input +from keras.layers import Conv2D,Dropout,Dense,Flatten,MaxPooling2D,BatchNormalization,GlobalAveragePooling2D,Reshape +from keras.models import Sequential,load_model from keras.callbacks import TensorBoard,ModelCheckpoint import keras.metrics @@ -73,41 +73,10 @@ def createCNN(): return model -def createHinted(): - input=Input((224,224,1)) - base=load_model(args.load_hints) - for layer in base.layers: - layer.trainable=False - hints=base(input) - - x=BatchNormalization()(input) - x=Conv2D(24,(5,5),padding="same",kernel_initializer="he_normal",activation="relu",input_shape=(224,224,1),data_format="channels_last")(x) - x=MaxPooling2D(pool_size=(2,2),strides=(2,2),padding="valid")(x) - x=Conv2D(36,(5,5),activation="relu")(x) - x=MaxPooling2D(pool_size=(2,2),strides=(2,2),padding="valid")(x) - x=Conv2D(48,(5,5),activation="relu")(x) - x=MaxPooling2D(pool_size=(2,2),strides=(2,2),padding="valid")(x) - x=Conv2D(64,(3,3),activation="relu")(x) - x=MaxPooling2D(pool_size=(2,2),strides=(2,2),padding="valid")(x) - x=Conv2D(64,(3,3),activation="relu")(x) - x=GlobalAveragePooling2D()(x) - - x=concatenate([x,Flatten()(hints)]) - x=Dense(500,activation="relu")(x) - x=Dense(90,activation="relu")(x) - - predictions=Reshape((4,2))(Dense(8)(x)) - - model=Model(inputs=input,outputs=predictions) - - model.compile(optimizer='rmsprop',loss=averageDistance,metrics=['mae','accuracy']) - return model - - if args.load_model: model=load_model(args.load_model) else: - model=createHinted() + model=createCNN() model.summary() log.info("loading data...")