# HG changeset patch
# User Laman
# Date 2019-05-15 16:26:02
# Node ID 4eb46a5b9c2bcd53c3352bec68e3ecc41f4ee01d
# Parent  247811dfb9beef1515ef63e2fc31eba7821f33c1
Sansa board detection module

diff --git a/.hgignore b/.hgignore
--- a/.hgignore
+++ b/.hgignore
@@ -1,5 +1,5 @@
 __pycache__/
 ^\.idea/
 ^images/
-^config.json$
+config.json$
 ftp.json$
diff --git a/exp/config.py b/exp/exp_config.py
rename from exp/config.py
rename to exp/exp_config.py
--- a/exp/config.py
+++ b/exp/exp_config.py
@@ -1,8 +1,17 @@
 import os
+import json
+import logging as log
+log.basicConfig(level=log.INFO,format="%(asctime)s %(levelname)s: %(message)s")
 
+thisDir=os.path.dirname(__file__)
+
+with open(os.path.join(thisDir,"config.json")) as f:
+	cfg=json.load(f)
+
 INTERACTIVE=False
 imgDir="/tmp/oneEye"
+sansaModel=cfg.get("sansaModel")
 
 i=1
 
 if not os.path.exists(imgDir):
diff --git a/exp/sansa.py b/exp/sansa.py
new file mode 100644
--- /dev/null
+++ b/exp/sansa.py
@@ -0,0 +1,88 @@
+import os
+import time
+import argparse
+import logging as log
+
+import numpy as np
+import tensorflow as tf
+from PIL import Image
+from distutils.version import StrictVersion
+
+import exp_config as cfg
+
+if StrictVersion(tf.__version__) < StrictVersion('1.9.0'):
+	raise ImportError('Please upgrade your TensorFlow installation to v1.9.* or later!')
+
+detection_graph = tf.Graph()
+with detection_graph.as_default():
+	od_graph_def = tf.GraphDef()
+	with tf.gfile.GFile(cfg.sansaModel, 'rb') as fid:
+		serialized_graph = fid.read()
+		od_graph_def.ParseFromString(serialized_graph)
+		tf.import_graph_def(od_graph_def, name='')
+
+
+def load_image_into_numpy_array(image):
+	(im_width, im_height) = image.size
+	return np.array(image.getdata()).reshape(
+		(im_height, im_width, 3)).astype(np.uint8)
+
+
+def run_inference_for_single_image(image, graph):
+	with graph.as_default():
+		with tf.Session() as sess:
+			# Get handles to input and output tensors
+			ops = tf.get_default_graph().get_operations()
+			all_tensor_names = {output.name for op in ops for output in op.outputs}
+			tensor_dict = {}
+			for key in [
+				'num_detections', 'detection_boxes', 'detection_scores',
+				'detection_classes', 'detection_masks'
+			]:
+				tensor_name = key + ':0'
+				if tensor_name in all_tensor_names:
+					tensor_dict[key] = tf.get_default_graph().get_tensor_by_name(
+						tensor_name)
+			image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0')
+
+			# Run inference
+			output_dict = sess.run(tensor_dict,
+				feed_dict={image_tensor: np.expand_dims(image, 0)})
+
+			# all outputs are float32 numpy arrays, so convert types as appropriate
+			output_dict['num_detections'] = int(output_dict['num_detections'][0])
+			output_dict['detection_classes'] = output_dict[
+				'detection_classes'][0].astype(np.uint8)
+			output_dict['detection_boxes'] = output_dict['detection_boxes'][0]
+			output_dict['detection_scores'] = output_dict['detection_scores'][0]
+	return output_dict
+
+
+def extractBoard(image):
+	t1=time.time()
+	image_np=load_image_into_numpy_array(image)
+	output_dict=run_inference_for_single_image(image_np,detection_graph)
+
+	(width,height)=image.size
+	(y1,x1,y2,x2)=output_dict["detection_boxes"][0]
+	(w,h)=((x2-x1),(y2-y1))
+	x1=max(0,x1-0.1*w)
+	x2=min(1,x2+0.1*w)
+	y1=max(0,y1-0.1*h)
+	y2=min(1,y2+0.1*h)
+	goban=image.crop((x1*width,y1*height,x2*width,y2*height))
+	t=time.time()-t1
+	log.info("board detected in {0:.3}s".format(t))
+	return goban
+
+
+if __name__=="__main__":
+	parser=argparse.ArgumentParser()
+	parser.add_argument("-i","--input",nargs="+")
+	parser.add_argument("-o","--output_dir",required=True)
+	args=parser.parse_args()
+
+	for image_path in args.input:
+		image=Image.open(image_path)
+		goban=extractBoard(image)
+		goban.save(os.path.join(args.output_dir,os.path.basename(image_path)))