diff --git a/exp/kerokero/train.py b/exp/kerokero/train.py --- a/exp/kerokero/train.py +++ b/exp/kerokero/train.py @@ -79,9 +79,9 @@ if args.load_model: log.info("loading data...") with np.load(args.data) as data: - trainImages=data["trainImages"].reshape((-1,224,224,1)) + trainImages=(np.float32(data["trainImages"])/128-1).reshape((-1,224,224,1)) trainLabels=data["trainLabels"].reshape((-1,4,2)) - testImages=data["testImages"].reshape((-1,224,224,1)) + testImages=(np.float32(data["testImages"])/128-1).reshape((-1,224,224,1)) testLabels=data["testLabels"].reshape((-1,4,2)) log.info("done")