Files
@ d9cf0ed8e7fd
Branch filter:
Location: OneEye/exp/kerokero/train.py - annotation
d9cf0ed8e7fd
3.6 KiB
text/x-python
Inception transfer learning (failed)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 | c934d44cdf5c c934d44cdf5c 655956f6ba89 9483b964f560 655956f6ba89 9483b964f560 d9cf0ed8e7fd d9cf0ed8e7fd d9cf0ed8e7fd c934d44cdf5c d9cf0ed8e7fd 655956f6ba89 c934d44cdf5c 9483b964f560 655956f6ba89 655956f6ba89 9483b964f560 655956f6ba89 655956f6ba89 655956f6ba89 655956f6ba89 d9cf0ed8e7fd 655956f6ba89 655956f6ba89 dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc 655956f6ba89 dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc 9483b964f560 dd45e200a0dc dd45e200a0dc 9483b964f560 dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc dd45e200a0dc d9cf0ed8e7fd d9cf0ed8e7fd d9cf0ed8e7fd d9cf0ed8e7fd d9cf0ed8e7fd d9cf0ed8e7fd d9cf0ed8e7fd d9cf0ed8e7fd d9cf0ed8e7fd d9cf0ed8e7fd d9cf0ed8e7fd d9cf0ed8e7fd d9cf0ed8e7fd d9cf0ed8e7fd d9cf0ed8e7fd d9cf0ed8e7fd 655956f6ba89 655956f6ba89 d9cf0ed8e7fd d9cf0ed8e7fd 655956f6ba89 9483b964f560 9483b964f560 d9cf0ed8e7fd 9483b964f560 d9cf0ed8e7fd 9483b964f560 9483b964f560 655956f6ba89 d9cf0ed8e7fd d9cf0ed8e7fd d9cf0ed8e7fd d9cf0ed8e7fd d9cf0ed8e7fd d9cf0ed8e7fd d9cf0ed8e7fd d9cf0ed8e7fd d9cf0ed8e7fd d9cf0ed8e7fd d9cf0ed8e7fd fad34516870e d9cf0ed8e7fd fad34516870e 9483b964f560 9483b964f560 fad34516870e 
import os
from time import time
import argparse
import logging as log

import numpy as np
from keras.layers import Conv2D, Dropout, Dense, Flatten, MaxPooling2D, GlobalAveragePooling2D, BatchNormalization
from keras.models import Sequential, load_model, Model
from keras.optimizers import SGD
from keras.callbacks import TensorBoard
from keras.applications.inception_v3 import InceptionV3, preprocess_input

import config as cfg
import ftp

# Command-line interface of the training script.
parser = argparse.ArgumentParser(description="Train an 8-output regression model on an .npz image dataset.")
parser.add_argument(
    "data",
    help="path to an .npz archive containing trainImages, trainLabels, testImages and testLabels",
)
parser.add_argument(
    "--load_model",
    help="resume from a saved Keras model instead of building a new one (skips the warm-up phase)",
)
parser.add_argument(
    "--save_model",
    default="/tmp/gogo-{0:03}.h5",
    help="checkpoint path template; formatted with the epoch number reached",
)
parser.add_argument("--epochs", type=int, default=100, help="epoch number to train up to")
parser.add_argument("--initial_epoch", type=int, default=0, help="epoch number to resume counting from")
parser.add_argument(
    "--log_dir",
    default="/tmp/tflogs",
    help="TensorBoard log root; each run logs to a timestamped subdirectory",
)
args = parser.parse_args()
def createFullyConnected():
    """Build and compile a small fully connected regressor.

    The network flattens a (224, 224) input and passes it through two ReLU
    hidden layers (128 and 64 units, with 10% dropout after the first). It
    ends in a linear layer with 8 outputs.

    Returns:
        A ``Sequential`` model compiled with Adam on mean squared error,
        also reporting MAE and accuracy.
    """
    net = Sequential()
    for layer in (
        Flatten(input_shape=(224, 224)),
        Dense(128, activation="relu"),
        Dropout(0.1),
        Dense(64, activation="relu"),
        Dense(8),  # linear output: 8 regression targets
    ):
        net.add(layer)
    net.compile(loss='mse', optimizer='adam', metrics=['mae', 'accuracy'])
    return net
def createCNN():
    """Build and compile a four-block convolutional regressor.

    The input is (224, 224, 1). Each conv block applies a ReLU convolution,
    a 2x2/stride-2 max-pool, optional dropout and batch normalization. The
    features are then flattened into a dense head (500 -> 128 ReLU units,
    each followed by 10% dropout) ending in 8 linear outputs.

    Returns:
        A ``Sequential`` model compiled with Adam on mean squared error,
        also reporting MAE and accuracy.
    """
    def downsample():
        # 2x2 max-pool with stride 2 and no padding.
        return MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="valid")

    stack = [
        # Block 1: 16 filters, 2x2 kernel, "same" padding.
        Conv2D(filters=16, kernel_size=2, padding="same", activation="relu", input_shape=(224, 224, 1)),
        Dropout(0.1),
        downsample(),
        BatchNormalization(),
        # Block 2: 32 filters, 5x5 kernel.
        Conv2D(32, (5, 5), activation="relu"),
        downsample(),
        Dropout(0.2),
        BatchNormalization(),
        # Block 3: 64 filters, 5x5 kernel.
        Conv2D(64, (5, 5), activation="relu"),
        downsample(),
        BatchNormalization(),
        # Block 4: 128 filters, 3x3 kernel.
        Conv2D(128, (3, 3), activation="relu"),
        downsample(),
        Dropout(0.4),
        BatchNormalization(),
        # Dense regression head.
        Flatten(),
        Dense(500, activation="relu"),
        Dropout(0.1),
        Dense(128, activation="relu"),
        Dropout(0.1),
        Dense(8),
    ]
    cnn = Sequential(stack)
    cnn.compile(loss='mse', optimizer='adam', metrics=['mae', 'accuracy'])
    return cnn
def createPretrained():
    """Build an InceptionV3 transfer-learning regressor with a frozen base.

    The ImageNet-pretrained InceptionV3 convolutional base (no classification
    top, (224, 224, 3) input) is followed by global average pooling, a
    1024-unit ReLU layer and 8 linear outputs. Every base layer is marked
    non-trainable, so until the caller changes the flags and recompiles only
    the new head is trained.

    Returns:
        A ``Model`` compiled with Adam on mean squared error, also reporting
        MAE and accuracy.
    """
    backbone = InceptionV3(weights="imagenet", include_top=False, input_shape=(224, 224, 3))
    pooled = GlobalAveragePooling2D()(backbone.output)
    hidden = Dense(1024, activation="relu")(pooled)
    outputs = Dense(8)(hidden)
    regressor = Model(inputs=backbone.input, outputs=outputs)
    # Freeze the pretrained base; only the head added above stays trainable.
    for frozen in backbone.layers:
        frozen.trainable = False
    regressor.compile(loss='mse', optimizer='adam', metrics=['mae', 'accuracy'])
    return regressor
# The progress messages below are logged at INFO level, which the root logger
# drops by default. basicConfig does nothing if logging was already configured
# elsewhere (e.g. by an imported module).
log.basicConfig(level=log.INFO)

# Input layout expected by createPretrained(): N x 224 x 224 x 3.
INPUT_SHAPE = (-1, 224, 224, 3)
# Number of epochs per training chunk; a checkpoint is saved after each chunk.
BIG_STEP = 20


def _epoch_chunks(start, stop, step):
    """Yield ``(initial_epoch, end_epoch)`` pairs covering epochs ``[start, stop)``.

    Chunks are ``step`` epochs long. The last chunk is shortened so training
    ends exactly at ``stop``. A ``start`` that is not a multiple of ``step``
    resumes from ``start`` itself rather than being rounded down.
    """
    for begin in range(start, stop, step):
        yield begin, min(begin + step, stop)


if args.load_model:
    model = load_model(args.load_model)
else:
    model = createPretrained()

log.info("loading data...")
with np.load(args.data) as data:
    # Reshape once here rather than at every fit/evaluate call.
    trainImages = preprocess_input(data["trainImages"]).reshape(INPUT_SHAPE)
    trainLabels = data["trainLabels"]
    testImages = preprocess_input(data["testImages"]).reshape(INPUT_SHAPE)
    testLabels = data["testLabels"]
log.info("done")

tensorboard = TensorBoard(log_dir=os.path.join(args.log_dir, str(time())))

if not args.load_model:
    # Warm-up: train only the new head while the InceptionV3 base layers are
    # frozen (createPretrained() marks them non-trainable).
    model.fit(trainImages, trainLabels, epochs=10, batch_size=128,
              validation_split=0.2, callbacks=[tensorboard])

# Fine-tuning: freeze the first 249 layers and unfreeze the rest.
# NOTE(review): 249 follows the Keras InceptionV3 fine-tuning example (train the
# top two inception blocks) — confirm it still matches this model's layer list.
for layer in model.layers[:249]:
    layer.trainable = False
for layer in model.layers[249:]:
    layer.trainable = True
# Recompile so the new trainable flags take effect. The small SGD learning rate
# limits damage to the pretrained weights. 'mae' is kept so fine-tuning reports
# the same metric as the initial compile in createPretrained().
# NOTE: this builds a fresh optimizer, so any optimizer state restored by
# load_model is discarded.
model.compile(optimizer=SGD(lr=0.0001, momentum=0.9), loss='mse', metrics=['mae'])

# Train in BIG_STEP-epoch chunks and checkpoint after each chunk. Epoch numbers
# are absolute, so --initial_epoch/--epochs resume exactly where requested.
for firstEpoch, endEpoch in _epoch_chunks(args.initial_epoch, args.epochs, BIG_STEP):
    model.fit(trainImages, trainLabels, epochs=endEpoch, initial_epoch=firstEpoch,
              batch_size=128, validation_split=0.2, callbacks=[tensorboard])
    path = args.save_model.format(endEpoch)
    log.info("saving model...")
    model.save(path)
    # ftp.push(path)

log.info(model.evaluate(testImages, testLabels))
|