Files @ f1f8a2421f92
Branch filter:

Location: OneEye/exp/sansa.py

Laman
updated readme
import os
import time
import argparse
import logging as log

import numpy as np
import tensorflow as tf
from PIL import Image
from distutils.version import StrictVersion

import exp_config as cfg

# Import-time guard: refuse to load this module when the installed TensorFlow
# is older than 1.9.0, instead of failing later with a less obvious error.
if StrictVersion(tf.__version__) < StrictVersion('1.9.0'):
	raise ImportError('Please upgrade your TensorFlow installation to v1.9.* or later!')

# Build the detection graph once, at import time. The serialized GraphDef is
# read from the path stored in cfg.sansaModel and imported into
# `detection_graph` with an empty name prefix (name=''), so tensors can later
# be looked up under their plain names such as 'image_tensor:0'.
detection_graph = tf.Graph()
with detection_graph.as_default():
	od_graph_def = tf.GraphDef()
	# Binary mode: the file content is fed straight to ParseFromString.
	with tf.gfile.GFile(cfg.sansaModel, 'rb') as fid:
		serialized_graph = fid.read()
		od_graph_def.ParseFromString(serialized_graph)
		tf.import_graph_def(od_graph_def, name='')


def load_image_into_numpy_array(image):
	"""Convert a PIL-style image into a ``(height, width, 3)`` uint8 array.

	Args:
		image: Object exposing ``size`` as ``(width, height)``, ``mode``,
			``getdata()`` and ``convert()`` (e.g. a ``PIL.Image.Image``).

	Returns:
		A ``numpy.ndarray`` of dtype ``uint8`` with shape
		``(height, width, 3)``.
	"""
	# Greyscale, palette, RGBA or CMYK images do not yield exactly three values
	# per pixel, which makes the reshape below raise; normalise them to RGB.
	if image.mode != 'RGB':
		image = image.convert('RGB')
	(im_width, im_height) = image.size
	return np.array(image.getdata()).reshape(
			(im_height, im_width, 3)).astype(np.uint8)


def run_inference_for_single_image(image, graph):
	"""Run the detection graph on one image and return its outputs.

	A new ``tf.Session`` is opened on every call, with ``graph`` installed as
	the default graph.

	Args:
		image: Image array; a leading batch axis is added before it is fed to
			the ``image_tensor:0`` placeholder.
		graph: ``tf.Graph`` holding the imported detection model.

	Returns:
		Dict keyed by output name. ``num_detections`` is a Python ``int``,
		``detection_classes`` is cast to ``uint8``, and ``detection_boxes`` /
		``detection_scores`` have their batch axis removed. ``detection_masks``
		is included only when the graph provides it, and is returned exactly
		as the session produced it.
	"""
	candidate_keys = (
		'num_detections', 'detection_boxes', 'detection_scores',
		'detection_classes', 'detection_masks',
	)
	with graph.as_default():
		with tf.Session() as sess:
			default_graph = tf.get_default_graph()

			# Names of every tensor produced by any operation in the graph.
			available_names = set()
			for operation in default_graph.get_operations():
				for produced in operation.outputs:
					available_names.add(produced.name)

			# Keep only the outputs this particular model actually exposes.
			tensor_dict = {
				key: default_graph.get_tensor_by_name(key + ':0')
				for key in candidate_keys
				if key + ':0' in available_names
			}
			image_tensor = default_graph.get_tensor_by_name('image_tensor:0')

			# Run inference on a batch containing just this image.
			batched_image = np.expand_dims(image, 0)
			output_dict = sess.run(tensor_dict,
					feed_dict={image_tensor: batched_image})

			# Outputs come back as float32 arrays with a batch axis; take the
			# single batch entry and convert types where appropriate.
			detection_count = output_dict['num_detections'][0]
			output_dict['num_detections'] = int(detection_count)
			first_classes = output_dict['detection_classes'][0]
			output_dict['detection_classes'] = first_classes.astype(np.uint8)
			for batched_key in ('detection_boxes', 'detection_scores'):
				output_dict[batched_key] = output_dict[batched_key][0]
	return output_dict


def extractBoard(image):
	"""Detect the board in ``image`` and return it cropped with a 10% margin.

	The first detection box is used. It is unpacked as ``(y1, x1, y2, x2)`` in
	coordinates relative to the image size, widened by 10% of its width and
	height on each side, clamped to the image, and used to crop ``image``.

	Args:
		image: ``PIL.Image.Image`` to search.

	Returns:
		A new image holding the cropped board, or an uncropped copy of
		``image`` when the model reports no detection.
	"""
	t1=time.time()
	image_np=load_image_into_numpy_array(image)
	output_dict=run_inference_for_single_image(image_np,detection_graph)

	# With zero detections the first box row is not a real detection
	# (presumably zero padding - verify against the exported model), and
	# cropping to it would produce a degenerate, empty image.
	if output_dict["num_detections"]<1:
		log.warning("no board detected, returning the uncropped image")
		return image.copy()

	(width,height)=image.size
	(y1,x1,y2,x2)=output_dict["detection_boxes"][0]
	(w,h)=((x2-x1),(y2-y1))
	# Grow the box by 10% per side without leaving the [0, 1] range.
	x1=max(0,x1-0.1*w)
	x2=min(1,x2+0.1*w)
	y1=max(0,y1-0.1*h)
	y2=min(1,y2+0.1*h)
	# Scale the relative coordinates to pixels for the crop.
	goban=image.crop((x1*width,y1*height,x2*width,y2*height))
	t=time.time()-t1
	log.info("board detected in {0:.3}s".format(t))
	return goban


if __name__=="__main__":
	# Command-line entry point: crop the detected board out of every input
	# image and save the result under the same file name in the output
	# directory.
	parser=argparse.ArgumentParser()
	# --input is mandatory: when omitted, args.input would be None and the
	# loop below would fail with a TypeError instead of a usage message.
	parser.add_argument("-i","--input",nargs="+",required=True)
	parser.add_argument("-o","--output_dir",required=True)
	args=parser.parse_args()

	# Create the destination up front so saving does not fail on a missing
	# directory.
	if not os.path.isdir(args.output_dir):
		os.makedirs(args.output_dir)

	for image_path in args.input:
		# The context manager releases the source file once the crop is saved.
		with Image.open(image_path) as image:
			goban=extractBoard(image)
			goban.save(os.path.join(args.output_dir,os.path.basename(image_path)))