00001
00002
00003 from distutils.version import LooseVersion
00004
00005 import chainer
00006 from chainer import cuda
00007 import chainer.serializers as S
00008 import fcn
00009
00010 import cv_bridge
00011 from jsk_topic_tools import ConnectionBasedTransport
00012 import message_filters
00013 import numpy as np
00014 import rospy
00015 from sensor_msgs.msg import Image
00016
00017
# torch is an optional dependency: record whether it could be imported so the
# torch backend can report a helpful error later instead of failing at import.
try:
    import torch
except ImportError:
    is_torch_available = False
else:
    is_torch_available = True
00023
00024
def assert_torch_available():
    """Raise RuntimeError with install instructions when torch is missing.

    Does nothing when the module-level ``is_torch_available`` flag is true.
    """
    if is_torch_available:
        return
    wheel_url = 'http://download.pytorch.org/whl/cu80/torch-0.1.11.post4-cp27-none-linux_x86_64.whl'
    raise RuntimeError('Please install pytorch: pip install %s' % wheel_url)
00029
00030
class FCNObjectSegmentation(ConnectionBasedTransport):
    """ROS node that segments an image with an FCN (chainer or torch).

    Subscribes to ``~input`` (and ``~input/mask`` when ``~use_mask`` is set)
    and publishes a label image on ``~output`` and a per-class probability
    image on ``~output/proba_image``.

    Parameters read from the ROS parameter server:

    * ``~backend`` (default ``'chainer'``): ``'chainer'`` or ``'torch'``.
    * ``~gpu`` (default ``-1``): GPU id; ``-1`` keeps the model on the CPU.
    * ``~target_names`` (required): class names; its length is the number
      of classes of the network.
    * ``~bg_label`` (default ``0``): label written where the best class
      probability is below ``~proba_threshold``.
    * ``~proba_threshold`` (default ``0.0``).
    * ``~model_name`` / ``~model_file`` (required): network variant and
      trained weights.  ``~model_h5`` is a deprecated alias of
      ``~model_file`` for the chainer backend.
    * ``~use_mask``, ``~queue_size``, ``~approximate_sync``, ``~slop``:
      see :meth:`subscribe`.
    """

    def __init__(self):
        # Name the class explicitly: super(self.__class__, self) resolves to
        # the subclass when this class is inherited from, which makes
        # __init__ call itself until the recursion limit is hit.
        super(FCNObjectSegmentation, self).__init__()
        self.backend = rospy.get_param('~backend', 'chainer')
        self.gpu = rospy.get_param('~gpu', -1)  # -1: cpu
        self.target_names = rospy.get_param('~target_names')
        self.bg_label = rospy.get_param('~bg_label', 0)
        self.proba_threshold = rospy.get_param('~proba_threshold', 0.0)
        # Per-channel offset subtracted from the BGR input before inference
        # (presumably the mean of the training data -- TODO confirm).
        self.mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434])
        self._load_model()
        self.pub = self.advertise('~output', Image, queue_size=1)
        self.pub_proba = self.advertise(
            '~output/proba_image', Image, queue_size=1)

    def _load_model(self):
        """Load the network for the configured backend into ``self.model``.

        Raises:
            RuntimeError: if ``~backend`` is neither ``'chainer'`` nor
                ``'torch'``, or if the torch backend is requested but torch
                is not installed.
        """
        if self.backend == 'chainer':
            self._load_chainer_model()
        elif self.backend == 'torch':
            assert_torch_available()
            torch.backends.cudnn.benchmark = True
            self._load_torch_model()
        else:
            # Format the message with %; passing the value as a second
            # argument would leave the '%s' placeholder unformatted.
            raise RuntimeError('Unsupported backend: %s' % self.backend)

    def _load_chainer_model(self):
        """Build the chainer FCN named by ``~model_name`` and load weights.

        Weights come from ``~model_file`` (or the deprecated ``~model_h5``);
        files ending in ``.npz`` are read with ``load_npz``, anything else
        with ``load_hdf5``.

        Raises:
            ValueError: if ``~model_name`` is not a supported variant.
        """
        model_name = rospy.get_param('~model_name')
        if rospy.has_param('~model_h5'):
            rospy.logwarn('Rosparam ~model_h5 is deprecated,'
                          ' and please use ~model_file instead.')
            model_file = rospy.get_param('~model_h5')
        else:
            model_file = rospy.get_param('~model_file')
        n_class = len(self.target_names)
        if model_name == 'fcn32s':
            self.model = fcn.models.FCN32s(n_class=n_class)
        elif model_name == 'fcn16s':
            self.model = fcn.models.FCN16s(n_class=n_class)
        elif model_name == 'fcn8s':
            self.model = fcn.models.FCN8s(n_class=n_class)
        elif model_name == 'fcn8s_at_once':
            self.model = fcn.models.FCN8sAtOnce(n_class=n_class)
        else:
            raise ValueError('Unsupported ~model_name: {}'.format(model_name))
        rospy.loginfo('Loading trained model: {0}'.format(model_file))
        if model_file.endswith('.npz'):
            S.load_npz(model_file, self.model)
        else:
            S.load_hdf5(model_file, self.model)
        rospy.loginfo('Finished loading trained model: {0}'.format(model_file))
        if self.gpu != -1:
            self.model.to_gpu(self.gpu)
        # chainer < 2 switches to inference mode through a model attribute;
        # newer versions use chainer.using_config at call time instead.
        if LooseVersion(chainer.__version__) < LooseVersion('2.0.0'):
            self.model.train = False

    def _load_torch_model(self):
        """Build the torchfcn model named by ``~model_name`` and load weights.

        The model is moved to GPU ``~gpu`` when that id is non-negative and
        is put into evaluation mode.

        Raises:
            ImportError: if ``torchfcn`` is not installed.
            ValueError: if ``~model_name`` is not a supported variant.
        """
        try:
            import torchfcn
        except ImportError:
            raise ImportError('Please install torchfcn: pip install torchfcn')
        n_class = len(self.target_names)
        model_file = rospy.get_param('~model_file')
        model_name = rospy.get_param('~model_name')
        if model_name == 'fcn32s':
            self.model = torchfcn.models.FCN32s(n_class=n_class)
        elif model_name == 'fcn32s_bilinear':
            self.model = torchfcn.models.FCN32s(n_class=n_class, nodeconv=True)
        else:
            raise ValueError('Unsupported ~model_name: {0}'.format(model_name))
        rospy.loginfo('Loading trained model: %s' % model_file)
        self.model.load_state_dict(torch.load(model_file))
        rospy.loginfo('Finished loading trained model: %s' % model_file)
        if self.gpu >= 0:
            self.model = self.model.cuda(self.gpu)
        self.model.eval()

    def subscribe(self):
        """Subscribe to the input image, optionally synchronized with a mask.

        With ``~use_mask`` set, ``~input`` and ``~input/mask`` are paired by
        a time synchronizer whose queue length is ``~queue_size`` (default
        10); ``~approximate_sync`` selects the approximate policy with
        tolerance ``~slop`` (default 0.1).  Otherwise only ``~input`` is
        subscribed.  The created subscribers are kept in ``self.subs``.
        """
        use_mask = rospy.get_param('~use_mask', False)
        if use_mask:
            # ~queue_size sizes the synchronizer; the topic subscribers
            # themselves keep a queue of one message.
            queue_size = rospy.get_param('~queue_size', 10)
            sub_img = message_filters.Subscriber(
                '~input', Image, queue_size=1, buff_size=2**24)
            sub_mask = message_filters.Subscriber(
                '~input/mask', Image, queue_size=1, buff_size=2**24)
            self.subs = [sub_img, sub_mask]
            if rospy.get_param('~approximate_sync', False):
                slop = rospy.get_param('~slop', 0.1)
                sync = message_filters.ApproximateTimeSynchronizer(
                    fs=self.subs, queue_size=queue_size, slop=slop)
            else:
                sync = message_filters.TimeSynchronizer(
                    fs=self.subs, queue_size=queue_size)
            sync.registerCallback(self._cb_with_mask)
        else:
            sub_img = rospy.Subscriber(
                '~input', Image, self._cb, queue_size=1, buff_size=2**24)
            self.subs = [sub_img]

    def unsubscribe(self):
        """Unregister every subscriber created by :meth:`subscribe`."""
        for sub in self.subs:
            sub.unregister()

    def _publish_result(self, br, header, label, proba_img):
        """Publish ``label`` (as 32SC1) and ``proba_img`` (as float32).

        Both outgoing messages carry ``header`` so they can be matched to
        the input image they were computed from.
        """
        label_msg = br.cv2_to_imgmsg(label.astype(np.int32), '32SC1')
        label_msg.header = header
        self.pub.publish(label_msg)
        proba_msg = br.cv2_to_imgmsg(proba_img.astype(np.float32))
        proba_msg.header = header
        self.pub_proba.publish(proba_msg)

    def _cb_with_mask(self, img_msg, mask_msg):
        """Segment ``img_msg`` and blank out pixels where the mask is zero.

        Where the mask is 0 the label is set to 0, the probability of
        class 0 to 1 and the probability of every other class to 0.
        """
        br = cv_bridge.CvBridge()
        img = br.imgmsg_to_cv2(img_msg, desired_encoding='bgr8')
        mask = br.imgmsg_to_cv2(mask_msg, desired_encoding='mono8')
        if mask.ndim > 2:
            mask = np.squeeze(mask, axis=2)
        label, proba_img = self.segment(img)
        # NOTE(review): label 0 / channel 0 are hard-coded here rather than
        # using ~bg_label -- confirm that is intended.
        outside = mask == 0
        label[outside] = 0
        proba_img[:, :, 0][outside] = 1
        proba_img[:, :, 1:][outside] = 0
        self._publish_result(br, img_msg.header, label, proba_img)

    def _cb(self, img_msg):
        """Segment ``img_msg`` and publish the label and probability images."""
        br = cv_bridge.CvBridge()
        img = br.imgmsg_to_cv2(img_msg, desired_encoding='bgr8')
        label, proba_img = self.segment(img)
        self._publish_result(br, img_msg.header, label, proba_img)

    def segment(self, bgr):
        """Run the configured backend on a BGR image.

        Returns:
            ``(label, proba_img)`` as produced by the backend-specific
            method.

        Raises:
            ValueError: if ``self.backend`` is not a supported backend.
        """
        if self.backend == 'chainer':
            return self._segment_chainer_backend(bgr)
        elif self.backend == 'torch':
            return self._segment_torch_backend(bgr)
        raise ValueError('Unsupported backend: {0}'.format(self.backend))

    def _segment_chainer_backend(self, bgr):
        """Segment ``bgr`` with the chainer model.

        Returns:
            ``(label, proba_img)`` for the single image in the batch:
            ``label`` is the arg-max class per pixel, with pixels whose best
            probability is below ``self.proba_threshold`` replaced by
            ``self.bg_label``; ``proba_img`` is the softmax of the score
            with the class axis moved last.
        """
        # Mean-subtract and reorder HWC -> CHW, then add the batch axis.
        blob = (bgr - self.mean_bgr).transpose((2, 0, 1))
        x_data = np.array([blob], dtype=np.float32)
        if self.gpu != -1:
            x_data = cuda.to_gpu(x_data, device=self.gpu)
        # The forward pass stores its output on self.model.score.
        if LooseVersion(chainer.__version__) < LooseVersion('2.0.0'):
            x = chainer.Variable(x_data, volatile=True)
            self.model(x)
        else:
            with chainer.using_config('train', False), \
                    chainer.no_backprop_mode():
                x = chainer.Variable(x_data)
                self.model(x)
        proba_img = chainer.functions.softmax(self.model.score)
        proba_img = chainer.functions.transpose(proba_img, (0, 2, 3, 1))
        max_proba_img = chainer.functions.max(proba_img, axis=-1)
        label = chainer.functions.argmax(self.model.score, axis=1)
        # Bring everything back to host memory and drop the batch axis.
        proba_img = cuda.to_cpu(proba_img.data)[0]
        max_proba_img = cuda.to_cpu(max_proba_img.data)[0]
        label = cuda.to_cpu(label.data)[0]
        # Low-confidence pixels fall back to the background label.
        label[max_proba_img < self.proba_threshold] = self.bg_label
        return label, proba_img

    def _segment_torch_backend(self, bgr):
        """Segment ``bgr`` with the torch model.

        Returns:
            ``(label, proba)`` as numpy arrays for the single image in the
            batch, with the same meaning as the chainer backend's result.
        """
        # Mean-subtract and reorder HWC -> CHW, then add the batch axis.
        blob = (bgr - self.mean_bgr).transpose((2, 0, 1))
        x_data = np.array([blob], dtype=np.float32)
        x_data = torch.from_numpy(x_data)
        if self.gpu >= 0:
            x_data = x_data.cuda(self.gpu)
        # NOTE(review): Variable(volatile=True), softmax without an explicit
        # dim and the squeeze((0, 1)) below look written for an old torch
        # release -- confirm before running on a newer torch.
        x = torch.autograd.Variable(x_data, volatile=True)
        score = self.model(x)
        proba = torch.nn.functional.softmax(score)
        max_proba, label = torch.max(proba, 1)
        # Low-confidence pixels fall back to the background label.
        label[max_proba < self.proba_threshold] = self.bg_label
        # Move the class axis last and drop the batch axis.
        proba = proba.permute(0, 2, 3, 1).data.cpu().numpy()[0]
        label = label.data.cpu().numpy().squeeze((0, 1))
        return label, proba
00212
00213
if __name__ == '__main__':
    # Start the ROS node, construct the segmentation object (which loads the
    # model and advertises its topics), then hand control to rospy.
    rospy.init_node('fcn_object_segmentation')
    node = FCNObjectSegmentation()
    rospy.spin()