fcn_object_segmentation.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 
00003 from distutils.version import LooseVersion
00004 
00005 import chainer
00006 from chainer import cuda
00007 import chainer.serializers as S
00008 import fcn
00009 
00010 import cv_bridge
00011 from jsk_topic_tools import ConnectionBasedTransport
00012 import message_filters
00013 import numpy as np
00014 import rospy
00015 from sensor_msgs.msg import Image
00016 
00017 
00018 is_torch_available = True
00019 try:
00020     import torch
00021 except ImportError:
00022     is_torch_available = False
00023 
00024 
def assert_torch_available():
    """Raise RuntimeError with install hints if PyTorch is missing.

    Relies on the module-level ``is_torch_available`` flag, which is set
    by the guarded ``import torch`` at the top of this file.
    """
    if is_torch_available:
        return
    wheel_url = 'http://download.pytorch.org/whl/cu80/torch-0.1.11.post4-cp27-none-linux_x86_64.whl'  # NOQA
    raise RuntimeError('Please install pytorch: pip install %s' % wheel_url)
00029 
00030 
class FCNObjectSegmentation(ConnectionBasedTransport):
    """ROS node that labels every pixel of an image with an FCN model.

    The network runs on either the chainer or the torch backend.  For
    each input image a label image (``32SC1``) is published on
    ``~output`` and the per-class softmax probabilities are published on
    ``~output/proba_image``.

    Parameters read from the ROS parameter server:
      ~backend (default 'chainer'): 'chainer' or 'torch'.
      ~gpu (default -1): GPU id; -1 is CPU mode.
      ~target_names: class names; their count is the number of classes.
      ~bg_label (default 0): label assigned to low-confidence pixels.
      ~proba_threshold (default 0.0): pixels whose best class
        probability is below this value are set to ``~bg_label``.
      ~model_name, ~model_file: architecture name and trained weights.
      ~use_mask, ~queue_size, ~approximate_sync, ~slop: see
        ``subscribe``.
    """

    def __init__(self):
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as this node is subclassed.
        super(FCNObjectSegmentation, self).__init__()
        self.backend = rospy.get_param('~backend', 'chainer')
        self.gpu = rospy.get_param('~gpu', -1)  # -1 is cpu mode
        self.target_names = rospy.get_param('~target_names')
        self.bg_label = rospy.get_param('~bg_label', 0)
        self.proba_threshold = rospy.get_param('~proba_threshold', 0.0)
        # Per-channel (B, G, R) mean subtracted from every input image;
        # presumably the mean of the training data -- TODO confirm.
        self.mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434])
        # Load the model before advertising so that no callback can run
        # without a model.
        self._load_model()
        self.pub = self.advertise('~output', Image, queue_size=1)
        self.pub_proba = self.advertise(
            '~output/proba_image', Image, queue_size=1)

    def _load_model(self):
        """Load the network for the backend selected by ``~backend``.

        Raises:
            RuntimeError: if the backend is neither 'chainer' nor
                'torch', or if 'torch' is requested but not installed.
        """
        if self.backend == 'chainer':
            self._load_chainer_model()
        elif self.backend == 'torch':
            assert_torch_available()
            # we assume input data size won't change in dynamic
            torch.backends.cudnn.benchmark = True
            self._load_torch_model()
        else:
            # Interpolate the backend name; passing it as a second
            # argument would leave the '%s' placeholder unformatted.
            raise RuntimeError('Unsupported backend: %s' % self.backend)

    def _load_chainer_model(self):
        """Build the chainer FCN named by ``~model_name`` and load weights.

        Weights come from ``~model_file`` (or the deprecated
        ``~model_h5``); files ending in '.npz' are read with
        ``load_npz``, everything else with ``load_hdf5``.

        Raises:
            ValueError: if ``~model_name`` is not a supported model.
        """
        model_name = rospy.get_param('~model_name')
        if rospy.has_param('~model_h5'):
            rospy.logwarn('Rosparam ~model_h5 is deprecated,'
                          ' and please use ~model_file instead.')
            model_file = rospy.get_param('~model_h5')
        else:
            model_file = rospy.get_param('~model_file')
        n_class = len(self.target_names)
        if model_name == 'fcn32s':
            self.model = fcn.models.FCN32s(n_class=n_class)
        elif model_name == 'fcn16s':
            self.model = fcn.models.FCN16s(n_class=n_class)
        elif model_name == 'fcn8s':
            self.model = fcn.models.FCN8s(n_class=n_class)
        elif model_name == 'fcn8s_at_once':
            self.model = fcn.models.FCN8sAtOnce(n_class=n_class)
        else:
            raise ValueError('Unsupported ~model_name: {}'.format(model_name))
        rospy.loginfo('Loading trained model: {0}'.format(model_file))
        if model_file.endswith('.npz'):
            S.load_npz(model_file, self.model)
        else:
            S.load_hdf5(model_file, self.model)
        rospy.loginfo('Finished loading trained model: {0}'.format(model_file))
        if self.gpu != -1:
            self.model.to_gpu(self.gpu)
        # Chainer < 2.0.0 switches to inference mode through a model
        # attribute; newer versions use chainer.using_config at call time.
        if LooseVersion(chainer.__version__) < LooseVersion('2.0.0'):
            self.model.train = False

    def _load_torch_model(self):
        """Build the torchfcn model named by ``~model_name`` and load it.

        Raises:
            ImportError: if ``torchfcn`` is not installed.
            ValueError: if ``~model_name`` is not a supported model.
        """
        try:
            import torchfcn
        except ImportError:
            raise ImportError('Please install torchfcn: pip install torchfcn')
        n_class = len(self.target_names)
        model_file = rospy.get_param('~model_file')
        model_name = rospy.get_param('~model_name')
        if model_name == 'fcn32s':
            self.model = torchfcn.models.FCN32s(n_class=n_class)
        elif model_name == 'fcn32s_bilinear':
            self.model = torchfcn.models.FCN32s(n_class=n_class, nodeconv=True)
        else:
            raise ValueError('Unsupported ~model_name: {0}'.format(model_name))
        rospy.loginfo('Loading trained model: %s' % model_file)
        self.model.load_state_dict(torch.load(model_file))
        rospy.loginfo('Finished loading trained model: %s' % model_file)
        if self.gpu >= 0:
            self.model = self.model.cuda(self.gpu)
        self.model.eval()

    def subscribe(self):
        """Subscribe to the input image (and optionally a mask).

        With ``~use_mask`` the image (``~input``) and the mask
        (``~input/mask``) are time-synchronized: exactly by default, or
        approximately within ``~slop`` seconds when ``~approximate_sync``
        is set.  ``~queue_size`` is the synchronizer queue size.
        Without ``~use_mask`` only ``~input`` is subscribed.

        NOTE(review): presumably invoked by ConnectionBasedTransport
        when the output topics gain a subscriber -- confirm in the base
        class.
        """
        use_mask = rospy.get_param('~use_mask', False)
        if use_mask:
            queue_size = rospy.get_param('~queue_size', 10)
            sub_img = message_filters.Subscriber(
                '~input', Image, queue_size=1, buff_size=2**24)
            sub_mask = message_filters.Subscriber(
                '~input/mask', Image, queue_size=1, buff_size=2**24)
            self.subs = [sub_img, sub_mask]
            if rospy.get_param('~approximate_sync', False):
                slop = rospy.get_param('~slop', 0.1)
                sync = message_filters.ApproximateTimeSynchronizer(
                    fs=self.subs, queue_size=queue_size, slop=slop)
            else:
                sync = message_filters.TimeSynchronizer(
                    fs=self.subs, queue_size=queue_size)
            sync.registerCallback(self._cb_with_mask)
        else:
            # larger buff_size is necessary for taking time callback
            # http://stackoverflow.com/questions/26415699/ros-subscriber-not-up-to-date/29160379#29160379  # NOQA
            sub_img = rospy.Subscriber(
                '~input', Image, self._cb, queue_size=1, buff_size=2**24)
            self.subs = [sub_img]

    def unsubscribe(self):
        """Unregister every subscriber created by ``subscribe``."""
        for sub in self.subs:
            sub.unregister()

    def _publish_result(self, bridge, header, label, proba_img):
        """Publish the label and probability images with ``header``.

        The label is sent as ``32SC1``; the probabilities are cast to
        float32 and sent with cv_bridge's default encoding.
        """
        label_msg = bridge.cv2_to_imgmsg(label.astype(np.int32), '32SC1')
        label_msg.header = header
        self.pub.publish(label_msg)
        proba_msg = bridge.cv2_to_imgmsg(proba_img.astype(np.float32))
        proba_msg.header = header
        self.pub_proba.publish(proba_msg)

    def _cb_with_mask(self, img_msg, mask_msg):
        """Segment ``img_msg`` and blank out pixels where the mask is 0."""
        br = cv_bridge.CvBridge()
        img = br.imgmsg_to_cv2(img_msg, desired_encoding='bgr8')
        mask = br.imgmsg_to_cv2(mask_msg, desired_encoding='mono8')
        if mask.ndim > 2:
            mask = np.squeeze(mask, axis=2)
        label, proba_img = self.segment(img)
        # Outside the mask: label 0 with probability 1 on channel 0.
        # NOTE(review): this hard-codes 0 instead of using ~bg_label;
        # confirm that is intended when ~bg_label is not 0.
        outside = mask == 0
        label[outside] = 0
        proba_img[:, :, 0][outside] = 1
        proba_img[:, :, 1:][outside] = 0
        self._publish_result(br, img_msg.header, label, proba_img)

    def _cb(self, img_msg):
        """Segment ``img_msg`` and publish the result."""
        br = cv_bridge.CvBridge()
        img = br.imgmsg_to_cv2(img_msg, desired_encoding='bgr8')
        label, proba_img = self.segment(img)
        self._publish_result(br, img_msg.header, label, proba_img)

    def segment(self, bgr):
        """Run the configured backend on a BGR image.

        Args:
            bgr: image array as returned by cv_bridge for 'bgr8'.

        Returns:
            Tuple ``(label, proba_img)``: the per-pixel class index and
            the per-pixel class probabilities with the class axis last.

        Raises:
            ValueError: if ``self.backend`` is not supported.
        """
        if self.backend == 'chainer':
            return self._segment_chainer_backend(bgr)
        elif self.backend == 'torch':
            return self._segment_torch_backend(bgr)
        raise ValueError('Unsupported backend: {0}'.format(self.backend))

    def _preprocess(self, bgr):
        """Return ``bgr`` as a mean-subtracted float32 batch of one.

        The channel axis is moved first, so the result has a leading
        batch axis followed by (channel, row, column).
        """
        blob = (bgr - self.mean_bgr).transpose((2, 0, 1))
        return np.array([blob], dtype=np.float32)

    def _segment_chainer_backend(self, bgr):
        """Segment ``bgr`` with the chainer model; see ``segment``."""
        x_data = self._preprocess(bgr)
        if self.gpu != -1:
            x_data = cuda.to_gpu(x_data, device=self.gpu)
        if LooseVersion(chainer.__version__) < LooseVersion('2.0.0'):
            x = chainer.Variable(x_data, volatile=True)
            self.model(x)
        else:
            with chainer.using_config('train', False), \
                 chainer.no_backprop_mode():
                x = chainer.Variable(x_data)
                self.model(x)
        # The forward pass leaves the class scores in self.model.score.
        proba_img = chainer.functions.softmax(self.model.score)
        proba_img = chainer.functions.transpose(proba_img, (0, 2, 3, 1))
        max_proba_img = chainer.functions.max(proba_img, axis=-1)
        label = chainer.functions.argmax(self.model.score, axis=1)
        # squeeze batch axis, gpu -> cpu
        proba_img = cuda.to_cpu(proba_img.data)[0]
        max_proba_img = cuda.to_cpu(max_proba_img.data)[0]
        label = cuda.to_cpu(label.data)[0]
        # uncertain because the probability is low
        label[max_proba_img < self.proba_threshold] = self.bg_label
        return label, proba_img

    def _segment_torch_backend(self, bgr):
        """Segment ``bgr`` with the torch model; see ``segment``.

        NOTE(review): this uses the legacy PyTorch API (``volatile``
        variables, softmax without ``dim``) and assumes ``torch.max``
        keeps the reduced axis -- confirm against the installed version.
        """
        x_data = torch.from_numpy(self._preprocess(bgr))
        if self.gpu >= 0:
            x_data = x_data.cuda(self.gpu)
        x = torch.autograd.Variable(x_data, volatile=True)
        score = self.model(x)
        proba = torch.nn.functional.softmax(score)
        max_proba, label = torch.max(proba, 1)
        # uncertain because the probability is low
        label[max_proba < self.proba_threshold] = self.bg_label
        # gpu -> cpu
        proba = proba.permute(0, 2, 3, 1).data.cpu().numpy()[0]
        label = label.data.cpu().numpy().squeeze((0, 1))
        return label, proba
00212 
00213 
def main():
    """Start the fcn_object_segmentation node and spin until shutdown."""
    rospy.init_node('fcn_object_segmentation')
    # The node registers its publishers (and later its callbacks) with
    # rospy, so no local reference is kept here.
    FCNObjectSegmentation()
    rospy.spin()


if __name__ == '__main__':
    main()


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Tue Jul 2 2019 19:41:07