vgg16_object_recognition.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 
7 
8 import itertools, pkg_resources, sys
9 from distutils.version import LooseVersion
# Start-up dependency guards, executed before chainer itself is imported.

# Abort on Python 2 when the installed chainer is 7.0.0 or newer; the
# message points at the pull request that motivated this guard.
if LooseVersion(
        pkg_resources.get_distribution("chainer").version
) >= LooseVersion('7.0.0'):
    if sys.version_info.major == 2:
        print('''Please install chainer < 7.0.0:

    sudo pip install chainer==6.7.0

c.f https://github.com/jsk-ros-pkg/jsk_recognition/pull/2485
''', file=sys.stderr)
        sys.exit(1)

# Look through every distribution reachable from sys.path for a project
# whose name contains "cupy-".
if not [
        dist
        for dist in itertools.chain.from_iterable(
            [pkg_resources.find_distributions(entry) for entry in sys.path])
        if "cupy-" in dist.project_name
]:
    print('''Please install CuPy

    sudo pip install cupy-cuda[your cuda version]
i.e.
    sudo pip install cupy-cuda91

''', file=sys.stderr)
    # Missing CuPy is only reported: the exit call is deliberately disabled.
    # sys.exit(1)
28 import chainer
29 from chainer import cuda
30 import chainer.serializers as S
31 from chainer import Variable
32 from distutils.version import LooseVersion
33 import numpy as np
34 import skimage.transform
35 
36 import cv_bridge
37 from jsk_recognition_msgs.msg import ClassificationResult
39 from jsk_recognition_utils.chainermodels import VGG16BatchNormalization
40 from jsk_topic_tools import ConnectionBasedTransport
41 from jsk_topic_tools.log_utils import logerr_throttle
42 import message_filters
43 import rospy
44 from sensor_msgs.msg import Image
45 
46 
class VGG16ObjectRecognition(ConnectionBasedTransport):
    """Classify incoming images with a VGG16-family Chainer model.

    Private ROS parameters read by this class:

    * ``~gpu`` (default ``-1``): device id; ``-1`` leaves the model on CPU.
    * ``~target_names``: sequence of class names; its length is passed to
      the model as ``n_class``.
    * ``~model_name``: ``'vgg16'`` or ``'vgg16_batch_normalization'``.
    * ``~model_file``: HDF5 file loaded into the model.
    * ``~use_mask`` (default ``False``): also subscribe to ``~input/mask``
      and synchronize it with ``~input``.
    * ``~queue_size`` (default ``10``), ``~approximate_sync`` (default
      ``False``), ``~slop`` (default ``0.1``): synchronizer settings, only
      read when ``~use_mask`` is true.

    Published topics: ``~output`` (``ClassificationResult``) and
    ``~debug/net_input`` (``Image``, the resized image fed to the network).
    """

    # Per-channel mean in BGR order. It is subtracted from the resized image
    # before inference and is also written into masked-out pixels.
    mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434])

    def __init__(self):
        """Read parameters, build and load the model, advertise outputs.

        Raises:
            ValueError: if ``~model_name`` is not one of the supported names.
        """
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as this class is subclassed.
        super(VGG16ObjectRecognition, self).__init__()
        self.insize = 224
        self.gpu = rospy.get_param('~gpu', -1)
        self.target_names = rospy.get_param('~target_names')
        self.model_name = rospy.get_param('~model_name')
        n_class = len(self.target_names)
        if self.model_name == 'vgg16':
            # Imported here because no module-level import of VGG16 is
            # visible in this file; it comes from the same module as
            # VGG16BatchNormalization.
            from jsk_recognition_utils.chainermodels import VGG16
            self.model = VGG16(n_class=n_class)
        elif self.model_name == 'vgg16_batch_normalization':
            self.model = VGG16BatchNormalization(n_class=n_class)
        else:
            # Fail here with a clear message instead of continuing and
            # crashing below on the missing ``self.model`` attribute.
            message = 'Unsupported ~model_name: {0}'.format(self.model_name)
            rospy.logerr(message)
            raise ValueError(message)
        model_file = rospy.get_param('~model_file')
        S.load_hdf5(model_file, self.model)
        if self.gpu != -1:
            self.model.to_gpu(self.gpu)
        self.pub = self.advertise('~output', ClassificationResult,
                                  queue_size=1)
        self.pub_input = self.advertise(
            '~debug/net_input', Image, queue_size=1)

    def subscribe(self):
        """Create the input subscriber(s) and route them to ``_recognize``.

        With ``~use_mask`` set, ``~input`` and ``~input/mask`` are combined
        by an exact or approximate time synchronizer; otherwise ``~input``
        alone drives the callback. The created subscribers are kept in
        ``self.subs`` so ``unsubscribe`` can release them.
        """
        if rospy.get_param('~use_mask', False):
            # larger buff_size is necessary for taking time callback
            # http://stackoverflow.com/questions/26415699/ros-subscriber-not-up-to-date/29160379#29160379  # NOQA
            sub = message_filters.Subscriber(
                '~input', Image, queue_size=1, buff_size=2**24)
            sub_mask = message_filters.Subscriber(
                '~input/mask', Image, queue_size=1, buff_size=2**24)
            self.subs = [sub, sub_mask]
            queue_size = rospy.get_param('~queue_size', 10)
            if rospy.get_param('~approximate_sync', False):
                slop = rospy.get_param('~slop', 0.1)
                sync = message_filters.ApproximateTimeSynchronizer(
                    self.subs, queue_size=queue_size, slop=slop)
            else:
                sync = message_filters.TimeSynchronizer(
                    self.subs, queue_size=queue_size)
            sync.registerCallback(self._recognize)
        else:
            sub = rospy.Subscriber(
                '~input', Image, self._recognize, callback_args=None,
                queue_size=1, buff_size=2**24)
            self.subs = [sub]

    def unsubscribe(self):
        """Unregister every subscriber created by ``subscribe``."""
        for sub in self.subs:
            sub.unregister()

    def _recognize(self, imgmsg, mask_msg=None):
        """Classify one image and publish the result.

        Args:
            imgmsg: input ``Image`` message, converted to ``bgr8``.
            mask_msg: optional mask ``Image``. Pixels where the mask is 0
                are overwritten with ``mean_bgr`` before resizing. The
                message is dropped (with a throttled error log) when the
                mask size differs from the image or the mask is empty.
        """
        bridge = cv_bridge.CvBridge()
        bgr = bridge.imgmsg_to_cv2(imgmsg, desired_encoding='bgr8')
        if mask_msg is not None:
            mask = bridge.imgmsg_to_cv2(mask_msg)
            if mask.shape != bgr.shape[:2]:
                logerr_throttle(10,
                                'Size of input image and mask is different')
                return
            elif mask.size == 0:
                logerr_throttle(10, 'Size of input mask is 0')
                return
            bgr[mask == 0] = self.mean_bgr

        # preserve_range keeps the original pixel value scale instead of
        # rescaling during the resize.
        bgr = skimage.transform.resize(
            bgr, (self.insize, self.insize), preserve_range=True)

        # Publish exactly what the network will see, for debugging.
        input_msg = bridge.cv2_to_imgmsg(bgr.astype(np.uint8), encoding='bgr8')
        input_msg.header = imgmsg.header
        self.pub_input.publish(input_msg)

        # Mean-subtract, move channels first, and add a batch axis of 1.
        blob = (bgr - self.mean_bgr).transpose((2, 0, 1))
        x_data = np.array([blob], dtype=np.float32)
        if self.gpu != -1:
            x_data = cuda.to_gpu(x_data, device=self.gpu)

        # Chainer < 2.0.0 selects inference mode through per-variable and
        # per-model flags; newer versions use configuration contexts.
        if LooseVersion(chainer.__version__) < LooseVersion('2.0.0'):
            x = Variable(x_data, volatile=True)
            self.model.train = False
            self.model(x)
        else:
            with chainer.using_config('train', False), \
                    chainer.no_backprop_mode():
                x = Variable(x_data)
                self.model(x)

        # The forward pass leaves its prediction on ``model.pred``; take the
        # single batch entry and report the highest-scoring class.
        proba = cuda.to_cpu(self.model.pred.data)[0]
        label = np.argmax(proba)
        label_name = self.target_names[label]
        label_proba = proba[label]
        cls_msg = ClassificationResult(
            header=imgmsg.header,
            labels=[label],
            label_names=[label_name],
            label_proba=[label_proba],
            probabilities=proba,
            classifier=self.model_name,
            target_names=self.target_names,
        )
        self.pub.publish(cls_msg)
149 
150 
# Script entry point: start the ROS node, build the recognizer, and block
# until shutdown.
if __name__ == '__main__':
    rospy.init_node('vgg16_object_recognition')
    # The recognizer must be constructed, otherwise the node spins without
    # advertising or subscribing to anything. The reference is held for the
    # lifetime of the spin.
    app = VGG16ObjectRecognition()  # NOQA
    rospy.spin()


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Mon May 3 2021 03:03:27