00001
00002 from PIL import Image, ImageFont, ImageDraw, ImageOps
00003 import cv2
00004 import urllib
00005 import numpy as np
00006 import rospy
00007 from sensor_msgs.msg import Image as ROSIMAGE
00008 import os
00009 from matplotlib import pyplot as plt
00010 from dji_ronin.msg import gimbalangle
00011 import tf
00012 from cv_bridge import CvBridge
00013 import subprocess
00014
# Module authorship metadata.
__author__ = "Itamar Eliakim"
00016
class DJIRoninClassifier(object):
    """ROS node that reads DJI Ronin gimbal angles off the phone app's screen.

    Frames are read as a stream of concatenated JPEGs from ``~Phone_URL``.
    For each frame the three angle read-outs (pan, tilt, roll) are cropped,
    their digits are recognised with a 1-nearest-neighbour model trained on
    glyphs rendered from ``open-sans.light.ttf``, and the angles are published
    as a ``gimbalangle`` message on ``/GimbalAngle``.  When ``~Publish_App`` /
    ``~Publish_Digits`` are set, the raw frame and a digit-recognition debug
    mosaic are published as images as well.

    Constructing the object runs the node: ``__init__`` only returns once ROS
    shuts down or the stream fails.
    """

    # Pixel layout of the angle read-outs; the start-up banner asks for a
    # 360x640 stream.  NOTE(review): presumably tied to one app version.
    ROI_TOP = 450      # first row of the pan read-out
    ROI_LEFT = 200     # first column of every read-out
    ROI_RIGHT = 270    # one past the last column of every read-out
    ROW_PITCH = 30     # vertical distance between consecutive read-outs

    # Pan readings that differ from the reference by this many degrees or
    # more are treated as recognition glitches (see _accept_pan).
    PAN_JUMP_LIMIT = 50

    def __init__(self):
        """Set up publishers and the digit model, then run the classifier loop.

        Blocks until ROS shuts down or the stream fails; failures are logged
        rather than raised.
        """
        rospy.init_node("DJI_Ronin_App")
        self.pub = rospy.Publisher('/GimbalAngle', gimbalangle, queue_size=10)
        self.app_pub = rospy.Publisher("/DJI_Ronin_App", ROSIMAGE, queue_size=10)
        self.digit_pub = rospy.Publisher("/DJI_Ronin_App_Digits", ROSIMAGE, queue_size=10)
        self.url = rospy.get_param("~Phone_URL", "...")
        self.digitheight = 14
        self.bool_publish_Digit = rospy.get_param("~Publish_Digits", 0)
        self.bool_publish_App = rospy.get_param("~Publish_App", 0)
        # First entry of ROS_PACKAGE_PATH.  Read straight from the environment:
        # the previous `echo` subprocess was never reaped and its trailing
        # newline corrupted the path when only one entry was set.
        self.getROSpath = os.environ.get("ROS_PACKAGE_PATH", "").split(':')[0]
        self.model = self.createDigitsModel(
            self.getROSpath + "/dji_ronin/scripts/open-sans.light.ttf", self.digitheight)
        rospy.loginfo("######################################")
        rospy.loginfo("## DJI Ronin Server is ON ##")
        rospy.loginfo("#Update URL in launch set to 360x640 #")
        rospy.loginfo("######################################")

        self.lastyaw = 0     # pan value the outlier gate compares against
        self.panangle = 0    # last published pan (read by updateangle)

        try:
            rospy.sleep(5)
            self.runClassifier(self.url)
            rospy.spin()
        except rospy.ROSInterruptException:
            # Shutdown requested (e.g. Ctrl-C during the start-up sleep).
            pass
        except IOError:
            # The stream could not be opened, failed, or ended.
            rospy.loginfo("[DJI Ronin] - Phone is not connected, Restart Server.")
        except Exception as err:
            # Anything else is a bug; report it instead of blaming the phone.
            rospy.logerr("[DJI Ronin] - Classifier stopped: %s" % (err,))

    def updateangle(self):
        """Broadcast ``self.panangle`` as a tf yaw rotation at 10 Hz until shutdown.

        Sends the transform of frame ``base_link_ronin`` relative to
        ``base_link`` with zero translation.
        NOTE(review): nothing in this file calls this method.
        """
        br = tf.TransformBroadcaster()
        rate = rospy.Rate(10.0)
        while not rospy.is_shutdown():
            quat = tf.transformations.quaternion_from_euler(0, 0, np.deg2rad(self.panangle))
            br.sendTransform((0, 0, 0), quat, rospy.Time.now(), "base_link_ronin", "base_link")
            rate.sleep()

    def createDigitsModel(self, fontfile, digitheight):
        """Train a digit classifier on the glyphs 0-9 rendered from *fontfile*.

        Each digit is drawn with the TrueType font at size *digitheight*,
        cropped to its ink, inverted, blurred, adaptive-thresholded and resized
        to digitheight x digitheight//2 pixels, giving one feature row.

        Returns:
            The trained ``cv2.KNearest`` model (OpenCV 2.x API); its responses
            are the digit values as float32.
        """
        ttfont = ImageFont.truetype(fontfile, digitheight)
        feat_size = (digitheight, digitheight // 2)   # cv2 dsize is (width, height)
        rows = []
        for digit in range(10):
            pil_im = Image.new("RGB", (digitheight, digitheight * 2))
            ImageDraw.Draw(pil_im).text((0, 0), str(digit), font=ttfont)
            pil_im = ImageOps.invert(pil_im.crop(pil_im.getbbox()))
            gray = cv2.cvtColor(np.array(pil_im), cv2.COLOR_RGB2GRAY)
            blur = cv2.GaussianBlur(gray, (5, 5), 0)
            thresh = cv2.adaptiveThreshold(blur, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                           cv2.THRESH_BINARY_INV, 11, 2)
            rows.append(cv2.resize(thresh, feat_size).reshape(-1))

        samples = np.array(rows, np.float32)
        responses = np.arange(10, dtype=np.float32)

        model = cv2.KNearest()
        model.train(samples, responses)
        return model

    def findDigits(self, im, digitheight, lower, upper):
        """Recognise the digits of one angle read-out in a frame.

        Args:
            im: BGR frame.
            digitheight: expected glyph height in pixels.
            lower, upper: row offsets of the read-out relative to ROI_TOP.

        Returns:
            ``(stringlist, sign, im, thresh, out)``.  ``stringlist`` holds
            the recognised digit characters ordered right-to-left (least
            significant first).  ``sign`` is -1 if a minus-like shape was found
            left of column 40, else 1.  ``im`` is the crop with digit boxes
            drawn on it; being a numpy slice it is a view, so the boxes also
            land on the caller's frame.  ``thresh`` is the crop's binary image
            and ``out`` a debug canvas with the recognised digits.
        """
        im = im[self.ROI_TOP + lower:self.ROI_TOP + upper, self.ROI_LEFT:self.ROI_RIGHT]
        out = np.zeros(im.shape, np.uint8)
        gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
        thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                       cv2.THRESH_BINARY_INV, 11, 15)
        # [-2] is the contour list for both the 2-tuple (OpenCV 2.x/4.x) and
        # the 3-tuple (OpenCV 3.x) return of findContours.
        contours = cv2.findContours(thresh.copy(), cv2.RETR_EXTERNAL,
                                    cv2.CHAIN_APPROX_SIMPLE)[-2]

        # The rightmost tiny blob (presumably the decimal point) is the
        # cut-off: only digits left of it are read.
        centers = []
        for cnt in contours:
            if cv2.contourArea(cnt) > 2:
                continue
            m = cv2.moments(cnt)
            if m['m00'] != 0:
                centers.append((int(m['m10'] / m['m00']), int(m['m01'] / m['m00'])))
        if centers:
            center = max(centers)
            cv2.circle(out, center, 3, (255, 0, 0), -1)
        else:
            center = (50, 1)   # no blob found: accept digits left of column 50

        low, high = (digitheight * 4) // 5, (digitheight * 6) // 5
        feat_w, feat_h = digitheight, digitheight // 2
        found = []   # (x, digit character) pairs
        sign = 1
        for cnt in contours:
            x, y, w, h = cv2.boundingRect(cnt)
            if h > w and low < h < high and x < center[0]:
                cv2.rectangle(im, (x, y), (x + w, y + h), (0, 255, 0), 1)
                roi = cv2.resize(thresh[y:y + h, x:x + w], (feat_w, feat_h))
                roi = np.float32(roi.reshape((1, feat_w * feat_h)))
                _, results, _, _ = self.model.find_nearest(roi, k=1)
                string = str(int(results[0][0]))
                cv2.putText(out, string, (x - 5, y + h), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0))
                found.append((x, string))
            elif x < 40:
                approx = cv2.approxPolyDP(cnt, 0.01 * cv2.arcLength(cnt, True), True)
                if len(approx) <= 4:
                    cv2.drawContours(out, [cnt], 0, (0, 0, 255), -1)
                    sign = -1

        # findContours does not guarantee any spatial order, so order the
        # digits by column explicitly (right-to-left, as callers expect).
        found.sort(key=lambda item: item[0], reverse=True)
        stringlist = [digit for _, digit in found]
        return stringlist, sign, im, thresh, out

    def runClassifier(self, url):
        """Read frames from *url* and publish gimbal angles until ROS shuts down.

        Frames whose read-outs cannot be parsed are skipped rather than
        aborting the loop.

        Raises:
            IOError: if the stream cannot be opened, fails, or ends.
        """
        anglist = ["Pan", "Tilt", "Roll"]
        bridge = CvBridge()
        stream = urllib.urlopen(url)
        buf = b''
        prev_pan = None   # raw pan reading of the previous parsed frame
        try:
            while not rospy.is_shutdown():
                chunk = stream.read(1024)
                if not chunk:
                    raise IOError("stream from %s ended" % url)
                buf += chunk
                jpg, buf = self._extract_jpeg(buf)
                if jpg is None:
                    continue
                frame = cv2.imdecode(np.frombuffer(jpg, dtype=np.uint8), cv2.IMREAD_COLOR)
                if frame is None:
                    continue   # corrupt JPEG

                angles = []
                implotlist = []
                for row in range(len(anglist)):
                    digits, sign, im, thresh, out = self.findDigits(
                        frame, self.digitheight,
                        self.ROW_PITCH * row, self.ROW_PITCH * (row + 1))
                    implotlist.extend([im, thresh, out])
                    angles.append(self._digits_to_angle(digits, sign))

                if self.bool_publish_App:
                    self.app_pub.publish(bridge.cv2_to_imgmsg(frame, encoding="bgr8"))
                if self.bool_publish_Digit:
                    self._publish_digit_debug(bridge, implotlist)

                if None in angles:
                    continue   # a read-out had no recognisable digits

                pan = angles[0]
                if self._accept_pan(pan, self.lastyaw, prev_pan, self.PAN_JUMP_LIMIT):
                    msg = gimbalangle()
                    msg.header.stamp = rospy.get_rostime()
                    msg.pan = self.panangle = pan
                    msg.tilt = angles[1]
                    msg.roll = angles[2]
                    self.pub.publish(msg)
                prev_pan = pan
                self.lastyaw = self.panangle
        finally:
            stream.close()

    def _publish_digit_debug(self, bridge, implotlist):
        """Publish a 3x3 mosaic of the nine per-read-out debug images.

        The mosaic is rendered with matplotlib, saved as
        DigitsRecognition.jpg in the working directory, read back and
        published on /DJI_Ronin_App_Digits.  It no longer calls plt.show():
        in matplotlib's default non-interactive mode that blocked the frame
        loop until the window was closed.
        """
        name = os.path.join(os.getcwd(), "DigitsRecognition.jpg")
        fig = plt.figure(2)
        fig.clf()   # otherwise every frame stacks another image on each axis
        for idx, img in enumerate(implotlist):
            plt.subplot(3, 3, idx + 1)
            plt.imshow(img, cmap='gray')
        plt.savefig(name)
        pic = cv2.imread(name)
        self.digit_pub.publish(bridge.cv2_to_imgmsg(pic, encoding="bgr8"))

    @staticmethod
    def _extract_jpeg(buf):
        """Split the first complete JPEG out of a byte buffer.

        Returns:
            ``(jpg, rest)``.  ``jpg`` runs from a start-of-image marker
            (FF D8) through the next end-of-image marker (FF D9) found after
            it, and ``rest`` is everything after that.  If no complete image
            is buffered yet, ``jpg`` is None and ``rest`` keeps only the bytes
            that may still belong to one.
        """
        start = buf.find(b'\xff\xd8')
        if start == -1:
            # Keep the last byte: it may be the first half of a split marker.
            return None, buf[-1:]
        end = buf.find(b'\xff\xd9', start + 2)
        if end == -1:
            return None, buf[start:]
        return buf[start:end + 2], buf[end + 2:]

    @staticmethod
    def _digits_to_angle(digits, sign):
        """Combine right-to-left digit characters and a sign into an int angle.

        Returns None when *digits* is empty.  Values above 180 are truncated
        to their first two characters.  NOTE(review): presumably this drops a
        misread trailing glyph such as the degree sign.
        """
        if not digits:
            return None
        text = ''.join(reversed(digits))
        value = int(text)
        if abs(value) > 180:
            value = int(text[:2])
        return sign * value

    @staticmethod
    def _angle_delta(a, b):
        """Return the absolute difference of two angles in degrees, wrapped to [0, 180]."""
        return abs((a - b + 180) % 360 - 180)

    @staticmethod
    def _accept_pan(pan, last_pan, prev_raw, limit):
        """Decide whether a pan reading is plausible enough to publish.

        A reading is accepted if it is within *limit* degrees (wrap-aware) of
        the last published pan, or of the previous frame's raw reading.  The
        second case lets two agreeing consecutive frames re-anchor the gate
        after a real fast move or a start-up far from 0.  Without it, the gate
        would reject every later frame forever.
        """
        if DJIRoninClassifier._angle_delta(pan, last_pan) < limit:
            return True
        return prev_raw is not None and DJIRoninClassifier._angle_delta(pan, prev_raw) < limit
00202
00203
if __name__ == "__main__":
    # Creating the classifier is what runs the node; the constructor does not
    # return until the processing loop and rospy.spin() have finished.
    DJIRoninClassifier()