yolo_detector.py
Go to the documentation of this file.
1 #!/usr/bin/python
2 # BSD 3-Clause License
3 
4 # Copyright (c) 2019, Noam C. Golombek
5 # All rights reserved.
6 
7 # Redistribution and use in source and binary forms, with or without
8 # modification, are permitted provided that the following conditions are met:
9 
10 # 1. Redistributions of source code must retain the above copyright notice, this
11 # list of conditions and the following disclaimer.
12 
13 # 2. Redistributions in binary form must reproduce the above copyright notice,
14 # this list of conditions and the following disclaimer in the documentation
15 # and/or other materials provided with the distribution.
16 
17 # 3. Neither the name of the copyright holder nor the names of its
18 # contributors may be used to endorse or promote products derived from
19 # this software without specific prior written permission.
20 
21 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22 # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23 # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25 # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26 # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27 # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28 # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 
32 
33 
34 import numpy as np
35 import tensorflow as tf
36 import timeit
37 import rospy
38 
39 
40 from tools import ResizeAndCrop
41 
def load_hypes(model_dir):
    """Load the hyper-parameter ("hypes") object from a JSON file.

    Args:
        model_dir: Either a directory, in which case the file
            ``deeplab.json`` inside it is read, or the path of the JSON
            file itself.

    Returns:
        The object decoded by ``json.load`` from that file.
    """
    import json
    import os

    # A directory means "use the conventional file name inside it";
    # any other path is taken to be the JSON file directly.
    hypes_path = (os.path.join(model_dir, "deeplab.json")
                  if os.path.isdir(model_dir) else model_dir)

    with open(hypes_path, 'r') as hypes_file:
        hypes = json.load(hypes_file)
    return hypes
52 
53 
class YoloDetector(object):
    """Load a frozen TensorFlow detection graph and run inference on images.

    The graph path is read from the ``frozen_graph_path`` entry of the hypes
    file. The graph is imported into a private ``tf.Graph`` and served by a
    dedicated ``tf.Session``.
    """

    def __init__(self, model_dir, original_image_size, tensor_io, runCPU,
                 gpu_percent=1):
        """Load the hypes file and the frozen graph, then open a session.

        Args:
            model_dir: Directory containing ``deeplab.json``, or the path of
                the hypes JSON file itself (see ``load_hypes``).
            original_image_size: Forwarded to ``ResizeAndCrop``.
            tensor_io: Mapping with ``'input_tensor'`` (name of the graph
                input tensor used as the ``feed_dict`` key) and
                ``'output_tensor'`` (iterable of output tensor names; the
                fetched results are unpacked as boxes, scores, classes).
            runCPU: Not used by this class.
                NOTE(review): presumably meant to force CPU execution;
                confirm the intent before relying on it.
            gpu_percent: Value assigned to
                ``gpu_options.per_process_gpu_memory_fraction``.

        Raises:
            RuntimeError: If no graph definition was obtained.
        """
        self.hypes = load_hypes(model_dir)
        self.input_tensor = tensor_io['input_tensor']
        self.output_tensor = tensor_io['output_tensor']
        # Buffer exchanged with ResizeAndCrop.preprocess_image on every
        # run_model_on_image call. It must exist before the first call;
        # previously it was never initialized, so that call raised
        # AttributeError. NOTE(review): assumes preprocess_image accepts None
        # as "not allocated yet" -- confirm against tools.ResizeAndCrop.
        self.output_image_uncropped = None

        frozen_graph_path = self.hypes['frozen_graph_path']
        rospy.logwarn("Weights to load: " + frozen_graph_path)

        # --- Create and load the pretrained frozen graph ----------------
        self.graph = tf.Graph()
        with open(frozen_graph_path, 'rb') as file_handle:
            graph_def = tf.GraphDef.FromString(file_handle.read())

        # Defensive check. Protobuf parsing normally raises on bad input
        # rather than returning None.
        if graph_def is None:
            raise RuntimeError('Cannot find inference graph in given path.')

        with self.graph.as_default():
            tf.import_graph_def(graph_def, name='')

        config = tf.ConfigProto()
        config.gpu_options.per_process_gpu_memory_fraction = gpu_percent
        self.sess = tf.Session(graph=self.graph, config=config)

        # Resolve the output tensor names to Tensor objects once, so every
        # sess.run call fetches them directly.
        self.output_tensor = [self.graph.get_tensor_by_name(tensor_name)
                              for tensor_name in self.output_tensor]
        # ----------------------------------------------------------------

        # Fall back to (641, 361) when the hypes file does not specify a size.
        self.input_image_size = self.hypes.get("input_image_size", (641, 361))

        self.tools = ResizeAndCrop(self.hypes, original_image_size)

    def run_model_on_image(self, image):
        """Preprocess a raw image and run the detector on it.

        The image is passed through ``self.tools.preprocess_image`` together
        with ``self.output_image_uncropped``. The buffer returned by that
        call is stored back on ``self.output_image_uncropped``, so it is
        handed in again on the next call.

        Args:
            image: Raw input image.

        Returns:
            Tuple ``(detected_classes, time_tf)`` as returned by
            ``run_processed_image``.
        """
        image_for_proc, self.output_image_uncropped = (
            self.tools.preprocess_image(image, self.output_image_uncropped))

        return self.run_processed_image(image_for_proc)

    def run_processed_image(self, image):
        """Run the detection graph on an already preprocessed image.

        Args:
            image: Preprocessed image. It is converted with ``np.asarray``
                and fed as a one-element batch to ``self.input_tensor``.

        Returns:
            Tuple ``(detected_classes, time_tf)``:
            ``detected_classes`` is a dict with keys ``'boxes'``,
            ``'scores'`` and ``'classes'`` holding the three fetched
            outputs, in that order.
            ``time_tf`` is the duration of the ``sess.run`` call in seconds,
            measured with ``timeit.default_timer``.
        """
        time_tf_start = timeit.default_timer()
        boxes, scores, classes = self.sess.run(
            self.output_tensor,
            feed_dict={self.input_tensor: [np.asarray(image)]})
        time_tf = timeit.default_timer() - time_tf_start

        detected_classes = {
            'boxes': boxes,
            'scores': scores,
            'classes': classes,
        }
        return detected_classes, time_tf
def __init__(self, model_dir, original_image_size, tensor_io, runCPU, gpu_percent=1)
def load_hypes(model_dir)


cnn_bridge
Author(s): Noam C. Golombek, Alexander Beringolts
autogenerated on Mon Jun 10 2019 12:53:26