fcn_object_segmentation.py
#!/usr/bin/env python

from __future__ import print_function

from distutils.version import LooseVersion

import itertools, pkg_resources, sys
if LooseVersion(pkg_resources.get_distribution("chainer").version) >= LooseVersion('7.0.0') and \
        sys.version_info.major == 2:
    print('''Please install chainer < 7.0.0:

    sudo pip install chainer==6.7.0

c.f https://github.com/jsk-ros-pkg/jsk_recognition/pull/2485
''', file=sys.stderr)
    sys.exit(1)
if [p for p in list(itertools.chain(*[pkg_resources.find_distributions(_) for _ in sys.path])) if "cupy-" in p.project_name] == []:
    print('''Please install CuPy

    sudo pip install cupy-cuda[your cuda version]
i.e.
    sudo pip install cupy-cuda91

''', file=sys.stderr)
    # sys.exit(1)
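# NOTE: the CuPy check above only prints a warning (sys.exit is commented
# out), so the node can still start without CuPy, e.g. in CPU mode (~gpu: -1).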
import chainer
from chainer import cuda
import chainer.serializers as S
import fcn

import cv_bridge
from jsk_topic_tools import ConnectionBasedTransport
import message_filters
import numpy as np
import rospy
from sensor_msgs.msg import Image


is_torch_available = True
try:
    import torch
except ImportError:
    is_torch_available = False

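# torch is optional: it is only needed when the ~backend rosparam is set to
# 'torch' (see _load_model below); the default chainer backend runs without it.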
def assert_torch_available():
    if not is_torch_available:
        url = 'http://download.pytorch.org/whl/cu80/torch-0.1.11.post4-cp27-none-linux_x86_64.whl'  # NOQA
        raise RuntimeError('Please install pytorch: pip install %s' % url)

class FCNObjectSegmentation(ConnectionBasedTransport):

    def __init__(self):
        super(self.__class__, self).__init__()
        self.backend = rospy.get_param('~backend', 'chainer')
        self.gpu = rospy.get_param('~gpu', -1)  # -1 is cpu mode
        self.target_names = rospy.get_param('~target_names')
        self.bg_label = rospy.get_param('~bg_label', 0)
        self.proba_threshold = rospy.get_param('~proba_threshold', 0.0)
        self.mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434])
        self._load_model()
        self.pub = self.advertise('~output', Image, queue_size=1)
        self.pub_proba = self.advertise(
            '~output/proba_image', Image, queue_size=1)

    def _load_model(self):
        if self.backend == 'chainer':
            self._load_chainer_model()
        elif self.backend == 'torch':
            assert_torch_available()
            # we assume the input data size won't change dynamically
            torch.backends.cudnn.benchmark = True
            self._load_torch_model()
        else:
            raise RuntimeError('Unsupported backend: %s' % self.backend)

    def _load_chainer_model(self):
        model_name = rospy.get_param('~model_name')
        if rospy.has_param('~model_h5'):
            rospy.logwarn('Rosparam ~model_h5 is deprecated,'
                          ' please use ~model_file instead.')
            model_file = rospy.get_param('~model_h5')
        else:
            model_file = rospy.get_param('~model_file')
        n_class = len(self.target_names)
        if model_name == 'fcn32s':
            self.model = fcn.models.FCN32s(n_class=n_class)
        elif model_name == 'fcn16s':
            self.model = fcn.models.FCN16s(n_class=n_class)
        elif model_name == 'fcn8s':
            self.model = fcn.models.FCN8s(n_class=n_class)
        elif model_name == 'fcn8s_at_once':
            self.model = fcn.models.FCN8sAtOnce(n_class=n_class)
        else:
            raise ValueError('Unsupported ~model_name: {}'.format(model_name))
        rospy.loginfo('Loading trained model: {0}'.format(model_file))
        if model_file.endswith('.npz'):
            S.load_npz(model_file, self.model)
        else:
            S.load_hdf5(model_file, self.model)
        rospy.loginfo('Finished loading trained model: {0}'.format(model_file))
        if self.gpu != -1:
            self.model.to_gpu(self.gpu)
        if LooseVersion(chainer.__version__) < LooseVersion('2.0.0'):
            self.model.train = False

    def _load_torch_model(self):
        try:
            import torchfcn
        except ImportError:
            raise ImportError('Please install torchfcn: pip install torchfcn')
        n_class = len(self.target_names)
        model_file = rospy.get_param('~model_file')
        model_name = rospy.get_param('~model_name')
        if model_name == 'fcn32s':
            self.model = torchfcn.models.FCN32s(n_class=n_class)
        elif model_name == 'fcn32s_bilinear':
            self.model = torchfcn.models.FCN32s(n_class=n_class, nodeconv=True)
        else:
            raise ValueError('Unsupported ~model_name: {0}'.format(model_name))
        rospy.loginfo('Loading trained model: %s' % model_file)
        self.model.load_state_dict(torch.load(model_file))
        rospy.loginfo('Finished loading trained model: %s' % model_file)
        if self.gpu >= 0:
            self.model = self.model.cuda(self.gpu)
        self.model.eval()

    def subscribe(self):
        use_mask = rospy.get_param('~use_mask', False)
        if use_mask:
            queue_size = rospy.get_param('~queue_size', 10)
            sub_img = message_filters.Subscriber(
                '~input', Image, queue_size=1, buff_size=2**24)
            sub_mask = message_filters.Subscriber(
                '~input/mask', Image, queue_size=1, buff_size=2**24)
            self.subs = [sub_img, sub_mask]
            if rospy.get_param('~approximate_sync', False):
                slop = rospy.get_param('~slop', 0.1)
                sync = message_filters.ApproximateTimeSynchronizer(
                    fs=self.subs, queue_size=queue_size, slop=slop)
            else:
                sync = message_filters.TimeSynchronizer(
                    fs=self.subs, queue_size=queue_size)
            sync.registerCallback(self._cb_with_mask)
        else:
            # a larger buff_size is necessary when the callback takes time
            # http://stackoverflow.com/questions/26415699/ros-subscriber-not-up-to-date/29160379#29160379  # NOQA
            sub_img = rospy.Subscriber(
                '~input', Image, self._cb, queue_size=1, buff_size=2**24)
            self.subs = [sub_img]

    def unsubscribe(self):
        for sub in self.subs:
            sub.unregister()

    def _cb_with_mask(self, img_msg, mask_msg):
        br = cv_bridge.CvBridge()
        img = br.imgmsg_to_cv2(img_msg, desired_encoding='bgr8')
        mask = br.imgmsg_to_cv2(mask_msg, desired_encoding='mono8')
        if mask.ndim > 2:
            mask = np.squeeze(mask, axis=2)
        label, proba_img = self.segment(img)
        # outside the mask, force the background label and a one-hot
        # background probability
        label[mask == 0] = 0
        proba_img[:, :, 0][mask == 0] = 1
        proba_img[:, :, 1:][mask == 0] = 0
        label_msg = br.cv2_to_imgmsg(label.astype(np.int32), '32SC1')
        label_msg.header = img_msg.header
        self.pub.publish(label_msg)
        proba_msg = br.cv2_to_imgmsg(proba_img.astype(np.float32))
        proba_msg.header = img_msg.header
        self.pub_proba.publish(proba_msg)

    def _cb(self, img_msg):
        br = cv_bridge.CvBridge()
        img = br.imgmsg_to_cv2(img_msg, desired_encoding='bgr8')
        label, proba_img = self.segment(img)
        label_msg = br.cv2_to_imgmsg(label.astype(np.int32), '32SC1')
        label_msg.header = img_msg.header
        self.pub.publish(label_msg)
        proba_msg = br.cv2_to_imgmsg(proba_img.astype(np.float32))
        proba_msg.header = img_msg.header
        self.pub_proba.publish(proba_msg)

    def segment(self, bgr):
        if self.backend == 'chainer':
            return self._segment_chainer_backend(bgr)
        elif self.backend == 'torch':
            return self._segment_torch_backend(bgr)
        raise ValueError('Unsupported backend: {0}'.format(self.backend))

    def _segment_chainer_backend(self, bgr):
        blob = (bgr - self.mean_bgr).transpose((2, 0, 1))
        x_data = np.array([blob], dtype=np.float32)
        if self.gpu != -1:
            x_data = cuda.to_gpu(x_data, device=self.gpu)
        if LooseVersion(chainer.__version__) < LooseVersion('2.0.0'):
            x = chainer.Variable(x_data, volatile=True)
            self.model(x)
        else:
            with chainer.using_config('train', False), \
                    chainer.no_backprop_mode():
                x = chainer.Variable(x_data)
                self.model(x)
        proba_img = chainer.functions.softmax(self.model.score)
        proba_img = chainer.functions.transpose(proba_img, (0, 2, 3, 1))
        max_proba_img = chainer.functions.max(proba_img, axis=-1)
        label = chainer.functions.argmax(self.model.score, axis=1)
        # squeeze batch axis, gpu -> cpu
        proba_img = cuda.to_cpu(proba_img.data)[0]
        max_proba_img = cuda.to_cpu(max_proba_img.data)[0]
        label = cuda.to_cpu(label.data)[0]
        # reset uncertain pixels (low max probability) to the background label
        label[max_proba_img < self.proba_threshold] = self.bg_label
        return label, proba_img

    def _segment_torch_backend(self, bgr):
        blob = (bgr - self.mean_bgr).transpose((2, 0, 1))
        x_data = np.array([blob], dtype=np.float32)
        x_data = torch.from_numpy(x_data)
        if self.gpu >= 0:
            x_data = x_data.cuda(self.gpu)
        x = torch.autograd.Variable(x_data, volatile=True)
        score = self.model(x)
        proba = torch.nn.functional.softmax(score)
        max_proba, label = torch.max(proba, 1)
        # reset uncertain pixels (low max probability) to the background label
        label[max_proba < self.proba_threshold] = self.bg_label
        # gpu -> cpu
        proba = proba.permute(0, 2, 3, 1).data.cpu().numpy()[0]
        label = label.data.cpu().numpy().squeeze((0, 1))
        return label, proba


if __name__ == '__main__':
    rospy.init_node('fcn_object_segmentation')
    FCNObjectSegmentation()
    rospy.spin()
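
For reference, the lines below are a minimal standalone sketch (not part of the node) of the preprocessing and confidence thresholding shared by _segment_chainer_backend and _segment_torch_backend; the image size, threshold value, and the fake probability map are illustrative assumptions only.

# minimal sketch, assuming a 480x640 BGR image and 2 target classes
import numpy as np

mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434])
proba_threshold = 0.5   # cf. rosparam ~proba_threshold
bg_label = 0            # cf. rosparam ~bg_label

bgr = np.zeros((480, 640, 3), dtype=np.uint8)   # placeholder input image
blob = (bgr - mean_bgr).transpose((2, 0, 1))    # HWC -> CHW, mean-subtracted
x_data = np.array([blob], dtype=np.float32)     # add batch axis: (1, 3, 480, 640)

# stand-in for the network's softmax output, shape (H, W, n_class)
proba_img = np.random.rand(480, 640, 2).astype(np.float32)
proba_img /= proba_img.sum(axis=-1, keepdims=True)
label = proba_img.argmax(axis=-1)
max_proba_img = proba_img.max(axis=-1)
# pixels whose best probability is below the threshold fall back to background
label[max_proba_img < proba_threshold] = bg_label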


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Mon May 3 2021 03:03:27