00001
00002
00003
import sys

import numpy as np
import scipy.ndimage as ndi
from skimage.color import gray2rgb
from skimage.color import label2rgb
from skimage.future.graph import draw_rag
from skimage.future.graph import merge_hierarchical
from skimage.future.graph import RAG
from skimage.future.graph.rag import _add_edge_filter
from skimage.measure import regionprops
from skimage.morphology.convex_hull import convex_hull_image
from skimage.segmentation import slic
from skimage.util import img_as_ubyte
from skimage.util import img_as_uint

import cv_bridge
from jsk_topic_tools import ConnectionBasedTransport
from jsk_topic_tools import warn_no_remap
import message_filters
import rospy
from sensor_msgs.msg import Image
00025
00026
00027
00028
00029
00030
def rag_solidity(labels, connectivity=2):
    """Build a region adjacency graph whose edges are weighted by solidity.

    Every non-background label ``n`` becomes a node with the attributes
    ``labels`` (``[n]``), ``mask`` (``labels == n``) and ``solidity``
    (region area divided by the area of its convex hull).  Each edge
    ``(x, y)`` gets ``weight = mean(solidity_x, solidity_y) /
    solidity(mask_x | mask_y)``, so a pair whose union is more convex than
    its parts on average gets a weight below 1.

    Parameters
    ----------
    labels : ndarray of int
        Label image. Label 0 is treated as background and removed from
        the graph.
    connectivity : int, optional
        Connectivity passed to ``ndi.generate_binary_structure``
        (default 2).

    Returns
    -------
    RAG
        The weighted region adjacency graph.
    """
    graph = RAG()

    # Zero the leading slice along every axis of the structuring element so
    # that the centre pixel is the first non-zero entry (in C order), followed
    # only by neighbours that come after it along each axis.
    fp = ndi.generate_binary_structure(labels.ndim, connectivity)
    for d in range(fp.ndim):
        fp = fp.swapaxes(0, d)
        fp[0, ...] = 0
        fp = fp.swapaxes(0, d)

    # generic_filter is used only to drive _add_edge_filter, which receives
    # `graph` as an extra argument and populates it; the filtered output
    # array itself is a throw-away buffer.
    ndi.generic_filter(
        labels,
        function=_add_edge_filter,
        footprint=fp,
        mode='nearest',
        output=np.zeros(labels.shape, dtype=np.uint8),
        extra_arguments=(graph,))

    # Drop the background label. It may be missing (e.g. when `labels`
    # contains no 0 pixels), and networkx raises on removing an absent node.
    if 0 in graph:
        graph.remove_node(0)

    for n in graph:
        mask = (labels == n)
        solidity = 1. * mask.sum() / convex_hull_image(mask).sum()
        graph.node[n].update({'labels': [n],
                              'solidity': solidity,
                              'mask': mask})

    # edges(data=True) exists in networkx 1.x and 2.x (edges_iter was
    # removed in 2.0); the yielded dicts are the live edge attributes.
    for x, y, d in graph.edges(data=True):
        new_mask = np.logical_or(graph.node[x]['mask'], graph.node[y]['mask'])
        new_solidity = 1. * new_mask.sum() / convex_hull_image(new_mask).sum()
        org_solidity = np.mean([graph.node[x]['solidity'],
                                graph.node[y]['solidity']])
        d['weight'] = org_solidity / new_solidity

    return graph
00082
00083
00084
00085
00086
00087
def _solidity_weight_func(graph, src, dst, n):
    """Recompute an edge weight after ``src`` and ``dst`` are merged.

    Used as the ``weight_func`` of ``merge_hierarchical``.  The reference
    value is the mean solidity of ``src`` and ``dst``.  For each of those
    two nodes, the solidity of its mask united with ``n``'s mask is
    computed, and the reference is divided by it.  The smaller of the two
    ratios is returned.
    """
    nodes = graph.node
    reference = np.mean([nodes[src]['solidity'], nodes[dst]['solidity']])
    neighbour_mask = nodes[n]['mask']
    ratios = []
    for endpoint in (src, dst):
        union = np.logical_or(nodes[endpoint]['mask'], neighbour_mask)
        union_solidity = 1. * union.sum() / convex_hull_image(union).sum()
        ratios.append(reference / union_solidity)
    return min(ratios)
00099
00100
def _solidity_merge_func(graph, src, dst):
    """Callback called before merging two nodes of a solidity graph.

    Used as the ``merge_func`` of ``merge_hierarchical``.  It stores the
    union of both masks on ``dst``, then refreshes ``dst``'s ``solidity``
    as union area / convex-hull area.  ``src``'s attributes are not
    modified.
    """
    src_mask = graph.node[src]['mask']
    dst_attrs = graph.node[dst]
    union = np.logical_or(src_mask, dst_attrs['mask'])
    dst_attrs['mask'] = union
    union_area = np.sum(union)
    dst_attrs['solidity'] = 1. * union_area / np.sum(convex_hull_image(union))
00107
00108
00109
00110
00111
00112
def masked_slic(img, mask, n_segments, compactness):
    """Run SLIC on ``img`` and set every label outside ``mask`` to 0.

    SLIC labels are shifted by +1 so that 0 is free to mean background.
    The mask is first closed with a 3x3 structuring element (one iteration)
    to fill small gaps.

    Parameters
    ----------
    img : ndarray
        Image passed to ``slic``.
    mask : ndarray
        2-D mask matching ``img``'s height and width; non-zero is foreground.
    n_segments, compactness
        Forwarded to ``slic``.

    Returns
    -------
    ndarray or None
        Label image with 0 outside the closed mask.  ``None`` if the closing
        raised ``IndexError``; that error is logged with ``rospy.logerr``.
    """
    labels = slic(img, n_segments=n_segments, compactness=compactness)
    labels += 1
    n_labels = len(np.unique(labels))
    try:
        # scipy's binary_closing erodes against border_value=0, which would
        # strip foreground pixels on the crop border. Zero-pad by the number
        # of iterations so closing can only add pixels, then crop back.
        pad = 1
        padded = np.pad(np.asarray(mask), pad, mode='constant')
        closed = ndi.binary_closing(
            padded, structure=np.ones((3, 3)), iterations=pad)
        mask = closed[tuple(slice(pad, pad + s) for s in np.shape(mask))]
    except IndexError as e:  # `as` form is valid on Python 2.6+ and 3
        rospy.logerr(e)
        return
    labels[mask == 0] = 0
    # Losing several SLIC labels to the mask suggests a poor fit.
    if len(np.unique(labels)) < n_labels - 2:
        sys.stderr.write('WARNING: number of label differs after masking.'
                         ' Maybe this is not good for RAG construction.\n')
    return labels
00127
00128
def closed_mask_roi(mask, iterations=8):
    """Return the bounding-box slices of ``mask`` after a morphological closing.

    The mask is closed with a 3x3 square structuring element so nearby
    foreground fragments are bridged.  The bounding box of all resulting
    foreground pixels is returned.

    Parameters
    ----------
    mask : ndarray
        2-D mask; non-zero pixels are foreground.
    iterations : int, optional
        Positive number of dilation/erosion iterations of the closing
        (default 8).

    Returns
    -------
    tuple of slice or None
        Slices usable as ``array[roi]``.  ``None`` when the mask has no
        foreground pixels.
    """
    # scipy's binary_closing erodes against border_value=0, which strips
    # foreground touching the array edge and can even erase the whole mask.
    # Zero-padding by `iterations` keeps the dilation inside the array, so
    # the closed mask is a superset of `mask` and the ROI always covers it.
    pad = iterations
    padded = np.pad(np.asarray(mask), pad, mode='constant')
    closed_padded = ndi.binary_closing(
        padded, structure=np.ones((3, 3)), iterations=iterations)
    closed_mask = closed_padded[
        tuple(slice(pad, pad + s) for s in np.shape(mask))]
    roi = ndi.find_objects(closed_mask, max_label=1)[0]
    return roi
00134
00135
00136
00137
00138
00139
class SolidityRagMerge(ConnectionBasedTransport):
    """Merge SLIC superpixels inside a mask by region solidity.

    Subscribes to ``~input`` (image) and ``~input/mask`` (read as mono8).
    The two are synchronized exactly, or approximately when
    ``~approximate_sync`` is true.  The merged label image (int32, 0
    outside the mask) is published on ``~output``.  When ``~debug`` is true
    (the default) the node also publishes ``~debug/slic``, ``~debug/rag``
    and ``~debug/label_viz``.
    """

    def __init__(self):
        super(SolidityRagMerge, self).__init__()
        self.pub = self.advertise('~output', Image, queue_size=5)
        self.is_debugging = rospy.get_param('~debug', True)
        if self.is_debugging:
            self.pub_rag = self.advertise('~debug/rag', Image, queue_size=5)
            self.pub_slic = self.advertise('~debug/slic', Image, queue_size=5)
            self.pub_label = self.advertise('~debug/label_viz', Image,
                                            queue_size=5)

    def subscribe(self):
        """Subscribe to the image and mask topics and synchronize them."""
        self.sub = message_filters.Subscriber('~input', Image)
        self.sub_mask = message_filters.Subscriber('~input/mask', Image)
        self.use_async = rospy.get_param('~approximate_sync', False)
        rospy.loginfo('~approximate_sync: {}'.format(self.use_async))
        if self.use_async:
            sync = message_filters.ApproximateTimeSynchronizer(
                [self.sub, self.sub_mask], queue_size=1000, slop=0.1)
        else:
            sync = message_filters.TimeSynchronizer(
                [self.sub, self.sub_mask], queue_size=1000)
        sync.registerCallback(self._apply)
        warn_no_remap('~input', '~input/mask')

    def unsubscribe(self):
        """Unregister both input subscribers."""
        self.sub.unregister()
        self.sub_mask.unregister()

    def _apply(self, imgmsg, maskmsg):
        """Segment one synchronized (image, mask) pair and publish labels."""
        bridge = cv_bridge.CvBridge()
        img = bridge.imgmsg_to_cv2(imgmsg)
        if img.ndim == 2:
            img = gray2rgb(img)
        mask = bridge.imgmsg_to_cv2(maskmsg, desired_encoding='mono8')
        # The mask must stay 2-D (H x W). closed_mask_roi / masked_slic close
        # it with a 3x3 structuring element, and `labels` / `merged_labels`
        # below take their shape from it and are indexed with it.
        mask = mask.reshape(mask.shape[:2])

        # Restrict SLIC to the bounding box of the closed mask.
        roi = closed_mask_roi(mask)
        if roi is None:
            # Empty mask: indexing with None would only add an axis.
            rospy.logwarn('~input/mask has no foreground pixels; skipping')
            return
        roi_labels = masked_slic(img=img[roi], mask=mask[roi],
                                 n_segments=20, compactness=30)
        if roi_labels is None:
            return
        # Pixels outside the ROI keep the background label 0.
        labels = np.zeros(mask.shape, dtype=np.int32)
        labels[roi] = roi_labels
        if self.is_debugging:
            slic_labelmsg = bridge.cv2_to_imgmsg(labels)
            slic_labelmsg.header = imgmsg.header
            self.pub_slic.publish(slic_labelmsg)

        g = rag_solidity(labels, connectivity=2)
        if self.is_debugging:
            rag_img = draw_rag(labels, g, img)
            # img_as_ubyte rescales to 0..255. Casting a 16-bit image with
            # astype(np.uint8) would keep only the low byte of every value.
            rag_img = img_as_ubyte(rag_img)
            rag_imgmsg = bridge.cv2_to_imgmsg(rag_img, encoding='rgb8')
            rag_imgmsg.header = imgmsg.header
            self.pub_rag.publish(rag_imgmsg)

        merged_labels = merge_hierarchical(
            labels, g, thresh=1, rag_copy=False,
            in_place_merge=True,
            merge_func=_solidity_merge_func,
            weight_func=_solidity_weight_func)
        # Shift so 0 can be reserved for pixels outside the mask.
        merged_labels += 1
        merged_labels[mask == 0] = 0
        merged_labelmsg = bridge.cv2_to_imgmsg(merged_labels.astype(np.int32))
        merged_labelmsg.header = imgmsg.header
        self.pub.publish(merged_labelmsg)
        if self.is_debugging:
            out = label2rgb(merged_labels, img)
            out = (out * 255).astype(np.uint8)
            out_msg = bridge.cv2_to_imgmsg(out, encoding='rgb8')
            out_msg.header = imgmsg.header
            self.pub_label.publish(out_msg)
00219
00220
if __name__ == '__main__':
    # The node must be initialised before SolidityRagMerge advertises topics.
    rospy.init_node('solidity_rag_merge')
    node = SolidityRagMerge()  # keep a handle for the lifetime of spin()
    rospy.spin()  # block, servicing callbacks until shutdown