diff --git a/exp/kerokero/train.py b/exp/kerokero/train.py --- a/exp/kerokero/train.py +++ b/exp/kerokero/train.py @@ -20,7 +20,6 @@ parser=argparse.ArgumentParser() parser.add_argument("data") parser.add_argument("--load_model") parser.add_argument("--save_model",default="/tmp/gogo-{0:03}.h5") -parser.add_argument("--load_hints") parser.add_argument("--log_dir",default="/tmp/tflogs") parser.add_argument("--epochs",type=int,default=100) parser.add_argument("--initial_epoch",type=int,default=0) @@ -66,10 +65,10 @@ def createCNN(): model.add(Dense(500,activation="relu")) model.add(Dense(90,activation="relu")) - model.add(Dense(8)) - model.add(Reshape((4,2))) + model.add(Dense(2)) + model.add(Reshape((1,2))) - model.compile(optimizer="rmsprop",loss=averageDistance,metrics=["mae","accuracy"]) + model.compile(optimizer="rmsprop",loss="mae",metrics=["mae","accuracy"]) return model @@ -82,9 +81,9 @@ else: log.info("loading data...") with np.load(args.data) as data: trainImages=(np.float32(data["trainImages"])/128-1).reshape((-1,224,224,1)) - trainLabels=data["trainLabels"].reshape((-1,4,2)) + trainLabels=data["trainLabels"].reshape((-1,1,2)) testImages=(np.float32(data["testImages"])/128-1).reshape((-1,224,224,1)) - testLabels=data["testLabels"].reshape((-1,4,2)) + testLabels=data["testLabels"].reshape((-1,1,2)) log.info("done") n=len(trainImages)