00001 
00002 from PIL import Image, ImageFont, ImageDraw, ImageOps
00003 import cv2
00004 import urllib
00005 import numpy as np
00006 import rospy
00007 from sensor_msgs.msg import Image as ROSIMAGE
00008 import os
00009 from matplotlib import pyplot as plt
00010 from dji_ronin.msg import gimbalangle
00011 import tf
00012 from cv_bridge import CvBridge
00013 import subprocess
00014 
00015 __author__ = 'Itamar Eliakim'
00016 
class DJIRoninClassifier(object):
    """ROS node that reads the DJI Ronin gimbal angles off the phone app's screen.

    JPEG frames are pulled from the stream opened at ``~Phone_URL`` (located by
    their start/end markers). The Pan/Tilt/Roll readouts are cropped from fixed
    regions of each frame and recognised with a 1-nearest-neighbour digit model.
    The result is published as ``gimbalangle`` on ``/GimbalAngle``. Optionally
    the frame (``~Publish_App``) and a debug image of the digit segmentation
    (``~Publish_Digits``) are published as well.

    Constructing the class initialises the node and runs it; the constructor
    only returns once the classifier has stopped.
    """

    def __init__(self):
        """Set up node, publishers, parameters and digit model, then run.

        Failures while running are logged rather than raised.
        """
        rospy.init_node("DJI_Ronin_App")
        self.pub = rospy.Publisher('/GimbalAngle', gimbalangle, queue_size=10)
        self.app_pub = rospy.Publisher("/DJI_Ronin_App", ROSIMAGE, queue_size=10)
        self.digit_pub = rospy.Publisher("/DJI_Ronin_App_Digits", ROSIMAGE, queue_size=10)
        self.url = rospy.get_param("~Phone_URL", "...")
        self.digitheight = 14
        self.bool_publish_Digit = rospy.get_param("~Publish_Digits", 0)
        self.bool_publish_App = rospy.get_param("~Publish_App", 0)
        # The font is looked up under the first ROS_PACKAGE_PATH entry. Read the
        # variable directly: going through a shell `echo` appends a newline,
        # which ends up inside the path when there is only one entry.
        self.getROSpath = os.environ.get("ROS_PACKAGE_PATH", "").split(':')[0]
        self.model = self.createDigitsModel(
            self.getROSpath + "/dji_ronin/scripts/open-sans.light.ttf", self.digitheight)
        rospy.loginfo("######################################")
        rospy.loginfo("##      DJI Ronin Server is ON      ##")
        rospy.loginfo("#Update URL in launch set to 360x640 #")
        rospy.loginfo("######################################")

        # lastyaw: pan read from the previous frame (jump filter).
        # panangle: last published pan (broadcast by updateangle).
        self.lastyaw = 0
        self.panangle = 0

        try:
            # NOTE(review): presumably gives the phone stream time to come up -- confirm.
            rospy.sleep(5)
            self.runClassifier(self.url)
            rospy.spin()
        except (rospy.ROSInterruptException, KeyboardInterrupt):
            pass  # normal shutdown
        except IOError:
            # Stream could not be opened, failed, or ended.
            rospy.loginfo("[DJI Ronin] - Phone is not connected, Restart Server.")
        except Exception as err:
            # Still best-effort, but report the real cause instead of
            # blaming the phone connection.
            rospy.logerr("[DJI Ronin] - Classifier stopped: %s", err)

    def updateangle(self):
        """Broadcast the base_link -> base_link_ronin transform at 10 Hz.

        The child frame is rotated about z by ``self.panangle`` (degrees), the
        last published pan angle. Blocks until ROS shuts down. Not started from
        this file; presumably meant to run in its own thread -- confirm.
        """
        br = tf.TransformBroadcaster()
        rate = rospy.Rate(10.0)
        while not rospy.is_shutdown():
            br.sendTransform((0, 0, 0), tf.transformations.quaternion_from_euler(0, 0, np.deg2rad(self.panangle)), rospy.Time.now(), "base_link_ronin", "base_link")
            rate.sleep()

    def createDigitsModel(self, fontfile, digitheight):
        """Train a k-nearest-neighbour model on the digits 0-9 drawn in a font.

        Each digit is rendered with ``fontfile`` at size ``digitheight``,
        cropped to its bounding box, inverted, blurred, adaptively thresholded
        and resized to the same patch size findDigits uses for its candidates.

        :param fontfile: path to a TrueType font file.
        :param digitheight: font size in pixels; also fixes the feature size.
        :return: a trained ``cv2.KNearest`` whose responses are 0..9.
        """
        # cv2.resize takes (width, height): each sample becomes a
        # (digitheight // 2)-row by digitheight-column patch, flattened.
        half = digitheight // 2
        feature_len = digitheight * half
        ttfont = ImageFont.truetype(fontfile, digitheight)
        rows = []
        for n in range(10):
            pil_im = Image.new("RGB", (digitheight, digitheight * 2))
            ImageDraw.Draw(pil_im).text((0, 0), str(n), font=ttfont)
            pil_im = pil_im.crop(pil_im.getbbox())
            pil_im = ImageOps.invert(pil_im)

            cv_image = cv2.cvtColor(np.array(pil_im), cv2.COLOR_RGBA2BGRA)
            gray = cv2.cvtColor(cv_image, cv2.COLOR_BGR2GRAY)
            blur = cv2.GaussianBlur(gray, (5, 5), 0)
            thresh = cv2.adaptiveThreshold(blur, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                           cv2.THRESH_BINARY_INV, 11, 2)
            roi = cv2.resize(thresh, (digitheight, half))
            rows.append(roi.reshape(feature_len))

        samples = np.array(rows, np.float32)
        responses = np.arange(10, dtype=np.float32)

        model = cv2.KNearest()
        model.train(samples, responses)
        return model

    def findDigits(self, im, digitheight, lower, upper):
        """Recognise the digits of one angle readout in a frame.

        :param im: BGR frame. The crop below is a view into it, so the boxes
            drawn around recognised digits also appear on the caller's frame.
        :param digitheight: expected digit height in pixels (must match the model).
        :param lower: start row of the readout strip, relative to row 450.
        :param upper: end row of the readout strip, relative to row 450.
            Columns 200-270 are used.
        :return: ``(digits, sign, crop, thresh, overlay)``:

            - ``digits``: single-digit strings in contour order.
            - ``sign``: -1 when a minus sign was detected, else 1.
            - ``crop``: the strip with digit boxes drawn on it.
            - ``thresh``: the strip's binary image.
            - ``overlay``: a debug image of the recognised digits and markers.
        """
        half = digitheight // 2
        im = im[450 + lower:450 + upper, 200:270]
        out = np.zeros(im.shape, np.uint8)
        gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
        thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                       cv2.THRESH_BINARY_INV, 11, 15)
        # [-2] is the contour list under both the 2-tuple (OpenCV 2.x/4.x) and
        # 3-tuple (OpenCV 3.x) return conventions.
        contours = cv2.findContours(thresh.copy(), cv2.RETR_EXTERNAL,
                                    cv2.CHAIN_APPROX_SIMPLE)[-2]

        # Tiny blobs (area <= 2) give a reference point: only digit-shaped
        # contours left of the right-most such blob are classified.
        # NOTE(review): presumably the blob is the degree/decimal mark -- confirm.
        centers = []
        for contour in contours:
            if cv2.contourArea(contour) > 2:
                continue
            m = cv2.moments(contour)
            if m['m00'] == 0:
                continue  # degenerate contour: centroid undefined
            centers.append((int(m['m10'] / m['m00']), int(m['m01'] / m['m00'])))

        if centers:
            center = max(centers)
            cv2.circle(out, center, 3, (255, 0, 0), -1)
        else:
            center = (50, 1)  # fallback reference when no small blob was found

        min_h = (digitheight * 4) // 5
        max_h = (digitheight * 6) // 5
        stringlist = []
        sign = 1
        for cnt in contours:
            x, y, w, h = cv2.boundingRect(cnt)
            if h > w and min_h < h < max_h and x < center[0]:
                # Digit candidate: taller than wide, height close to digitheight.
                cv2.rectangle(im, (x, y), (x + w, y + h), (0, 255, 0), 1)
                roi = cv2.resize(thresh[y:y + h, x:x + w], (digitheight, half))
                roi = np.float32(roi.reshape((1, digitheight * half)))
                _, results, _, _ = self.model.find_nearest(roi, k=1)
                string = str(int(results[0][0]))
                cv2.putText(out, string, (x - 5, y + h), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0))
                stringlist.append(string)
            elif x < 40:
                # Near the left edge, a shape with at most 4 polygon vertices
                # is taken as a minus sign.
                approx = cv2.approxPolyDP(cnt, 0.01 * cv2.arcLength(cnt, True), True)
                if len(approx) <= 4:
                    cv2.drawContours(out, [cnt], 0, (0, 0, 255), -1)
                    sign = -1

        return stringlist, sign, im, thresh, out

    @staticmethod
    def _splitJpeg(buf):
        """Return ``(jpeg, rest)`` for the first complete JPEG in ``buf``.

        ``jpeg`` is None when no complete frame is buffered yet. The end marker
        is searched for only after the start marker, so a stray end marker left
        over from a partial frame cannot yield an empty frame. With no start
        marker at all, only the last byte is kept, since it may be the first
        half of a start marker split across reads.
        """
        start = buf.find('\xff\xd8')
        if start == -1:
            return None, buf[-1:]
        end = buf.find('\xff\xd9', start + 2)
        if end == -1:
            return None, buf
        return buf[start:end + 2], buf[end + 2:]

    @staticmethod
    def _digitsToAngle(digits, sign):
        """Combine findDigits' digit strings into a signed integer angle.

        The digits are joined in reverse of the order findDigits reports them.
        Values above 180 keep only their first two digits (presumably dropping
        a spurious trailing digit -- confirm). Returns None if no digit was read.
        """
        text = ''.join(reversed(digits))
        if not text:
            return None
        value = int(text)
        if value > 180:
            value = int(text[0:2])
        return sign * value

    def _publishDigitsDebug(self, frame, implotlist, bridge):
        """Plot the per-axis debug images and publish the saved plot.

        The 3x3 grid (one row per axis: crop, threshold, overlay) is saved to
        ``DigitsRecognition.jpg`` in the working directory and published on
        ``/DJI_Ronin_App_Digits``.
        """
        name = os.getcwd() + "/DigitsRecognition.jpg"
        plt.figure(1)
        plt.imshow(frame)
        plt.figure(2)
        for idx, panel in enumerate(implotlist[:9]):
            plt.subplot(3, 3, idx + 1)
            plt.imshow(panel, 'gray')
        plt.savefig(name)
        # NOTE(review): plt.show() blocks under interactive matplotlib backends,
        # which stalls the stream loop -- confirm this is intended.
        plt.show()
        pic = cv2.imread(name)
        self.digit_pub.publish(bridge.cv2_to_imgmsg(pic, encoding="bgr8"))

    def _publishAngles(self, angles):
        """Publish ``[pan, tilt, roll]`` unless pan jumped by >= 50 degrees.

        The jump check presumably rejects misread digits. It compares against
        the pan read in the previous frame, so a genuine large move (or a
        starting pan far from 0) is accepted on the next consistent frame
        instead of being rejected forever.
        """
        pan, tilt, roll = angles
        if abs(pan - self.lastyaw) < 50:
            msg = gimbalangle()
            msg.header.stamp = rospy.get_rostime()
            msg.pan = self.panangle = pan
            msg.tilt = tilt
            msg.roll = roll
            self.pub.publish(msg)
        self.lastyaw = pan

    def runClassifier(self, url):
        """Read frames from ``url`` and publish the recognised gimbal angles.

        Runs until ROS shuts down. Frames that fail to decode, or in which
        some axis yields no digits, are skipped rather than stopping the loop.

        :param url: stream URL passed to ``urllib.urlopen``.
        :raises IOError: if the stream cannot be opened, fails, or ends.
        """
        stream = urllib.urlopen(url)
        anglist = ["Pan", "Tilt", "Roll"]
        bridge = CvBridge()
        buf = ''
        while not rospy.is_shutdown():
            chunk = stream.read(1024)
            if not chunk:
                # read() returns '' only at end of stream; looping on would spin forever.
                raise IOError("stream closed: %s" % url)
            buf += chunk
            jpg, buf = self._splitJpeg(buf)
            if jpg is None:
                continue
            frame = cv2.imdecode(np.fromstring(jpg, dtype=np.uint8), cv2.IMREAD_COLOR)
            if frame is None:
                continue  # corrupt frame

            # The readouts are stacked 30-row strips, in the order of anglist.
            angles = []
            implotlist = []
            for j in range(len(anglist)):
                digits, sign, im, thresh, out = self.findDigits(
                    frame, self.digitheight, 30 * j, 30 * (j + 1))
                implotlist.extend([im, thresh, out])
                angles.append(self._digitsToAngle(digits, sign))

            if self.bool_publish_App:
                self.app_pub.publish(bridge.cv2_to_imgmsg(frame, encoding="bgr8"))

            if self.bool_publish_Digit:
                self._publishDigitsDebug(frame, implotlist, bridge)

            if None in angles:
                rospy.logdebug("[DJI Ronin] - Could not read all angles, frame skipped.")
                continue
            self._publishAngles(angles)
00202 
00203 
if __name__ == "__main__":
    # Constructing the classifier starts the ROS node and runs it.
    _ronin_node = DJIRoninClassifier()