regional_feature_based_object_recognition.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 
3 from __future__ import print_function
4 
5 import os.path as osp
6 
7 import itertools, pkg_resources, sys
8 from distutils.version import LooseVersion
# Refuse to start on Python 2 when chainer >= 7.0.0 is installed; the message
# points at the pull request that introduced this check.
if (LooseVersion(pkg_resources.get_distribution("chainer").version)
        >= LooseVersion('7.0.0')
        and sys.version_info.major == 2):
    print('''Please install chainer < 7.0.0:

 sudo pip install chainer==6.7.0

c.f https://github.com/jsk-ros-pkg/jsk_recognition/pull/2485
''', file=sys.stderr)
    sys.exit(1)

# Look through every distribution reachable from sys.path for a CuPy package
# (either plain "cupy" or a "cupy-*" wheel such as cupy-cuda91).
if not any(
        "cupy-" in dist.project_name or "cupy" == dist.project_name
        for dist in itertools.chain.from_iterable(
            pkg_resources.find_distributions(path_entry)
            for path_entry in sys.path)):
    print('''Please install CuPy

 sudo pip install cupy-cuda[your cuda version]
i.e.
 sudo pip install cupy-cuda91

''', file=sys.stderr)
    # NOTE: warning only -- the script keeps running when CuPy is missing.
27 import chainer
28 from chainer import cuda
29 import numpy as np
30 from sklearn.neighbors import KNeighborsClassifier
31 
32 import cv_bridge
33 from jsk_recognition_msgs.msg import ClassificationResult
34 from jsk_recognition_utils.chainermodels import ResNet152
35 from jsk_recognition_utils.chainermodels import ResNet152Feature
36 from jsk_topic_tools import ConnectionBasedTransport
37 import message_filters
38 import rospy
39 from sensor_msgs.msg import Image
40 
41 
42 import rospkg
# Filesystem path of the jsk_perception package as resolved by rospkg; the
# node joins it with 'trained_data/...' to locate the model and mean files.
PKG_PATH = rospkg.RosPack().get_path('jsk_perception')
44 
45 
class RegionalFeatureBasedObjectRecognition(ConnectionBasedTransport):
    """Classify a masked image region with ResNet152 features and k-NN.

    The node extracts a feature vector from the masked input image with
    ``ResNet152Feature`` and classifies it with a ``KNeighborsClassifier``
    (``n_neighbors=10``) fitted on a pre-computed feature database.

    ROS parameters:
        ~db_file: path of the database loaded with ``np.load``; it must
            provide the keys ``X``, ``y`` and ``target_names``. No default.
        ~gpu: GPU device id (default ``0``). A negative value keeps the
            model and the input on the CPU.

    Topics:
        ~input (sensor_msgs/Image): image, converted to ``bgr8``.
        ~input/mask (sensor_msgs/Image): mask, converted to ``mono8``.
        ~output (jsk_recognition_msgs/ClassificationResult): result.

    The mean image is asserted to be ``(224, 224, 3)`` and is indexed with
    the incoming mask, so image and mask must both be 224x224.
    """

    def __init__(self):
        """Load the model, the mean image and the k-NN database."""
        super(RegionalFeatureBasedObjectRecognition, self).__init__()
        # parameters
        db_file = rospy.get_param('~db_file')
        self.gpu = rospy.get_param('~gpu', 0)
        # setup chainer: this node only runs forward passes
        chainer.global_config.train = False
        chainer.global_config.enable_backprop = False
        # model
        pretrained_model = osp.join(
            PKG_PATH, 'trained_data/resnet152_from_caffe.npz')
        rospy.loginfo('Loading pretrained model: %s' % pretrained_model)
        # TODO(wkentaro): Support Resnet50/101
        self.model = ResNet152Feature()
        chainer.serializers.load_npz(pretrained_model, self.model)
        if self.gpu >= 0:
            chainer.cuda.get_device_from_id(self.gpu).use()
            self.model.to_gpu()
        rospy.loginfo('Finished loading pretrained model')
        # mean
        mean_file = osp.join(
            PKG_PATH, 'trained_data/resnet_lsvrc2012_mean.npy')
        self.mean = np.load(mean_file)
        assert self.mean.shape == (224, 224, 3)  # BGR order
        # knn
        rospy.loginfo('Fitting KNN from db')
        # Read every array inside the ``with`` block so the archive's file
        # handle is closed as soon as the data is in memory.
        with np.load(db_file) as db:
            X, y, target_names = db['X'], db['y'], db['target_names']
        # Entries that expose ``decode`` (byte strings) become text; any
        # other entry is kept unchanged.
        self.target_names = np.array([
            name.decode('utf-8') if hasattr(name, 'decode') else name
            for name in target_names])
        self.knn = KNeighborsClassifier(n_neighbors=10)
        self.knn.fit(X, y)
        rospy.loginfo('Finished fitting KNN from db')
        # setup publishers
        self.pub = self.advertise(
            '~output', ClassificationResult, queue_size=1)

    def subscribe(self):
        """Subscribe to image and mask and synchronize them approximately."""
        self.subs = [
            message_filters.Subscriber('~input', Image),
            message_filters.Subscriber('~input/mask', Image),
        ]
        queue_size = 100
        slop = 0.1
        self.sync = message_filters.ApproximateTimeSynchronizer(
            self.subs, queue_size=queue_size, slop=slop)
        self.sync.registerCallback(self.callback)

    def unsubscribe(self):
        """Unregister every subscriber created by ``subscribe``."""
        for sub in self.subs:
            sub.unregister()

    def callback(self, imgmsg, mask_msg):
        """Classify the masked region and publish a ClassificationResult.

        Args:
            imgmsg: sensor_msgs/Image holding the input image.
            mask_msg: sensor_msgs/Image holding the region mask; pixels with
                a value >= 127 belong to the region.

        Nothing is published when the image or mask size does not match the
        mean image; an error is logged instead.
        """
        bridge = cv_bridge.CvBridge()
        img = bridge.imgmsg_to_cv2(imgmsg, desired_encoding='bgr8')
        mask = bridge.imgmsg_to_cv2(mask_msg, desired_encoding='mono8')
        if mask.ndim == 3:
            mask = np.squeeze(mask, axis=2)
        mask = mask >= 127  # uint8 -> bool

        # ``self.mean[mask]`` and ``img[mask]`` are boolean-mask lookups, so
        # all three arrays must agree in height and width. Report a mismatch
        # clearly instead of failing inside the indexing expression.
        if img.shape != self.mean.shape or mask.shape != img.shape[:2]:
            rospy.logerr(
                'Input image shape %s and mask shape %s must match the '
                'mean image shape %s'
                % (img.shape, mask.shape, self.mean.shape))
            return

        # Subtract the mean inside the region and zero everything outside.
        img = img.astype(np.float64)
        img[mask] -= self.mean[mask]
        img[~mask] = 0

        # HWC -> CHW, float32, wrapped into a batch holding one image.
        img = img.transpose(2, 0, 1)
        img = img.astype(np.float32)
        x_data = np.asarray([img])
        if self.gpu >= 0:
            x_data = cuda.to_gpu(x_data)
        x = chainer.Variable(x_data)
        y = self.model(x)

        feat = cuda.to_cpu(y.data)
        feat = feat.squeeze(axis=(2, 3))  # drop the singleton axes 2 and 3
        X_query = feat

        # One row per query sample, one column per entry of knn.classes_.
        y_pred_proba = self.knn.predict_proba(X_query)
        y_pred = np.argmax(y_pred_proba, axis=1)

        # NOTE(review): assumes the class labels in the database are integer
        # indices into target_names -- confirm against the db generator.
        classes = self.knn.classes_
        target_names = self.target_names[classes]

        msg = ClassificationResult()
        msg.header = imgmsg.header
        msg.labels = y_pred.tolist()
        msg.label_names = target_names[y_pred].tolist()
        # Pair each row with its own winning column. ``[:, y_pred]`` would
        # build an N x N cross product for a batch of N > 1 samples.
        msg.label_proba = y_pred_proba[
            np.arange(len(y_pred)), y_pred].tolist()
        msg.probabilities = y_pred_proba.flatten().tolist()
        msg.target_names = target_names.tolist()
        self.pub.publish(msg)
136 
137 
if __name__ == '__main__':
    rospy.init_node('regional_feature_based_object_recognition')
    # Create the node object after init_node; without it no publisher or
    # subscriber exists and spin() would idle forever. The reference is kept
    # so the object stays alive while spinning.
    app = RegionalFeatureBasedObjectRecognition()
    rospy.spin()
node_scripts.regional_feature_based_object_recognition.RegionalFeatureBasedObjectRecognition.__init__
def __init__(self)
Definition: regional_feature_based_object_recognition.py:48
node_scripts.regional_feature_based_object_recognition.RegionalFeatureBasedObjectRecognition.sync
sync
Definition: regional_feature_based_object_recognition.py:90
jsk_recognition_utils::chainermodels
node_scripts.regional_feature_based_object_recognition.RegionalFeatureBasedObjectRecognition.knn
knn
Definition: regional_feature_based_object_recognition.py:77
node_scripts.regional_feature_based_object_recognition.RegionalFeatureBasedObjectRecognition.callback
def callback(self, imgmsg, mask_msg)
Definition: regional_feature_based_object_recognition.py:98
message_filters::Subscriber
node_scripts.regional_feature_based_object_recognition.RegionalFeatureBasedObjectRecognition.subs
subs
Definition: regional_feature_based_object_recognition.py:85
node_scripts.regional_feature_based_object_recognition.RegionalFeatureBasedObjectRecognition.gpu
gpu
Definition: regional_feature_based_object_recognition.py:52
node_scripts.regional_feature_based_object_recognition.RegionalFeatureBasedObjectRecognition.subscribe
def subscribe(self)
Definition: regional_feature_based_object_recognition.py:84
node_scripts.regional_feature_based_object_recognition.RegionalFeatureBasedObjectRecognition
Definition: regional_feature_based_object_recognition.py:46
node_scripts.regional_feature_based_object_recognition.RegionalFeatureBasedObjectRecognition.unsubscribe
def unsubscribe(self)
Definition: regional_feature_based_object_recognition.py:94
node_scripts.regional_feature_based_object_recognition.RegionalFeatureBasedObjectRecognition.model
model
Definition: regional_feature_based_object_recognition.py:61
node_scripts.regional_feature_based_object_recognition.RegionalFeatureBasedObjectRecognition.pub
pub
Definition: regional_feature_based_object_recognition.py:81
node_scripts.regional_feature_based_object_recognition.RegionalFeatureBasedObjectRecognition.target_names
target_names
Definition: regional_feature_based_object_recognition.py:75
node_scripts.regional_feature_based_object_recognition.RegionalFeatureBasedObjectRecognition.mean
mean
Definition: regional_feature_based_object_recognition.py:70


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Fri May 16 2025 03:11:17