regional_feature_based_object_recognition.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 
3 from __future__ import print_function
4 
5 import os.path as osp
6 
7 import itertools, pkg_resources, sys
8 from distutils.version import LooseVersion
# Under Python 2 this node requires chainer < 7.0.0; abort early with
# installation instructions instead of failing later with an obscure error.
if LooseVersion(pkg_resources.get_distribution("chainer").version) >= LooseVersion('7.0.0') and \
   sys.version_info.major == 2:
    print('''Please install chainer < 7.0.0:

    sudo pip install chainer==6.7.0

c.f https://github.com/jsk-ros-pkg/jsk_recognition/pull/2485
''', file=sys.stderr)
    sys.exit(1)


def _is_cupy_distribution(project_name):
    """Return True if ``project_name`` names a CuPy distribution.

    CuPy ships both as ``cupy`` (built from source) and as prebuilt
    ``cupy-cudaXX`` wheels, so accept either form.
    """
    name = project_name.lower()
    return name == 'cupy' or 'cupy-' in name


# Only a warning (the exit is deliberately disabled): the node selects
# CPU execution when ~gpu < 0, so a missing CuPy is not always fatal.
if not any(
        _is_cupy_distribution(dist.project_name)
        for dist in itertools.chain.from_iterable(
            pkg_resources.find_distributions(path) for path in sys.path)):
    print('''Please install CuPy

    sudo pip install cupy-cuda[your cuda version]
i.e.
    sudo pip install cupy-cuda91

''', file=sys.stderr)
    # sys.exit(1)
27 import chainer
28 from chainer import cuda
29 import numpy as np
30 from sklearn.neighbors import KNeighborsClassifier
31 
32 import cv_bridge
33 from jsk_recognition_msgs.msg import ClassificationResult
34 from jsk_recognition_utils.chainermodels import ResNet152
35 from jsk_recognition_utils.chainermodels import ResNet152Feature
36 from jsk_topic_tools import ConnectionBasedTransport
37 import message_filters
38 import rospy
39 from sensor_msgs.msg import Image
40 
41 
42 import rospkg
# Absolute path of the installed jsk_perception package; used to locate the
# bundled pretrained model and mean image under its trained_data/ directory.
PKG_PATH = rospkg.RosPack().get_path('jsk_perception')
44 
45 
class RegionalFeatureBasedObjectRecognition(ConnectionBasedTransport):
    """Classify a masked image region with ResNet-152 features and a KNN.

    Subscribes to ``~input`` (image) and ``~input/mask`` (mask image), which
    are paired with an approximate-time synchronizer. The masked pixels are
    mean-subtracted and fed to ``ResNet152Feature``. The resulting feature
    vector is classified with a k-nearest-neighbor model (k=10) fitted on
    ``~db_file``. The result is published on ``~output`` as a
    ``ClassificationResult``.

    ROS parameters:
        ~db_file (str, required): ``.npz`` file holding ``X`` (training
            features), ``y`` (labels) and ``target_names`` (label names,
            indexed by label value).
        ~gpu (int, default 0): GPU id; a negative value runs on CPU.
    """

    def __init__(self):
        super(RegionalFeatureBasedObjectRecognition, self).__init__()
        # parameters
        db_file = rospy.get_param('~db_file')
        self.gpu = rospy.get_param('~gpu', 0)
        # setup chainer: inference only, no backprop graph construction
        chainer.global_config.train = False
        chainer.global_config.enable_backprop = False
        # model
        pretrained_model = osp.join(
            PKG_PATH, 'trained_data/resnet152_from_caffe.npz')
        rospy.loginfo('Loading pretrained model: %s' % pretrained_model)
        # TODO(wkentaro): Support Resnet50/101
        self.model = ResNet152Feature()
        chainer.serializers.load_npz(pretrained_model, self.model)
        if self.gpu >= 0:
            chainer.cuda.get_device_from_id(self.gpu).use()
            self.model.to_gpu()
        rospy.loginfo('Finished loading pretrained model')
        # mean image subtracted from the masked pixels of every input
        mean_file = osp.join(
            PKG_PATH, 'trained_data/resnet_lsvrc2012_mean.npy')
        self.mean = np.load(mean_file)
        assert self.mean.shape == (224, 224, 3)  # BGR order
        # knn
        rospy.loginfo('Fitting KNN from db')
        db = np.load(db_file)
        X, y, self.target_names = db['X'], db['y'], db['target_names']
        self.knn = KNeighborsClassifier(n_neighbors=10)
        self.knn.fit(X, y)
        rospy.loginfo('Finished fitting KNN from db')
        # One bridge reused for every message; created before advertise() so
        # it exists even if a subscriber connects immediately.
        self.bridge = cv_bridge.CvBridge()
        # setup publishers
        self.pub = self.advertise(
            '~output', ClassificationResult, queue_size=1)

    def subscribe(self):
        """Subscribe to the image and mask topics and pair them in time."""
        self.subs = []
        self.subs.append(message_filters.Subscriber('~input', Image))
        self.subs.append(message_filters.Subscriber('~input/mask', Image))
        queue_size = 100
        slop = 0.1
        self.sync = message_filters.ApproximateTimeSynchronizer(
            self.subs, queue_size=queue_size, slop=slop)
        self.sync.registerCallback(self.callback)

    def unsubscribe(self):
        """Unregister the subscribers created by :meth:`subscribe`."""
        for sub in self.subs:
            sub.unregister()

    def callback(self, imgmsg, mask_msg):
        """Classify the masked region of ``imgmsg`` and publish the result.

        Args:
            imgmsg (sensor_msgs/Image): input image, converted to ``bgr8``.
            mask_msg (sensor_msgs/Image): mask, converted to ``mono8``;
                pixels >= 127 are foreground.
        """
        img = self.bridge.imgmsg_to_cv2(imgmsg, desired_encoding='bgr8')
        mask = self.bridge.imgmsg_to_cv2(mask_msg, desired_encoding='mono8')
        if mask.ndim == 3:
            mask = np.squeeze(mask, axis=2)
        mask = mask >= 127  # uint8 -> bool

        x_data = self._preprocess(img, mask)
        X_query = self._extract_feature(x_data)

        y_pred_proba = self.knn.predict_proba(X_query)
        # Column index into knn.classes_ (not the raw label value).
        y_pred = np.argmax(y_pred_proba, axis=1)

        # Names in the same column order as y_pred_proba / y_pred.
        classes = self.knn.classes_
        target_names = self.target_names[classes]

        msg = ClassificationResult()
        msg.header = imgmsg.header
        msg.labels = y_pred.tolist()
        msg.label_names = target_names[y_pred].tolist()
        # Pairwise indexing: probability of each row's own predicted class
        # (``[:, y_pred]`` would build an N x N cross product instead).
        rows = np.arange(len(y_pred))
        msg.label_proba = y_pred_proba[rows, y_pred].tolist()
        msg.probabilities = y_pred_proba.flatten().tolist()
        msg.target_names = target_names.tolist()
        self.pub.publish(msg)

    def _preprocess(self, img, mask):
        """Mean-subtract masked pixels and return a (1, 3, H, W) float32 batch.

        Pixels outside ``mask`` are zeroed. ``img`` and ``mask`` must share
        the height/width of ``self.mean`` (asserted to be 224x224 in
        ``__init__``); otherwise the boolean indexing of the mean raises
        IndexError.
        """
        img = img.astype(np.float64)
        img[mask] -= self.mean[mask]
        img[~mask] = 0
        img = img.transpose(2, 0, 1).astype(np.float32)  # HWC -> CHW
        return np.asarray([img])

    def _extract_feature(self, x_data):
        """Run the ResNet feature extractor; return a 2-D numpy array (N, C)."""
        if self.gpu < 0:
            y = self.model(chainer.Variable(x_data))
            return cuda.to_cpu(y.data).squeeze(axis=(2, 3))
        # The CUDA current device is per host thread, and subscriber callbacks
        # may run on a different thread than __init__ (where .use() was
        # called). Select the configured GPU explicitly so the input ends up
        # on the same device as the model.
        with cuda.get_device_from_id(self.gpu):
            x = chainer.Variable(cuda.to_gpu(x_data, device=self.gpu))
            y = self.model(x)
            return cuda.to_cpu(y.data).squeeze(axis=(2, 3))
135 
136 
if __name__ == '__main__':
    rospy.init_node('regional_feature_based_object_recognition')
    # Construct the node; without this nothing is advertised or subscribed
    # and rospy.spin() would just idle.
    app = RegionalFeatureBasedObjectRecognition()
    rospy.spin()


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Mon May 3 2021 03:03:27