00001
00002 import socket
00003 from select import select
00004 from threading import Lock, Thread
00005 from struct import pack, unpack
00006 import re
00007 import hashlib
00008 from base64 import b64encode
00009
class Sender(Thread):
    """Background thread that writes a SockProxy's queued message to its socket.

    Hand-off protocol: ``self.lock`` is created already held, so ``run``
    blocks until the proxy stores ``proxy.msg`` / ``proxy.args`` and releases
    it. The thread then writes the whole message and releases ``proxy.lock``
    so the proxy can accept the next one. Setting ``go`` to False and
    releasing ``self.lock`` stops the thread.
    """

    def __init__(self, proxy):
        Thread.__init__(self)
        self.proxy = proxy
        self.go = True
        # Most recent exception raised by sock.send(), kept for inspection.
        self.latestException = None
        self.lock = Lock()
        # Nothing to send yet: run() waits here until a message is posted.
        self.lock.acquire()

    def run(self):
        while self.go:
            self.lock.acquire()
            if not self.go:
                # Stop request: SockProxy.close() holds proxy.lock; hand it back.
                self.proxy.lock.release()
                break
            msg = self.proxy.msg
            args = self.proxy.args
            sent = 0
            while sent < len(msg):
                try:
                    sent += self.proxy.sock.send(msg[sent:], *args)
                except Exception as e:
                    # Abandon this message instead of retrying: a failing
                    # socket would otherwise spin here forever while
                    # proxy.lock is held, and SockProxy.close() waits on it.
                    self.latestException = e
                    break
            self.proxy.lock.release()
00038
class SockProxy(object):
    """Socket wrapper whose writes are carried out by a background Sender.

    ``lock`` is held from the moment a message is handed over until the
    Sender thread has finished writing it, so ``send`` blocks while an
    earlier message is still in flight. While ``raw`` is True messages are
    written unchanged; otherwise ``send`` re-frames them (see ``send``).
    """

    def __init__(self, sock):
        self.raw = True
        self.sock = sock
        self.lock = Lock()
        self.msg = ''
        self.args = []
        worker = Sender(self)
        self.sender = worker
        worker.start()

    def send(self, msg, *args):
        """Hand ``msg`` (and extra ``socket.send`` arguments) to the Sender.

        In non-raw mode the first and last characters of ``msg`` are
        discarded -- presumably the old 0x00 ... 0xFF frame delimiters;
        confirm against callers -- and a header of 0x81 (FIN + text opcode)
        followed by a 7-bit, 16-bit or 64-bit payload length is prepended.
        """
        if not self.raw:
            payloadLength = len(msg) - 2
            if payloadLength < 126:
                lengthField = pack('!B', payloadLength)
            elif payloadLength <= 65535:
                lengthField = pack('!BH', 126, payloadLength)
            else:
                lengthField = pack('!BQ', 127, payloadLength)
            msg = '\x81' + lengthField + msg[1:-1]
        # Wait until the Sender has finished the previous message.
        self.lock.acquire()
        self.msg = msg
        self.args = args
        self.sender.lock.release()

    def close(self):
        """Stop the Sender thread once any in-flight message has been written."""
        self.lock.acquire()
        self.sender.go = False
        self.sender.lock.release()
00069
class Session(object):
    """Per-connection parser state used by serveForever.

    The class attributes name the states of the byte-by-byte handshake and
    framing state machine; ``state`` holds the current one.
    """

    (handshake, determineResponse, ready, receiveKey, sentinel, closed,
     finAndOpCode, sixteenBitLength, sixtyFourBitLength, maskingKey,
     payload) = range(11)

    def __init__(self, sock):
        # Writes go through a SockProxy, which hands them to a Sender thread.
        self.sock = SockProxy(sock)
        self.state = Session.handshake
        self.buffer = []   # received characters; cleared by transition(flush=True)
        self.count = 0     # per-state byte counter; reset by transition()
        self.msg = []      # characters of the text message being assembled
        self.opcode = 0    # opcode of the current hybi frame
        self.fin = 0       # FIN bit of the current hybi frame
        self.mask = 0      # replaced by the frame's 4-character masking key
        self.length = 0    # payload length of the current hybi frame
        # Not used by the visible parser code; presumably for application
        # handlers -- confirm.
        self.data = {}
        self.publishers = {}
        self.authStatus = ''
        self.latest = {}
        self.latestQueue = {}
        self.latestSubs = {}
        self.latestLock = Lock()

    def transition(self, state, flush=False):
        """Switch to ``state`` and zero ``count``; empty ``buffer`` if ``flush``."""
        self.state, self.count = state, 0
        if flush:
            self.buffer = []
00105
def hybiHandshake(buffer):
    """Build the 101 response for a hybi / RFC 6455 opening handshake.

    ``buffer`` is the client's request headers. Sec-WebSocket-Accept is
    base64(SHA-1(key + GUID)), where ``key`` is the Sec-WebSocket-Key value
    with surrounding whitespace removed, as RFC 6455 section 4.2.2 requires.
    The header name is matched case-insensitively at the start of a line.

    Raises AttributeError if there is no Sec-WebSocket-Key header.
    """
    resp = ["HTTP/1.1 101 Switching Protocols\r\n"]
    resp.append("Upgrade: websocket\r\n")
    resp.append("Connection: Upgrade\r\n")

    key = re.search(r'^Sec-WebSocket-Key:[ \t]*([^\r\n]*)', buffer,
                    re.MULTILINE | re.IGNORECASE).group(1).strip()
    magicString = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

    material = key + magicString
    if not isinstance(material, bytes):
        # Text input (Python 3): hash its byte values. On Python 2 ``str``
        # already is ``bytes`` and this branch is skipped.
        material = material.encode('latin-1')
    accept = b64encode(hashlib.sha1(material).digest())
    if not isinstance(accept, str):
        accept = accept.decode('ascii')

    resp.append("Sec-WebSocket-Accept: ")
    resp.append(accept)
    resp.append("\r\n")

    resp.append("\r\n")

    return ''.join(resp)
00121
def googleHandshake(buffer):
    """Build the 101 response for a hixie-75 ("Google") opening handshake.

    ``buffer`` is the client's request headers. WebSocket-Location is
    ``ws://`` + the Host header + the resource named on the request line
    (``/`` if the request line cannot be parsed); WebSocket-Origin echoes the
    Origin header. Header names are matched case-insensitively, and only at
    the start of a line.

    Raises AttributeError if the Host or Origin header is missing.
    """
    def header(name):
        # Line-anchored so that e.g. "X-Forwarded-Host:" is not read as "Host:".
        found = re.search(r'^' + name + r':[ \t]*([^\r\n]*)', buffer,
                          re.MULTILINE | re.IGNORECASE)
        return found.group(1).strip()

    # The hixie drafts require the Location to name the URL the client asked
    # for, so echo the requested resource rather than always answering "/".
    requestLine = re.match(r'[A-Za-z]+[ \t]+(/[^ \t\r\n]*)', buffer)
    if requestLine is not None:
        resource = requestLine.group(1)
    else:
        resource = '/'

    resp = ["HTTP/1.1 101 Web Socket Protocol Handshake\r\n"]
    resp.append("Upgrade: WebSocket\r\n")
    resp.append("Connection: Upgrade\r\n")

    resp.append("WebSocket-Location: ws://")
    resp.append(header('Host'))
    resp.append(resource)
    resp.append("\r\n")

    resp.append("WebSocket-Origin: ")
    resp.append(header('Origin'))
    resp.append("\r\n")

    resp.append("\r\n")

    return ''.join(resp)
00138
def actualHandshake(buffer):
    """Build the 101 response for a hixie-76 opening handshake.

    ``buffer`` holds the client's request headers followed by the 8-byte
    key3 body (its last 8 characters). For Sec-WebSocket-Key1 and -Key2 the
    digits form a number that is divided by the count of spaces; the two
    quotients, packed as big-endian unsigned 32-bit integers, followed by
    key3 are hashed with MD5 to give the 16-byte answer appended after the
    headers.

    Returns the response, or None if the request is malformed: a missing
    header, a key with no digits or no spaces, a key not divisible by its
    space count, or a quotient that does not fit in 32 bits.
    """
    try:
        parts = []
        for name in ('Sec-WebSocket-Key1', 'Sec-WebSocket-Key2'):
            value = re.search(r'^' + name + r': ([^\r\n]*)', buffer,
                              re.MULTILINE | re.IGNORECASE).group(1)
            spaces = value.count(' ')
            number = int(re.sub(r'[^0123456789]', '', value))
            if spaces == 0 or number % spaces != 0:
                return None
            # Floor division: identical to "/" on Python 2 ints, stays an int.
            parts.append(number // spaces)

        key3 = buffer[-8:]
        if not isinstance(key3, bytes):
            key3 = key3.encode('latin-1')  # Python 3 text input
        # Unsigned ('!II'): a quotient may legitimately exceed 2**31 - 1.
        answer = hashlib.md5(pack('!II', parts[0], parts[1]) + key3).digest()
        if not isinstance(answer, str):
            answer = answer.decode('latin-1')

        # The Location must name the resource the client requested.
        requestLine = re.match(r'[A-Za-z]+[ \t]+(/[^ \t\r\n]*)', buffer)
        if requestLine is not None:
            resource = requestLine.group(1)
        else:
            resource = '/'
        host = re.search(r'^Host:[ \t]*([^\r\n]*)', buffer,
                         re.MULTILINE | re.IGNORECASE).group(1).strip()
        origin = re.search(r'^Origin:[ \t]*([^\r\n]*)', buffer,
                           re.MULTILINE | re.IGNORECASE).group(1).strip()

        resp = ["HTTP/1.1 101 Web Socket Protocol Handshake\r\n"]
        resp.append("Upgrade: WebSocket\r\n")
        resp.append("Connection: Upgrade\r\n")

        resp.append("Sec-WebSocket-Location: ws://")
        resp.append(host)
        resp.append(resource)
        resp.append("\r\n")

        resp.append("Sec-WebSocket-Origin: ")
        resp.append(origin)
        resp.append("\r\n")

        resp.append("\r\n")

        resp.append(answer)

        return ''.join(resp)

    except Exception:
        # Any parse failure (AttributeError for a missing header, ValueError
        # for a key without digits, struct.error for an out-of-range
        # quotient, ...) rejects the handshake.
        return None
00187
def defaultHandleFrame(frame, session):
    """Fallback frame handler: ignore ``frame`` and move ``session`` back to
    the Session.ready state so the parser looks for the next frame start."""
    readyState = Session.ready
    session.transition(readyState)
00190
def defaultHandleOutput(session):
    """Fallback output hook for serveForever: writes nothing to ``session``."""
    return None
00193
def defaultLoop():
    """Default ``loop`` callback for serveForever: keep serving indefinitely.

    serveForever runs ``while loop():``, so the default must return a true
    value; a bare ``pass`` returns None and made the default server exit
    before handling a single connection.
    """
    return True
00196
def serveForever(handleFrame=defaultHandleFrame, handleOutput=defaultHandleOutput, loop=defaultLoop, host='', port=9090, hz=1.0/200):
    """Run the select()-based WebSocket / raw-socket server.

    Listens on (host, port) and, for as long as ``loop()`` returns a true
    value, accepts clients, answers their opening handshake (hixie-76,
    "raw", hybi/RFC 6455, or -- as the fallback -- hixie-75 "Google") and
    feeds every received byte through the client's ``Session`` state machine.

    handleFrame(frame, session): called with each complete text frame
        (hixie/raw framing) or reassembled text message (hybi framing).
    handleOutput(session): called once per pass for every writable client;
        socket.error raised from it is ignored.
    loop(): polled before every pass; the server stops once it is false.
    hz: timeout in seconds for each select() call (a period, not a
        frequency, despite the name).

    Every socket, listening and client, is closed on the way out, including
    when an exception propagates.
    """
    incoming = []
    outgoing = []
    sessions = {}
    # Patterns used to classify a client's opening handshake.
    hixie76Request = re.compile(r'.*Sec-WebSocket-Key2', re.DOTALL)
    rawRequest = re.compile(r'raw')
    hybiRequest = re.compile(r'.*Sec-WebSocket-Key[^0-9]', re.DOTALL)

    def closeSocket(sock):
        """Stop the client's Sender, forget the client and close its socket.

        Does nothing if ``sock`` was already closed by an earlier call.
        """
        if sock not in incoming:
            return
        fd = sock.fileno()
        print("closed %s" % (fd,))
        if fd in sessions:
            sessions[fd].sock.close()
            del sessions[fd]
        incoming.remove(sock)
        outgoing.remove(sock)
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except socket.error:
            print("Socket did not close smoothly.")
        finally:
            # Close even when shutdown() fails (e.g. the peer already reset
            # the connection) so the descriptor is not leaked.
            sock.close()

    def closeSockets():
        """Close every session that a handler has put in Session.closed."""
        for fd in list(sessions.keys()):
            session = sessions.get(fd)
            if session is not None and session.state == Session.closed:
                closeSocket(session.sock.sock)

    def sendRaw(sock, session, data):
        """Write ``data`` directly to ``sock``; on failure close it and return False.

        The session's SockProxy lock is held while writing, so these bytes
        cannot interleave with a message its Sender thread is writing.
        """
        proxyLock = session.sock.lock
        proxyLock.acquire()
        try:
            sock.sendall(data)
            ok = True
        except socket.error:
            ok = False
        finally:
            proxyLock.release()
        if not ok:
            closeSocket(sock)
        return ok

    def sendHandshake(sock, session, build, request):
        """Send ``build(request)``; close ``sock`` and return False on failure."""
        try:
            response = build(request)
        except Exception:
            # A malformed request (e.g. a required header is missing) must
            # drop only this client, not abort the whole server.
            print("Malformed handshake request; closing connection.")
            closeSocket(sock)
            return False
        return sendRaw(sock, session, response)

    def finishHybiFrame(sock, session, payload):
        """Act on one complete hybi frame, then expect the next frame header.

        ``payload`` is the unmasked payload as a list of one-character
        strings. Text messages (opcode 1, optionally continued by opcode 0
        frames) go to handleFrame once their FIN frame arrives; binary
        messages (opcode 2) are read and discarded. A ping (9) is answered
        with a pong (0xA) carrying the same data; a close frame (8) closes
        the connection. The opcode of the data message being assembled is
        kept in ``session.messageOpcode`` (added here). A handler that put
        the session in Session.closed is left in that state.

        Returns False if the connection was closed.
        """
        opcode = session.opcode
        if opcode == 8:
            closeSocket(sock)
            return False
        if opcode == 9:
            if len(payload) > 125:
                # Control frames may carry at most 125 bytes of payload.
                closeSocket(sock)
                return False
            pong = '\x8a' + pack('!B', len(payload)) + ''.join(payload)
            if not sendRaw(sock, session, pong):
                return False
        elif opcode in (0, 1, 2):
            if opcode != 0:
                # First frame of a new data message: 1 = text, 2 = binary.
                session.messageOpcode = opcode
                session.msg = []
            if getattr(session, 'messageOpcode', 0) == 1:
                session.msg = session.msg + payload
            if session.fin == 1:
                if getattr(session, 'messageOpcode', 0) == 1:
                    handleFrame(''.join(session.msg), session)
                session.msg = []
                session.messageOpcode = 0
        # Any other opcode (e.g. 10, pong) is ignored.
        if session.state != Session.closed:
            session.transition(Session.finAndOpCode, flush=True)
        return True

    def processBytes(conn, session, buff):
        """Feed one recv() chunk through the session's state machine.

        Returns early, discarding the rest of ``buff``, once the connection
        has been closed.
        """
        buffLength = len(buff)
        idx = -1
        while idx + 1 < buffLength:
            idx += 1
            data = buff[idx]
            session.buffer.append(data)

            # Deliberately sequential ``if``s rather than ``elif``s: a
            # transition made for this byte can let a later state act on the
            # same byte (e.g. handshake -> determineResponse -> ready).

            if session.state == Session.handshake:
                if ''.join(session.buffer[-4:]) == "\r\n\r\n":
                    session.transition(Session.determineResponse)

            if session.state == Session.determineResponse:
                call = ''.join(session.buffer[:-1])

                print("------")
                print(call)
                print("------")

                if hixie76Request.match(call) is not None:
                    print("\"original\" websocket handshake (will be deprecated)")
                    # The 8-byte key3 body follows; receiveKey counts it.
                    session.transition(Session.receiveKey)
                elif rawRequest.match(call) is not None:
                    print("raw socket")
                    # NOTE(review): the header's final "\n" falls through to
                    # the ready state below and opens the first frame --
                    # confirm raw clients rely on this.
                    session.transition(Session.ready)
                elif hybiRequest.match(call) is not None:
                    print("draft-ietf-hybi-thewebsocketprotocol-06 (preliminary, _may_ be deprecated)")
                    if not sendHandshake(conn, session, hybiHandshake, call):
                        return
                    session.sock.raw = False
                    session.transition(Session.finAndOpCode)
                    continue
                else:
                    print("Google handshake (will be deprecated)")
                    if not sendHandshake(conn, session, googleHandshake, call):
                        return
                    session.transition(Session.ready)
                    print("------\n")
                    # hixie-75 frames start with their own 0x00 byte; the
                    # header's final "\n" must not open a frame, or the first
                    # message would arrive with a stray "\x00" prefix.
                    continue

                print("------\n")

            if session.state == Session.receiveKey:
                if session.count < 8:
                    session.count += 1
                else:
                    resp = actualHandshake(''.join(session.buffer))
                    if resp is None:
                        closeSocket(conn)
                        return
                    if not sendRaw(conn, session, resp):
                        return
                    session.transition(Session.ready)
                    continue

            if session.state == Session.ready:
                # Frame-type byte: high bit clear opens a 0x00 ... 0xFF text frame.
                if ord(data) < 0x80:
                    session.transition(Session.sentinel, flush=True)
                    continue
                print("Binary data frames are unsupported")
                closeSocket(conn)
                return

            if session.state == Session.sentinel:
                found = buff.find('\xff', idx)
                if found != -1:
                    # ``data`` is already buffered: drop it and copy
                    # buff[idx:found], so a terminator that is itself the
                    # current byte (an empty frame, or 0xFF arriving first in
                    # a recv() chunk) is not delivered as part of the frame.
                    session.buffer.pop()
                    session.buffer.extend(buff[idx:found])
                    handleFrame(''.join(session.buffer), session)
                    idx = found
                else:
                    session.buffer.extend(buff[idx + 1:])
                    idx = buffLength - 1

            if session.state == Session.finAndOpCode:
                session.count += 1
                byte = ord(data)
                if session.count == 1:
                    session.fin = byte >> 7
                    session.opcode = byte & 15
                    continue
                if byte >> 7 != 1:
                    print("Only masked frames are supported.")
                    closeSocket(conn)
                    return
                session.length = byte & 127
                if session.length == 126:
                    session.transition(Session.sixteenBitLength, flush=True)
                elif session.length == 127:
                    session.transition(Session.sixtyFourBitLength, flush=True)
                else:
                    session.transition(Session.maskingKey, flush=True)
                continue

            if session.state == Session.sixteenBitLength:
                session.count += 1
                if session.count == 2:
                    session.length = unpack('!H', ''.join(session.buffer))[0]
                    session.transition(Session.maskingKey, flush=True)
                continue

            if session.state == Session.sixtyFourBitLength:
                session.count += 1
                if session.count == 8:
                    session.length = unpack('!Q', ''.join(session.buffer))[0]
                    session.transition(Session.maskingKey, flush=True)
                continue

            if session.state == Session.maskingKey:
                session.count += 1
                if session.count == 4:
                    session.mask = session.buffer + []
                    if session.length == 0:
                        # No payload follows the key: the frame is complete now.
                        if not finishHybiFrame(conn, session, []):
                            return
                    else:
                        session.transition(Session.payload, flush=True)
                continue

            if session.state == Session.payload:
                # The current byte is already buffered; take whatever else of
                # this frame's payload is available in buff in one slice.
                wanted = session.length - session.count - 1
                chunk = buff[idx + 1:idx + 1 + wanted]
                session.buffer.extend(chunk)
                session.count += 1 + len(chunk)
                idx += len(chunk)
                if session.count == session.length:
                    mask = session.mask
                    payload = [chr(ord(c) ^ ord(mask[i % 4]))
                               for i, c in enumerate(session.buffer)]
                    if not finishHybiFrame(conn, session, payload):
                        return
                continue

    try:
        serverSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Registered before bind()/listen() so the finally clause closes it
        # even if either of them fails.
        incoming.append(serverSocket)
        serverSocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        serverSocket.bind((host, port))
        serverSocket.listen(5)

        while loop():
            inputReady, outputReady, _ = select(incoming, outgoing, [], hz)

            for conn in inputReady:
                if conn is serverSocket:
                    connection, address = serverSocket.accept()
                    print("Connection from %s:%s" % address)
                    sessions[connection.fileno()] = Session(connection)
                    incoming.append(connection)
                    outgoing.append(connection)
                    # ``incoming`` also holds the listener, so count sessions.
                    print("%s concurrent connections.\n" % (len(sessions),))
                    continue

                session = sessions[conn.fileno()]
                try:
                    buff = conn.recv(8192)
                except socket.error:
                    closeSocket(conn)
                    continue
                if not buff:
                    # The peer closed its end of the connection.
                    closeSocket(conn)
                    continue
                processBytes(conn, session, buff)

            # Poll again so sockets closed during the input pass above are
            # no longer offered to handleOutput.
            inputReady, outputReady, _ = select(incoming, outgoing, [], hz)

            for conn in outputReady:
                try:
                    handleOutput(sessions[conn.fileno()])
                except socket.error:
                    pass  # output is deliberately best-effort

            closeSockets()

    finally:
        socks = incoming + [s for s in outgoing if s not in incoming]
        for sock in socks:
            fd = None
            try:
                fd = sock.fileno()
                if fd in sessions:
                    sessions[fd].sock.close()
                print("closing sock: %s" % (fd,))
                try:
                    sock.shutdown(socket.SHUT_WR)
                except socket.error:
                    # e.g. the listening socket, which is not connected;
                    # it still has to be closed below.
                    pass
                sock.close()
            except Exception:
                print("failed to close sock: %s" % (fd,))