"""Train a Keras regression model on 224x224 single-channel images.

The script loads an ``.npz`` archive (``trainImages``, ``trainLabels``,
``testImages``, ``testLabels``), holds out the last 10% of the training set
for validation, trains with TensorBoard logging and periodic checkpoints,
and finally prints the evaluation result on the test set.

Labels are reshaped to ``(4, 2)`` per sample (presumably four 2-D points --
confirm against the data producer).
"""
import os
import math
from time import time
import argparse
import logging as log

import numpy as np
from keras.layers import Conv2D, Dropout, Dense, Flatten, MaxPooling2D, BatchNormalization, GlobalAveragePooling2D, Reshape, concatenate
from keras.models import Sequential, load_model, Model, Input
from keras.callbacks import TensorBoard, ModelCheckpoint
import keras.losses
import keras.metrics

import config as cfg
from k_util import averageDistance, generateData

# Register the custom loss/metric under its own name so that ``load_model``
# can resolve "averageDistance" when deserializing a saved model.
keras.losses.averageDistance = averageDistance
keras.metrics.averageDistance = averageDistance

# Mini-batch size shared by the training and validation generators and by
# the step-count computations below.
BATCH_SIZE = 20

parser = argparse.ArgumentParser()
parser.add_argument("data")
parser.add_argument("--load_model")
# ModelCheckpoint formats this path with keyword arguments only
# (``epoch`` plus the logged metrics), so the placeholder must be named;
# a positional ``{0}`` field raises IndexError when the checkpoint is written.
parser.add_argument("--save_model", default="/tmp/gogo-{epoch:03d}.h5")
parser.add_argument("--load_hints")
parser.add_argument("--log_dir", default="/tmp/tflogs")
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--initial_epoch", type=int, default=0)
args = parser.parse_args()


def createFullyConnected():
    """Build and compile a small dense baseline.

    Takes ``(224, 224)`` inputs, flattens them and produces 8 outputs.
    Compiled with Adam and mean squared error.

    Returns:
        The compiled ``Sequential`` model.
    """
    model = Sequential([
        Flatten(input_shape=(224, 224)),
        Dense(128, activation="relu"),
        Dropout(0.1),
        Dense(64, activation="relu"),
        Dense(8),
    ])
    model.compile(
        optimizer='adam',
        loss='mse',
        metrics=['mae', 'accuracy'],
    )
    return model


def createCNN():
    """Build and compile the convolutional model.

    Takes ``(224, 224, 1)`` inputs, applies five convolution stages with
    max pooling in between, global average pooling and two dense layers,
    and emits a ``(4, 2)`` output. Compiled with RMSprop and the
    ``averageDistance`` loss.

    Returns:
        The compiled ``Sequential`` model.
    """
    model = Sequential()
    model.add(BatchNormalization(input_shape=(224, 224, 1)))
    model.add(Conv2D(24, (5, 5), padding="same", kernel_initializer="he_normal", activation="relu", input_shape=(224, 224, 1), data_format="channels_last"))
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="valid"))
    model.add(Conv2D(36, (5, 5), activation="relu"))
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="valid"))
    model.add(Conv2D(48, (5, 5), activation="relu"))
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="valid"))
    model.add(Conv2D(64, (3, 3), activation="relu"))
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="valid"))
    model.add(Conv2D(64, (3, 3), activation="relu"))
    model.add(GlobalAveragePooling2D())
    model.add(Dense(500, activation="relu"))
    model.add(Dense(90, activation="relu"))
    model.add(Dense(8))
    model.add(Reshape((4, 2)))
    model.compile(optimizer="rmsprop", loss=averageDistance, metrics=["mae", "accuracy"])
    return model


def createHinted():
    """Build and compile the CNN augmented with a frozen "hints" model.

    The model stored at ``args.load_hints`` is loaded, all of its layers are
    frozen, and its (flattened) output on the same input image is
    concatenated with the pooled convolutional features before the dense
    head. The output has shape ``(4, 2)``; compiled with RMSprop and the
    ``averageDistance`` loss.

    Returns:
        The compiled functional ``Model``.

    Raises:
        ValueError: if ``--load_hints`` was not supplied.
    """
    if not args.load_hints:
        # Fail with an actionable message instead of passing None to load_model.
        raise ValueError("--load_hints is required to build the hinted model")

    inputs = Input((224, 224, 1))

    base = load_model(args.load_hints)
    # Freeze the hint network so only the new layers are trained.
    for layer in base.layers:
        layer.trainable = False
    hints = base(inputs)

    x = BatchNormalization()(inputs)
    x = Conv2D(24, (5, 5), padding="same", kernel_initializer="he_normal", activation="relu", input_shape=(224, 224, 1), data_format="channels_last")(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="valid")(x)
    x = Conv2D(36, (5, 5), activation="relu")(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="valid")(x)
    x = Conv2D(48, (5, 5), activation="relu")(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="valid")(x)
    x = Conv2D(64, (3, 3), activation="relu")(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="valid")(x)
    x = Conv2D(64, (3, 3), activation="relu")(x)
    x = GlobalAveragePooling2D()(x)
    x = concatenate([x, Flatten()(hints)])
    x = Dense(500, activation="relu")(x)
    x = Dense(90, activation="relu")(x)
    predictions = Reshape((4, 2))(Dense(8)(x))

    model = Model(inputs=inputs, outputs=predictions)
    model.compile(optimizer='rmsprop', loss=averageDistance, metrics=['mae', 'accuracy'])
    return model


# Resume from a saved model when requested, otherwise build the hinted CNN.
if args.load_model:
    model = load_model(args.load_model)
else:
    model = createHinted()
model.summary()

log.info("loading data...")
with np.load(args.data) as data:
    # x / 128 - 1 recentres the pixel values around zero
    # (assumes 8-bit intensities in [0, 255] -- TODO confirm).
    trainImages = (np.float32(data["trainImages"]) / 128 - 1).reshape((-1, 224, 224, 1))
    trainLabels = data["trainLabels"].reshape((-1, 4, 2))
    testImages = (np.float32(data["testImages"]) / 128 - 1).reshape((-1, 224, 224, 1))
    testLabels = data["testLabels"].reshape((-1, 4, 2))
log.info("done")

# Hold out the trailing 10% of the training data for validation.
n = len(trainImages)
k = int(round(n * 0.9))  # number of training samples; int() keeps it usable as a slice index
n_ = n - k               # number of validation samples
(trainImages, valImages) = (np.float32(trainImages[:k]), np.float32(trainImages[k:]))
(trainLabels, valLabels) = (np.float32(trainLabels[:k]), np.float32(trainLabels[k:]))

# Each run logs into its own timestamped sub-directory.
tensorboard = TensorBoard(log_dir=os.path.join(args.log_dir, "{}".format(time())))
checkpoint = ModelCheckpoint(args.save_model, monitor="val_loss", period=10)

model.fit_generator(
    generateData(trainImages, trainLabels, batch_size=BATCH_SIZE),
    epochs=args.epochs,
    initial_epoch=args.initial_epoch,
    # One epoch is one pass over the k training samples ...
    steps_per_epoch=math.ceil(k / BATCH_SIZE),
    validation_data=generateData(valImages, valLabels, batch_size=BATCH_SIZE),
    # ... and one validation run is one pass over the n_ held-out samples.
    validation_steps=math.ceil(n_ / BATCH_SIZE),
    callbacks=[tensorboard, checkpoint],
)

print(model.evaluate(testImages, testLabels))