"""Detect a board in photos with a frozen TensorFlow detection graph and crop it.

The frozen graph is loaded once, at import time, from ``cfg.sansaModel``.
Run as a script to crop every ``--input`` image into ``--output_dir``.
"""
import os
import time
import argparse
import logging as log

import numpy as np
import tensorflow as tf
from PIL import Image
from distutils.version import StrictVersion

import exp_config as cfg

if StrictVersion(tf.__version__) < StrictVersion('1.9.0'):
    raise ImportError('Please upgrade your TensorFlow installation to v1.9.* or later!')

# Load the frozen inference graph once at import time; every call to
# extractBoard() reuses this graph.
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(cfg.sansaModel, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')


class BoardNotFoundError(ValueError):
    """Raised by extractBoard when the detector returns no detection."""


def load_image_into_numpy_array(image):
    """Return the PIL *image* as a (height, width, 3) uint8 numpy array.

    The image is converted to RGB first, so RGBA, greyscale and palette
    images are accepted instead of failing on a 3-channel reshape.
    """
    return np.array(image.convert("RGB"), dtype=np.uint8)


def run_inference_for_single_image(image, graph):
    """Run the detection *graph* on one image and return its outputs.

    A new tf.Session is opened (and closed) on every call.

    Args:
        image: (height, width, 3) uint8 array. A batch dimension is added
            before it is fed to the graph's ``image_tensor:0`` input.
        graph: tf.Graph holding the imported frozen model. It must provide
            ``num_detections``, ``detection_boxes``, ``detection_scores`` and
            ``detection_classes`` outputs, since they are post-processed below.

    Returns:
        dict holding the first batch element of each output:
        ``num_detections`` (int), ``detection_classes`` (uint8 array),
        ``detection_boxes`` and ``detection_scores`` (arrays). If the graph
        has ``detection_masks``, it is fetched but returned unmodified
        (still batched).
    """
    with graph.as_default():
        with tf.Session() as sess:
            # Get handles to whichever of the known output tensors exist.
            ops = graph.get_operations()
            all_tensor_names = {output.name for op in ops for output in op.outputs}
            tensor_dict = {}
            for key in [
                    'num_detections', 'detection_boxes', 'detection_scores',
                    'detection_classes', 'detection_masks'
            ]:
                tensor_name = key + ':0'
                if tensor_name in all_tensor_names:
                    tensor_dict[key] = graph.get_tensor_by_name(tensor_name)
            image_tensor = graph.get_tensor_by_name('image_tensor:0')

            # Run inference on a batch of one image.
            output_dict = sess.run(tensor_dict,
                                   feed_dict={image_tensor: np.expand_dims(image, 0)})

            # All outputs are float32 numpy arrays, so convert types as
            # appropriate and drop the batch dimension.
            output_dict['num_detections'] = int(output_dict['num_detections'][0])
            output_dict['detection_classes'] = output_dict[
                'detection_classes'][0].astype(np.uint8)
            output_dict['detection_boxes'] = output_dict['detection_boxes'][0]
            output_dict['detection_scores'] = output_dict['detection_scores'][0]
    return output_dict


def extractBoard(image, margin=0.1):
    """Detect the board in *image* and return it cropped with a margin.

    Args:
        image: PIL image to search.
        margin: fraction of the detected box's width (resp. height) added on
            the left/right (resp. top/bottom) before cropping. The enlarged
            box is clamped to the image bounds. Defaults to 0.1.

    Returns:
        The cropped PIL image (the "goban").

    Raises:
        BoardNotFoundError: (a ValueError subclass) if the detector reports
            no detection.
    """
    t1 = time.time()
    image_np = load_image_into_numpy_array(image)
    output_dict = run_inference_for_single_image(image_np, detection_graph)
    if output_dict["num_detections"] < 1:
        # No valid box: cropping detection_boxes[0] would give a meaningless
        # (likely empty) image.
        raise BoardNotFoundError("no board detected")

    (width, height) = image.size
    # Boxes are treated as normalised (ymin, xmin, ymax, xmax) coordinates.
    # Only the first box is used. Presumably it is the highest-scoring one
    # (TF OD API exports sort by score); TODO confirm for this model.
    (y1, x1, y2, x2) = output_dict["detection_boxes"][0]
    (w, h) = ((x2 - x1), (y2 - y1))
    # Grow the box by `margin` on every side, clamped to [0, 1].
    x1 = max(0, x1 - margin * w)
    x2 = min(1, x2 + margin * w)
    y1 = max(0, y1 - margin * h)
    y2 = min(1, y2 + margin * h)
    goban = image.crop((x1 * width, y1 * height, x2 * width, y2 * height))
    t = time.time() - t1
    log.info("board detected in {0:.3}s".format(t))
    return goban


def _main():
    """CLI entry point: crop each --input image into --output_dir."""
    parser = argparse.ArgumentParser(
        description="Crop the detected board out of each input image.")
    parser.add_argument("-i", "--input", nargs="+", required=True)
    parser.add_argument("-o", "--output_dir", required=True)
    args = parser.parse_args()

    # Without configuration the root logger drops INFO messages, so the
    # timing reported by extractBoard would never be shown.
    log.basicConfig(level=log.INFO)
    if not os.path.isdir(args.output_dir):
        os.makedirs(args.output_dir)

    for image_path in args.input:
        with Image.open(image_path) as image:
            try:
                goban = extractBoard(image)
            except BoardNotFoundError as err:
                # Skip this image rather than aborting the whole batch.
                log.warning("%s: %s", image_path, err)
                continue
            goban.save(os.path.join(args.output_dir, os.path.basename(image_path)))


if __name__ == "__main__":
    _main()