import os
import time
import argparse
import logging as log
import numpy as np
import tensorflow as tf
from PIL import Image
from distutils.version import StrictVersion
import exp_config as cfg
# Abort the import early when the installed TensorFlow is older than 1.9.0.
if StrictVersion('1.9.0') > StrictVersion(tf.__version__):
    raise ImportError(
        'Please upgrade your TensorFlow installation to v1.9.* or later!')
# Build the detection graph once, at import time: read the serialized
# GraphDef stored at cfg.sansaModel and import it into a dedicated tf.Graph.
detection_graph = tf.Graph()

# The file only needs to stay open while its bytes are read.
with tf.gfile.GFile(cfg.sansaModel, 'rb') as fid:
    serialized_graph = fid.read()

od_graph_def = tf.GraphDef()
od_graph_def.ParseFromString(serialized_graph)

# name='' imports the nodes without a prefix, so tensors can later be looked
# up by plain names such as 'image_tensor:0'.
with detection_graph.as_default():
    tf.import_graph_def(od_graph_def, name='')
def load_image_into_numpy_array(image):
    """Convert a PIL image into a ``(height, width, 3)`` uint8 array.

    Args:
        image: A ``PIL.Image.Image`` (or an object exposing the same
            ``mode``, ``size``, ``getdata()`` and ``convert()`` members).

    Returns:
        numpy.ndarray of dtype ``uint8`` with shape ``(height, width, 3)``.
    """
    # The reshape below assumes exactly three values per pixel. Greyscale,
    # palette, RGBA or CMYK images would make it raise ValueError, so anything
    # that is not already RGB is converted first.
    if image.mode != 'RGB':
        image = image.convert('RGB')
    (im_width, im_height) = image.size
    return np.array(image.getdata()).reshape(
        (im_height, im_width, 3)).astype(np.uint8)
def run_inference_for_single_image(image, graph):
    """Run the detection graph on one image and return its outputs.

    Args:
        image: Array for a single image; a leading batch axis is added
            before it is fed to the ``image_tensor:0`` placeholder.
        graph: ``tf.Graph`` containing ``image_tensor:0`` and the detection
            output tensors.

    Returns:
        dict keyed by output name. ``num_detections`` is an ``int``,
        ``detection_classes`` is a uint8 array, and ``detection_boxes`` /
        ``detection_scores`` hold the first (only) batch entry. A
        ``detection_masks`` entry, if the graph provides one, is returned
        exactly as the session produced it.

    Raises:
        KeyError: if the graph lacks one of the four mandatory outputs.
    """
    wanted_keys = (
        'num_detections', 'detection_boxes', 'detection_scores',
        'detection_classes', 'detection_masks',
    )
    # A new session is opened (and closed) for every call.
    with graph.as_default(), tf.Session() as sess:
        active_graph = tf.get_default_graph()

        # Only fetch the outputs that this particular graph really exposes.
        available_names = {
            tensor.name
            for operation in active_graph.get_operations()
            for tensor in operation.outputs
        }
        fetches = {
            key: active_graph.get_tensor_by_name(key + ':0')
            for key in wanted_keys
            if key + ':0' in available_names
        }
        input_tensor = active_graph.get_tensor_by_name('image_tensor:0')

        # Run inference on a batch that contains just this image.
        batch = np.expand_dims(image, 0)
        results = sess.run(fetches, feed_dict={input_tensor: batch})

    # Original note: all outputs are float32 numpy arrays, so convert types
    # as appropriate and strip the batch dimension.
    results['num_detections'] = int(results['num_detections'][0])
    results['detection_classes'] = (
        results['detection_classes'][0].astype(np.uint8))
    results['detection_boxes'] = results['detection_boxes'][0]
    results['detection_scores'] = results['detection_scores'][0]
    return results
def extractBoard(image):
    """Crop *image* to the first detected box, enlarged by a 10% margin.

    The image is run through ``detection_graph``. The first entry of
    ``detection_boxes`` is unpacked as ``(y1, x1, y2, x2)`` in coordinates
    relative to the image size, grown by 10% of its width/height on every
    side, clamped to ``[0, 1]`` and used to crop the image.

    Args:
        image: ``PIL.Image.Image`` to search for a board.

    Returns:
        The cropped ``PIL.Image.Image``. When the model reports no detection
        a warning is logged and the crop is still taken from the first box
        as returned by the model.
    """
    t1 = time.time()
    image_np = load_image_into_numpy_array(image)
    output_dict = run_inference_for_single_image(image_np, detection_graph)

    # Make the "nothing found" case visible instead of silently cropping
    # whatever happens to be stored in the first box slot.
    # NOTE(review): the exported model presumably pads boxes with zeros, which
    # would give an empty crop here; confirm and decide whether to raise.
    if output_dict["num_detections"] == 0:
        log.warning("no board detected; the returned crop is not meaningful")

    (width, height) = image.size
    # Presumably the highest-scoring detection; confirm against the model.
    (y1, x1, y2, x2) = output_dict["detection_boxes"][0]

    # Grow the box by 10% of its size on each side, staying inside the image.
    (w, h) = ((x2 - x1), (y2 - y1))
    x1 = max(0, x1 - 0.1 * w)
    x2 = min(1, x2 + 0.1 * w)
    y1 = max(0, y1 - 0.1 * h)
    y2 = min(1, y2 + 0.1 * h)

    # Scale the relative coordinates to pixels for PIL's (left, upper,
    # right, lower) crop box.
    goban = image.crop((x1 * width, y1 * height, x2 * width, y2 * height))
    t = time.time() - t1
    log.info("board detected in {0:.3}s".format(t))
    return goban
if __name__ == "__main__":
    # Command-line entry point: crop every input image to its detected board
    # and save the result in the output directory under the same file name.
    parser = argparse.ArgumentParser()
    # --input must be mandatory: when it is omitted argparse leaves
    # args.input as None and the loop below fails with a TypeError.
    parser.add_argument("-i", "--input", nargs="+", required=True)
    parser.add_argument("-o", "--output_dir", required=True)
    args = parser.parse_args()

    # Create the destination up front so a missing directory does not make
    # the first save fail after inference has already been run.
    if not os.path.isdir(args.output_dir):
        os.makedirs(args.output_dir)

    for image_path in args.input:
        image = Image.open(image_path)
        goban = extractBoard(image)
        goban.save(os.path.join(args.output_dir, os.path.basename(image_path)))