solidity_rag_merge.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 # -*- coding: utf-8 -*-
00003 
00004 import sys
00005 
00006 import numpy as np
00007 import scipy.ndimage as ndi
00008 from skimage.color import gray2rgb
00009 from skimage.color import label2rgb
00010 from skimage.future.graph import draw_rag
00011 from skimage.future.graph import merge_hierarchical
00012 from skimage.future.graph import RAG
00013 from skimage.future.graph.rag import _add_edge_filter
00014 from skimage.measure import regionprops
00015 from skimage.morphology.convex_hull import convex_hull_image
00016 from skimage.segmentation import slic
00017 from skimage.util import img_as_uint
00018 
00019 import cv_bridge
00020 from jsk_topic_tools import ConnectionBasedTransport
00021 from jsk_topic_tools import warn_no_remap
00022 import message_filters
00023 import rospy
00024 from sensor_msgs.msg import Image
00025 
00026 
00027 ###############################################################################
00028 # rag function
00029 ###############################################################################
00030 
def rag_solidity(labels, connectivity=2):
    """Build a region adjacency graph weighted by a solidity criterion.

    Every label of ``labels`` that takes part in an adjacency becomes a node.
    Label 0, the background, is removed from the graph. Each node stores:
      - ``'labels'``: ``[n]``
      - ``'mask'``: the boolean mask ``labels == n``
      - ``'solidity'``: area of the mask divided by the area of its convex hull

    The edge between regions ``x`` and ``y`` gets the weight::

        mean(solidity(x), solidity(y)) / solidity(x | y)

    A weight below 1 therefore means the union of the two regions is more
    solid than the pair is on average.

    Args:
        labels: integer label image; 0 marks the background.
        connectivity: connectivity passed to
            ``scipy.ndimage.generate_binary_structure``.

    Returns:
        The populated ``RAG``.
    """
    def solidity_of(mask):
        # fraction of the region's convex hull that the region covers
        return 1. * mask.sum() / convex_hull_image(mask).sum()

    graph = RAG()

    # The footprint is constructed in such a way that the first
    # element in the array being passed to _add_edge_filter is
    # the central value.
    fp = ndi.generate_binary_structure(labels.ndim, connectivity)
    for d in range(fp.ndim):
        fp = fp.swapaxes(0, d)
        fp[0, ...] = 0
        fp = fp.swapaxes(0, d)

    # For example
    # if labels.ndim = 2 and connectivity = 1
    # fp = [[0,0,0],
    #       [0,1,1],
    #       [0,1,0]]
    #
    # if labels.ndim = 2 and connectivity = 2
    # fp = [[0,0,0],
    #       [0,1,1],
    #       [0,1,1]]

    # generic_filter is run only for its side effect: _add_edge_filter
    # (a private skimage helper) records adjacencies into ``graph``.
    # The filtered output array is discarded.
    ndi.generic_filter(
        labels,
        function=_add_edge_filter,
        footprint=fp,
        mode='nearest',
        output=np.zeros(labels.shape, dtype=np.uint8),
        extra_arguments=(graph,))

    # Remove the background label. The guard is needed because label 0 is
    # presumably only added as a node when it is adjacent to another label
    # (e.g. it is absent when the mask covers the whole image).
    if 0 in graph:
        graph.remove_node(0)

    for n in graph:
        mask = (labels == n)
        graph.node[n].update({'labels': [n],
                              'solidity': solidity_of(mask),
                              'mask': mask})

    # edges(data=True) yields the live edge-attribute dicts in networkx 1.x
    # and 2.x alike; edges_iter() was removed in networkx 2.0.
    for x, y, d in graph.edges(data=True):
        new_mask = np.logical_or(graph.node[x]['mask'], graph.node[y]['mask'])
        org_solidity = np.mean([graph.node[x]['solidity'],
                                graph.node[y]['solidity']])
        d['weight'] = org_solidity / solidity_of(new_mask)

    return graph
00082 
00083 
00084 ###############################################################################
00085 # rag merging functions
00086 ###############################################################################
00087 
def _solidity_weight_func(graph, src, dst, n):
    """Recompute the weight of an edge touching a freshly merged node.

    Passed as ``weight_func`` to ``merge_hierarchical``. It is presumably
    invoked for each neighbour ``n`` of a merged ``src``/``dst`` pair
    (skimage semantics — confirm against the installed version).

    For each of ``src`` and ``dst``, the solidity of its union with ``n`` is
    computed. The mean solidity of ``src`` and ``dst`` is divided by that
    union solidity, and the smaller of the two ratios is returned.

    NOTE(review): unlike the edge weight built in ``rag_solidity``, the
    reference solidity here does not include ``n``'s own solidity. Also,
    ``_solidity_merge_func`` may already have replaced ``dst``'s mask and
    solidity with the merged values. Confirm this is the intended criterion.
    """
    nodes = graph.node
    baseline = np.mean([nodes[src]['solidity'], nodes[dst]['solidity']])
    ratios = []
    for endpoint in (src, dst):
        union = np.logical_or(nodes[endpoint]['mask'], nodes[n]['mask'])
        union_solidity = 1. * union.sum() / convex_hull_image(union).sum()
        ratios.append(baseline / union_solidity)
    return min(ratios)
00099 
00100 
def _solidity_merge_func(graph, src, dst):
    """Absorb ``src``'s mask into ``dst`` and refresh ``dst``'s solidity.

    Passed as ``merge_func`` to ``merge_hierarchical``. It mutates only the
    attribute dict of ``dst`` (its ``'mask'`` and ``'solidity'`` entries)
    and returns None.
    """
    src_mask = graph.node[src]['mask']
    dst_attrs = graph.node[dst]
    merged = np.logical_or(src_mask, dst_attrs['mask'])
    dst_attrs['mask'] = merged
    # solidity = region area / convex hull area
    hull_area = np.sum(convex_hull_image(merged))
    dst_attrs['solidity'] = 1. * np.sum(merged) / hull_area
00107 
00108 
00109 ###############################################################################
00110 # utils
00111 ###############################################################################
00112 
def masked_slic(img, mask, n_segments, compactness):
    """Run SLIC superpixel segmentation and clear the background.

    Args:
        img: image forwarded to ``skimage.segmentation.slic``.
        mask: foreground mask with the same spatial shape as the label image
            returned by ``slic``. It is expected to be 2-D, because it is
            closed with a 3x3 structuring element. Pixels that are 0 after
            the closing become background.
        n_segments: forwarded to ``slic``.
        compactness: forwarded to ``slic``.

    Returns:
        Integer label image in which the background is 0 and every other
        label has been shifted up by one. Returns None if the closing step
        raised IndexError; the error is logged with ``rospy.logerr``.
    """
    labels = slic(img, n_segments=n_segments, compactness=compactness)
    # shift all labels up by one so that 0 can be reserved for background
    labels += 1
    n_labels = len(np.unique(labels))
    try:
        mask = ndi.binary_closing(mask, structure=np.ones((3, 3)), iterations=1)
    except IndexError as e:  # 'as' form: valid on both Python 2.6+ and 3
        rospy.logerr(e)
        return None
    labels[mask == 0] = 0  # set bg_label
    # warn when masking removed more than two distinct labels
    if len(np.unique(labels)) < n_labels - 2:
        sys.stderr.write('WARNING: number of label differs after masking.'
                         ' Maybe this is not good for RAG construction.\n')
    return labels
00127 
00128 
def closed_mask_roi(mask, iterations=8):
    """Return the bounding box of ``mask`` after a binary closing.

    The mask is closed with a 3x3 structuring element repeated
    ``iterations`` times, so that small gaps and holes do not split the
    region of interest.

    Args:
        mask: 2-D array; nonzero pixels are foreground.
        iterations: number of closing iterations (must be >= 1).

    Returns:
        A tuple of slices, as produced by ``scipy.ndimage.find_objects``,
        bounding the closed foreground. Returns None when the mask has no
        foreground.

    Raises:
        ValueError: if ``iterations`` is smaller than 1.
    """
    if iterations < 1:
        # binary_closing treats iterations < 1 as "repeat until stable",
        # for which no finite padding is known.
        raise ValueError(
            'iterations must be >= 1, got {}'.format(iterations))
    mask = np.asarray(mask) != 0
    # binary_closing treats pixels outside the array as background
    # (border_value=0), so without padding a region touching the image
    # border is eroded up to ``iterations`` pixels away from that border.
    # Each iteration of the 3x3 element grows/shrinks by one pixel, so
    # padding by ``iterations`` is enough; crop back afterwards.
    pad = int(iterations)
    padded = np.pad(mask, pad, mode='constant')
    closed = ndi.binary_closing(
        padded, structure=np.ones((3, 3)), iterations=pad)
    closed = closed[pad:pad + mask.shape[0], pad:pad + mask.shape[1]]
    roi = ndi.find_objects(closed.astype(np.uint8), max_label=1)[0]
    return roi
00134 
00135 
00136 ###############################################################################
00137 # ros node
00138 ###############################################################################
00139 
class SolidityRagMerge(ConnectionBasedTransport):
    """Merge SLIC superpixels inside a mask using a solidity criterion.

    Inputs (time-synchronized): ``~input`` (image) and ``~input/mask``
    (mask image, converted to ``mono8``).

    Output: ``~output``, an int32 label image whose background is 0.

    When the ``~debug`` parameter is true (the default), three extra topics
    are also published:
      - ``~debug/slic``: the SLIC labels
      - ``~debug/rag``: a drawing of the region adjacency graph
      - ``~debug/label_viz``: a colorized merged label image
    """

    def __init__(self):
        super(SolidityRagMerge, self).__init__()
        self.pub = self.advertise('~output', Image, queue_size=5)
        self.is_debugging = rospy.get_param('~debug', True)
        if self.is_debugging:
            self.pub_rag = self.advertise('~debug/rag', Image, queue_size=5)
            self.pub_slic = self.advertise('~debug/slic', Image, queue_size=5)
            self.pub_label = self.advertise('~debug/label_viz', Image,
                                            queue_size=5)

    def subscribe(self):
        """Subscribe to the image and mask topics and pair them by stamp."""
        self.sub = message_filters.Subscriber('~input', Image)
        self.sub_mask = message_filters.Subscriber('~input/mask', Image)
        # ~approximate_sync selects ApproximateTimeSynchronizer (slop 0.1 s)
        # instead of the exact-stamp TimeSynchronizer.
        self.use_async = rospy.get_param('~approximate_sync', False)
        rospy.loginfo('~approximate_sync: {}'.format(self.use_async))
        if self.use_async:
            sync = message_filters.ApproximateTimeSynchronizer(
                [self.sub, self.sub_mask], queue_size=1000, slop=0.1)
        else:
            sync = message_filters.TimeSynchronizer(
                [self.sub, self.sub_mask], queue_size=1000)
        sync.registerCallback(self._apply)
        warn_no_remap('~input', '~input/mask')

    def unsubscribe(self):
        """Unregister both input subscribers."""
        self.sub.unregister()
        self.sub_mask.unregister()

    def _apply(self, imgmsg, maskmsg):
        """Segment one image/mask pair and publish the merged labels."""
        bridge = cv_bridge.CvBridge()
        img = bridge.imgmsg_to_cv2(imgmsg)
        if img.ndim == 2:
            img = gray2rgb(img)
        mask = bridge.imgmsg_to_cv2(maskmsg, desired_encoding='mono8')
        # Keep the mask single-channel. The 3x3 structuring elements used
        # by closed_mask_roi/masked_slic require a 2-D input (scipy raises
        # RuntimeError on a dimensionality mismatch). The label images
        # allocated from mask.shape below must be 2-D as well.
        mask = mask.reshape(mask.shape[:2])
        # compute labels inside the bounding box of the closed mask
        roi = closed_mask_roi(mask)
        if roi is None:
            # find_objects found no foreground: nothing to segment
            return
        roi_labels = masked_slic(img=img[roi], mask=mask[roi],
                                 n_segments=20, compactness=30)
        if roi_labels is None:
            return
        labels = np.zeros(mask.shape, dtype=np.int32)  # 0 = background
        labels[roi] = roi_labels
        if self.is_debugging:
            # publish debug slic label
            slic_labelmsg = bridge.cv2_to_imgmsg(labels)
            slic_labelmsg.header = imgmsg.header
            self.pub_slic.publish(slic_labelmsg)
        # compute rag
        g = rag_solidity(labels, connectivity=2)
        if self.is_debugging:
            # publish debug rag drawn image
            rag_img = draw_rag(labels, g, img)
            # draw_rag returns a float image in [0, 1]. Scale it to 8 bit
            # the same way as the label2rgb output below. The former
            # img_as_uint + astype(np.uint8) wrapped 16-bit values mod 256.
            rag_img = (np.clip(rag_img, 0, 1) * 255).astype(np.uint8)
            rag_imgmsg = bridge.cv2_to_imgmsg(rag_img, encoding='rgb8')
            rag_imgmsg.header = imgmsg.header
            self.pub_rag.publish(rag_imgmsg)
        # merge rag with solidity; in-place merging keeps the node ids used
        # by _solidity_merge_func/_solidity_weight_func
        merged_labels = merge_hierarchical(
            labels, g, thresh=1, rag_copy=False,
            in_place_merge=True,
            merge_func=_solidity_merge_func,
            weight_func=_solidity_weight_func)
        # Shift the merged labels (which presumably start at 0) so that no
        # region shares 0 with the background, then restore the background
        # from the input mask.
        merged_labels += 1
        merged_labels[mask == 0] = 0
        merged_labelmsg = bridge.cv2_to_imgmsg(merged_labels.astype(np.int32))
        merged_labelmsg.header = imgmsg.header
        self.pub.publish(merged_labelmsg)
        if self.is_debugging:
            out = label2rgb(merged_labels, img)
            out = (out * 255).astype(np.uint8)
            out_msg = bridge.cv2_to_imgmsg(out, encoding='rgb8')
            out_msg.header = imgmsg.header
            self.pub_label.publish(out_msg)
00219 
00220 
if __name__ == '__main__':
    # Register with the ROS master before any publisher is advertised.
    rospy.init_node('solidity_rag_merge')
    _node = SolidityRagMerge()
    # Block here; all work happens in the subscriber callbacks.
    rospy.spin()


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Sun Oct 8 2017 02:43:23