deeplab_segmenter.py
Go to the documentation of this file.
1 #!/usr/bin/python
2 # BSD 3-Clause License
3 
4 # Copyright (c) 2019, Noam C. Golombek
5 # All rights reserved.
6 
7 # Redistribution and use in source and binary forms, with or without
8 # modification, are permitted provided that the following conditions are met:
9 
10 # 1. Redistributions of source code must retain the above copyright notice, this
11 # list of conditions and the following disclaimer.
12 
13 # 2. Redistributions in binary form must reproduce the above copyright notice,
14 # this list of conditions and the following disclaimer in the documentation
15 # and/or other materials provided with the distribution.
16 
17 # 3. Neither the name of the copyright holder nor the names of its
18 # contributors may be used to endorse or promote products derived from
19 # this software without specific prior written permission.
20 
21 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22 # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23 # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25 # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26 # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27 # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28 # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 
32 
33 
34 import numpy as np
35 import tensorflow as tf
36 import timeit
37 import rospy
38 
39 
40 from tools import ResizeAndCrop
41 
42 
def load_hypes(model_dir):
    """Read the DeepLab hyper-parameter file and return its decoded JSON.

    Args:
        model_dir: Either a directory that contains ``deeplab.json`` or the
            path of the JSON hypes file itself.

    Returns:
        The object decoded from the JSON file (the callers in this module
        index it like a dict).
    """
    import json
    import os

    hypes_path = model_dir
    if os.path.isdir(hypes_path):
        # A directory was given: the hypes file has a fixed name inside it.
        hypes_path = os.path.join(hypes_path, "deeplab.json")

    with open(hypes_path, 'r') as hypes_file:
        hypes = json.load(hypes_file)
    return hypes
53 
54 
class DeepLabSegmenter(object):
    """Load a frozen DeepLab graph and run segmentation inference on images.

    Hyper-parameters are read with ``load_hypes``; resizing/cropping before
    and after inference is delegated to ``ResizeAndCrop``.
    """

    # Fallback network input size, used when the hypes file has no
    # "input_image_size" entry.
    DEFAULT_INPUT_IMAGE_SIZE = (641, 361)

    def __init__(self, model_dir, original_image_size, tensor_io, runCPU,
                 gpu_percent=1):
        """Create a TensorFlow session for the frozen DeepLab graph.

        Args:
            model_dir: Directory containing ``deeplab.json``, or the path of
                the JSON hypes file itself.
            original_image_size: Size of the incoming images; passed through
                to ``ResizeAndCrop``.
            tensor_io: Mapping with the graph tensor names under the keys
                ``"input_tensor"`` and ``"output_tensor"``.
            runCPU: If true, no GPU devices are exposed to the session, so
                inference runs on the CPU.
            gpu_percent: Value for ``per_process_gpu_memory_fraction``.

        Raises:
            KeyError: if the hypes lack ``"frozen_graph_path"`` or
                ``tensor_io`` lacks one of the tensor names.
            IOError: if the frozen graph file cannot be opened.
        """
        self.hypes = load_hypes(model_dir)
        self.input_tensor = tensor_io["input_tensor"]
        self.output_tensor = tensor_io["output_tensor"]
        frozen_graph_path = self.hypes['frozen_graph_path']
        rospy.logwarn("Deeplab to load: " + frozen_graph_path)

        # Deserialize the frozen GraphDef. open() raises on a missing file and
        # FromString() raises on a malformed protobuf, so the old
        # "graph_def is None" check could never trigger and was removed.
        with open(frozen_graph_path, 'rb') as file_handle:
            graph_def = tf.GraphDef.FromString(file_handle.read())

        self.graph = tf.Graph()
        with self.graph.as_default():
            tf.import_graph_def(graph_def, name='')

        if runCPU:
            # runCPU used to be silently ignored; hiding every GPU device
            # makes TensorFlow place the graph on the CPU.
            config = tf.ConfigProto(device_count={'GPU': 0})
        else:
            config = tf.ConfigProto()
        config.gpu_options.per_process_gpu_memory_fraction = gpu_percent
        self.sess = tf.Session(graph=self.graph, config=config)

        self.input_image_size = self.hypes.get(
            "input_image_size", self.DEFAULT_INPUT_IMAGE_SIZE)

        self.tools = ResizeAndCrop(self.hypes, original_image_size)
        # run_model_on_image() reads and rewrites this on every frame, so it
        # must exist before the first call (previously an AttributeError).
        # NOTE(review): assumes ResizeAndCrop.preprocess_image accepts None
        # on the first frame -- confirm against tools.ResizeAndCrop.
        self.output_image_uncropped = None

    def run_model_on_image(self, image):
        """Segment one raw image with DeepLab.

        The image is prepared by ``ResizeAndCrop.preprocess_image`` and run
        through the network. The prediction is then handed to
        ``ResizeAndCrop.postprocess_image`` together with the original image
        and ``hypes["selected_classes"]``.

        Args:
            image: Raw input image, as received from the caller.

        Returns:
            Tuple ``(postprocessed_output, time_tf)``, where ``time_tf`` is
            the number of seconds spent inside ``sess.run``.
        """
        # preprocess_image returns the updated uncropped buffer, which is
        # kept on the instance for reuse on the next frame.
        image_for_proc, self.output_image_uncropped = \
            self.tools.preprocess_image(image, self.output_image_uncropped)

        output_image, time_tf = self.run_processed_image(image_for_proc)

        segmentation = self.tools.postprocess_image(
            output_image, self.output_image_uncropped, image,
            self.hypes["selected_classes"])
        return segmentation, time_tf

    def run_processed_image(self, image):
        """Run the network on a single, already preprocessed image.

        Args:
            image: Preprocessed image. It is converted with ``np.asarray``
                and fed to ``self.input_tensor`` as a batch of one.

        Returns:
            Tuple ``(seg_map, time_tf)``: the first batch element of
            ``self.output_tensor`` and the seconds spent in ``sess.run``
            (measured with ``timeit.default_timer``).
        """
        start = timeit.default_timer()
        batch_seg_map = self.sess.run(
            self.output_tensor,
            feed_dict={self.input_tensor: [np.asarray(image)]})
        elapsed = timeit.default_timer() - start

        return batch_seg_map[0], elapsed
135 
136 # def create_pascal_label_colormap():
137 # """Creates a label colormap used in PASCAL VOC segmentation benchmark.
138 
139 # Returns:
140 # A Colormap for visualizing segmentation results.
141 # """
142 # colormap = np.zeros((256, 3), dtype=int)
143 # ind = np.arange(256, dtype=int)
144 
145 # for shift in reversed(range(8)):
146 # for channel in range(3):
147 # colormap[:, channel] |= ((ind >> channel) & 1) << shift
148 # ind >>= 3
149 
150 # return colormap
151 
152 # def label_to_color_image(label):
153 # """Adds color defined by the dataset colormap to the label.
154 
155 # Args:
156 # label: A 2D array with integer type, storing the segmentation label.
157 
158 # Returns:
159 # result: A 2D array with floating type. The element of the array
160 # is the color indexed by the corresponding element in the input label
161 # to the PASCAL color map.
162 
163 # Raises:
164 # ValueError: If label is not of rank 2 or its value is larger than color
165 # map maximum entry.
166 # """
167 # if label.ndim != 2:
168 # raise ValueError('Expect 2-D input label')
169 
170 # colormap = create_pascal_label_colormap()
171 
172 # if np.max(label) >= len(colormap):
173 # raise ValueError('label value too large.')
174 
175 # return colormap[label]
def __init__(self, model_dir, original_image_size, tensor_io, runCPU, gpu_percent=1)


cnn_bridge
Author(s): Noam C. Golombek, Alexander Beringolts
autogenerated on Mon Jun 10 2019 12:53:26