vgg16_object_recognition.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 
00003 from __future__ import absolute_import
00004 from __future__ import division
00005 from __future__ import print_function
00006 
00007 
00008 import chainer
00009 from chainer import cuda
00010 import chainer.serializers as S
00011 from chainer import Variable
00012 from distutils.version import LooseVersion
00013 import numpy as np
00014 import skimage.transform
00015 
00016 import cv_bridge
00017 from jsk_recognition_msgs.msg import ClassificationResult
00018 from jsk_recognition_utils.chainermodels import VGG16
00019 from jsk_recognition_utils.chainermodels import VGG16BatchNormalization
00020 from jsk_topic_tools import ConnectionBasedTransport
00021 from jsk_topic_tools.log_utils import logerr_throttle
00022 import message_filters
00023 import rospy
00024 from sensor_msgs.msg import Image
00025 
00026 
class VGG16ObjectRecognition(ConnectionBasedTransport):
    """ROS node that classifies images with a VGG16-family Chainer model.

    Subscribes to ``~input`` (and, when ``~use_mask`` is true, also to
    ``~input/mask``). Publishes a ``ClassificationResult`` on ``~output``
    and the resized network input image on ``~debug/net_input``.

    ROS parameters read by this class:
        ~gpu (int, default -1): GPU device id; a negative value runs on CPU.
        ~target_names (list of str): class names, indexed by label id; its
            length is used as the model's number of classes.
        ~model_name (str): ``'vgg16'`` or ``'vgg16_batch_normalization'``.
        ~model_file (str): HDF5 weights file loaded into the model.
        ~use_mask (bool, default False): also subscribe to ``~input/mask``.
        ~queue_size (int, default 10), ~approximate_sync (bool, default
            False), ~slop (float, default 0.1): synchronizer settings, only
            used when ``~use_mask`` is true.
    """

    # Per-channel BGR mean: subtracted from the network input, and used to
    # fill pixels outside the mask.
    mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434])

    def __init__(self):
        """Read parameters, build and load the model, and advertise topics.

        Raises:
            ValueError: if ``~model_name`` is not a supported model.
        """
        # Name the class explicitly: ``super(self.__class__, self)`` recurses
        # forever as soon as this class is subclassed.
        super(VGG16ObjectRecognition, self).__init__()
        self.insize = 224  # side length of the square network input
        self.gpu = rospy.get_param('~gpu', -1)
        self.target_names = rospy.get_param('~target_names')
        self.model_name = rospy.get_param('~model_name')
        n_class = len(self.target_names)
        if self.model_name == 'vgg16':
            self.model = VGG16(n_class=n_class)
        elif self.model_name == 'vgg16_batch_normalization':
            self.model = VGG16BatchNormalization(n_class=n_class)
        else:
            # Fail loudly here instead of crashing later with an
            # AttributeError on ``self.model``.
            msg = 'Unsupported ~model_name: {0}'.format(self.model_name)
            rospy.logerr(msg)
            raise ValueError(msg)
        model_file = rospy.get_param('~model_file')
        S.load_hdf5(model_file, self.model)
        if self.gpu >= 0:
            self.model.to_gpu(self.gpu)
        # Create everything callbacks/unsubscribe rely on before advertising,
        # since advertising may lead to subscribe() being triggered.
        self.bridge = cv_bridge.CvBridge()
        self.subs = []
        self.pub = self.advertise('~output', ClassificationResult,
                                  queue_size=1)
        self.pub_input = self.advertise(
            '~debug/net_input', Image, queue_size=1)

    def subscribe(self):
        """Subscribe to the input image (and mask when ``~use_mask``)."""
        if rospy.get_param('~use_mask', False):
            # larger buff_size is necessary for taking time callback
            # http://stackoverflow.com/questions/26415699/ros-subscriber-not-up-to-date/29160379#29160379  # NOQA
            sub = message_filters.Subscriber(
                '~input', Image, queue_size=1, buff_size=2**24)
            sub_mask = message_filters.Subscriber(
                '~input/mask', Image, queue_size=1, buff_size=2**24)
            self.subs = [sub, sub_mask]
            queue_size = rospy.get_param('~queue_size', 10)
            if rospy.get_param('~approximate_sync', False):
                slop = rospy.get_param('~slop', 0.1)
                sync = message_filters.ApproximateTimeSynchronizer(
                    self.subs, queue_size=queue_size, slop=slop)
            else:
                sync = message_filters.TimeSynchronizer(
                    self.subs, queue_size=queue_size)
            # The synchronizer calls _recognize(imgmsg, mask_msg).
            sync.registerCallback(self._recognize)
        else:
            sub = rospy.Subscriber(
                '~input', Image, self._recognize, callback_args=None,
                queue_size=1, buff_size=2**24)
            self.subs = [sub]

    def unsubscribe(self):
        """Unregister every subscriber created by ``subscribe``."""
        for sub in self.subs:
            sub.unregister()

    def _forward(self, x_data):
        """Run the model on ``x_data`` and return its prediction on CPU."""
        if LooseVersion(chainer.__version__) < LooseVersion('2.0.0'):
            # Chainer v1: inference via volatile Variables and model.train.
            x = Variable(x_data, volatile=True)
            self.model.train = False
            self.model(x)
        else:
            with chainer.using_config('train', False), \
                    chainer.no_backprop_mode():
                x = Variable(x_data)
                self.model(x)
        # The model exposes its output through the ``pred`` attribute.
        return cuda.to_cpu(self.model.pred.data)

    def _recognize(self, imgmsg, mask_msg=None):
        """Classify one image and publish the result.

        Args:
            imgmsg: ``sensor_msgs/Image`` to classify (converted to bgr8).
            mask_msg: optional ``sensor_msgs/Image`` mask. Pixels where the
                mask is 0 are replaced by ``mean_bgr`` before inference. The
                message is dropped (throttled error logged) when the mask
                size differs from the image or the mask is empty.
        """
        bridge = self.bridge
        bgr = bridge.imgmsg_to_cv2(imgmsg, desired_encoding='bgr8')
        if mask_msg is not None:
            mask = bridge.imgmsg_to_cv2(mask_msg)
            if mask.shape != bgr.shape[:2]:
                logerr_throttle(10,
                                'Size of input image and mask is different')
                return
            elif mask.size == 0:
                logerr_throttle(10, 'Size of input mask is 0')
                return
            bgr[mask == 0] = self.mean_bgr
        # preserve_range keeps 0-255 values (as float) instead of rescaling.
        bgr = skimage.transform.resize(
            bgr, (self.insize, self.insize), preserve_range=True)
        input_msg = bridge.cv2_to_imgmsg(bgr.astype(np.uint8), encoding='bgr8')
        input_msg.header = imgmsg.header
        self.pub_input.publish(input_msg)

        # HWC -> CHW, then add a batch axis of size 1.
        blob = (bgr - self.mean_bgr).transpose((2, 0, 1))
        x_data = np.array([blob], dtype=np.float32)
        if self.gpu >= 0:
            x_data = cuda.to_gpu(x_data, device=self.gpu)
        proba = self._forward(x_data)[0]

        # Cast numpy scalars to plain Python types for the ROS message.
        label = int(np.argmax(proba))
        label_name = self.target_names[label]
        label_proba = float(proba[label])
        cls_msg = ClassificationResult(
            header=imgmsg.header,
            labels=[label],
            label_names=[label_name],
            label_proba=[label_proba],
            probabilities=proba,
            classifier=self.model_name,
            target_names=self.target_names,
        )
        self.pub.publish(cls_msg)
00129 
00130 
if __name__ == '__main__':
    rospy.init_node('vgg16_object_recognition')
    # Keep a reference to the node object while rospy.spin() blocks.
    node = VGG16ObjectRecognition()
    rospy.spin()


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Sun Oct 8 2017 02:43:23