import os
from time import time
import argparse
import logging as log
import numpy as np
from keras.layers import Conv2D,Dropout,Dense,Flatten,MaxPooling2D,BatchNormalization,GlobalAveragePooling2D
from keras.models import Sequential,load_model
from keras.callbacks import TensorBoard,ModelCheckpoint
import config as cfg
# Command-line interface for the training script.
parser = argparse.ArgumentParser(description="Train a keypoint-regression network.")
parser.add_argument("data", help="path to an .npz archive with train/test images and labels")
parser.add_argument("--load_model", help="resume training from a previously saved .h5 model")
# ModelCheckpoint formats this path with keyword arguments (epoch plus the
# logged metrics), so the placeholder must be named: the original positional
# "{0:03}" raised IndexError the first time a checkpoint was written.
parser.add_argument("--save_model", default="/tmp/gogo-{epoch:03d}.h5",
                    help="checkpoint path template; {epoch} is filled in by ModelCheckpoint")
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--initial_epoch", type=int, default=0,
                    help="starting epoch index when resuming training")
parser.add_argument("--log_dir", default="/tmp/tflogs",
                    help="root directory for TensorBoard run logs")
args = parser.parse_args()
def createFullyConnected():
    """Build and compile a small fully connected baseline network.

    Flattens a 224x224 single-channel input and regresses 8 output
    values with MSE loss (same output head as the CNN variant).
    """
    net = Sequential()
    net.add(Flatten(input_shape=(224, 224)))
    net.add(Dense(128, activation="relu"))
    net.add(Dropout(0.1))
    net.add(Dense(64, activation="relu"))
    # Linear output layer: 8 regression targets.
    net.add(Dense(8))
    net.compile(optimizer="adam", loss="mse", metrics=["mae", "accuracy"])
    return net
def createCNN():
    """Build and compile the convolutional keypoint-regression network.

    Input: 224x224x1 grayscale images. Output: 8 linear regression values.
    """
    model = Sequential()
    # BatchNormalization declares the input shape; repeating input_shape on
    # the first Conv2D (as the original did) was redundant.
    model.add(BatchNormalization(input_shape=(224, 224, 1)))
    # Keras 2 spellings: border_mode -> padding, init -> kernel_initializer;
    # dim_ordering is gone (channels_last / "tf" is the default). The old
    # Keras 1 kwargs raise TypeError under Keras 2.
    model.add(Conv2D(24, (5, 5), padding="same", kernel_initializer="he_normal",
                     activation="relu"))
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="valid"))
    model.add(Conv2D(36, (5, 5), activation="relu"))
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="valid"))
    model.add(Conv2D(48, (5, 5), activation="relu"))
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="valid"))
    model.add(Conv2D(64, (3, 3), activation="relu"))
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="valid"))
    model.add(Conv2D(64, (3, 3), activation="relu"))
    model.add(GlobalAveragePooling2D())
    model.add(Dense(500, activation="relu"))
    model.add(Dense(90, activation="relu"))
    # Linear output layer: 8 regression targets.
    model.add(Dense(8))
    # NOTE(review): 'accuracy' is not meaningful for MSE regression but is
    # kept so existing metric logs/parsers keep working.
    model.compile(optimizer="rmsprop", loss="mse", metrics=["mae", "accuracy"])
    return model
# Emit the info-level messages below; without this the root logger's default
# WARNING level silently dropped every log.info call.
log.basicConfig(level=log.INFO)

# Build a fresh network only when not resuming: the original constructed a
# CNN unconditionally and threw it away whenever --load_model was given.
if args.load_model:
    model = load_model(args.load_model)
else:
    model = createCNN()

log.info("loading data...")
with np.load(args.data) as data:
    trainImages = data["trainImages"]
    trainLabels = data["trainLabels"]
    testImages = data["testImages"]
    testLabels = data["testLabels"]
log.info("done")

# One TensorBoard run directory per invocation, keyed by wall-clock time.
tensorboard = TensorBoard(log_dir=os.path.join(args.log_dir, "{}".format(time())))
# Save a checkpoint every 10 epochs; args.save_model may contain an
# {epoch}-style placeholder that ModelCheckpoint fills in.
checkpoint = ModelCheckpoint(args.save_model, monitor="val_loss", period=10)

model.fit(
    trainImages.reshape((-1, 224, 224, 1)),  # add the single channel axis
    trainLabels,
    epochs=args.epochs,
    initial_epoch=args.initial_epoch,
    batch_size=20,
    validation_split=0.2,  # hold out 20% of the training set for validation
    callbacks=[tensorboard, checkpoint],
)

# Final held-out evaluation; logs [loss, mae, accuracy] per the compile metrics.
log.info(model.evaluate(testImages.reshape((-1, 224, 224, 1)), testLabels))