35 import tensorflow
as tf
40 from tools
import ResizeAndCrop
46 if os.path.isdir(model_dir):
47 hypes_name = os.path.join(model_dir,
"deeplab.json")
49 hypes_name = model_dir
51 with open(hypes_name,
'r') as f: 56 """Class to load deeplab model and run inference.""" 58 def __init__(self, model_dir, original_image_size, tensor_io, runCPU, gpu_percent=1):
62 frozen_graph_path = self.
hypes[
'frozen_graph_path']
63 rospy.logwarn(
"Deeplab to load: " + frozen_graph_path)
65 """Creates and loads pretrained deeplab model.""" 70 with open(frozen_graph_path,
'rb')
as file_handle:
71 graph_def = tf.GraphDef.FromString(file_handle.read())
74 raise RuntimeError(
'Cannot find inference graph in given path.')
76 with self.graph.as_default():
77 tf.import_graph_def(graph_def, name=
'')
79 config = tf.ConfigProto()
80 config.gpu_options.per_process_gpu_memory_fraction = gpu_percent
81 self.
sess = tf.Session(graph=self.
graph, config=config)
84 if "input_image_size" in self.hypes.keys():
89 self.
tools = ResizeAndCrop(self.
hypes, original_image_size)
93 """A function that sets up and runs an image through KittiSeg 94 Input: Image to process 95 Output: way_prediction, time_tf""" 112 return self.tools.postprocess_image(
116 """Runs inference on a single image. 119 image: A PIL.Image object, raw input image. 122 resized_image: RGB image resized from original input image. 123 seg_map: Segmentation map of `resized_image`. 125 time__tf_start = timeit.default_timer()
127 batch_seg_map = self.sess.run(
131 time__tf = timeit.default_timer() - time__tf_start
133 seg_map = batch_seg_map[0]
134 return seg_map, time__tf
def load_hypes(model_dir)
def run_processed_image(self, image)
def run_model_on_image(self, image)
def __init__(self, model_dir, original_image_size, tensor_io, runCPU, gpu_percent=1)