object_detectors.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 # This file bridges ROS to the deformable R-FCN object detector (drfcn_detector.Detector)
00003 import os
00004 import sys
00005 import cv2
00006 import time
00007 
00008 import numpy as np
00009 import matplotlib.pyplot as plt
00010 
00011 import rospy
00012 import rospkg
00013 
00014 from cv_bridge import CvBridge, CvBridgeError
00015 from rail_object_detector import drfcn_detector
00016 
00017 from sensor_msgs.msg import Image, CompressedImage
00018 from rail_object_detection_msgs.msg import Object, Detections
00019 
# Debug Helpers
# ANSI escape sequences: switch terminal text to red / reset attributes.
FAIL_COLOR = '\033[91m'
ENDC_COLOR = '\033[0m'


def eprint(error):
    """
    Write an exception's type name and message to stderr, coloured red.

    ``str(error)`` is used rather than ``error.message``: the ``message``
    attribute was deprecated in Python 2.6 (PEP 352) and does not exist in
    Python 3, so reading it can itself raise. A trailing newline is written so
    that consecutive reports do not run together on one line.

    :param error: the exception (or any object) to report
    :return: None
    """
    sys.stderr.write(
        FAIL_COLOR
        + type(error).__name__
        + ": "
        + str(error)
        + ENDC_COLOR
        + "\n"
    )
# End Debug Helpers
00034 
00035 
class DRFCNDetector():
    """
    ROS front end for the deformable R-FCN object detector.

    Subscribes to an image topic (raw or compressed), runs
    ``drfcn_detector.Detector.detect_objects`` on every frame and publishes
    the results as a ``Detections`` message. When the ``~debug`` parameter is
    set, an annotated copy of each frame is published as well.
    """

    def __init__(self):
        """
        Read the node's private parameters and load the detector model.

        ROS parameters:
            ~debug (default False): also publish annotated debug images.
            ~image_sub_topic_name (default '/kinect/qhd/image_color_rect'):
                image topic to subscribe to.
            ~use_compressed_image (default False): subscribe to
                ``<topic>/compressed`` and decode it with ``cv2.imdecode``.
            ~model_filename: model path handed to ``drfcn_detector.Detector``;
                defaults to ``libs/drfcn/model/rfcn_dcn_coco`` inside the
                ``rail_object_detector`` package.
        """
        # Per-class detection arrays from the most recent frame.
        self.objects = []
        # Not read anywhere in this class; kept in case external code uses them.
        self.image_datastream = None
        self.input_image = None

        # Created in start(); defined here so the attributes always exist.
        self.image_pub = None
        self.object_pub = None

        self.bridge = CvBridge()

        self.debug = rospy.get_param('~debug', default=False)
        self.image_sub_topic_name = rospy.get_param('~image_sub_topic_name', default='/kinect/qhd/image_color_rect')
        self.use_compressed_image = rospy.get_param('~use_compressed_image', default=False)

        rospack = rospkg.RosPack()
        self.model_filename = rospy.get_param(
            '~model_filename',
            os.path.join(rospack.get_path('rail_object_detector'), 'libs', 'drfcn', 'model', 'rfcn_dcn_coco')
        )
        self.detector = drfcn_detector.Detector(self.model_filename)

    def _convert_msg_to_image(self, image_msg):
        """
        Convert an incoming image message into a BGR cv2 image.

        :param image_msg: sensor_msgs/Image, or sensor_msgs/CompressedImage
            when ``use_compressed_image`` is set
        :return: the image, or None if conversion or decoding failed
        """
        if not self.use_compressed_image:
            try:
                return self.bridge.imgmsg_to_cv2(image_msg, "bgr8")
            except CvBridgeError as e:
                rospy.logerr('Failed to convert image message: %s', e)
                return None

        # np.fromstring is deprecated; np.frombuffer wraps the payload bytes.
        # cv2.CV_LOAD_IMAGE_COLOR only exists in OpenCV 2.x, while
        # cv2.IMREAD_COLOR is the portable name for the same flag.
        image_np = np.frombuffer(image_msg.data, np.uint8)
        image_cv = cv2.imdecode(image_np, cv2.IMREAD_COLOR)
        if image_cv is None:
            # cv2.imdecode reports failure by returning None.
            rospy.logwarn('Failed to decode compressed image message')
        return image_cv

    def _parse_image(self, image_msg):
        """
        Image callback: run detection on one frame and publish the results.

        Publishes a ``Detections`` message (one ``Object`` per detection, with
        the input frame's header) on ``object_pub``. When ``debug`` is set, it
        also publishes an annotated image on ``image_pub``. Frames that cannot
        be converted are dropped.

        :param image_msg: incoming image message (raw or compressed)
        :return: None
        """
        header = image_msg.header
        image_cv = self._convert_msg_to_image(image_msg)
        if image_cv is None:
            return
        self.objects = self.detector.detect_objects(image_cv)
        if self.debug:
            self._publish_debug_image(image_cv, header)

        object_arr = Detections()
        object_arr.header = header
        for cls_idx, cls_name in enumerate(self.detector.classes):
            for obj in self.objects[cls_idx]:
                # Rows are [x1, y1, x2, y2, ..., score]. Image y grows
                # downward, so (x1, y2) is the bottom-left corner and (x2, y1)
                # is the top-right corner.
                x1, y1, x2, y2 = self._box_corners(obj)
                msg = Object()
                msg.left_bot_x = x1
                msg.right_top_y = y1
                msg.right_top_x = x2
                msg.left_bot_y = y2
                # Centroid is computed from the unrounded coordinates.
                msg.centroid_x = int((obj[0] + obj[2]) / 2)
                msg.centroid_y = int((obj[1] + obj[3]) / 2)
                msg.probability = float(obj[-1])
                msg.label = cls_name
                object_arr.objects.append(msg)
        self.object_pub.publish(object_arr)

    def _publish_debug_image(self, image_cv, header):
        """
        Draw the latest detections onto ``image_cv`` and publish the result
        on ``image_pub`` with the source frame's header.

        If the annotated image cannot be converted to a ROS message, the
        error is logged and nothing is published. This ensures the raw input
        message, which may be a CompressedImage, is never sent on the Image
        topic.
        """
        debug_im = self._draw_boxes(image_cv, self.objects, self.detector.classes)
        try:
            debug_msg = self.bridge.cv2_to_imgmsg(debug_im, "bgr8")
        except CvBridgeError as e:
            rospy.logerr('Failed to convert debug image: %s', e)
            return
        debug_msg.header = header
        self.image_pub.publish(debug_msg)

    @staticmethod
    def _box_corners(det, scale=1.0):
        """
        Return ``(x1, y1, x2, y2)`` from the first four values of a detection
        row, each multiplied by ``scale`` and truncated to a Python int.

        Plain ints are needed both for the integer message fields and for the
        cv2 drawing functions, which reject float points in newer OpenCV.
        """
        return tuple(int(v * scale) for v in det[:4])

    def _draw_boxes(self, im, dets, classes, scale=1.0):
        """
        Draw detection boxes, plus scores when present, onto ``im`` in place.

        :param im: BGR image to annotate; it is modified and also returned
        :param dets: per-class collections of detection rows, indexed in the
            same order as ``classes``
        :param classes: class names
        :param scale: factor applied to box coordinates before drawing
        :return: ``im``
        """
        font = cv2.FONT_HERSHEY_SIMPLEX
        for cls_idx, cls_name in enumerate(classes):
            cls_dets = dets[cls_idx]
            # One random colour per class per frame. cv2 expects a tuple of
            # plain ints, not a numpy array.
            color = tuple(int(c) for c in np.random.randint(0, 256, 3))
            for det in cls_dets:
                x1, y1, x2, y2 = self._box_corners(det, scale)
                cv2.rectangle(im, (x1, y1), (x2, y2), color, thickness=3)

                # Only rows that carry a trailing score get a text label.
                if cls_dets.shape[1] == 5:
                    score = det[-1]
                    cv2.putText(im, '{:s} {:.3f}'.format(cls_name, score), (x1, y1),
                                font, fontScale=0.75, color=(0, 0, 0), thickness=2)
        return im

    def start(self,
              pub_image_topic='~debug/object_image',
              pub_object_topic='~detections'):
        """
        Advertise the output topics and subscribe to the input image topic.

        Publishers are created before the subscriber. rospy runs subscriber
        callbacks on their own threads, so a frame arriving immediately after
        subscribing could otherwise reach _parse_image before ``object_pub``
        or ``image_pub`` had been assigned.

        :param pub_image_topic: topic for annotated debug images (advertised
            only when ``debug`` is set)
        :param pub_object_topic: topic for Detections messages
        """
        if self.debug:
            self.image_pub = rospy.Publisher(pub_image_topic, Image, queue_size=2)
        self.object_pub = rospy.Publisher(pub_object_topic, Detections, queue_size=2)

        # NOTE(review): no queue_size is passed, so rospy.Subscriber uses an
        # unbounded queue. With a slow detector, consider queue_size=1 so
        # stale frames are dropped instead of piling up.
        if self.use_compressed_image:
            rospy.Subscriber(self.image_sub_topic_name + '/compressed', CompressedImage, self._parse_image)
        else:
            rospy.Subscriber(self.image_sub_topic_name, Image, self._parse_image)
00141 
if __name__ == '__main__':
    # The node must be registered first: DRFCNDetector.__init__ reads private
    # ("~") parameters, which rospy resolves against the node name set here.
    rospy.init_node('drfcn_node')
    drfcn_node = DRFCNDetector()
    drfcn_node.start()
    # Block here; the subscriber callback does all further work until shutdown.
    rospy.spin()


rail_object_detector
Author(s):
autogenerated on Sat Jun 8 2019 20:26:30