ws.py
Go to the documentation of this file.
00001 #!/usr/bin/python
00002 import socket
00003 from select import select
00004 from threading import Lock, Thread
00005 from struct import pack, unpack
00006 import re
00007 import hashlib
00008 from base64 import b64encode
00009 
class Sender(Thread):
    """Background thread that writes the messages handed over by a SockProxy.

    Hand-off protocol (both locks are plain Locks used as binary semaphores):
    SockProxy.send() acquires ``proxy.lock``, stores ``proxy.msg`` /
    ``proxy.args`` and releases ``self.lock``; this thread wakes, writes the
    message to ``proxy.sock`` and releases ``proxy.lock`` so the next send()
    may proceed.  Setting ``go`` to False before releasing ``self.lock`` makes
    the thread release ``proxy.lock`` and exit.
    """

    def __init__(self, proxy):
        Thread.__init__(self)
        self.proxy = proxy
        self.go = True
        self.latestException = None  # most recent exception raised by sock.send
        self.lock = Lock()
        self.lock.acquire()  # start locked: run() waits for the first hand-off

    def run(self):
        import errno
        # Errors after which retrying the same send can succeed.
        transient = (errno.EINTR, errno.EAGAIN, errno.EWOULDBLOCK)
        while self.go:
            self.lock.acquire()
            if not self.go:
                self.proxy.lock.release()
                break
            msg = self.proxy.msg
            args = self.proxy.args
            sent = 0
            toSend = len(msg)
            while sent < toSend:
                try:
                    sent += self.proxy.sock.send(msg[sent:], *args)
                except Exception as e:
                    self.latestException = e
                    if getattr(e, 'errno', None) in transient:
                        continue
                    # The socket is unusable.  Retrying would spin forever and
                    # never release proxy.lock, so drop the rest of this message.
                    break
            self.proxy.lock.release()
00038 
class SockProxy(object):
    """Socket wrapper whose send() hands data to a dedicated Sender thread.

    At most one message is in flight: send() blocks on ``self.lock`` until the
    Sender has finished writing the previous message.  After close(), both
    send() and close() are no-ops.
    """

    def __init__(self, sock):
        self.raw = True  # False after a hybi handshake: send() re-frames messages
        self.sock = sock
        self.lock = Lock()
        self.msg = ''
        self.args = []
        self.closed = False
        self.sender = Sender(self)
        self.sender.start()

    def send(self, msg, *args):
        """Queue ``msg`` (plus extra socket.send flags in ``args``) for the Sender.

        In non-raw mode the first and last byte of ``msg`` are removed and the
        rest is prefixed with a single unfragmented text-frame header (0x81 plus
        a 7-, 16- or 64-bit length).  Presumably callers pass messages wrapped
        in the one-byte '\\x00' ... '\\xff' framing; verify against callers.
        Messages sent after close() are silently dropped.
        """
        if not self.raw:
            payload = msg[1:-1]
            length = len(payload)
            if length < 126:
                header = '\x81' + pack('!B', length)
            elif length <= 65535:
                header = '\x81' + pack('!B', 126) + pack('!H', length)
            else:
                header = '\x81' + pack('!B', 127) + pack('!Q', length)
            msg = header + payload
        self.lock.acquire()  # wait for the previous message to be written
        if self.closed:
            # The Sender thread has exited; handing it work would leave
            # self.lock held forever and deadlock the next caller.
            self.lock.release()
            return
        self.msg = msg
        self.args = args
        self.sender.lock.release()  # wake the Sender; it releases self.lock

    def close(self):
        """Stop the Sender thread once any in-flight message is written.

        Safe to call more than once.
        """
        self.lock.acquire()
        if self.closed:
            self.lock.release()
            return
        self.closed = True
        self.sender.go = False
        self.sender.lock.release()  # the Sender releases self.lock on exit
00069 
class Session(object):
    """Per-connection parser state used by serveForever.

    The integer class attributes name the states of serveForever's byte-level
    state machine; ``state`` holds the current one.
    """

    (handshake, determineResponse, ready, receiveKey, sentinel, closed,
     finAndOpCode, sixteenBitLength, sixtyFourBitLength, maskingKey,
     payload) = range(11)

    def __init__(self, sock):
        self.sock = SockProxy(sock)  # all writes go through a Sender thread
        self.state = Session.handshake
        self.buffer = []  # bytes collected while in the current state
        self.count = 0    # bytes consumed in the current state
        # Header fields of the hybi frame being parsed, plus the text message
        # being reassembled from fragments.
        self.fin = 0
        self.opcode = 0
        self.mask = 0
        self.length = 0
        self.msg = []
        # Not referenced elsewhere in the visible code; presumably for use by
        # the handleFrame/handleOutput callbacks -- confirm against callers.
        self.data = {}
        self.publishers = {}
        self.authStatus = ''
        self.latest = {}
        self.latestQueue = {}
        self.latestSubs = {}
        self.latestLock = Lock()

    def transition(self, state, flush=False):
        """Move to ``state`` with a zeroed byte counter; empty the buffer if ``flush``."""
        self.state = state
        self.count = 0
        if flush:
            self.buffer = []
00105 
def hybiHandshake(buffer):
    """Build the server reply to a hybi / RFC 6455 opening handshake.

    ``buffer`` is the client's request headers as a string.  The accept value
    is base64(SHA-1(key + protocol GUID)).  Works on Python 2 byte strings and
    Python 3 text alike.

    Raises ValueError if the request has no Sec-WebSocket-Key header.
    """
    # [ \t]* (not \s*) so an empty value cannot swallow the next header line.
    match = re.search(r'Sec-WebSocket-Key:[ \t]*([^\r\n]*)', buffer)
    if match is None:
        raise ValueError("request has no Sec-WebSocket-Key header")
    # Trailing whitespace is not part of the header value; hashing it would
    # produce an Accept value the client rejects.
    key = match.group(1).strip()
    magicString = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"  # GUID fixed by the protocol

    material = key + magicString
    if not isinstance(material, bytes):  # Python 3 text input
        material = material.encode('latin-1')
    accept = b64encode(hashlib.sha1(material).digest())
    if not isinstance(accept, str):  # Python 3: b64encode returns bytes
        accept = accept.decode('ascii')

    return ''.join([
        "HTTP/1.1 101 Switching Protocols\r\n",
        "Upgrade: websocket\r\n",
        "Connection: Upgrade\r\n",
        "Sec-WebSocket-Accept: ", accept, "\r\n",
        "\r\n",
    ])
00121 
def googleHandshake(buffer):
    """Build the reply to a key-less handshake using the unprefixed WebSocket-* headers.

    This appears to be the pre-hixie-76 handshake (presumably draft-hixie-75);
    confirm before relying on that label.  The request's Host and Origin
    values are echoed back as WebSocket-Location (always path "/") and
    WebSocket-Origin.  A request missing either header makes ``.match()``
    return None, so this raises AttributeError.
    """
    hostHeader = re.compile(r'.*Host:\s*([^\r\n]*).*', re.DOTALL)
    originHeader = re.compile(r'.*Origin:\s*([^\r\n]*).*', re.DOTALL)
    host = hostHeader.match(buffer).group(1)
    origin = originHeader.match(buffer).group(1)
    parts = (
        "HTTP/1.1 101 Web Socket Protocol Handshake\r\n",
        "Upgrade: WebSocket\r\n",
        "Connection: Upgrade\r\n",
        "WebSocket-Location: ws://", host, "/\r\n",
        "WebSocket-Origin: ", origin, "\r\n",
        "\r\n",
    )
    return ''.join(parts)
00138 
def actualHandshake(buffer):
    """Build the reply to a draft-hixie-76 ("original") WebSocket handshake.

    ``buffer`` is the whole request: headers, the blank line, then the 8 key
    bytes.  For each of Sec-WebSocket-Key1/Key2, the digits form a number
    that is divided by the count of spaces in the value.  The two quotients,
    packed as big-endian unsigned 32-bit integers, are followed by the 8 key
    bytes and hashed with MD5.  The digest is appended after the response
    headers.

    Returns the response string, or None if the request is malformed (missing
    headers, no spaces, non-integral quotient, quotient out of 32-bit range).
    """
    def parseKey(pattern):
        # -> (number formed by the value's digits, number of spaces in value)
        value = re.compile(pattern, re.DOTALL).match(buffer).group(1)
        return int(re.sub(r'[^0-9]', '', value)), value.count(' ')

    try:
        keyNumber1, spaces1 = parseKey(r'.*Sec-WebSocket-Key1: ([^\r\n]*)\s*.*')
        keyNumber2, spaces2 = parseKey(r'.*Sec-WebSocket-Key2: ([^\r\n]*)\s*.*')

        if spaces1 == 0 or spaces2 == 0:
            return None
        if keyNumber1 % spaces1 != 0 or keyNumber2 % spaces2 != 0:
            return None

        key3 = buffer[-8:]
        if not isinstance(key3, bytes):  # Python 3 text input
            key3 = key3.encode('latin-1')
        # '//' keeps the quotients integral on Python 3.  '!II' (unsigned) because
        # a quotient may legitimately reach 2**32 - 1; larger values raise
        # struct.error and the handshake is rejected below.
        challenge = pack('!II', keyNumber1 // spaces1, keyNumber2 // spaces2) + key3
        response = hashlib.md5(challenge).digest()
        if not isinstance(response, str):  # Python 3: digest() returns bytes
            response = response.decode('latin-1')

        host = re.compile(r'.*Host:\s*([^\r\n]*).*', re.DOTALL).match(buffer).group(1)
        origin = re.compile(r'.*Origin:\s*([^\r\n]*).*', re.DOTALL).match(buffer).group(1)

        return ''.join([
            "HTTP/1.1 101 Web Socket Protocol Handshake\r\n",
            "Upgrade: WebSocket\r\n",
            "Connection: Upgrade\r\n",
            "Sec-WebSocket-Location: ws://", host, "/\r\n",
            "Sec-WebSocket-Origin: ", origin, "\r\n",
            "\r\n",
            response,
        ])
    except Exception:
        # Any parse failure (missing header -> AttributeError, no digits ->
        # ValueError, out-of-range quotient -> struct.error) means "reject".
        return None
00187 
def defaultHandleFrame(frame, session):
    """Default frame handler: ignore ``frame`` and return ``session`` to Session.ready."""
    session.transition(state=Session.ready)
00190 
def defaultHandleOutput(session):
    """Default output hook: takes no action for the writable ``session``."""
    return None
00193 
def defaultLoop():
    """Default loop predicate for serveForever: always keep serving.

    serveForever runs ``while loop():``, so this must return a true value for
    the server to do anything when no custom ``loop`` is supplied.
    """
    return True
00196 
def serveForever(handleFrame=defaultHandleFrame, handleOutput=defaultHandleOutput, loop=defaultLoop, host='', port=9090, hz=1.0/100000):
    """Accept clients on (host, port) and serve them until ``loop()`` is false.

    One thread multiplexes the listening socket and all clients with select().
    Received bytes are fed one at a time through each client's Session state
    machine.  The machine performs whichever handshake the request matches
    (hixie-76 key handshake, "raw", hybi, or the Google/key-less variant) and
    then reassembles frames.

    Args:
        handleFrame: called as handleFrame(frame, session) with each complete
            text message.  For '\\x00'...'\\xff' framed messages this loop does
            not reset the session itself; the handler is expected to call
            session.transition(Session.ready), as defaultHandleFrame does.
            Setting session.state to Session.closed asks this loop to close
            the connection at the end of the current iteration.
        handleOutput: called as handleOutput(session) for each client socket
            select() reports writable; socket.error raised by it is ignored.
        loop: polled before every iteration; serving stops once it returns a
            falsy value.
        host, port: address the listening socket binds to.
        hz: timeout in seconds for each of the two select() calls per
            iteration (a period, despite the name).
    """
    incoming = []
    outgoing = []
    sessions = {}

    # Handshake-variant detectors, compiled once instead of per request.
    hixie76Request = re.compile(r'.*Sec-WebSocket-Key2', re.DOTALL)
    rawRequest = re.compile(r'raw')
    hybiRequest = re.compile(r'.*Sec-WebSocket-Key[^0-9]', re.DOTALL)

    def closeSocket(sock):
        """Stop sock's sender thread, forget its session, drop it from the
        select lists and close it.  Calling it again for the same socket is a
        no-op."""
        if sock not in incoming and sock not in outgoing:
            return
        fileno = sock.fileno()
        print("closed %s" % (fileno,))
        session = sessions.pop(fileno, None)
        if session is not None:
            session.sock.close()
        if sock in incoming:
            incoming.remove(sock)
        if sock in outgoing:
            outgoing.remove(sock)
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except socket.error:
            print("Socket did not close smoothly.")
        finally:
            # Close even when shutdown fails (e.g. the peer already reset),
            # otherwise the descriptor leaks.
            sock.close()

    def closeSockets():
        """Close every session whose handler marked it Session.closed."""
        for session in list(sessions.values()):
            if session.state == Session.closed:
                closeSocket(session.sock.sock)

    def sendOrClose(sock, makeReply):
        """Send makeReply() on sock and return True.  If building the reply
        fails (a required request header is missing: AttributeError from
        .match(...).group, or ValueError) or the send fails, close sock and
        return False instead of letting one client crash the server."""
        try:
            sock.send(makeReply())
        except (AttributeError, ValueError, socket.error):
            print("Could not send reply; closing connection.")
            closeSocket(sock)
            return False
        return True

    def finishFrame(session, masked, sock):
        """Unmask a complete hybi payload and act on the frame's opcode.

        Text (1) and continuation (0) payloads are appended to session.msg,
        which is passed to handleFrame once a frame with FIN set arrives.
        Close (8) closes the connection; ping (9) is answered with a pong
        echoing the payload.  Returns False if the connection was closed.
        """
        mask = session.mask
        payload = [chr(ord(c) ^ ord(mask[i % 4])) for i, c in enumerate(masked)]
        if session.opcode in (0, 1):
            session.msg = session.msg + payload
            # Only a FIN frame completes a message; non-final continuation
            # fragments must keep accumulating.
            if session.fin == 1:
                handleFrame(''.join(session.msg), session)
                session.msg = []
        if session.opcode == 8:
            closeSocket(sock)
            return False
        if session.opcode == 9:
            pong = ''.join(payload)
            if len(pong) > 125:
                pong = ''  # control frames carry at most 125 payload bytes
            return sendOrClose(sock, lambda: '\x8a' + pack('!B', len(pong)) + pong)
        return True

    try:
        serverSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Registered before bind() so the finally block closes it even if
        # bind/listen fail.
        incoming.append(serverSocket)
        serverSocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        serverSocket.bind((host, port))
        serverSocket.listen(5)

        while loop():
            readable, _, _ = select(incoming, [], [], hz)

            for client in readable:
                if client is serverSocket:
                    # new connection
                    connection, address = serverSocket.accept()
                    print("Connection from %s:%s" % address)
                    sessions[connection.fileno()] = Session(connection)
                    incoming.append(connection)
                    outgoing.append(connection)
                    print("%s concurrent connections.\n" % (len(incoming),))
                    continue

                session = sessions[client.fileno()]
                try:
                    buff = client.recv(8192)
                except socket.error:
                    # Reset from remote client.
                    closeSocket(client)
                    continue
                if not buff:
                    closeSocket(client)
                    continue
                buffLength = len(buff)

                # Every path that closes the client must `break` out of this
                # loop: the socket and its session are gone afterwards.
                idx = -1
                while True:
                    idx = idx + 1
                    if idx >= buffLength:
                        break

                    data = buff[idx]
                    session.buffer.append(data)

                    if session.state == Session.handshake:
                        # The request headers end with an empty line.
                        if ''.join(session.buffer[-4:]) == "\r\n\r\n":
                            session.transition(Session.determineResponse)

                    if session.state == Session.determineResponse:
                        call = ''.join(session.buffer[:-1])

                        print("------")
                        print(call)
                        print("------")

                        if hixie76Request.match(call) is not None:
                            print("\"original\" websocket handshake (will be deprecated)")
                            # The reply needs the 8 key bytes after the headers.
                            session.transition(Session.receiveKey)
                        elif rawRequest.match(call) is not None:
                            print("raw socket")
                            session.transition(Session.ready)
                        elif hybiRequest.match(call) is not None:
                            print("draft-ietf-hybi-thewebsocketprotocol-06 (preliminary, _may_ be deprecated)")
                            if not sendOrClose(client, lambda: hybiHandshake(call)):
                                break
                            session.sock.raw = False
                            session.transition(Session.finAndOpCode)
                            continue
                        else:
                            # google handshake
                            print("Google handshake (will be deprecated)")
                            if not sendOrClose(client, lambda: googleHandshake(call)):
                                break
                            session.transition(Session.ready)

                        print("------\n")

                    if session.state == Session.receiveKey:
                        if session.count >= 8:
                            resp = actualHandshake(''.join(session.buffer))
                            if resp is None:
                                closeSocket(client)
                                break
                            if not sendOrClose(client, lambda: resp):
                                break
                            session.transition(Session.ready)
                            continue
                        session.count = session.count + 1

                    if session.state == Session.ready:
                        # A frame-start byte with the high bit clear begins a
                        # '\xff'-terminated frame; a set high bit marks a binary
                        # frame, which is not supported.
                        data = unpack('!b', data)[0]
                        if (data >> 7) == 0:
                            session.transition(Session.sentinel, flush=True)
                            continue
                        print("Binary data frames are unsupported")
                        closeSocket(client)
                        break

                    if session.state == Session.sentinel:
                        found = buff.find('\xff', idx)
                        if found != -1:
                            session.buffer.extend(buff[idx + 1:found])
                            frame = ''.join(session.buffer)
                            if frame[-1] == '\xff':
                                frame = frame[:-1]
                            handleFrame(frame, session)
                            idx = found
                        else:
                            session.buffer.extend(buff[idx + 1:])
                            idx = buffLength - 1

                    if session.state == Session.finAndOpCode:
                        session.count = session.count + 1
                        if session.count == 1:
                            data = unpack('!B', data)[0]
                            session.fin = data >> 7
                            session.opcode = data & 15
                        if session.count == 2:
                            data = unpack('!B', data)[0]
                            if (data >> 7) != 1:
                                print("Only masked frames are supported.")
                                closeSocket(client)
                                break
                            session.length = data & 127
                            if session.length == 126:
                                session.transition(Session.sixteenBitLength, flush=True)
                            elif session.length == 127:
                                session.transition(Session.sixtyFourBitLength, flush=True)
                            else:
                                session.transition(Session.maskingKey, flush=True)
                            continue

                    if session.state == Session.sixteenBitLength:
                        session.count = session.count + 1
                        if session.count == 2:
                            session.length = unpack('!H', ''.join(session.buffer))[0]
                            session.transition(Session.maskingKey, flush=True)
                            continue

                    if session.state == Session.sixtyFourBitLength:
                        session.count = session.count + 1
                        if session.count == 8:
                            session.length = unpack('!Q', ''.join(session.buffer))[0]
                            session.transition(Session.maskingKey, flush=True)
                            continue

                    if session.state == Session.maskingKey:
                        session.count = session.count + 1
                        if session.count == 4:
                            session.mask = session.buffer + []
                            if session.length == 0:
                                # No payload bytes follow, so the payload state
                                # would never reach count == length; finish here.
                                if not finishFrame(session, [], client):
                                    break
                                if session.state != Session.closed:
                                    session.transition(Session.finAndOpCode, flush=True)
                                continue
                            session.transition(Session.payload, flush=True)
                            continue

                    if session.state == Session.payload:
                        # Take as much of the remaining payload as this chunk
                        # holds in one slice rather than byte by byte.
                        before = len(session.buffer)
                        span = idx + (session.length - session.count)
                        session.buffer.extend(buff[idx + 1:span])
                        more = len(session.buffer) - before
                        session.count = session.count + more + 1
                        idx = idx + more

                        if session.count == session.length:
                            if not finishFrame(session, session.buffer, client):
                                break
                            session.buffer = list(buff[span:])
                            # Keep a handler's request to close; closeSockets()
                            # acts on it at the end of this iteration.
                            if session.state != Session.closed:
                                session.transition(Session.finAndOpCode)

            _, writable, _ = select([], outgoing, [], hz)

            for output in writable:
                try:
                    handleOutput(sessions[output.fileno()])
                except socket.error:
                    pass  # best effort: one client's write error must not stop the server

            closeSockets()

    finally:
        socks = incoming + [s for s in outgoing if s not in incoming]
        for sock in socks:
            fileno = None
            try:
                fileno = sock.fileno()
                session = sessions.get(fileno)
                if session is not None:
                    session.sock.close()
                print("closing sock: %s" % (fileno,))
                try:
                    sock.shutdown(socket.SHUT_WR)
                except socket.error:
                    pass  # e.g. the listening socket is not connected
                sock.close()
            except Exception:
                print("failed to close sock: %s" % (fileno,))


rosbridge
Author(s): Graylin Trevor Jay
autogenerated on Sun Oct 5 2014 22:41:14