OneEye/exp/sansa.py
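
"""Detect the go board (goban) in a photograph with a frozen TensorFlow
object-detection model and crop it out."""
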
import os
import time
import argparse
import logging as log

import numpy as np
import tensorflow as tf
from PIL import Image
from distutils.version import StrictVersion

import exp_config as cfg

if StrictVersion(tf.__version__) < StrictVersion('1.9.0'):
	raise ImportError('Please upgrade your TensorFlow installation to v1.9.* or later!')
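
# Load the board-detection model once at import time: cfg.sansaModel is the path
# to its serialized GraphDef (frozen inference graph).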

detection_graph = tf.Graph()
with detection_graph.as_default():
	od_graph_def = tf.GraphDef()
	with tf.gfile.GFile(cfg.sansaModel, 'rb') as fid:
		serialized_graph = fid.read()
		od_graph_def.ParseFromString(serialized_graph)
		tf.import_graph_def(od_graph_def, name='')


def load_image_into_numpy_array(image):
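	"""Convert a PIL image into an (im_height, im_width, 3) uint8 numpy array; assumes an RGB image."""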
	(im_width, im_height) = image.size
	return np.array(image.getdata()).reshape(
			(im_height, im_width, 3)).astype(np.uint8)


def run_inference_for_single_image(image, graph):
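	"""Run the frozen detection graph on a single HxWx3 uint8 image and return its raw outputs."""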
	with graph.as_default():
		with tf.Session() as sess:
			# Get handles to input and output tensors
			ops = tf.get_default_graph().get_operations()
			all_tensor_names = {output.name for op in ops for output in op.outputs}
			tensor_dict = {}
			for key in [
					'num_detections', 'detection_boxes', 'detection_scores',
					'detection_classes', 'detection_masks'
			]:
				tensor_name = key + ':0'
				if tensor_name in all_tensor_names:
					tensor_dict[key] = tf.get_default_graph().get_tensor_by_name(
							tensor_name)
			image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0')

			# Run inference
			output_dict = sess.run(tensor_dict,
					feed_dict={image_tensor: np.expand_dims(image, 0)})

			# all outputs are float32 numpy arrays, so convert types as appropriate
			output_dict['num_detections'] = int(output_dict['num_detections'][0])
			output_dict['detection_classes'] = output_dict[
					'detection_classes'][0].astype(np.uint8)
			output_dict['detection_boxes'] = output_dict['detection_boxes'][0]
			output_dict['detection_scores'] = output_dict['detection_scores'][0]
	return output_dict


def extractBoard(image):
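	"""Crop the detected board out of a PIL image, padded by 10% of the detection box on each side."""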
	t1=time.time()
	image_np=load_image_into_numpy_array(image)
	output_dict=run_inference_for_single_image(image_np,detection_graph)

	(width,height)=image.size
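	# detection_boxes are normalized (ymin, xmin, ymax, xmax) coordinates, sorted by score; take the best one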
	(y1,x1,y2,x2)=output_dict["detection_boxes"][0]
	(w,h)=((x2-x1),(y2-y1))
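	# grow the box by 10% of its width/height on every side, clamped to the [0, 1] range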
	x1=max(0,x1-0.1*w)
	x2=min(1,x2+0.1*w)
	y1=max(0,y1-0.1*h)
	y2=min(1,y2+0.1*h)
	goban=image.crop((x1*width,y1*height,x2*width,y2*height))
	t=time.time()-t1
	log.info("board detected in {0:.3}s".format(t))
	return goban


if __name__=="__main__":
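	# Usage: python sansa.py -i IMAGE [IMAGE ...] -o OUTPUT_DIR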
	parser=argparse.ArgumentParser()
	parser.add_argument("-i","--input",nargs="+")
	parser.add_argument("-o","--output_dir",required=True)
	args=parser.parse_args()

	for image_path in args.input:
		image=Image.open(image_path).convert("RGB")
		goban=extractBoard(image)
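		# keep the source file name; the output directory has to exist already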
		goban.save(os.path.join(args.output_dir,os.path.basename(image_path)))