ssd_object_detector.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 # -*- coding: utf-8 -*-
00003 # Author: Furushchev <furushchev@jsk.imi.i.u-tokyo.ac.jp>
00004 
00005 import matplotlib
00006 matplotlib.use("Agg")
00007 import matplotlib.pyplot as plt
00008 
00009 from cv_bridge import CvBridge
00010 from jsk_topic_tools import ConnectionBasedTransport
00011 import numpy as np
00012 import rospy
00013 import time
00014 import yaml
00015 
00016 from dynamic_reconfigure.server import Server
00017 from jsk_perception.cfg import SSDObjectDetectorConfig as Config
00018 
00019 from sensor_msgs.msg import Image
00020 from jsk_recognition_msgs.msg import Rect, RectArray
00021 from jsk_recognition_msgs.msg import ClassificationResult
00022 
00023 import chainer
00024 from chainercv.links import SSD300
00025 from chainercv.links import SSD512
00026 from chainercv.visualizations import vis_bbox
00027 
00028 
class SSDObjectDetector(ConnectionBasedTransport):
    """Detect objects in images with a chainercv SSD model and publish results.

    Input:
      ~input (sensor_msgs/Image): converted to RGB before inference.

    Outputs:
      ~output/rect (jsk_recognition_msgs/RectArray): one Rect per detection.
      ~output/class (jsk_recognition_msgs/ClassificationResult): label and
        score per detection, in the same order as the rects.
      ~output/image (sensor_msgs/Image): bbox visualization, rendered only
        while this topic has subscribers (see ``visualize``).

    Private parameters:
      ~gpu (int, default -1): CUDA device id; a negative value runs on CPU.
      ~classifier_name (str, default: node name): copied into
        ClassificationResult.classifier.
      ~label_names (list, or str path to a YAML list): class names; when
        empty, chainercv's VOC label names are used.
      ~model_path (str, default None): passed to the model as
        ``pretrained_model`` (pretrained model name or weight file path).
      ~model (str, default 'ssd300'): either 'ssd300' or 'ssd512'.

    Dynamic reconfigure (SSDObjectDetectorConfig): nms_thresh, score_thresh,
    profiling.
    """

    # Maps the ~model parameter value to the chainercv model class.
    _MODEL_CLASSES = {
        'ssd300': SSD300,
        'ssd512': SSD512,
    }

    def __init__(self):
        super(SSDObjectDetector, self).__init__()
        self.gpu = rospy.get_param("~gpu", -1)
        self.classifier_name = rospy.get_param("~classifier_name", rospy.get_name())

        self.cv_bridge = CvBridge()

        # load model
        self.label_names = self.load_label_names()
        rospy.loginfo("Loaded %d labels" % len(self.label_names))

        # model_path: name of pretrained model or path to model file
        model_path = rospy.get_param("~model_path", None)
        model_name = rospy.get_param('~model', 'ssd300')
        model_class = self._MODEL_CLASSES.get(model_name)
        if model_class is None:
            # Fail loudly here: continuing would use an undefined model class.
            rospy.logerr('Unsupported ~model: {}'.format(model_name))
            raise ValueError('Unsupported ~model: {}'.format(model_name))
        self.model = model_class(
            n_fg_class=len(self.label_names),
            pretrained_model=model_path)
        if self.gpu >= 0:
            chainer.cuda.get_device_from_id(self.gpu).use()
            self.model.to_gpu()
        rospy.loginfo("Loaded model: %s" % model_path)

        # dynamic reconfigure: config_callback sets the model thresholds and
        # self.profiling (presumably invoked once by Server on construction,
        # so self.profiling exists before image_cb runs -- confirm).
        self.srv = Server(Config, self.config_callback)

        # advertise
        self.pub_rects = self.advertise("~output/rect", RectArray,
                                        queue_size=1)
        self.pub_class = self.advertise("~output/class", ClassificationResult,
                                        queue_size=1)
        self.pub_image = self.advertise("~output/image", Image,
                                        queue_size=1)

    def subscribe(self):
        """Start listening on ~input (large buff_size so whole images fit)."""
        self.sub_image = rospy.Subscriber("~input", Image, self.image_cb,
                                          queue_size=1, buff_size=2**26)

    def unsubscribe(self):
        """Stop listening on ~input."""
        self.sub_image.unregister()

    @property
    def visualize(self):
        """True when someone subscribes to ~output/image, so rendering is worth it."""
        return self.pub_image.get_num_connections() > 0

    def load_label_names(self):
        """Return the class names used by the model.

        Reads ``~label_names``. An empty value falls back to chainercv's VOC
        label names. A string value is treated as a path to a YAML file
        containing a list of names. Any other value is returned as given.
        """
        label_names = rospy.get_param("~label_names", tuple())
        if not label_names:
            try:
                # Try the older chainercv name first ...
                from chainercv.datasets import voc_detection_label_names as default_names
            except ImportError:
                # ... and fall back to the name used by other chainercv versions.
                from chainercv.datasets import voc_bbox_label_names as default_names
            label_names = default_names
        elif isinstance(label_names, str):
            with open(label_names, "r") as f:
                # safe_load: a plain list of strings needs no arbitrary-object
                # construction, and yaml.load without a Loader fails on PyYAML>=6.
                label_names = tuple(yaml.safe_load(f))
        return label_names

    def config_callback(self, config, level):
        """Apply dynamic-reconfigure values to the model and profiling flag."""
        self.model.nms_thresh = config.nms_thresh
        self.model.score_thresh = config.score_thresh
        self.profiling = config.profiling
        return config

    def _log_elapsed(self, stage, tprev):
        """Log ms spent in ``stage`` since ``tprev`` when profiling.

        Returns the new reference time, or ``tprev`` unchanged when profiling
        is off.
        """
        if not self.profiling:
            return tprev
        tcur = time.time()
        rospy.loginfo("%s: elapsed %f msec" % (stage, (tcur-tprev)*1000))
        return tcur

    def image_cb(self, msg):
        """Run detection on one image and publish rects, classes and visualization."""
        if self.profiling:
            delay_msec = (rospy.Time.now() - msg.header.stamp).to_sec() * 1000.0
            rospy.loginfo("callback start: incomming msg is %s msec behind" % delay_msec)
        tprev = time.time()
        try:
            # transform image to RGB, float, CHW
            img = self.cv_bridge.imgmsg_to_cv2(msg, desired_encoding="rgb8")
            img = np.asarray(img, dtype=np.float32)
            img = img.transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)
        except Exception as e:
            rospy.logerr("Failed to convert image: %s" % str(e))
            return
        tprev = self._log_elapsed("convert", tprev)

        if self.gpu >= 0:
            chainer.cuda.get_device_from_id(self.gpu).use()
        bboxes, labels, scores = self.model.predict([img])
        bboxes, labels, scores = bboxes[0], labels[0], scores[0]
        tprev = self._log_elapsed("predict", tprev)

        rect_msg = RectArray(header=msg.header)
        # Each bbox row is (y_min, x_min, y_max, x_max) in floats. Rect's fields
        # are int32, so convert explicitly: float values fail to serialize on
        # Python 3.
        for y_min, x_min, y_max, x_max in bboxes:
            rect_msg.rects.append(Rect(x=int(x_min), y=int(y_min),
                                       width=int(x_max - x_min),
                                       height=int(y_max - y_min)))
        tprev = self._log_elapsed("make rect msg", tprev)

        cls_msg = ClassificationResult(
            header=msg.header,
            classifier=self.classifier_name,
            target_names=self.label_names,
            labels=labels,
            label_names=[self.label_names[label] for label in labels],
            label_proba=scores,
        )
        tprev = self._log_elapsed("make cls msg", tprev)

        self.pub_rects.publish(rect_msg)
        self.pub_class.publish(cls_msg)
        tprev = self._log_elapsed("publish msg", tprev)

        if self.visualize:
            self.publish_bbox_image(img, bboxes, labels, scores, msg.header)
        self._log_elapsed("callback end", tprev)

    def publish_bbox_image(self, img, bbox, label, score, header):
        """Render detections over ``img`` (CHW float) and publish them as rgb8."""
        try:
            vis_bbox(img, bbox, label, score,
                     label_names=self.label_names)
            fig = plt.gcf()
            fig.canvas.draw()
            w, h = fig.canvas.get_width_height()
            # frombuffer replaces the deprecated fromstring for binary data.
            # NOTE(review): tostring_rgb is deprecated in recent matplotlib;
            # revisit if upgrading.
            vis = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
            vis = vis.reshape(h, w, 3)
        finally:
            # Always release the figure, even if rendering failed.
            plt.close()
        try:
            out_msg = self.cv_bridge.cv2_to_imgmsg(vis, "rgb8")
        except Exception as e:
            rospy.logerr("Failed to convert bbox image: %s" % str(e))
            return
        out_msg.header = header
        self.pub_image.publish(out_msg)
00187 
00188 
def main():
    """Start the ssd_object_detector node and block until ROS shuts down."""
    rospy.init_node("ssd_object_detector")
    detector = SSDObjectDetector()  # noqa: F841 -- keep the node object alive
    rospy.spin()


if __name__ == '__main__':
    main()


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Tue Jul 2 2019 19:41:07