Files @ 7cb01d4080c9
Branch filter:

Location: OneEye/exp/kerokero/train.py - annotation

Laman
a hinted neural network (failed)
c934d44cdf5c
247811dfb9be
c934d44cdf5c
655956f6ba89
9483b964f560
655956f6ba89
9483b964f560
7cb01d4080c9
7cb01d4080c9
4381957e967e
006c6f1aab13
655956f6ba89
c934d44cdf5c
247811dfb9be
006c6f1aab13
006c6f1aab13
006c6f1aab13
655956f6ba89
655956f6ba89
9483b964f560
655956f6ba89
655956f6ba89
7cb01d4080c9
7cb01d4080c9
655956f6ba89
655956f6ba89
655956f6ba89
655956f6ba89
dd45e200a0dc
dd45e200a0dc
dd45e200a0dc
dd45e200a0dc
dd45e200a0dc
dd45e200a0dc
dd45e200a0dc
dd45e200a0dc
dd45e200a0dc
dd45e200a0dc
dd45e200a0dc
dd45e200a0dc
dd45e200a0dc
dd45e200a0dc
dd45e200a0dc
dd45e200a0dc
655956f6ba89
dd45e200a0dc
dd45e200a0dc
ecf98a415d97
ecf98a415d97
ecf98a415d97
006c6f1aab13
006c6f1aab13
ecf98a415d97
ecf98a415d97
006c6f1aab13
ecf98a415d97
ecf98a415d97
006c6f1aab13
ecf98a415d97
ecf98a415d97
006c6f1aab13
ecf98a415d97
ecf98a415d97
ecf98a415d97
ecf98a415d97
ecf98a415d97
9483b964f560
ecf98a415d97
dd45e200a0dc
006c6f1aab13
dd45e200a0dc
006c6f1aab13
dd45e200a0dc
dd45e200a0dc
dd45e200a0dc
7cb01d4080c9
7cb01d4080c9
7cb01d4080c9
7cb01d4080c9
7cb01d4080c9
7cb01d4080c9
7cb01d4080c9
7cb01d4080c9
7cb01d4080c9
7cb01d4080c9
7cb01d4080c9
7cb01d4080c9
7cb01d4080c9
7cb01d4080c9
7cb01d4080c9
7cb01d4080c9
7cb01d4080c9
7cb01d4080c9
7cb01d4080c9
7cb01d4080c9
7cb01d4080c9
7cb01d4080c9
7cb01d4080c9
7cb01d4080c9
7cb01d4080c9
7cb01d4080c9
7cb01d4080c9
7cb01d4080c9
7cb01d4080c9
7cb01d4080c9
7cb01d4080c9
655956f6ba89
655956f6ba89
7cb01d4080c9
7cb01d4080c9
7cb01d4080c9
655956f6ba89
9483b964f560
9483b964f560
5f42b982809c
247811dfb9be
5f42b982809c
247811dfb9be
9483b964f560
655956f6ba89
247811dfb9be
247811dfb9be
247811dfb9be
247811dfb9be
247811dfb9be
247811dfb9be
4381957e967e
4381957e967e
4381957e967e
247811dfb9be
247811dfb9be
4381957e967e
4381957e967e
247811dfb9be
247811dfb9be
247811dfb9be
4381957e967e
4381957e967e
4381957e967e
7cb01d4080c9
import os
import math
from time import time
import argparse
import logging as log

import numpy as np
from keras.layers import Conv2D,Dropout,Dense,Flatten,MaxPooling2D,BatchNormalization,GlobalAveragePooling2D,Reshape,concatenate
from keras.models import Sequential,load_model,Model,Input
from keras.callbacks import TensorBoard,ModelCheckpoint
import keras.metrics

import config as cfg
from k_util import averageDistance,generateData

# Attach the custom averageDistance function to both keras.losses and keras.metrics,
# presumably so that load_model() can resolve it by name when reloading a saved model
# that was compiled with it (NOTE(review): confirm against the Keras version in use).
keras.losses.averageDistance=keras.metrics.averageDistance=averageDistance

# Command-line interface of the training script. Only the positional "data" argument
# is required; it is the file opened with np.load() when the training data is read.
parser=argparse.ArgumentParser()
parser.add_argument("data",help="data file passed to np.load()")
parser.add_argument("--load_model",help="resume from this saved model instead of building a new one")
# ModelCheckpoint builds each path with filepath.format(epoch=..., **logs), i.e. keyword
# arguments only, so the epoch must be referenced as a *named* field. A positional field
# such as "{0:03}" raises IndexError as soon as the first checkpoint is written.
parser.add_argument("--save_model",default="/tmp/gogo-{epoch:03d}.h5",help="checkpoint path template")
parser.add_argument("--load_hints",help="saved model used as the frozen hint network")
parser.add_argument("--log_dir",default="/tmp/tflogs",help="TensorBoard log root directory")
parser.add_argument("--epochs",type=int,default=100,help="epoch index at which training stops")
parser.add_argument("--initial_epoch",type=int,default=0,help="epoch index to resume from")
args=parser.parse_args()


def createFullyConnected():
	"""Build and compile a small fully connected baseline model.

	The network flattens a (224,224) input and passes it through two ReLU dense
	layers (128 and 64 units, with 10% dropout after the first) into 8 linear
	outputs. It is compiled with Adam, an MSE loss and MAE/accuracy metrics.

	NOTE(review): this builder is not called by this script, and its input shape
	(224,224) and flat (8,) output differ from the (224,224,1) images and (4,2)
	labels the script loads. Confirm before reusing it.

	Returns:
		the compiled keras Sequential model.
	"""
	stack=[Flatten(input_shape=(224,224))]
	stack+=[Dense(128,activation="relu"),Dropout(0.1)]
	stack+=[Dense(64,activation="relu")]
	stack.append(Dense(8))

	net=Sequential(stack)
	net.compile(optimizer="adam",loss="mse",metrics=["mae","accuracy"])
	return net

def createCNN():
	"""Build and compile the plain convolutional regressor (no hints).

	Input: a (224,224,1) single-channel image, batch-normalised first.
	Body: five ReLU convolutions (24, 36, 48, 64, 64 filters), with 2x2/stride-2 max
	pooling after each of the first four, then global average pooling.
	Head: dense 500 -> 90 (ReLU) -> 8 linear units, reshaped to (4,2) (presumably
	four 2-D points; confirm against the labels).
	Compiled with RMSprop, the averageDistance loss and MAE/accuracy metrics.

	Returns:
		the compiled keras Sequential model.
	"""
	net=Sequential()
	net.add(BatchNormalization(input_shape=(224,224,1)))

	# The first convolution is the only one with "same" padding and an explicit initializer.
	net.add(Conv2D(24,(5,5),padding="same",kernel_initializer="he_normal",activation="relu",input_shape=(224,224,1),data_format="channels_last"))
	net.add(MaxPooling2D(pool_size=(2,2),strides=(2,2),padding="valid"))

	# Each of these convolutions is followed by a 2x2 max pool that halves the resolution.
	for filters,kernel in ((36,(5,5)),(48,(5,5)),(64,(3,3))):
		net.add(Conv2D(filters,kernel,activation="relu"))
		net.add(MaxPooling2D(pool_size=(2,2),strides=(2,2),padding="valid"))

	# The last convolution is pooled globally instead.
	net.add(Conv2D(64,(3,3),activation="relu"))
	net.add(GlobalAveragePooling2D())

	for units in (500,90):
		net.add(Dense(units,activation="relu"))
	net.add(Dense(8))
	net.add(Reshape((4,2)))

	net.compile(optimizer="rmsprop",loss=averageDistance,metrics=["mae","accuracy"])
	return net


def createHinted():
	"""Build and compile the CNN whose dense head also sees a frozen "hint" model's output.

	The hint model is loaded from ``args.load_hints`` and every one of its layers is
	marked non-trainable. Its prediction for the input image is flattened and
	concatenated with the CNN's globally pooled features before the dense layers.
	The network maps a (224,224,1) image to a (4,2) array and is compiled with
	RMSprop, the averageDistance loss and MAE/accuracy metrics.

	Returns:
		the compiled keras functional Model.

	Raises:
		ValueError: if no hint model path was given (--load_hints).
	"""
	# Without this check, load_model(None) fails with an opaque error deep inside Keras/h5py.
	if not args.load_hints:
		raise ValueError("createHinted() needs a hint model: pass --load_hints PATH")

	inputs=Input((224,224,1))  # not named `input`, to avoid shadowing the builtin
	base=load_model(args.load_hints)
	# Freeze the hint network so only the layers created below are trained.
	for layer in base.layers:
		layer.trainable=False
	hints=base(inputs)

	# Feature extractor (same layer stack as createCNN). In the functional API the input
	# shape comes from `inputs`, so the layers need no input_shape argument.
	x=BatchNormalization()(inputs)
	x=Conv2D(24,(5,5),padding="same",kernel_initializer="he_normal",activation="relu",data_format="channels_last")(x)
	x=MaxPooling2D(pool_size=(2,2),strides=(2,2),padding="valid")(x)
	x=Conv2D(36,(5,5),activation="relu")(x)
	x=MaxPooling2D(pool_size=(2,2),strides=(2,2),padding="valid")(x)
	x=Conv2D(48,(5,5),activation="relu")(x)
	x=MaxPooling2D(pool_size=(2,2),strides=(2,2),padding="valid")(x)
	x=Conv2D(64,(3,3),activation="relu")(x)
	x=MaxPooling2D(pool_size=(2,2),strides=(2,2),padding="valid")(x)
	x=Conv2D(64,(3,3),activation="relu")(x)
	x=GlobalAveragePooling2D()(x)

	# Append the (flattened) hint prediction to the image features.
	x=concatenate([x,Flatten()(hints)])
	x=Dense(500,activation="relu")(x)
	x=Dense(90,activation="relu")(x)

	predictions=Reshape((4,2))(Dense(8)(x))

	model=Model(inputs=inputs,outputs=predictions)

	model.compile(optimizer='rmsprop',loss=averageDistance,metrics=['mae','accuracy'])
	return model


# The root logger defaults to WARNING, so without this the log.info() progress
# messages below would be silently discarded.
log.basicConfig(level=log.INFO)

# Resume from a saved model if one was given, otherwise build a fresh hinted network.
if args.load_model:
	model=load_model(args.load_model)
else:
	model=createHinted()
	model.summary()

log.info("loading data...")
with np.load(args.data) as data:
	# x/128-1 maps [0,256) onto [-1,1); presumably the stored images are 8-bit grayscale (TODO confirm).
	trainImages=(np.float32(data["trainImages"])/128-1).reshape((-1,224,224,1))
	trainLabels=data["trainLabels"].reshape((-1,4,2))
	testImages=(np.float32(data["testImages"])/128-1).reshape((-1,224,224,1))
	testLabels=data["testLabels"].reshape((-1,4,2))
log.info("done")

BATCH_SIZE=20

# Hold out the last 10% of the training arrays as the validation set.
total=len(trainImages)
trainCount=round(total*0.9)
valCount=total-trainCount
# The images are already float32, so asarray just takes views instead of copying them again.
(trainImages,valImages)=(np.asarray(trainImages[:trainCount],dtype=np.float32),np.asarray(trainImages[trainCount:],dtype=np.float32))
(trainLabels,valLabels)=(np.asarray(trainLabels[:trainCount],dtype=np.float32),np.asarray(trainLabels[trainCount:],dtype=np.float32))

tensorboard=TensorBoard(log_dir=os.path.join(args.log_dir,"{}".format(time())))
checkpoint=ModelCheckpoint(args.save_model,monitor="val_loss",period=10)

# Step counts must match the split they iterate over: training steps come from the
# training count and validation steps from the validation count.
model.fit_generator(
	generateData(trainImages,trainLabels,batch_size=BATCH_SIZE),
	epochs=args.epochs,
	initial_epoch=args.initial_epoch,
	steps_per_epoch=math.ceil(trainCount/BATCH_SIZE),
	validation_data=generateData(valImages,valLabels,batch_size=BATCH_SIZE),
	validation_steps=math.ceil(valCount/BATCH_SIZE),
	callbacks=[tensorboard,checkpoint]
)

print(model.evaluate(testImages,testLabels))