Files
@ 8b30e6dba468
Branch filter:
Location: OneEye/exp/kerokero/train.py - annotation
8b30e6dba468
2.5 KiB
text/x-python
fix: corrupted training data
655956f6ba89 9483b964f560 655956f6ba89 9483b964f560 dd45e200a0dc 655956f6ba89 655956f6ba89 9483b964f560 655956f6ba89 9483b964f560 655956f6ba89 655956f6ba89 9483b964f560 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc 655956f6ba89 dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc 9483b964f560 dd45e200a0dc dd45e200a0dc 9483b964f560 dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc 655956f6ba89 655956f6ba89 655956f6ba89 9483b964f560 9483b964f560 9483b964f560 9483b964f560 9483b964f560 9483b964f560 9483b964f560 655956f6ba89 5fb721461494 dd45e200a0dc 5fb721461494 9483b964f560 9483b964f560 5fb721461494 9483b964f560 | import argparse
import logging as log
import numpy as np
from keras.layers import Conv2D,Dropout,Dense,Flatten,MaxPooling2D,BatchNormalization
from keras.models import Sequential,load_model
import ftp
log.basicConfig(level=log.INFO,format="%(asctime)s %(levelname)s: %(message)s")
parser=argparse.ArgumentParser()
parser.add_argument("data")
parser.add_argument("--load_model")
parser.add_argument("--save_model",default="/tmp/gogo-{0:03}.h5")
parser.add_argument("--epochs",type=int,default=100)
parser.add_argument("--initial_epoch",type=int,default=0)
args=parser.parse_args()
def createFullyConnected():
model=Sequential([
Flatten(input_shape=(224,224)),
Dense(128, activation="relu"),
Dropout(0.1),
Dense(64, activation="relu"),
Dense(8)
])
model.compile(
optimizer='adam',
loss='mse',
metrics=['mae','accuracy']
)
return model
def createCNN():
model=Sequential()
model.add(Conv2D(filters=16,kernel_size=2,padding="same",activation="relu",input_shape=(224,224,1)))
model.add(Dropout(0.1))
model.add(MaxPooling2D(pool_size=(2,2),strides=(2,2),padding="valid"))
model.add(BatchNormalization())
model.add(Conv2D(32,(5,5),activation="relu"))
model.add(MaxPooling2D(pool_size=(2,2),strides=(2,2),padding="valid"))
model.add(Dropout(0.2))
model.add(BatchNormalization())
model.add(Conv2D(64,(5,5),activation="relu"))
model.add(MaxPooling2D(pool_size=(2,2),strides=(2,2),padding="valid"))
model.add(BatchNormalization())
model.add(Conv2D(128,(3,3),activation="relu"))
model.add(MaxPooling2D(pool_size=(2,2),strides=(2,2),padding="valid"))
model.add(Dropout(0.4))
model.add(BatchNormalization())
model.add(Flatten())
model.add(Dense(500,activation="relu"))
model.add(Dropout(0.1))
model.add(Dense(128,activation="relu"))
model.add(Dropout(0.1))
model.add(Dense(8))
model.compile(optimizer='adam',loss='mse',metrics=['mae','accuracy'])
return model
model=createCNN()
if args.load_model:
model=load_model(args.load_model)
log.info("loading data...")
with np.load(args.data) as data:
trainImages=data["trainImages"]
trainLabels=data["trainLabels"]
testImages=data["testImages"]
testLabels=data["testLabels"]
log.info("done")
for i in range(args.initial_epoch//10,args.epochs//10):
model.fit(trainImages.reshape((-1,224,224,1)),trainLabels,epochs=(i+1)*10,initial_epoch=i*10,batch_size=128,validation_split=0.2)
path=args.save_model.format((i+1)*10)
log.info("saving model...")
model.save(path)
if i%2==1: ftp.push(path)
log.info(model.evaluate(testImages,testLabels))
|