# HG changeset patch
# User Laman
# Date 2019-05-30 12:42:29
# Node ID 7cb01d4080c978bfc9d697f49fc81af5d2e191da
# Parent  f1f8a2421f9208005fbea78704be5168c1d86b4c

a hinted neural network (failed)

diff --git a/exp/kerokero/train.py b/exp/kerokero/train.py
--- a/exp/kerokero/train.py
+++ b/exp/kerokero/train.py
@@ -5,10 +5,9 @@ import argparse
 import logging as log
 
 import numpy as np
-from keras.layers import Conv2D,Dropout,Dense,Flatten,MaxPooling2D,BatchNormalization,GlobalAveragePooling2D,Reshape
-from keras.models import Sequential,load_model
+from keras.layers import Conv2D,Dropout,Dense,Flatten,MaxPooling2D,BatchNormalization,GlobalAveragePooling2D,Reshape,concatenate
+from keras.models import Sequential,load_model,Model,Input
 from keras.callbacks import TensorBoard,ModelCheckpoint
-import keras.losses
 import keras.metrics
 
 import config as cfg
@@ -21,9 +20,10 @@ parser=argparse.ArgumentParser()
 parser.add_argument("data")
 parser.add_argument("--load_model")
 parser.add_argument("--save_model",default="/tmp/gogo-{0:03}.h5")
+parser.add_argument("--load_hints") # path of a saved model; createHinted() loads it and feeds its frozen output into the new network
+parser.add_argument("--log_dir",default="/tmp/tflogs") # NOTE(review): presumably the TensorBoard log directory -- its use is outside this hunk, confirm
 parser.add_argument("--epochs",type=int,default=100)
 parser.add_argument("--initial_epoch",type=int,default=0)
-parser.add_argument("--log_dir",default="/tmp/tflogs")
 args=parser.parse_args()
 
 
@@ -73,9 +73,42 @@ def createCNN():
 	return model
 
 
-model=createCNN()
+def createHinted():
+	"""Build and compile a CNN whose dense head also sees the output of the frozen model loaded from --load_hints."""
+	if not args.load_hints:
+		raise ValueError("--load_hints is required to build the hinted model")
+	image=Input((224,224,1))
+	base=load_model(args.load_hints)
+	for layer in base.layers: # freeze the hint network, only the newly added layers get trained
+		layer.trainable=False
+	hints=base(image)
+
+	x=BatchNormalization()(image)
+	x=Conv2D(24,(5,5),padding="same",kernel_initializer="he_normal",activation="relu",input_shape=(224,224,1),data_format="channels_last")(x)
+	x=MaxPooling2D(pool_size=(2,2),strides=(2,2),padding="valid")(x)
+	x=Conv2D(36,(5,5),activation="relu")(x)
+	x=MaxPooling2D(pool_size=(2,2),strides=(2,2),padding="valid")(x)
+	x=Conv2D(48,(5,5),activation="relu")(x)
+	x=MaxPooling2D(pool_size=(2,2),strides=(2,2),padding="valid")(x)
+	x=Conv2D(64,(3,3),activation="relu")(x)
+	x=MaxPooling2D(pool_size=(2,2),strides=(2,2),padding="valid")(x)
+	x=Conv2D(64,(3,3),activation="relu")(x)
+	x=GlobalAveragePooling2D()(x)
+	# join the pooled convolutional features with the flattened hints before the dense head
+	x=concatenate([x,Flatten()(hints)])
+	x=Dense(500,activation="relu")(x)
+	x=Dense(90,activation="relu")(x)
+	predictions=Reshape((4,2))(Dense(8)(x)) # 8 linear outputs arranged as a (4,2) array
+	model=Model(inputs=image,outputs=predictions)
+	model.compile(optimizer='rmsprop',loss=averageDistance,metrics=['mae','accuracy'])
+	return model
+
+
 if args.load_model:
 	model=load_model(args.load_model)
+else: # no saved model given: build a fresh hinted network instead
+	model=createHinted()
+	model.summary() # print the layer overview of the newly built network
 
 log.info("loading data...")
 with np.load(args.data) as data:
@@ -104,4 +137,4 @@ model.fit_generator(
 	callbacks=[tensorboard,checkpoint]
 )
 
-log.info(model.evaluate(testImages,testLabels))
+print(model.evaluate(testImages,testLabels)) # evaluate on testImages/testLabels and write the result to stdout rather than to the log