fcn_object_segmentation.py
#!/usr/bin/env python

from distutils.version import LooseVersion

import chainer
from chainer import cuda
import chainer.serializers as S
import fcn

import cv_bridge
from jsk_topic_tools import ConnectionBasedTransport
import message_filters
import numpy as np
import rospy
from sensor_msgs.msg import Image


is_torch_available = True
try:
    import torch
except ImportError:
    is_torch_available = False


def assert_torch_available():
    if not is_torch_available:
        url = 'http://download.pytorch.org/whl/cu80/torch-0.1.11.post4-cp27-none-linux_x86_64.whl'  # NOQA
        raise RuntimeError('Please install pytorch: pip install %s' % url)


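# FCNObjectSegmentation is a ConnectionBasedTransport node: it subscribes to a
# BGR image (and, if ~use_mask is set, a synchronized mask image), runs FCN
# semantic segmentation with either the chainer or the torch backend, and
# publishes a label image (~output) and a per-class probability image
# (~output/proba_image).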
class FCNObjectSegmentation(ConnectionBasedTransport):

    def __init__(self):
        super(FCNObjectSegmentation, self).__init__()
        self.backend = rospy.get_param('~backend', 'chainer')
        self.gpu = rospy.get_param('~gpu', -1)  # -1 is cpu mode
        self.target_names = rospy.get_param('~target_names')
        self.bg_label = rospy.get_param('~bg_label', 0)
        self.proba_threshold = rospy.get_param('~proba_threshold', 0.0)
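        # per-channel BGR mean subtracted from the input image before
        # inference (VGG-style preprocessing)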
        self.mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434])
        self._load_model()
        self.pub = self.advertise('~output', Image, queue_size=1)
        self.pub_proba = self.advertise(
            '~output/proba_image', Image, queue_size=1)

    def _load_model(self):
        if self.backend == 'chainer':
            self._load_chainer_model()
        elif self.backend == 'torch':
            assert_torch_available()
            # we assume the input size does not change dynamically
            torch.backends.cudnn.benchmark = True
            self._load_torch_model()
        else:
            raise RuntimeError(
                'Unsupported backend: {0}'.format(self.backend))

    def _load_chainer_model(self):
        model_name = rospy.get_param('~model_name')
        if rospy.has_param('~model_h5'):
            rospy.logwarn('Rosparam ~model_h5 is deprecated;'
                          ' please use ~model_file instead.')
            model_file = rospy.get_param('~model_h5')
        else:
            model_file = rospy.get_param('~model_file')
        n_class = len(self.target_names)
        if model_name == 'fcn32s':
            self.model = fcn.models.FCN32s(n_class=n_class)
        elif model_name == 'fcn16s':
            self.model = fcn.models.FCN16s(n_class=n_class)
        elif model_name == 'fcn8s':
            self.model = fcn.models.FCN8s(n_class=n_class)
        else:
            raise ValueError('Unsupported ~model_name: {0}'.format(model_name))
        rospy.loginfo('Loading trained model: {0}'.format(model_file))
        if model_file.endswith('.npz'):
            S.load_npz(model_file, self.model)
        else:
            S.load_hdf5(model_file, self.model)
        rospy.loginfo('Finished loading trained model: {0}'.format(model_file))
        if self.gpu != -1:
            self.model.to_gpu(self.gpu)
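        # chainer < 2.0.0 switches train/test mode via the model's train
        # attribute; newer versions use chainer.using_config('train', False)
        # (see _segment_chainer_backend)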
        if LooseVersion(chainer.__version__) < LooseVersion('2.0.0'):
            self.model.train = False

    def _load_torch_model(self):
        try:
            import torchfcn
        except ImportError:
            raise ImportError('Please install torchfcn: pip install torchfcn')
        n_class = len(self.target_names)
        model_file = rospy.get_param('~model_file')
        model_name = rospy.get_param('~model_name')
        if model_name == 'fcn32s':
            self.model = torchfcn.models.FCN32s(n_class=n_class)
        elif model_name == 'fcn32s_bilinear':
            self.model = torchfcn.models.FCN32s(n_class=n_class, nodeconv=True)
        else:
            raise ValueError('Unsupported ~model_name: {0}'.format(model_name))
        rospy.loginfo('Loading trained model: %s' % model_file)
        self.model.load_state_dict(torch.load(model_file))
        rospy.loginfo('Finished loading trained model: %s' % model_file)
        if self.gpu >= 0:
            self.model = self.model.cuda(self.gpu)
        self.model.eval()

    def subscribe(self):
        use_mask = rospy.get_param('~use_mask', False)
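        # with ~use_mask, the image and a mask are subscribed together and
        # synchronized (exactly or approximately); otherwise only the image
        # is subscribed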
        if use_mask:
            queue_size = rospy.get_param('~queue_size', 10)
            sub_img = message_filters.Subscriber(
                '~input', Image, queue_size=1, buff_size=2**24)
            sub_mask = message_filters.Subscriber(
                '~input/mask', Image, queue_size=1, buff_size=2**24)
            self.subs = [sub_img, sub_mask]
            if rospy.get_param('~approximate_sync', False):
                slop = rospy.get_param('~slop', 0.1)
                sync = message_filters.ApproximateTimeSynchronizer(
                    fs=self.subs, queue_size=queue_size, slop=slop)
            else:
                sync = message_filters.TimeSynchronizer(
                    fs=self.subs, queue_size=queue_size)
            sync.registerCallback(self._cb_with_mask)
        else:
            # a larger buff_size is necessary for callbacks that take a long
            # time to process each message
            # http://stackoverflow.com/questions/26415699/ros-subscriber-not-up-to-date/29160379#29160379  # NOQA
            sub_img = rospy.Subscriber(
                '~input', Image, self._cb, queue_size=1, buff_size=2**24)
            self.subs = [sub_img]

    def unsubscribe(self):
        for sub in self.subs:
            sub.unregister()

    def _cb_with_mask(self, img_msg, mask_msg):
        br = cv_bridge.CvBridge()
        img = br.imgmsg_to_cv2(img_msg, desired_encoding='bgr8')
        mask = br.imgmsg_to_cv2(mask_msg, desired_encoding='mono8')
        if mask.ndim > 2:
            mask = np.squeeze(mask, axis=2)
        label, proba_img = self.segment(img)
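        # force pixels outside the mask to the background label with
        # background probability 1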
        label[mask == 0] = 0
        proba_img[:, :, 0][mask == 0] = 1
        proba_img[:, :, 1:][mask == 0] = 0
        label_msg = br.cv2_to_imgmsg(label.astype(np.int32), '32SC1')
        label_msg.header = img_msg.header
        self.pub.publish(label_msg)
        proba_msg = br.cv2_to_imgmsg(proba_img.astype(np.float32))
        proba_msg.header = img_msg.header
        self.pub_proba.publish(proba_msg)

    def _cb(self, img_msg):
        br = cv_bridge.CvBridge()
        img = br.imgmsg_to_cv2(img_msg, desired_encoding='bgr8')
        label, proba_img = self.segment(img)
        label_msg = br.cv2_to_imgmsg(label.astype(np.int32), '32SC1')
        label_msg.header = img_msg.header
        self.pub.publish(label_msg)
        proba_msg = br.cv2_to_imgmsg(proba_img.astype(np.float32))
        proba_msg.header = img_msg.header
        self.pub_proba.publish(proba_msg)

    def segment(self, bgr):
        if self.backend == 'chainer':
            return self._segment_chainer_backend(bgr)
        elif self.backend == 'torch':
            return self._segment_torch_backend(bgr)
        raise ValueError('Unsupported backend: {0}'.format(self.backend))

    def _segment_chainer_backend(self, bgr):
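        # subtract the per-channel mean and convert HWC -> CHW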
        blob = (bgr - self.mean_bgr).transpose((2, 0, 1))
        x_data = np.array([blob], dtype=np.float32)
        if self.gpu != -1:
            x_data = cuda.to_gpu(x_data, device=self.gpu)
        if LooseVersion(chainer.__version__) < LooseVersion('2.0.0'):
            x = chainer.Variable(x_data, volatile=True)
            self.model(x)
        else:
            with chainer.using_config('train', False), \
                 chainer.no_backprop_mode():
                x = chainer.Variable(x_data)
                self.model(x)
        proba_img = chainer.functions.softmax(self.model.score)
        proba_img = chainer.functions.transpose(proba_img, (0, 2, 3, 1))
        max_proba_img = chainer.functions.max(proba_img, axis=-1)
        label = chainer.functions.argmax(self.model.score, axis=1)
        # squeeze batch axis, gpu -> cpu
        proba_img = cuda.to_cpu(proba_img.data)[0]
        max_proba_img = cuda.to_cpu(max_proba_img.data)[0]
        label = cuda.to_cpu(label.data)[0]
        # pixels whose max probability is below ~proba_threshold are regarded
        # as uncertain and labeled as background
        label[max_proba_img < self.proba_threshold] = self.bg_label
        return label, proba_img

    def _segment_torch_backend(self, bgr):
        blob = (bgr - self.mean_bgr).transpose((2, 0, 1))
        x_data = np.array([blob], dtype=np.float32)
        x_data = torch.from_numpy(x_data)
        if self.gpu >= 0:
            x_data = x_data.cuda(self.gpu)
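        # volatile Variable disables autograd bookkeeping during inference
        # (pre-0.4 torch API)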
        x = torch.autograd.Variable(x_data, volatile=True)
        score = self.model(x)
        proba = torch.nn.functional.softmax(score)
        max_proba, label = torch.max(proba, 1)
        # pixels whose max probability is below ~proba_threshold are regarded
        # as uncertain and labeled as background
        label[max_proba < self.proba_threshold] = self.bg_label
        # gpu -> cpu
        proba = proba.permute(0, 2, 3, 1).data.cpu().numpy()[0]
        label = label.data.cpu().numpy().squeeze((0, 1))
        return label, proba



if __name__ == '__main__':
    rospy.init_node('fcn_object_segmentation')
    FCNObjectSegmentation()
    rospy.spin()
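
For reference, below is a minimal, illustrative client sketch for this node (not part of the file above). It assumes the node runs under the name /fcn_object_segmentation with ~input and ~output left unremapped; the node name, topic names, and the dummy image are assumptions for illustration only.

#!/usr/bin/env python
# Illustrative client sketch: publish a dummy BGR image to the segmentation
# node's input and print the unique labels found in each label image.

import cv_bridge
import numpy as np
import rospy
from sensor_msgs.msg import Image


def on_label(msg):
    bridge = cv_bridge.CvBridge()
    label = bridge.imgmsg_to_cv2(msg, desired_encoding='passthrough')
    rospy.loginfo('labels in image: %s', np.unique(label).tolist())


def main():
    rospy.init_node('fcn_client_example')
    bridge = cv_bridge.CvBridge()
    # topic names assume the node is /fcn_object_segmentation (assumption)
    pub = rospy.Publisher('/fcn_object_segmentation/input', Image,
                          queue_size=1)
    rospy.Subscriber('/fcn_object_segmentation/output', Image, on_label,
                     queue_size=1)
    rate = rospy.Rate(1)
    while not rospy.is_shutdown():
        img = np.zeros((480, 640, 3), dtype=np.uint8)  # dummy BGR frame
        msg = bridge.cv2_to_imgmsg(img, encoding='bgr8')
        msg.header.stamp = rospy.Time.now()
        pub.publish(msg)
        rate.sleep()


if __name__ == '__main__':
    main()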

