00001
00002
00003
00004
00005 import matplotlib
00006 matplotlib.use("Agg")
00007 import matplotlib.pyplot as plt
00008
00009 from cv_bridge import CvBridge
00010 from jsk_topic_tools import ConnectionBasedTransport
00011 import numpy as np
00012 import rospy
00013 import time
00014 import yaml
00015
00016 from dynamic_reconfigure.server import Server
00017 from jsk_perception.cfg import SSDObjectDetectorConfig as Config
00018
00019 from sensor_msgs.msg import Image
00020 from jsk_recognition_msgs.msg import Rect, RectArray
00021 from jsk_recognition_msgs.msg import ClassificationResult
00022
00023 import chainer
00024 from chainercv.links import SSD300
00025 from chainercv.links import SSD512
00026 from chainercv.visualizations import vis_bbox
00027
00028
class SSDObjectDetector(ConnectionBasedTransport):
    """Run a ChainerCV SSD detector on incoming images and publish the results.

    Input:  ``~input`` (sensor_msgs/Image), converted to rgb8.
    Output: ``~output/rect`` (RectArray) with one Rect per detection,
            ``~output/class`` (ClassificationResult) with per-detection labels
            and scores, and ``~output/image`` (Image) with the detections
            drawn on the input -- rendered only while that topic has
            subscribers (see ``visualize``).

    ROS parameters read in ``__init__``:
        ~gpu (int, default -1): GPU id; negative values keep the model on CPU.
        ~classifier_name (str, default: this node's name).
        ~label_names (list, or str path to a YAML list; default: ChainerCV's
            VOC label names).
        ~model_path (default None): passed as ``pretrained_model``.
        ~model (str, default 'ssd300'): either 'ssd300' or 'ssd512'.

    ``subscribe``/``unsubscribe`` are the hooks ConnectionBasedTransport
    calls to attach/detach the input subscriber (presumably driven by output
    subscriber count -- see jsk_topic_tools).
    """

    def __init__(self):
        super(SSDObjectDetector, self).__init__()
        self.gpu = rospy.get_param("~gpu", -1)
        self.classifier_name = rospy.get_param("~classifier_name", rospy.get_name())
        # Real value comes from config_callback; default it so image_cb can
        # never hit a missing attribute.
        self.profiling = False

        self.cv_bridge = CvBridge()

        self.label_names = self.load_label_names()
        rospy.loginfo("Loaded %d labels" % len(self.label_names))

        model_path = rospy.get_param("~model_path", None)
        model_name = rospy.get_param('~model', 'ssd300')
        if model_name == 'ssd300':
            model_class = SSD300
        elif model_name == 'ssd512':
            model_class = SSD512
        else:
            # There is no model to build; stop here with a clear error
            # instead of failing later on an undefined ``model_class``.
            rospy.logerr('Unsupported ~model: {}'.format(model_name))
            raise ValueError('Unsupported ~model: {}'.format(model_name))
        self.model = model_class(
            n_fg_class=len(self.label_names),
            pretrained_model=model_path)
        if self.gpu >= 0:
            chainer.cuda.get_device_from_id(self.gpu).use()
            self.model.to_gpu()
        rospy.loginfo("Loaded model: %s" % model_path)

        # Created after self.model exists: config_callback writes the
        # model's nms_thresh / score_thresh.
        self.srv = Server(Config, self.config_callback)

        self.pub_rects = self.advertise("~output/rect", RectArray,
                                        queue_size=1)
        self.pub_class = self.advertise("~output/class", ClassificationResult,
                                        queue_size=1)
        self.pub_image = self.advertise("~output/image", Image,
                                        queue_size=1)

    def subscribe(self):
        """Attach the image subscriber (large buff_size to fit whole images)."""
        self.sub_image = rospy.Subscriber("~input", Image, self.image_cb,
                                          queue_size=1, buff_size=2**26)

    def unsubscribe(self):
        """Detach the image subscriber created by ``subscribe``."""
        self.sub_image.unregister()

    @property
    def visualize(self):
        """True when ``~output/image`` has at least one subscriber."""
        return self.pub_image.get_num_connections() > 0

    def load_label_names(self):
        """Return the class names the detector predicts.

        ``~label_names`` may be a list of names or a string path to a YAML
        file containing such a list. When it is unset or empty, ChainerCV's
        VOC label names are used.

        Raises:
            IOError/OSError if the YAML path cannot be opened.
        """
        label_names = rospy.get_param("~label_names", tuple())
        if not label_names:
            try:
                from chainercv.datasets import voc_detection_label_names
                label_names = voc_detection_label_names
            except ImportError:
                # The legacy name is missing; fall back to the other name
                # (presumably differs across ChainerCV versions).
                from chainercv.datasets import voc_bbox_label_names
                label_names = voc_bbox_label_names
        elif isinstance(label_names, str):
            with open(label_names, "r") as f:
                # safe_load: the file only needs to hold a plain list of
                # strings, so there is no reason to allow arbitrary tags.
                label_names = tuple(yaml.safe_load(f))
        return label_names

    def config_callback(self, config, level):
        """Apply dynamic_reconfigure values to the model and profiling flag."""
        self.model.nms_thresh = config.nms_thresh
        self.model.score_thresh = config.score_thresh
        self.profiling = config.profiling
        return config

    def _log_elapsed(self, stage, tprev):
        """Log milliseconds elapsed since ``tprev`` for ``stage``; return now."""
        tcur = time.time()
        rospy.loginfo("%s: elapsed %f msec" % (stage, (tcur - tprev) * 1000))
        return tcur

    def image_cb(self, msg):
        """Detect objects in ``msg`` and publish rects, classes and (maybe) an image."""
        # Read the flag once: config_callback may change it while this
        # callback runs, which would otherwise let a later profiling branch
        # use ``tprev`` before it was ever assigned.
        profiling = self.profiling
        if profiling:
            rospy.loginfo(
                "callback start: incomming msg is %s msec behind"
                % ((rospy.Time.now() - msg.header.stamp).to_sec() * 1000.0))
        tprev = time.time()

        try:
            img = self.cv_bridge.imgmsg_to_cv2(msg, desired_encoding="rgb8")
            # HWC uint8 -> CHW float32, the layout passed to model.predict.
            img = np.asarray(img, dtype=np.float32).transpose(2, 0, 1)
        except Exception as e:
            rospy.logerr("Failed to convert image: %s" % str(e))
            return
        if profiling:
            tprev = self._log_elapsed("convert", tprev)

        if self.gpu >= 0:
            # Select the device again for the thread running this callback.
            chainer.cuda.get_device_from_id(self.gpu).use()
        bboxes, labels, scores = self.model.predict([img])
        bboxes, labels, scores = bboxes[0], labels[0], scores[0]
        if profiling:
            tprev = self._log_elapsed("predict", tprev)

        rect_msg = RectArray(header=msg.header)
        for bbox in bboxes:
            # Boxes are ordered (y_min, x_min, y_max, x_max) as floats.
            # Rect's fields are integers (int32 in jsk_recognition_msgs/Rect),
            # so cast explicitly; struct packing rejects floats on Python 3.
            y_min, x_min, y_max, x_max = bbox
            rect_msg.rects.append(Rect(x=int(x_min), y=int(y_min),
                                       width=int(x_max - x_min),
                                       height=int(y_max - y_min)))
        if profiling:
            tprev = self._log_elapsed("make rect msg", tprev)

        cls_msg = ClassificationResult(
            header=msg.header,
            classifier=self.classifier_name,
            target_names=self.label_names,
            labels=labels,
            label_names=[self.label_names[lbl] for lbl in labels],
            label_proba=scores,
        )
        if profiling:
            tprev = self._log_elapsed("make cls msg", tprev)

        self.pub_rects.publish(rect_msg)
        self.pub_class.publish(cls_msg)
        if profiling:
            tprev = self._log_elapsed("publish msg", tprev)

        if self.visualize:
            self.publish_bbox_image(img, bboxes, labels, scores, msg.header)

        if profiling:
            self._log_elapsed("callback end", tprev)

    def publish_bbox_image(self, img, bbox, label, score, header):
        """Draw detections on ``img`` with matplotlib and publish on ``~output/image``.

        Args:
            img: the CHW image array given to ``self.model.predict``.
            bbox, label, score: that image's detections from ``predict``.
            header: copied onto the published Image message.
        """
        # Own the figure explicitly so exactly this figure is rendered and
        # closed, even if drawing fails -- open figures otherwise accumulate
        # in a long-running node.
        fig = plt.figure()
        try:
            ax = fig.add_subplot(1, 1, 1)
            vis_bbox(img, bbox, label, score,
                     label_names=self.label_names, ax=ax)
            fig.canvas.draw()
            w, h = fig.canvas.get_width_height()
            rendered = np.frombuffer(fig.canvas.tostring_rgb(),
                                     dtype=np.uint8).reshape(h, w, 3)
        finally:
            plt.close(fig)
        try:
            msg = self.cv_bridge.cv2_to_imgmsg(rendered, "rgb8")
        except Exception as e:
            rospy.logerr("Failed to convert bbox image: %s" % str(e))
            return
        msg.header = header
        self.pub_image.publish(msg)
00187
00188
if __name__ == '__main__':
    # The node must exist before the detector is built: its constructor reads
    # private (~) parameters and advertises topics.
    rospy.init_node("ssd_object_detector")
    # Keep a reference to the detector while rospy.spin() blocks until
    # shutdown; its subscriber callback handles each incoming image.
    detector = SSDObjectDetector()
    rospy.spin()