from distutils.version import LooseVersion

import chainer
from chainer import cuda
import chainer.serializers as S
import fcn

import cv_bridge
from jsk_topic_tools import ConnectionBasedTransport
import message_filters
import numpy as np
import rospy
from sensor_msgs.msg import Image


is_torch_available = True
try:
    import torch
except ImportError:
    is_torch_available = False


def assert_torch_available():
    if not is_torch_available:
        url = 'http://download.pytorch.org/whl/cu80/torch-0.1.11.post4-cp27-none-linux_x86_64.whl'
        raise RuntimeError('Please install pytorch: pip install %s' % url)


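# Segments objects in the input image with a trained FCN and publishes the
# label image on ~output and the per-class probability image on
# ~output/proba_image. Either the chainer or the torch backend can be used.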
class FCNObjectSegmentation(ConnectionBasedTransport):

    def __init__(self):
        super(FCNObjectSegmentation, self).__init__()
        self.backend = rospy.get_param('~backend', 'chainer')
        self.gpu = rospy.get_param('~gpu', -1)
        self.target_names = rospy.get_param('~target_names')
        self.bg_label = rospy.get_param('~bg_label', 0)
        self.proba_threshold = rospy.get_param('~proba_threshold', 0.0)
        self.mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434])
        self._load_model()
        self.pub = self.advertise('~output', Image, queue_size=1)
        self.pub_proba = self.advertise(
            '~output/proba_image', Image, queue_size=1)

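    # Dispatches model loading to the backend selected by ~backend.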
    def _load_model(self):
        if self.backend == 'chainer':
            self._load_chainer_model()
        elif self.backend == 'torch':
            assert_torch_available()
            torch.backends.cudnn.benchmark = True
            self._load_torch_model()
        else:
            raise RuntimeError('Unsupported backend: %s' % self.backend)

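    # Builds an fcn.models FCN{32,16,8}s with len(~target_names) classes and
    # loads the trained weights from ~model_file (.npz or HDF5).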
    def _load_chainer_model(self):
        model_name = rospy.get_param('~model_name')
        if rospy.has_param('~model_h5'):
            rospy.logwarn('Rosparam ~model_h5 is deprecated.'
                          ' Please use ~model_file instead.')
            model_file = rospy.get_param('~model_h5')
        else:
            model_file = rospy.get_param('~model_file')
        n_class = len(self.target_names)
        if model_name == 'fcn32s':
            self.model = fcn.models.FCN32s(n_class=n_class)
        elif model_name == 'fcn16s':
            self.model = fcn.models.FCN16s(n_class=n_class)
        elif model_name == 'fcn8s':
            self.model = fcn.models.FCN8s(n_class=n_class)
        else:
            raise ValueError('Unsupported ~model_name: {}'.format(model_name))
        rospy.loginfo('Loading trained model: {0}'.format(model_file))
        if model_file.endswith('.npz'):
            S.load_npz(model_file, self.model)
        else:
            S.load_hdf5(model_file, self.model)
        rospy.loginfo('Finished loading trained model: {0}'.format(model_file))
        if self.gpu != -1:
            self.model.to_gpu(self.gpu)
        if LooseVersion(chainer.__version__) < LooseVersion('2.0.0'):
            self.model.train = False

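    # Builds a torchfcn FCN32s (optionally with bilinear upsampling instead of
    # deconvolution) and loads the trained weights with load_state_dict.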
    def _load_torch_model(self):
        try:
            import torchfcn
        except ImportError:
            raise ImportError('Please install torchfcn: pip install torchfcn')
        n_class = len(self.target_names)
        model_file = rospy.get_param('~model_file')
        model_name = rospy.get_param('~model_name')
        if model_name == 'fcn32s':
            self.model = torchfcn.models.FCN32s(n_class=n_class)
        elif model_name == 'fcn32s_bilinear':
            self.model = torchfcn.models.FCN32s(n_class=n_class, nodeconv=True)
        else:
            raise ValueError('Unsupported ~model_name: {0}'.format(model_name))
        rospy.loginfo('Loading trained model: %s' % model_file)
        self.model.load_state_dict(torch.load(model_file))
        rospy.loginfo('Finished loading trained model: %s' % model_file)
        if self.gpu >= 0:
            self.model = self.model.cuda(self.gpu)
        self.model.eval()

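    # With ~use_mask the image and mask topics are synchronized (exactly, or
    # approximately when ~approximate_sync is set); otherwise only ~input is
    # subscribed.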
    def subscribe(self):
        use_mask = rospy.get_param('~use_mask', False)
        if use_mask:
            queue_size = rospy.get_param('~queue_size', 10)
            sub_img = message_filters.Subscriber(
                '~input', Image, queue_size=1, buff_size=2**24)
            sub_mask = message_filters.Subscriber(
                '~input/mask', Image, queue_size=1, buff_size=2**24)
            self.subs = [sub_img, sub_mask]
            if rospy.get_param('~approximate_sync', False):
                slop = rospy.get_param('~slop', 0.1)
                sync = message_filters.ApproximateTimeSynchronizer(
                    fs=self.subs, queue_size=queue_size, slop=slop)
            else:
                sync = message_filters.TimeSynchronizer(
                    fs=self.subs, queue_size=queue_size)
            sync.registerCallback(self._cb_with_mask)
        else:
            sub_img = rospy.Subscriber(
                '~input', Image, self._cb, queue_size=1, buff_size=2**24)
            self.subs = [sub_img]

    def unsubscribe(self):
        for sub in self.subs:
            sub.unregister()

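    # Segments the image, then forces pixels where the mask is zero to label 0
    # and to background probability 1 before publishing.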
    def _cb_with_mask(self, img_msg, mask_msg):
        br = cv_bridge.CvBridge()
        img = br.imgmsg_to_cv2(img_msg, desired_encoding='bgr8')
        mask = br.imgmsg_to_cv2(mask_msg, desired_encoding='mono8')
        if mask.ndim > 2:
            mask = np.squeeze(mask, axis=2)
        label, proba_img = self.segment(img)
        label[mask == 0] = 0
        proba_img[:, :, 0][mask == 0] = 1
        proba_img[:, :, 1:][mask == 0] = 0
        label_msg = br.cv2_to_imgmsg(label.astype(np.int32), '32SC1')
        label_msg.header = img_msg.header
        self.pub.publish(label_msg)
        proba_msg = br.cv2_to_imgmsg(proba_img.astype(np.float32))
        proba_msg.header = img_msg.header
        self.pub_proba.publish(proba_msg)

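    # Segments the image and publishes the label and probability images.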
    def _cb(self, img_msg):
        br = cv_bridge.CvBridge()
        img = br.imgmsg_to_cv2(img_msg, desired_encoding='bgr8')
        label, proba_img = self.segment(img)
        label_msg = br.cv2_to_imgmsg(label.astype(np.int32), '32SC1')
        label_msg.header = img_msg.header
        self.pub.publish(label_msg)
        proba_msg = br.cv2_to_imgmsg(proba_img.astype(np.float32))
        proba_msg.header = img_msg.header
        self.pub_proba.publish(proba_msg)

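    # Returns (label, proba_img) computed by the configured backend.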
    def segment(self, bgr):
        if self.backend == 'chainer':
            return self._segment_chainer_backend(bgr)
        elif self.backend == 'torch':
            return self._segment_torch_backend(bgr)
        raise ValueError('Unsupported backend: {0}'.format(self.backend))

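    # Subtracts the mean BGR value, runs a forward pass in test mode and
    # assigns ~bg_label to pixels whose maximum class probability is below
    # ~proba_threshold.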
    def _segment_chainer_backend(self, bgr):
        blob = (bgr - self.mean_bgr).transpose((2, 0, 1))
        x_data = np.array([blob], dtype=np.float32)
        if self.gpu != -1:
            x_data = cuda.to_gpu(x_data, device=self.gpu)
        if LooseVersion(chainer.__version__) < LooseVersion('2.0.0'):
            x = chainer.Variable(x_data, volatile=True)
            self.model(x)
        else:
            with chainer.using_config('train', False), \
                    chainer.no_backprop_mode():
                x = chainer.Variable(x_data)
                self.model(x)
        proba_img = chainer.functions.softmax(self.model.score)
        proba_img = chainer.functions.transpose(proba_img, (0, 2, 3, 1))
        max_proba_img = chainer.functions.max(proba_img, axis=-1)
        label = chainer.functions.argmax(self.model.score, axis=1)

        proba_img = cuda.to_cpu(proba_img.data)[0]
        max_proba_img = cuda.to_cpu(max_proba_img.data)[0]
        label = cuda.to_cpu(label.data)[0]

        label[max_proba_img < self.proba_threshold] = self.bg_label
        return label, proba_img

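    # Same pre/post-processing as the chainer backend, but the forward pass is
    # run with a volatile torch Variable so no gradients are kept.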
    def _segment_torch_backend(self, bgr):
        blob = (bgr - self.mean_bgr).transpose((2, 0, 1))
        x_data = np.array([blob], dtype=np.float32)
        x_data = torch.from_numpy(x_data)
        if self.gpu >= 0:
            x_data = x_data.cuda(self.gpu)
        x = torch.autograd.Variable(x_data, volatile=True)
        score = self.model(x)
        proba = torch.nn.functional.softmax(score)
        max_proba, label = torch.max(proba, 1)

        label[max_proba < self.proba_threshold] = self.bg_label

        proba = proba.permute(0, 2, 3, 1).data.cpu().numpy()[0]
        label = label.data.cpu().numpy().squeeze((0, 1))
        return label, proba


if __name__ == '__main__':
    rospy.init_node('fcn_object_segmentation')
    FCNObjectSegmentation()
    rospy.spin()