00001
00002
00003 import os
00004 import sys
00005 import cv2
00006 import time
00007
00008 import numpy as np
00009 import matplotlib.pyplot as plt
00010
00011 import rospy
00012 import rospkg
00013
00014 from cv_bridge import CvBridge, CvBridgeError
00015 from rail_object_detector import drfcn_detector
00016
00017 from sensor_msgs.msg import Image, CompressedImage
00018 from rail_object_detection_msgs.msg import Object, Detections
00019
00020
FAIL_COLOR = '\033[91m'  # ANSI escape: bright red foreground
ENDC_COLOR = '\033[0m'   # ANSI escape: reset all attributes


def eprint(error):
    """
    Write an exception to stderr as a red ``<ExceptionType>: <message>`` line.

    :param error: the exception instance to report
    :return: None
    """
    # str(error) works for every exception type. The .message attribute is
    # deprecated since Python 2.6, absent in Python 3, and empty for
    # exceptions constructed with more than one argument.
    sys.stderr.write(
        FAIL_COLOR
        + type(error).__name__
        + ": "
        + str(error)
        + ENDC_COLOR
        # Terminate the line so subsequent stderr output starts fresh.
        + "\n"
    )
00033
00034
00035
class DRFCNDetector():
    """
    ROS interface to the deformable R-FCN object detector.

    Subscribes to an image topic (``sensor_msgs/Image``, or
    ``sensor_msgs/CompressedImage`` on ``<topic>/compressed`` when
    ``~use_compressed_image`` is set), runs ``drfcn_detector.Detector`` on each
    frame and publishes a ``Detections`` message. When ``~debug`` is set, an
    annotated copy of the image is also published.

    ROS params read in ``__init__``:
      ~debug                (default False)
      ~image_sub_topic_name (default '/kinect/qhd/image_color_rect')
      ~use_compressed_image (default False)
      ~model_filename       (default <rail_object_detector>/libs/drfcn/model/rfcn_dcn_coco)
    """

    def __init__(self):
        self.objects = []  # raw output of the most recent detect_objects call
        self.image_datastream = None
        self.input_image = None

        self.bridge = CvBridge()

        # Created in start(); defined here so the attributes always exist.
        self.image_pub = None
        self.object_pub = None

        self.debug = rospy.get_param('~debug', default=False)
        self.image_sub_topic_name = rospy.get_param(
            '~image_sub_topic_name', default='/kinect/qhd/image_color_rect')
        self.use_compressed_image = rospy.get_param(
            '~use_compressed_image', default=False)

        rospack = rospkg.RosPack()
        self.model_filename = rospy.get_param(
            '~model_filename',
            os.path.join(rospack.get_path('rail_object_detector'),
                         'libs', 'drfcn', 'model', 'rfcn_dcn_coco'),
        )
        self.detector = drfcn_detector.Detector(self.model_filename)

    def _convert_msg_to_image(self, image_msg):
        """
        Convert an incoming image message into a BGR OpenCV image.

        :param image_msg: ``sensor_msgs/Image`` when ``use_compressed_image``
            is False, otherwise ``sensor_msgs/CompressedImage``
        :return: the converted image, or None if conversion/decoding failed
        """
        if not self.use_compressed_image:
            try:
                return self.bridge.imgmsg_to_cv2(image_msg, "bgr8")
            except CvBridgeError as e:
                rospy.logerr("Failed to convert image message: %s", e)
                return None

        # np.frombuffer replaces the deprecated np.fromstring.
        image_np = np.frombuffer(image_msg.data, np.uint8)
        if image_np.size == 0:
            # cv2.imdecode asserts on an empty buffer instead of returning None.
            rospy.logwarn("Received empty compressed image message")
            return None
        # cv2.IMREAD_COLOR replaces cv2.CV_LOAD_IMAGE_COLOR, which was removed
        # in OpenCV 3.
        image_cv = cv2.imdecode(image_np, cv2.IMREAD_COLOR)
        if image_cv is None:
            # imdecode returns None for corrupt or unsupported payloads.
            rospy.logwarn("Failed to decode compressed image message")
        return image_cv

    def _parse_image(self, image_msg):
        """
        Image subscriber callback: run detection and publish the results.

        Stores the detector output in ``self.objects``, publishes an annotated
        debug image when ``self.debug`` is set, and publishes a ``Detections``
        message stamped with the incoming image's header.

        :param image_msg: incoming Image or CompressedImage message
        :return: None
        """
        header = image_msg.header
        image_cv = self._convert_msg_to_image(image_msg)
        if image_cv is None:
            return

        # NOTE(review): assumes detect_objects returns one array per entry of
        # self.detector.classes, rows shaped like [x1, y1, x2, y2, score];
        # confirm against drfcn_detector.
        self.objects = self.detector.detect_objects(image_cv)

        if self.debug:
            self._publish_debug_image(image_cv, header)

        self.object_pub.publish(self._build_detections_msg(header))

    def _publish_debug_image(self, image_cv, header):
        """Draw the latest detections on a copy of image_cv and publish it."""
        # Draw on a copy: cv2 drawing functions modify their input in place,
        # and the converted image may be a read-only view of the message buffer.
        debug_im = self._draw_boxes(image_cv.copy(), self.objects,
                                    self.detector.classes)
        try:
            debug_msg = self.bridge.cv2_to_imgmsg(debug_im, "bgr8")
        except CvBridgeError as e:
            # Skip this frame rather than publishing the unannotated input
            # message (which may not even be a sensor_msgs/Image).
            rospy.logerr("Failed to convert debug image: %s", e)
            return
        debug_msg.header = header
        self.image_pub.publish(debug_msg)

    def _build_detections_msg(self, header):
        """Build a Detections message from the detections in self.objects."""
        object_arr = Detections()
        object_arr.header = header
        for cls_idx, cls_name in enumerate(self.detector.classes):
            for obj in self.objects[cls_idx]:
                msg = Object()
                # Corners: left-bottom = (obj[0], obj[3]),
                # right-top = (obj[2], obj[1]). Presumably image coordinates
                # with y growing downward; TODO confirm.
                msg.left_bot_x = int(obj[0])
                msg.right_top_y = int(obj[1])
                msg.right_top_x = int(obj[2])
                msg.left_bot_y = int(obj[3])
                msg.centroid_x = int((obj[0] + obj[2]) / 2)
                msg.centroid_y = int((obj[1] + obj[3]) / 2)
                msg.probability = obj[-1]
                msg.label = cls_name
                object_arr.objects.append(msg)
        return object_arr

    def _draw_boxes(self, im, dets, classes, scale=1.0):
        """
        Draw detection boxes (and scores, when present) onto im.

        :param im: BGR image; modified in place and returned
        :param dets: per-class detections, indexed in the same order as classes
        :param classes: class names
        :param scale: factor applied to box coordinates before drawing
        :return: im
        """
        font = cv2.FONT_HERSHEY_SIMPLEX
        for cls_idx, cls_name in enumerate(classes):
            cls_dets = dets[cls_idx]
            # Random colour per class per call. cv2 expects a tuple of Python
            # ints, not a numpy int64 array.
            color = tuple(int(c) for c in np.random.randint(0, 256, 3))
            for det in cls_dets:
                # cv2 drawing functions require integer pixel coordinates.
                x1, y1, x2, y2 = [int(v * scale) for v in det[:4]]
                cv2.rectangle(im, (x1, y1), (x2, y2), color, thickness=3)

                if cls_dets.shape[1] == 5:
                    score = det[-1]
                    cv2.putText(im, '{:s} {:.3f}'.format(cls_name, score),
                                (x1, y1), font, fontScale=0.75,
                                color=(0, 0, 0), thickness=2)
        return im

    def start(self,
              pub_image_topic='~debug/object_image',
              pub_object_topic='~detections'):
        """
        Advertise the output topics, then subscribe to the input image topic.

        Publishers are created before subscribing because the subscriber
        callback publishes on them and may run as soon as the subscription
        exists.

        :param pub_image_topic: topic for annotated images (advertised only
            when ~debug is set)
        :param pub_object_topic: topic for Detections messages
        """
        if self.debug:
            self.image_pub = rospy.Publisher(pub_image_topic, Image,
                                             queue_size=2)
        self.object_pub = rospy.Publisher(pub_object_topic, Detections,
                                          queue_size=2)

        if self.use_compressed_image:
            topic = self.image_sub_topic_name + '/compressed'
            msg_type = CompressedImage
        else:
            topic = self.image_sub_topic_name
            msg_type = Image
        # queue_size=1 drops stale frames instead of letting an unbounded
        # backlog build up when detection is slower than the camera rate.
        # buff_size is raised so a full image fits in rospy's receive buffer;
        # otherwise stale frames can still be delivered.
        rospy.Subscriber(topic, msg_type, self._parse_image,
                         queue_size=1, buff_size=2 ** 24)
00141
if __name__ == '__main__':
    # The node must be registered before any rospy param, publisher or
    # subscriber calls made by the detector.
    rospy.init_node('drfcn_node')

    detector = DRFCNDetector()
    detector.start()

    # Block here until shutdown; subscriber callbacks are dispatched by rospy.
    rospy.spin()