regional_feature_based_object_recognition.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 
00003 import os.path as osp
00004 
00005 import chainer
00006 from chainer import cuda
00007 import numpy as np
00008 from sklearn.neighbors import KNeighborsClassifier
00009 
00010 import cv_bridge
00011 from jsk_recognition_msgs.msg import ClassificationResult
00012 from jsk_recognition_utils.chainermodels import ResNet152
00013 from jsk_recognition_utils.chainermodels import ResNet152Feature
00014 from jsk_topic_tools import ConnectionBasedTransport
00015 import message_filters
00016 import rospy
00017 from sensor_msgs.msg import Image
00018 
00019 
00020 import rospkg
00021 PKG_PATH = rospkg.RosPack().get_path('jsk_perception')
00022 
00023 
class RegionalFeatureBasedObjectRecognition(ConnectionBasedTransport):
    """Classify the masked region of an image with ResNet152 features + KNN.

    Subscribes to ``~input`` (image) and ``~input/mask`` (mask image),
    extracts a feature vector for the masked region with
    ``ResNet152Feature`` and classifies it with a k-nearest-neighbors
    classifier fitted on the database given by ``~db_file``.  The result is
    published on ``~output`` as ``ClassificationResult``.

    ROS parameters:
        ~db_file: file readable by ``np.load`` providing the keys ``X``,
            ``y`` and ``target_names`` (required).
        ~gpu: GPU device id; a negative value keeps everything on the CPU
            (default: 0).
    """

    def __init__(self):
        """Load the model, the mean image and the KNN db, then advertise."""
        super(RegionalFeatureBasedObjectRecognition, self).__init__()
        # parameters
        db_file = rospy.get_param('~db_file')
        self.gpu = rospy.get_param('~gpu', 0)
        # setup chainer: inference only, no training mode and no backprop
        chainer.global_config.train = False
        chainer.global_config.enable_backprop = False
        # model
        pretrained_model = osp.join(
            PKG_PATH, 'trained_data/resnet152_from_caffe.npz')
        rospy.loginfo('Loading pretrained model: %s' % pretrained_model)
        # TODO(wkentaro): Support Resnet50/101
        self.model = ResNet152Feature()
        chainer.serializers.load_npz(pretrained_model, self.model)
        if self.gpu >= 0:
            cuda.get_device_from_id(self.gpu).use()
            self.model.to_gpu()
        rospy.loginfo('Finished loading pretrained model')
        # mean
        mean_file = osp.join(
            PKG_PATH, 'trained_data/resnet_lsvrc2012_mean.npy')
        self.mean = np.load(mean_file)
        assert self.mean.shape == (224, 224, 3)  # BGR order
        # knn
        rospy.loginfo('Fitting KNN from db')
        db = np.load(db_file)
        X, y, self.target_names = db['X'], db['y'], db['target_names']
        self.knn = KNeighborsClassifier(n_neighbors=10)
        self.knn.fit(X, y)
        rospy.loginfo('Finished fitting KNN from db')
        # setup publishers (done last, once everything callback() reads
        # has been initialized)
        self.pub = self.advertise(
            '~output', ClassificationResult, queue_size=1)

    def subscribe(self):
        """Subscribe to image and mask, synchronized approximately in time."""
        self.subs = []
        self.subs.append(message_filters.Subscriber('~input', Image))
        self.subs.append(message_filters.Subscriber('~input/mask', Image))
        queue_size = 100
        slop = 0.1
        self.sync = message_filters.ApproximateTimeSynchronizer(
            self.subs, queue_size=queue_size, slop=slop)
        self.sync.registerCallback(self.callback)

    def unsubscribe(self):
        """Unregister every subscriber created by subscribe()."""
        for sub in self.subs:
            sub.unregister()

    def callback(self, imgmsg, mask_msg):
        """Classify the masked image region and publish the result.

        Args:
            imgmsg: image message, converted to ``bgr8``.
            mask_msg: mask message, converted to ``mono8``; pixels with a
                value >= 127 are treated as foreground.

        The image and the mask must have the same height/width as the mean
        image (224x224); otherwise the frame is dropped with an error log.
        """
        bridge = cv_bridge.CvBridge()
        img = bridge.imgmsg_to_cv2(imgmsg, desired_encoding='bgr8')
        mask = bridge.imgmsg_to_cv2(mask_msg, desired_encoding='mono8')
        if mask.ndim == 3:
            mask = np.squeeze(mask, axis=2)

        # The boolean mask indexes both the image and the mean image, so all
        # three must agree in size.  Fail with a clear message instead of an
        # obscure indexing error (or silently wrong data on old NumPy).
        if img.shape != self.mean.shape or mask.shape != img.shape[:2]:
            rospy.logerr(
                'Unexpected input size: image %s and mask %s must both '
                'match the mean image size %s'
                % (img.shape, mask.shape, self.mean.shape[:2]))
            return

        mask = mask >= 127  # uint8 -> bool

        # subtract the mean inside the mask and zero out the background
        img = img.astype(np.float64)
        img[mask] -= self.mean[mask]
        img[~mask] = 0

        # HWC -> CHW, then add the batch axis (batch size is 1)
        img = img.transpose(2, 0, 1)
        img = img.astype(np.float32)
        x_data = np.asarray([img])
        if self.gpu >= 0:
            x_data = cuda.to_gpu(x_data)
        x = chainer.Variable(x_data)
        y = self.model(x)

        feat = cuda.to_cpu(y.data)
        # drop axes 2 and 3 (must be singleton) -> 2-D query matrix
        feat = feat.squeeze(axis=(2, 3))
        X_query = feat

        y_pred_proba = self.knn.predict_proba(X_query)
        # y_pred indexes the columns of y_pred_proba, i.e. knn.classes_
        y_pred = np.argmax(y_pred_proba, axis=1)

        # Reorder the names to follow knn.classes_ so that y_pred indexes
        # them directly (assumes the db labels are integer indices into
        # target_names -- TODO confirm against the db generator).
        classes = self.knn.classes_
        target_names = self.target_names[classes]

        msg = ClassificationResult()
        msg.header = imgmsg.header
        msg.labels = y_pred.tolist()
        msg.label_names = target_names[y_pred].tolist()
        # One probability per sample: pair row i with column y_pred[i].
        # (y_pred_proba[:, y_pred] would yield an N x N matrix for N > 1.)
        msg.label_proba = y_pred_proba[
            np.arange(len(y_pred)), y_pred].tolist()
        msg.probabilities = y_pred_proba.flatten().tolist()
        msg.target_names = target_names.tolist()
        self.pub.publish(msg)
00113 
00114 
if __name__ == '__main__':
    # Register this process as a ROS node before touching any ROS API.
    rospy.init_node('regional_feature_based_object_recognition')
    # Keep a reference to the node object for the lifetime of the process.
    node = RegionalFeatureBasedObjectRecognition()
    # Block until shutdown; callbacks run while spinning.
    rospy.spin()


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Sun Oct 8 2017 02:43:23