#!/usr/bin/env python

from __future__ import print_function

from distutils.version import LooseVersion
import itertools
import pkg_resources
import sys

# chainer >= 7.0.0 dropped Python 2 support, which this node still targets.
if LooseVersion(pkg_resources.get_distribution("chainer").version) >= \
        LooseVersion('7.0.0') and sys.version_info.major == 2:
    print('''Please install chainer < 7.0.0:

    sudo pip install chainer==6.7.0

c.f https://github.com/jsk-ros-pkg/jsk_recognition/pull/2485
''', file=sys.stderr)
    sys.exit(1)
# Warn if no cupy-* distribution is installed (required for GPU inference).
if [p for p in list(itertools.chain(
        *[pkg_resources.find_distributions(_) for _ in sys.path]))
        if "cupy-" in p.project_name] == []:
    print('''Please install CuPy

    sudo pip install cupy-cuda[your cuda version]
i.e.
    sudo pip install cupy-cuda91
''', file=sys.stderr)

import chainer
from chainer import cuda
import chainer.serializers as S
import cv_bridge
import fcn
from jsk_topic_tools import ConnectionBasedTransport
import message_filters
import numpy as np
import rospy
from sensor_msgs.msg import Image

is_torch_available = True
try:
    import torch
except ImportError:
    is_torch_available = False


def assert_torch_available():
    if not is_torch_available:
        url = 'http://download.pytorch.org/whl/cu80/torch-0.1.11.post4-cp27-none-linux_x86_64.whl'
        raise RuntimeError('Please install pytorch: pip install %s' % url)


class FCNObjectSegmentation(ConnectionBasedTransport):

    def __init__(self):
        super(self.__class__, self).__init__()
        self.backend = rospy.get_param('~backend', 'chainer')
        self.gpu = rospy.get_param('~gpu', -1)  # -1 means CPU mode
        self.target_names = rospy.get_param('~target_names')
        self.bg_label = rospy.get_param('~bg_label', 0)
        self.proba_threshold = rospy.get_param('~proba_threshold', 0.0)
        # Mean BGR pixel value of the FCN training data, subtracted
        # from the input image before inference.
        self.mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434])
        self._load_model()
        self.pub = self.advertise('~output', Image, queue_size=1)
        self.pub_proba = self.advertise(
            '~output/proba_image', Image, queue_size=1)
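    # A minimal parameter sketch (illustrative values, not from a shipped
    # config): ~target_names determines n_class for the network, with index 0
    # conventionally the background class (matching the ~bg_label default).
    #
    #   target_names: [background, bottle, box]
    #   model_name: fcn8s
    #   model_file: /path/to/fcn8s.npz
    #   backend: chainer
    #   gpu: 0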

    def _load_model(self):
        if self.backend == 'chainer':
            self._load_chainer_model()
        elif self.backend == 'torch':
            assert_torch_available()
            # cuDNN benchmark mode: the torch backend assumes a GPU is used.
            torch.backends.cudnn.benchmark = True
            self._load_torch_model()
        else:
            raise RuntimeError('Unsupported backend: %s' % self.backend)

    def _load_chainer_model(self):
        model_name = rospy.get_param('~model_name')
        if rospy.has_param('~model_h5'):
            rospy.logwarn('Rosparam ~model_h5 is deprecated,'
                          ' and please use ~model_file instead.')
            model_file = rospy.get_param('~model_h5')
        else:
            model_file = rospy.get_param('~model_file')
        n_class = len(self.target_names)
        if model_name == 'fcn32s':
            self.model = fcn.models.FCN32s(n_class=n_class)
        elif model_name == 'fcn16s':
            self.model = fcn.models.FCN16s(n_class=n_class)
        elif model_name == 'fcn8s':
            self.model = fcn.models.FCN8s(n_class=n_class)
        elif model_name == 'fcn8s_at_once':
            self.model = fcn.models.FCN8sAtOnce(n_class=n_class)
        else:
            raise ValueError('Unsupported ~model_name: {}'.format(model_name))
        rospy.loginfo('Loading trained model: {0}'.format(model_file))
        if model_file.endswith('.npz'):
            S.load_npz(model_file, self.model)
        else:
            S.load_hdf5(model_file, self.model)
        rospy.loginfo(
            'Finished loading trained model: {0}'.format(model_file))
        if self.gpu != -1:
            self.model.to_gpu(self.gpu)
        if LooseVersion(chainer.__version__) < LooseVersion('2.0.0'):
            self.model.train = False

    def _load_torch_model(self):
        try:
            import torchfcn
        except ImportError:
            raise ImportError('Please install torchfcn: pip install torchfcn')
        n_class = len(self.target_names)
        model_file = rospy.get_param('~model_file')
        model_name = rospy.get_param('~model_name')
        if model_name == 'fcn32s':
            self.model = torchfcn.models.FCN32s(n_class=n_class)
        elif model_name == 'fcn32s_bilinear':
            self.model = torchfcn.models.FCN32s(n_class=n_class,
                                                nodeconv=True)
        else:
            raise ValueError('Unsupported ~model_name: {0}'.format(model_name))
        rospy.loginfo('Loading trained model: %s' % model_file)
        self.model.load_state_dict(torch.load(model_file))
        rospy.loginfo('Finished loading trained model: %s' % model_file)
        if self.gpu >= 0:
            self.model = self.model.cuda(self.gpu)
        # Inference mode: disables dropout layers in the FCN models.
        self.model.eval()

    def subscribe(self):
        use_mask = rospy.get_param('~use_mask', False)
        if use_mask:
            queue_size = rospy.get_param('~queue_size', 10)
            sub_img = message_filters.Subscriber(
                '~input', Image, queue_size=1, buff_size=2**24)
            sub_mask = message_filters.Subscriber(
                '~input/mask', Image, queue_size=1, buff_size=2**24)
            self.subs = [sub_img, sub_mask]
            if rospy.get_param('~approximate_sync', False):
                slop = rospy.get_param('~slop', 0.1)
                sync = message_filters.ApproximateTimeSynchronizer(
                    fs=self.subs, queue_size=queue_size, slop=slop)
            else:
                sync = message_filters.TimeSynchronizer(
                    fs=self.subs, queue_size=queue_size)
            sync.registerCallback(self._cb_with_mask)
        else:
            # A large buff_size keeps the subscriber from falling behind when
            # segmentation takes longer than the message interval.
            sub_img = rospy.Subscriber(
                '~input', Image, self._cb, queue_size=1, buff_size=2**24)
            self.subs = [sub_img]

    def unsubscribe(self):
        for sub in self.subs:
            sub.unregister()

    def _cb_with_mask(self, img_msg, mask_msg):
        br = cv_bridge.CvBridge()
        img = br.imgmsg_to_cv2(img_msg, desired_encoding='bgr8')
        mask = br.imgmsg_to_cv2(mask_msg, desired_encoding='mono8')
        if mask.ndim > 2:
            mask = np.squeeze(mask, axis=2)
        label, proba_img = self.segment(img)
        # Outside the mask, force label 0 (background) with probability 1.
        label[mask == 0] = 0
        proba_img[:, :, 0][mask == 0] = 1
        proba_img[:, :, 1:][mask == 0] = 0
        label_msg = br.cv2_to_imgmsg(label.astype(np.int32), '32SC1')
        label_msg.header = img_msg.header
        self.pub.publish(label_msg)
        proba_msg = br.cv2_to_imgmsg(proba_img.astype(np.float32))
        proba_msg.header = img_msg.header
        self.pub_proba.publish(proba_msg)

    def _cb(self, img_msg):
        br = cv_bridge.CvBridge()
        img = br.imgmsg_to_cv2(img_msg, desired_encoding='bgr8')
        label, proba_img = self.segment(img)
        label_msg = br.cv2_to_imgmsg(label.astype(np.int32), '32SC1')
        label_msg.header = img_msg.header
        self.pub.publish(label_msg)
        proba_msg = br.cv2_to_imgmsg(proba_img.astype(np.float32))
        proba_msg.header = img_msg.header
        self.pub_proba.publish(proba_msg)

    def segment(self, bgr):
        if self.backend == 'chainer':
            return self._segment_chainer_backend(bgr)
        elif self.backend == 'torch':
            return self._segment_torch_backend(bgr)
        raise ValueError('Unsupported backend: {0}'.format(self.backend))

    def _segment_chainer_backend(self, bgr):
        # HWC BGR image -> CHW blob with the training mean subtracted.
        blob = (bgr - self.mean_bgr).transpose((2, 0, 1))
        x_data = np.array([blob], dtype=np.float32)
        if self.gpu != -1:
            x_data = cuda.to_gpu(x_data, device=self.gpu)
        if LooseVersion(chainer.__version__) < LooseVersion('2.0.0'):
            x = chainer.Variable(x_data, volatile=True)
            self.model(x)
        else:
            with chainer.using_config('train', False), \
                    chainer.no_backprop_mode():
                x = chainer.Variable(x_data)
                self.model(x)
        proba_img = chainer.functions.softmax(self.model.score)
        proba_img = chainer.functions.transpose(proba_img, (0, 2, 3, 1))
        max_proba_img = chainer.functions.max(proba_img, axis=-1)
        label = chainer.functions.argmax(self.model.score, axis=1)
        # Squeeze the batch axis and copy GPU -> CPU.
        proba_img = cuda.to_cpu(proba_img.data)[0]
        max_proba_img = cuda.to_cpu(max_proba_img.data)[0]
        label = cuda.to_cpu(label.data)[0]
        # Low-confidence pixels fall back to the background label.
        label[max_proba_img < self.proba_threshold] = self.bg_label
        return label, proba_img
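    # Shape note (both backends, under the reconstruction above): for an input
    # image of height H and width W, `label` is returned as an (H, W) array of
    # integer class indices and the probability image as an (H, W, n_class)
    # float32 array, which the callbacks publish as 32SC1 and 32FCn images.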

    def _segment_torch_backend(self, bgr):
        blob = (bgr - self.mean_bgr).transpose((2, 0, 1))
        x_data = np.array([blob], dtype=np.float32)
        x_data = torch.from_numpy(x_data)
        if self.gpu >= 0:
            x_data = x_data.cuda(self.gpu)
        x = torch.autograd.Variable(x_data, volatile=True)
        score = self.model(x)
        proba = torch.nn.functional.softmax(score)
        max_proba, label = torch.max(proba, 1)
        # Low-confidence pixels fall back to the background label.
        label[max_proba < self.proba_threshold] = self.bg_label
        # Copy GPU -> CPU and drop the batch (and channel) axes.
        proba = proba.permute(0, 2, 3, 1).data.cpu().numpy()[0]
        label = label.data.cpu().numpy().squeeze((0, 1))
        return label, proba


if __name__ == '__main__':
    rospy.init_node('fcn_object_segmentation')
    FCNObjectSegmentation()
    rospy.spin()
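# Example invocation (a sketch, assuming this script is installed as the
# fcn_object_segmentation node of the jsk_perception package; the topic and
# model paths below are placeholders):
#
#   rosparam set /fcn_object_segmentation/target_names \
#       "[background, bottle, box]"
#   rosrun jsk_perception fcn_object_segmentation.py \
#       ~input:=/camera/rgb/image_rect_color \
#       _backend:=chainer _gpu:=0 _model_name:=fcn8s \
#       _model_file:=/path/to/fcn8s.npz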