solidity_rag_merge.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 
4 from distutils.version import LooseVersion
5 import sys
6 
7 import matplotlib
8 matplotlib.use('Agg') # NOQA
9 import matplotlib.pyplot as plt
10 import networkx
11 import numpy as np
12 import scipy.ndimage as ndi
13 import skimage
14 from skimage.color import gray2rgb
15 from skimage.color import label2rgb
16 from skimage.future.graph import merge_hierarchical
17 from skimage.future.graph import RAG
18 from skimage.future.graph.rag import _add_edge_filter
19 from skimage.morphology.convex_hull import convex_hull_image
20 from skimage.segmentation import slic
21 from skimage.util import img_as_uint
22 
23 import cv_bridge
24 from jsk_topic_tools import ConnectionBasedTransport
25 from jsk_topic_tools import warn_no_remap
26 import message_filters
27 import rospy
28 from sensor_msgs.msg import Image
29 
30 if LooseVersion(skimage.__version__) >= '0.13.0':
31  from skimage.future.graph import show_rag
32 else:
33  from skimage.future.graph import draw_rag
34 
35 
36 ###############################################################################
37 # rag function
38 ###############################################################################
39 
def rag_solidity(labels, connectivity=2):
    """Build a region adjacency graph weighted by solidity change.

    The solidity of a region is its pixel count divided by the pixel count
    of its convex hull. Each edge weight is the mean solidity of the two
    adjacent regions divided by the solidity of their union.

    Parameters
    ----------
    labels : numpy.ndarray
        Integer label image. Label 0 is treated as background and is
        removed from the graph when present.
    connectivity : int
        Connectivity passed to ``scipy.ndimage.generate_binary_structure``
        to decide which pixels count as neighbors.

    Returns
    -------
    RAG
        Graph whose nodes carry ``'labels'``, ``'solidity'`` and ``'mask'``
        attributes and whose edges carry a ``'weight'`` attribute.
    """
    graph = RAG()

    # The footprint is constructed in such a way that the first
    # element in the array being passed to _add_edge_filter is
    # the central value.
    fp = ndi.generate_binary_structure(labels.ndim, connectivity)
    for d in range(fp.ndim):
        # swapaxes returns a view, so zeroing the first slice of the view
        # clears the leading slice of ``fp`` along axis ``d``.
        fp = fp.swapaxes(0, d)
        fp[0, ...] = 0
        fp = fp.swapaxes(0, d)

    # For example
    # if labels.ndim = 2 and connectivity = 1
    # fp = [[0,0,0],
    #       [0,1,1],
    #       [0,1,0]]
    #
    # if labels.ndim = 2 and connectivity = 2
    # fp = [[0,0,0],
    #       [0,1,1],
    #       [0,1,1]]

    # The filter is run only for its side effect of adding edges to
    # ``graph``; the filtered output array is discarded.
    ndi.generic_filter(
        labels,
        function=_add_edge_filter,
        footprint=fp,
        mode='nearest',
        output=np.zeros(labels.shape, dtype=np.uint8),
        extra_arguments=(graph,))

    # Remove bg_label. The background may be absent (e.g. the mask covers
    # the whole image), in which case remove_node would raise.
    if graph.has_node(0):
        graph.remove_node(0)

    # ``Graph.node`` was removed in NetworkX 2.4; ``Graph.nodes`` is only
    # subscriptable from NetworkX 2.0 on, so try the old name first.
    try:
        nodes = graph.node
    except AttributeError:
        nodes = graph.nodes

    for n in graph:
        mask = (labels == n)
        solidity = 1. * mask.sum() / convex_hull_image(mask).sum()
        nodes[n].update({'labels': [n],
                         'solidity': solidity,
                         'mask': mask})

    if LooseVersion(networkx.__version__) >= '2':
        edges_iter = graph.edges(data=True)
    else:
        edges_iter = graph.edges_iter(data=True)
    for x, y, d in edges_iter:
        new_mask = np.logical_or(nodes[x]['mask'], nodes[y]['mask'])
        new_solidity = 1. * new_mask.sum() / convex_hull_image(new_mask).sum()
        org_solidity = np.mean([nodes[x]['solidity'],
                                nodes[y]['solidity']])
        d['weight'] = org_solidity / new_solidity

    return graph
95 
96 
97 ###############################################################################
98 # rag merging functions
99 ###############################################################################
100 
def _solidity_weight_func(graph, src, dst, n):
    """Callback to handle merging nodes by recomputing solidity.

    Parameters
    ----------
    graph : RAG
        Graph whose nodes carry ``'solidity'`` and ``'mask'`` attributes.
    src, dst : int
        The two nodes being merged.
    n : int
        A neighbor node whose edge weight is being recomputed.

    Returns
    -------
    dict
        ``{'weight': w}`` where ``w`` is the smaller of the two ratios
        (mean solidity of ``src`` and ``dst``) / (solidity of the union of
        ``n`` with ``src``, respectively with ``dst``).
    """
    # ``Graph.node`` was removed in NetworkX 2.4; ``Graph.nodes`` is only
    # subscriptable from NetworkX 2.0 on, so try the old name first.
    try:
        nodes = graph.node
    except AttributeError:
        nodes = graph.nodes

    org_solidity = np.mean([nodes[src]['solidity'],
                            nodes[dst]['solidity']])
    new_mask1 = np.logical_or(nodes[src]['mask'], nodes[n]['mask'])
    new_mask2 = np.logical_or(nodes[dst]['mask'], nodes[n]['mask'])
    new_solidity1 = 1. * new_mask1.sum() / convex_hull_image(new_mask1).sum()
    new_solidity2 = 1. * new_mask2.sum() / convex_hull_image(new_mask2).sum()
    weight1 = org_solidity / new_solidity1
    weight2 = org_solidity / new_solidity2
    return {'weight': min([weight1, weight2])}
112 
113 
def _solidity_merge_func(graph, src, dst):
    """Callback called before merging two nodes of a solidity graph.

    Stores the union of the ``src`` and ``dst`` masks on ``dst`` and
    recomputes the solidity of ``dst`` from that union (pixel count of the
    mask divided by pixel count of its convex hull).

    Parameters
    ----------
    graph : RAG
        Graph whose nodes carry ``'solidity'`` and ``'mask'`` attributes.
    src, dst : int
        Nodes being merged; only the attributes of ``dst`` are modified.
    """
    # ``Graph.node`` was removed in NetworkX 2.4; ``Graph.nodes`` is only
    # subscriptable from NetworkX 2.0 on, so try the old name first.
    try:
        nodes = graph.node
    except AttributeError:
        nodes = graph.nodes

    new_mask = np.logical_or(nodes[src]['mask'], nodes[dst]['mask'])
    nodes[dst]['mask'] = new_mask
    nodes[dst]['solidity'] = \
        1. * np.sum(new_mask) / np.sum(convex_hull_image(new_mask))
120 
121 
122 ###############################################################################
123 # utils
124 ###############################################################################
125 
def masked_slic(img, mask, n_segments, compactness):
    """Run SLIC on ``img`` and zero out the labels outside ``mask``.

    The SLIC labels are shifted up by one so that 0 can be used as the
    background label. The mask is closed once with a 3x3 structuring
    element before being applied.

    Returns the label array, or None when closing the mask raises
    IndexError (the error is logged through rospy).
    """
    segments = slic(img, n_segments=n_segments, compactness=compactness)
    segments += 1
    count_before = np.unique(segments).size

    try:
        mask = ndi.binary_closing(
            mask, structure=np.ones((3, 3)), iterations=1)
    except IndexError as err:
        rospy.logerr(err)
        return None

    segments[mask == 0] = 0  # set bg_label

    count_after = np.unique(segments).size
    if count_after < count_before - 2:
        # Masking removed several superpixels entirely.
        sys.stderr.write('WARNING: number of label differs after masking.'
                         ' Maybe this is not good for RAG construction.\n')
    return segments
140 
141 
def closed_mask_roi(mask):
    """Return the bounding-box slices of ``mask`` after binary closing.

    The mask is closed with a 3x3 structuring element for 8 iterations and
    the bounding box of the resulting foreground is looked up with
    ``scipy.ndimage.find_objects``. The result is a tuple of slices, or
    None when the closed mask has no foreground pixel.
    """
    kernel = np.ones((3, 3))
    closed = ndi.binary_closing(mask, structure=kernel, iterations=8)
    bounding_boxes = ndi.find_objects(closed, max_label=1)
    return bounding_boxes[0]
147 
148 
149 ###############################################################################
150 # ros node
151 ###############################################################################
152 
class SolidityRagMerge(ConnectionBasedTransport):
    """ROS node that merges masked SLIC superpixels by solidity.

    Subscribes to ``~input`` (image) and ``~input/mask`` (mask image),
    and publishes the merged label image on ``~output``. When the
    ``~debug`` parameter is true (the default), it additionally publishes
    ``~debug/slic``, ``~debug/rag`` and ``~debug/label_viz``.
    """

    def __init__(self):
        super(SolidityRagMerge, self).__init__()
        self.pub = self.advertise('~output', Image, queue_size=5)
        self.is_debugging = rospy.get_param('~debug', True)
        if self.is_debugging:
            self.pub_rag = self.advertise('~debug/rag', Image, queue_size=5)
            self.pub_slic = self.advertise('~debug/slic', Image, queue_size=5)
            self.pub_label = self.advertise('~debug/label_viz', Image,
                                            queue_size=5)

    def subscribe(self):
        """Subscribe to the image and mask topics and synchronize them."""
        self.sub = message_filters.Subscriber('~input', Image)
        self.sub_mask = message_filters.Subscriber('~input/mask', Image)
        self.use_async = rospy.get_param('~approximate_sync', False)
        rospy.loginfo('~approximate_sync: {}'.format(self.use_async))
        if self.use_async:
            sync = message_filters.ApproximateTimeSynchronizer(
                [self.sub, self.sub_mask], queue_size=1000, slop=0.1)
        else:
            sync = message_filters.TimeSynchronizer(
                [self.sub, self.sub_mask], queue_size=1000)
        sync.registerCallback(self._apply)
        warn_no_remap('~input', '~input/mask')

    def unsubscribe(self):
        """Unregister both input subscribers."""
        self.sub.unregister()
        self.sub_mask.unregister()

    def _draw_rag_image(self, labels, g, img):
        """Render the RAG ``g`` over ``img`` and return an image array."""
        if LooseVersion(skimage.__version__) >= '0.13.0':
            fig, ax = plt.subplots(
                figsize=(img.shape[1] * 0.01, img.shape[0] * 0.01))
            show_rag(labels, g, img, ax=ax)
            ax.axis('off')
            plt.subplots_adjust(0, 0, 1, 1)
            fig.canvas.draw()
            w, h = fig.canvas.get_width_height()
            # np.fromstring in binary mode is deprecated; frombuffer reads
            # the same bytes without copying.
            rag_img = np.frombuffer(
                fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(h, w, 3)
            # Close this specific figure so figures do not accumulate.
            plt.close(fig)
            return rag_img
        # skimage < 0.13. Convert to 8-bit directly: going through
        # img_as_uint (16-bit) and then casting to uint8 wraps the values.
        from skimage.util import img_as_ubyte
        return img_as_ubyte(draw_rag(labels, g, img))

    def _apply(self, imgmsg, maskmsg):
        """Segment the masked image and publish the merged label image."""
        bridge = cv_bridge.CvBridge()
        img = bridge.imgmsg_to_cv2(imgmsg)
        if img.ndim == 2:
            img = gray2rgb(img)
        mask = bridge.imgmsg_to_cv2(maskmsg, desired_encoding='mono8')
        mask = mask.reshape(mask.shape[:2])
        # compute label
        roi = closed_mask_roi(mask)
        if roi is None:
            # Empty mask: there is no region to segment, and indexing the
            # image with None would silently add an axis instead of failing.
            rospy.logwarn('Input mask is empty, skipping.')
            return
        roi_labels = masked_slic(img=img[roi], mask=mask[roi],
                                 n_segments=20, compactness=30)
        if roi_labels is None:
            return
        # Pixels outside the ROI keep label 0 (bg_label).
        labels = np.zeros(mask.shape, dtype=np.int32)
        labels[roi] = roi_labels
        if self.is_debugging:
            # publish debug slic label
            slic_labelmsg = bridge.cv2_to_imgmsg(labels)
            slic_labelmsg.header = imgmsg.header
            self.pub_slic.publish(slic_labelmsg)
        # compute rag
        g = rag_solidity(labels, connectivity=2)
        if self.is_debugging:
            # publish debug rag drawn image
            rag_img = self._draw_rag_image(labels, g, img)
            rag_imgmsg = bridge.cv2_to_imgmsg(
                rag_img.astype(np.uint8), encoding='rgb8')
            rag_imgmsg.header = imgmsg.header
            self.pub_rag.publish(rag_imgmsg)
        # merge rag with solidity
        merged_labels = merge_hierarchical(
            labels, g, thresh=1, rag_copy=False,
            in_place_merge=True,
            merge_func=_solidity_merge_func,
            weight_func=_solidity_weight_func)
        # Shift labels up by one, then reset everything outside the mask
        # to the background label 0.
        merged_labels += 1
        merged_labels[mask == 0] = 0
        merged_labelmsg = bridge.cv2_to_imgmsg(merged_labels.astype(np.int32))
        merged_labelmsg.header = imgmsg.header
        self.pub.publish(merged_labelmsg)
        if self.is_debugging:
            out = label2rgb(merged_labels, img)
            out = (out * 255).astype(np.uint8)
            out_msg = bridge.cv2_to_imgmsg(out, encoding='rgb8')
            out_msg.header = imgmsg.header
            self.pub_label.publish(out_msg)
244 
245 
if __name__ == '__main__':
    rospy.init_node('solidity_rag_merge')
    # Instantiate the node so that its publishers are advertised before
    # blocking in spin(); without this the process would do nothing.
    SolidityRagMerge()
    rospy.spin()
def rag_solidity(labels, connectivity=2)
rag function
def _solidity_weight_func(graph, src, dst, n)
rag merging functions
def _solidity_merge_func(graph, src, dst)
def label2rgb(lbl, img=None, label_names=None, alpha=0.3, bg_label=0)
def masked_slic(img, mask, n_segments, compactness)
utils


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Mon May 3 2021 03:03:27