solidity_rag_merge.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 
4 from distutils.version import LooseVersion
5 import sys
6 
7 import matplotlib
8 matplotlib.use('Agg') # NOQA
9 import matplotlib.pyplot as plt
10 import networkx
11 import numpy as np
12 import scipy.ndimage as ndi
13 import skimage
14 from skimage.color import gray2rgb
15 from skimage.color import label2rgb
16 from skimage.future.graph import merge_hierarchical
17 from skimage.future.graph import RAG
18 from skimage.future.graph.rag import _add_edge_filter
19 from skimage.morphology.convex_hull import convex_hull_image
20 from skimage.segmentation import slic
21 from skimage.util import img_as_uint
22 
23 import cv_bridge
24 from jsk_topic_tools import ConnectionBasedTransport
25 from jsk_topic_tools import warn_no_remap
26 import message_filters
27 import rospy
28 from sensor_msgs.msg import Image
29 
30 if LooseVersion(skimage.__version__) >= '0.13.0':
31  from skimage.future.graph import show_rag
32 else:
33  from skimage.future.graph import draw_rag
34 
35 
36 
39 
def rag_solidity(labels, connectivity=2):
    """Build a region adjacency graph (RAG) whose edge weights measure solidity.

    Each node of the returned graph (background label 0 removed) holds:

    * ``'labels'``   -- ``[n]``, the label id of the node,
    * ``'solidity'`` -- region pixel count divided by the pixel count of the
      region's convex hull,
    * ``'mask'``     -- boolean mask ``labels == n``.

    Each edge holds ``'weight'`` = (mean solidity of its two regions) /
    (solidity of their union). A weight below 1 therefore means the union
    is more solid (convex) than its parts are on average.

    Parameters
    ----------
    labels : ndarray
        Integer label image; label 0 is treated as background.
    connectivity : int
        Passed to ``ndi.generate_binary_structure`` to choose which
        neighbouring pixels count as adjacent.

    Returns
    -------
    RAG
        The weighted region adjacency graph.
    """
    # networkx >= 2.4 removed ``Graph.node``, but the merge callbacks in this
    # module use ``graph.node``. Alias it only when it is actually missing,
    # rather than keying on the scikit-image version.
    if not hasattr(RAG, 'node'):
        RAG.node = RAG.nodes
    graph = RAG()

    # The footprint is constructed in such a way that the first
    # element in the array being passed to _add_edge_filter is
    # the central value: the first slice along every axis is zeroed.
    fp = ndi.generate_binary_structure(labels.ndim, connectivity)
    for d in range(fp.ndim):
        fp = fp.swapaxes(0, d)
        fp[0, ...] = 0
        fp = fp.swapaxes(0, d)

    # For example
    # if labels.ndim = 2 and connectivity = 1
    # fp = [[0,0,0],
    #       [0,1,1],
    #       [0,1,0]]
    #
    # if labels.ndim = 2 and connectivity = 2
    # fp = [[0,0,0],
    #       [0,1,1],
    #       [0,1,1]]

    # generic_filter is used only for its side effect of adding edges to
    # ``graph``; its output array is discarded.
    ndi.generic_filter(
        labels,
        function=_add_edge_filter,
        footprint=fp,
        mode='nearest',
        output=np.zeros(labels.shape, dtype=np.uint8),
        extra_arguments=(graph,))

    # Remove the background label. It can be missing, e.g. when no pixel is
    # background or when no edge touching label 0 was added. networkx raises
    # when removing a node that does not exist.
    if 0 in graph:
        graph.remove_node(0)

    def _solidity(mask):
        # Ratio of region area to the area of its convex hull (float division
        # even on Python 2).
        return 1. * mask.sum() / convex_hull_image(mask).sum()

    for n in graph:
        mask = (labels == n)
        graph.node[n].update({'labels': [n],
                              'solidity': _solidity(mask),
                              'mask': mask})

    if LooseVersion(networkx.__version__) >= '2':
        edges_iter = graph.edges(data=True)
    else:
        edges_iter = graph.edges_iter(data=True)
    for x, y, d in edges_iter:
        new_solidity = _solidity(
            np.logical_or(graph.node[x]['mask'], graph.node[y]['mask']))
        org_solidity = np.mean([graph.node[x]['solidity'],
                                graph.node[y]['solidity']])
        d['weight'] = org_solidity / new_solidity

    return graph
97 
98 
99 
102 
def _solidity_weight_func(graph, src, dst, n):
    """Compute edge data between the merged node and its neighbour ``n``.

    Passed to ``merge_hierarchical`` as ``weight_func``. The reference
    solidity is the mean of the ``src`` and ``dst`` solidities. For each of
    ``src`` and ``dst``, that reference is divided by the solidity of the
    region obtained by joining it with ``n``. The smaller ratio becomes the
    new edge weight.

    Returns
    -------
    dict
        ``{'weight': <float>}``.
    """
    nodes = graph.node
    mean_solidity = np.mean([nodes[src]['solidity'],
                             nodes[dst]['solidity']])
    neighbour_mask = nodes[n]['mask']

    ratios = []
    for endpoint in (src, dst):
        joined = np.logical_or(nodes[endpoint]['mask'], neighbour_mask)
        joined_solidity = \
            1. * joined.sum() / convex_hull_image(joined).sum()
        ratios.append(mean_solidity / joined_solidity)
    return {'weight': min(ratios)}
114 
115 
def _solidity_merge_func(graph, src, dst):
    """Fold the region of ``src`` into ``dst`` before the nodes are merged.

    Passed to ``merge_hierarchical`` as ``merge_func``. The mask of ``dst``
    is replaced with the union of both masks, and its solidity (area divided
    by convex hull area) is recomputed from that union.
    """
    dst_attrs = graph.node[dst]
    union = np.logical_or(graph.node[src]['mask'], dst_attrs['mask'])
    hull_area = np.sum(convex_hull_image(union))
    dst_attrs['mask'] = union
    dst_attrs['solidity'] = 1. * np.sum(union) / hull_area
122 
123 
124 
127 
def masked_slic(img, mask, n_segments, compactness):
    """Segment ``img`` with SLIC and zero out labels outside ``mask``.

    SLIC labels are shifted up by one so that 0 is free to mark background.
    Background is every pixel where the 3x3-closed ``mask`` is 0.

    Returns
    -------
    ndarray or None
        The label image. Returns None, after logging the error, if closing
        the mask raises ``IndexError``.
    """
    segments = slic(img, n_segments=n_segments, compactness=compactness)
    segments += 1
    n_before = np.unique(segments).size
    try:
        closed = ndi.binary_closing(
            mask, structure=np.ones((3, 3)), iterations=1)
    except IndexError as err:
        rospy.logerr(err)
        return None
    segments[closed == 0] = 0  # set bg_label
    if np.unique(segments).size < n_before - 2:
        sys.stderr.write('WARNING: number of label differs after masking.'
                         ' Maybe this is not good for RAG construction.\n')
    return segments
142 
143 
def closed_mask_roi(mask, iterations=8):
    """Return the bounding box of the morphologically closed foreground.

    ``mask`` is closed with a 3x3 structuring element ``iterations`` times.
    The slices bounding all foreground pixels of the result are returned, so
    they can be used directly as ``img[roi]``.

    The mask is zero-padded by ``iterations`` pixels before closing.
    ``ndi.binary_closing`` treats out-of-image pixels as 0 (its default
    ``border_value``) in its erosion step. Without padding, that strips an
    ``iterations``-wide band along the image border: foreground touching the
    border was cut out of the ROI, and a small mask near the border vanished.

    Parameters
    ----------
    mask : 2D array
        Foreground wherever nonzero.
    iterations : int
        Number of closing iterations (default 8, the previous hard-coded
        value). Expected to be positive.

    Returns
    -------
    tuple of slice
        Bounding box of the closed foreground. If there is no foreground,
        the full image extent is returned instead of ``None``; ``None``
        would make ``img[roi]`` insert a new axis.
    """
    fg = np.asarray(mask) != 0
    pad = int(iterations)
    padded = np.pad(fg, pad, mode='constant')
    closed = ndi.binary_closing(
        padded, structure=np.ones((3, 3)), iterations=iterations)
    # Crop the padding back off so the slices refer to the original image.
    closed = closed[tuple(slice(pad, pad + size) for size in fg.shape)]
    roi = ndi.find_objects(closed.astype(np.uint8), max_label=1)[0]
    if roi is None:
        roi = tuple(slice(0, size) for size in fg.shape)
    return roi
149 
150 
151 
154 
class SolidityRagMerge(ConnectionBasedTransport):
    """ROS node: SLIC over-segmentation followed by solidity-based RAG merge.

    Subscribes to ``~input`` and, when ``~use_mask`` is true (default), to
    ``~input/mask``. It publishes the merged int32 label image on
    ``~output``. When ``~debug`` is true (default), it also publishes
    ``~debug/slic``, ``~debug/rag`` and ``~debug/label_viz``.
    """

    def __init__(self):
        super(SolidityRagMerge, self).__init__()
        self.pub = self.advertise('~output', Image, queue_size=5)
        self.is_debugging = rospy.get_param('~debug', True)
        if self.is_debugging:
            self.pub_rag = self.advertise('~debug/rag', Image, queue_size=5)
            self.pub_slic = self.advertise('~debug/slic', Image, queue_size=5)
            self.pub_label = self.advertise('~debug/label_viz', Image,
                                            queue_size=5)

    def subscribe(self):
        self.use_mask = rospy.get_param('~use_mask', True)
        if self.use_mask:
            self.sub = message_filters.Subscriber('~input', Image)
            self.sub_mask = message_filters.Subscriber('~input/mask', Image)
            self.use_async = rospy.get_param('~approximate_sync', False)
            rospy.loginfo('~approximate_sync: {}'.format(self.use_async))
            if self.use_async:
                sync = message_filters.ApproximateTimeSynchronizer(
                    [self.sub, self.sub_mask], queue_size=1000, slop=0.1)
            else:
                sync = message_filters.TimeSynchronizer(
                    [self.sub, self.sub_mask], queue_size=1000)
            sync.registerCallback(self.sub_cb)
            warn_no_remap('~input', '~input/mask')
        else:
            self.sub = rospy.Subscriber('~input', Image, self.sub_img_cb,
                                        queue_size=1)
            warn_no_remap('~input')

    def unsubscribe(self):
        self.sub.unregister()
        if self.use_mask:
            self.sub_mask.unregister()

    def sub_cb(self, imgmsg, maskmsg):
        self._apply(imgmsg, maskmsg)

    def sub_img_cb(self, imgmsg):
        self._apply(imgmsg, None)

    @staticmethod
    def _compute_labels(bridge, img, maskmsg):
        """Return ``(mask, labels)`` for one frame.

        ``mask`` is the foreground mask: all ones when ``maskmsg`` is None.
        ``labels`` is an int32 SLIC label image with 0 as background, or
        None when no labels could be computed.
        """
        if maskmsg is None:
            mask = np.ones(img.shape[:2], dtype=np.uint8)
            labels = masked_slic(img=img, mask=mask,
                                 n_segments=20, compactness=30)
            if labels is None:
                return mask, None
            return mask, labels.astype(np.int32)

        mask = bridge.imgmsg_to_cv2(maskmsg, desired_encoding='mono8')
        mask = mask.reshape(mask.shape[:2])
        # Run SLIC only inside the bounding box of the (closed) mask.
        roi = closed_mask_roi(mask)
        if roi is None:
            # No foreground box found; indexing with None would add an axis.
            return mask, None
        roi_labels = masked_slic(img=img[roi], mask=mask[roi],
                                 n_segments=20, compactness=30)
        if roi_labels is None:
            return mask, None
        labels = np.zeros(mask.shape, dtype=np.int32)
        labels[roi] = roi_labels
        return mask, labels

    @staticmethod
    def _draw_rag(labels, graph, img):
        """Render ``graph`` over ``img`` as an ``(H, W, 3)`` uint8 RGB image."""
        if LooseVersion(skimage.__version__) < '0.13.0':
            # draw_rag returns a float image. img_as_ubyte rescales it to
            # 0-255; the former img_as_uint(...).astype(np.uint8) wrapped
            # 16-bit values modulo 256.
            from skimage.util import img_as_ubyte
            return img_as_ubyte(draw_rag(labels, graph, img))

        fig, ax = plt.subplots(
            figsize=(img.shape[1] * 0.01, img.shape[0] * 0.01))
        try:
            show_rag(labels, graph, img, ax=ax)
            ax.axis('off')
            plt.subplots_adjust(0, 0, 1, 1)
            fig.canvas.draw()
            w, h = fig.canvas.get_width_height()
            # buffer_rgba replaces np.fromstring + tostring_rgb, which are
            # deprecated/removed in recent numpy/matplotlib. Copy the RGB
            # part before the figure is closed.
            rgba = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
            rag_img = np.array(rgba.reshape(h, w, 4)[..., :3],
                               dtype=np.uint8)
        finally:
            plt.close(fig)
        return rag_img

    def _apply(self, imgmsg, maskmsg=None):
        """Segment, build the solidity RAG, merge, and publish one frame."""
        bridge = cv_bridge.CvBridge()
        img = bridge.imgmsg_to_cv2(imgmsg)
        if img.ndim == 2:
            img = gray2rgb(img)

        mask, labels = self._compute_labels(bridge, img, maskmsg)
        if labels is None:
            return

        if self.is_debugging:
            # publish debug slic label
            slic_labelmsg = bridge.cv2_to_imgmsg(labels)
            slic_labelmsg.header = imgmsg.header
            self.pub_slic.publish(slic_labelmsg)

        g = rag_solidity(labels, connectivity=2)
        if self.is_debugging:
            # publish debug rag drawn image
            rag_imgmsg = bridge.cv2_to_imgmsg(
                self._draw_rag(labels, g, img).astype(np.uint8),
                encoding='rgb8')
            rag_imgmsg.header = imgmsg.header
            self.pub_rag.publish(rag_imgmsg)

        # Merge while the edge weight (mean solidity / union solidity) is
        # below 1, i.e. while merging increases solidity.
        merged_labels = merge_hierarchical(
            labels, g, thresh=1, rag_copy=False,
            in_place_merge=True,
            merge_func=_solidity_merge_func,
            weight_func=_solidity_weight_func)
        # Shift so that 0 is reserved for background, then restore it.
        merged_labels += 1
        merged_labels[mask == 0] = 0
        merged_labelmsg = bridge.cv2_to_imgmsg(merged_labels.astype(np.int32))
        merged_labelmsg.header = imgmsg.header
        self.pub.publish(merged_labelmsg)

        if self.is_debugging:
            out = label2rgb(merged_labels, img)
            out = (out * 255).astype(np.uint8)
            out_msg = bridge.cv2_to_imgmsg(out, encoding='rgb8')
            out_msg.header = imgmsg.header
            self.pub_label.publish(out_msg)
270 
271 
if __name__ == '__main__':
    rospy.init_node('solidity_rag_merge')
    # The node must be instantiated; otherwise rospy.spin() just idles with
    # no publishers or subscribers.
    SolidityRagMerge()
    rospy.spin()
node_scripts.solidity_rag_merge.SolidityRagMerge.subscribe
def subscribe(self)
Definition: solidity_rag_merge.py:167
node_scripts.solidity_rag_merge.SolidityRagMerge.use_async
use_async
Definition: solidity_rag_merge.py:172
node_scripts.solidity_rag_merge.SolidityRagMerge.pub
pub
Definition: solidity_rag_merge.py:159
node_scripts.solidity_rag_merge.SolidityRagMerge
ROS node
Definition: solidity_rag_merge.py:155
node_scripts.solidity_rag_merge.SolidityRagMerge._apply
def _apply(self, imgmsg, maskmsg=None)
Definition: solidity_rag_merge.py:198
node_scripts.solidity_rag_merge.rag_solidity
def rag_solidity(labels, connectivity=2)
RAG construction function
Definition: solidity_rag_merge.py:40
node_scripts.solidity_rag_merge._solidity_weight_func
def _solidity_weight_func(graph, src, dst, n)
RAG merging functions
Definition: solidity_rag_merge.py:103
node_scripts.solidity_rag_merge.SolidityRagMerge.__init__
def __init__(self)
Definition: solidity_rag_merge.py:157
message_filters::Subscriber
node_scripts.solidity_rag_merge.SolidityRagMerge.unsubscribe
def unsubscribe(self)
Definition: solidity_rag_merge.py:187
node_scripts.solidity_rag_merge.SolidityRagMerge.is_debugging
is_debugging
Definition: solidity_rag_merge.py:160
node_scripts.solidity_rag_merge.SolidityRagMerge.pub_slic
pub_slic
Definition: solidity_rag_merge.py:163
node_scripts.solidity_rag_merge.SolidityRagMerge.pub_rag
pub_rag
Definition: solidity_rag_merge.py:162
node_scripts.solidity_rag_merge.SolidityRagMerge.sub
sub
Definition: solidity_rag_merge.py:170
node_scripts.solidity_rag_merge.closed_mask_roi
def closed_mask_roi(mask)
Definition: solidity_rag_merge.py:144
node_scripts.solidity_rag_merge.SolidityRagMerge.sub_img_cb
def sub_img_cb(self, imgmsg)
Definition: solidity_rag_merge.py:195
node_scripts.solidity_rag_merge.SolidityRagMerge.sub_cb
def sub_cb(self, imgmsg, maskmsg)
Definition: solidity_rag_merge.py:192
node_scripts.label_image_decomposer.label2rgb
def label2rgb(lbl, img=None, label_names=None, alpha=0.3, bg_label=0)
Definition: label_image_decomposer.py:32
node_scripts.solidity_rag_merge._solidity_merge_func
def _solidity_merge_func(graph, src, dst)
Definition: solidity_rag_merge.py:116
node_scripts.solidity_rag_merge.SolidityRagMerge.sub_mask
sub_mask
Definition: solidity_rag_merge.py:171
message_filters::TimeSynchronizer
node_scripts.solidity_rag_merge.SolidityRagMerge.pub_label
pub_label
Definition: solidity_rag_merge.py:164
node_scripts.solidity_rag_merge.SolidityRagMerge.use_mask
use_mask
Definition: solidity_rag_merge.py:168
node_scripts.solidity_rag_merge.masked_slic
def masked_slic(img, mask, n_segments, compactness)
utils
Definition: solidity_rag_merge.py:128


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Fri May 16 2025 03:11:17