Files
@ 29f28718a69b
Branch filter:
Location: OneEye/exp/sansa.py - annotation
29f28718a69b
2.7 KiB
text/x-python
transitional data processing
4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b 4eb46a5b9c2b | import os
import time
import argparse
import logging as log
import numpy as np
import tensorflow as tf
from PIL import Image
from distutils.version import StrictVersion
import exp_config as cfg
# Fail fast on TensorFlow older than 1.9. Only the leading "major.minor" is
# compared: distutils' StrictVersion raises ValueError on pre-release version
# strings such as "1.13.0-rc2" (which would break this import outright), and
# distutils itself is deprecated (PEP 632) and absent from Python 3.12+.
# NOTE(review): this module also uses TF1-only APIs (tf.GraphDef, tf.gfile,
# tf.Session); TF 2.x passes this check but those calls would presumably fail
# without tf.compat.v1 -- confirm the supported TensorFlow range.
import re as _re

_tf_major_minor = _re.match(r'(\d+)\.(\d+)', tf.__version__)
if _tf_major_minor is not None and tuple(
        int(part) for part in _tf_major_minor.groups()) < (1, 9):
    raise ImportError('Please upgrade your TensorFlow installation to v1.9.* or later!')
def _load_frozen_graph(model_path):
    """Return a new ``tf.Graph`` populated from the serialized GraphDef at *model_path*.

    The file is read in binary mode, parsed into a ``tf.GraphDef`` and imported
    with an empty name prefix, so tensors keep their exported names
    (e.g. ``image_tensor:0``).
    """
    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        with tf.gfile.GFile(model_path, 'rb') as model_file:
            graph_def.ParseFromString(model_file.read())
        tf.import_graph_def(graph_def, name='')
    return graph


# Detection model used by extractBoard(); loaded once, at import time, from the
# path configured in exp_config.
detection_graph = _load_frozen_graph(cfg.sansaModel)
def load_image_into_numpy_array(image):
    """Convert a PIL image into a ``(height, width, 3)`` ``uint8`` RGB array.

    Images in any other mode (e.g. RGBA PNGs, grayscale ``L``, palette ``P``)
    are converted to RGB first. Previously these made the three-channel
    reshape fail with ``ValueError``.

    Args:
        image: a ``PIL.Image.Image``; it is not modified.

    Returns:
        A new, writable ``numpy.ndarray`` of dtype ``uint8`` with shape
        ``(image.size[1], image.size[0], 3)``.
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')
    # np.array() reads the pixel buffer through PIL's array interface instead
    # of building a Python sequence of per-pixel tuples via getdata(). It also
    # copies, so the result is writable and independent of `image`.
    return np.array(image, dtype=np.uint8)
def run_inference_for_single_image(image, graph):
    """Run the detection model held in *graph* on one image.

    A fresh ``tf.Session`` is opened and closed on every call.

    Args:
        image: image array; a leading batch axis is added before it is fed to
            the graph's ``image_tensor:0`` input (presumably ``(height, width,
            3)`` ``uint8`` as returned by ``load_image_into_numpy_array`` --
            confirm for other callers).
        graph: ``tf.Graph`` containing the imported detection model.

    Returns:
        dict holding whichever of ``num_detections``, ``detection_boxes``,
        ``detection_scores``, ``detection_classes`` and ``detection_masks``
        the graph exposes. ``num_detections`` is converted to a Python
        ``int``; boxes, scores and classes have their batch axis removed, and
        classes are cast to ``uint8``. ``detection_masks``, if present, is
        returned exactly as fetched (batch axis included).

    Raises:
        KeyError: if any of the four non-mask outputs is missing from the
            graph, since those are post-processed unconditionally.
    """
    candidate_keys = ('num_detections', 'detection_boxes', 'detection_scores',
                      'detection_classes', 'detection_masks')
    with graph.as_default():
        with tf.Session() as session:
            default_graph = tf.get_default_graph()
            # Every tensor name in the graph, so outputs this model was not
            # exported with are simply not fetched.
            present = {tensor.name
                       for operation in default_graph.get_operations()
                       for tensor in operation.outputs}
            fetches = {
                key: default_graph.get_tensor_by_name(key + ':0')
                for key in candidate_keys
                if key + ':0' in present
            }
            input_tensor = default_graph.get_tensor_by_name('image_tensor:0')
            batch = np.expand_dims(image, 0)
            detections = session.run(fetches, feed_dict={input_tensor: batch})

    # Fetched values are batched numpy arrays; take element 0 of the batch
    # and convert types as needed.
    detections['num_detections'] = int(detections['num_detections'][0])
    detections['detection_classes'] = (
        detections['detection_classes'][0].astype(np.uint8))
    detections['detection_boxes'] = detections['detection_boxes'][0]
    detections['detection_scores'] = detections['detection_scores'][0]
    return detections
def extractBoard(image, margin=0.1):
    """Detect the board in *image* and return that region, padded, as a new image.

    Entry 0 of the detector's ``detection_boxes`` (normalized
    ``(ymin, xmin, ymax, xmax)``) is used as the board. It is presumably the
    highest-scoring detection -- confirm the model sorts its outputs by score.
    The box is widened by *margin* times its own width/height on every side,
    clamped to the image, and that region is cropped from *image*.

    Args:
        image: a ``PIL.Image.Image``.
        margin: fraction of the detected box's width (height) added to its
            left and right (top and bottom) edges before cropping. Defaults to
            0.1, the previously hard-coded value.

    Returns:
        The cropped ``PIL.Image.Image``.
    """
    t1 = time.time()
    image_np = load_image_into_numpy_array(image)
    output_dict = run_inference_for_single_image(image_np, detection_graph)
    if output_dict["num_detections"] == 0:
        # With no reported detections, box 0 is not a real detection, so the
        # crop below is meaningless (possibly empty). Make that visible.
        log.warning("no board detected, the extracted region is unreliable")
    (width, height) = image.size
    (y1, x1, y2, x2) = output_dict["detection_boxes"][0]
    (w, h) = ((x2 - x1), (y2 - y1))
    # Pad each side by `margin` of the box size, staying inside [0, 1].
    x1 = max(0, x1 - margin * w)
    x2 = min(1, x2 + margin * w)
    y1 = max(0, y1 - margin * h)
    y2 = min(1, y2 + margin * h)
    # Convert normalized coordinates to pixels for PIL's (left, upper, right, lower).
    goban = image.crop((x1 * width, y1 * height, x2 * width, y2 * height))
    t = time.time() - t1
    log.info("board detected in {0:.3}s".format(t))
    return goban
if __name__ == "__main__":
    # Make extractBoard's INFO-level timing messages visible. This is a no-op
    # if the root logger was already configured (e.g. by an imported module).
    log.basicConfig(level=log.INFO)

    parser = argparse.ArgumentParser(
        description="Crop the detected board out of each input image.")
    # required=True: without it, omitting -i left args.input as None and the
    # loop below failed with a TypeError.
    parser.add_argument("-i", "--input", nargs="+", required=True,
                        help="image file(s) to process")
    parser.add_argument("-o", "--output_dir", required=True,
                        help="directory for the cropped images (created if missing)")
    args = parser.parse_args()

    if not os.path.isdir(args.output_dir):
        os.makedirs(args.output_dir)

    for image_path in args.input:
        # Close each source file once its crop has been saved, so long input
        # lists do not accumulate open file handles.
        with Image.open(image_path) as image:
            goban = extractBoard(image)
            goban.save(os.path.join(args.output_dir, os.path.basename(image_path)))
|