00001
00002 import socket
00003 from select import select
00004 from threading import Lock, Thread
00005 from struct import pack, unpack
00006 import re
00007 import hashlib
00008 from base64 import b64encode
00009
class Sender(Thread):
    """Thread that writes the messages queued on a SockProxy to its socket.

    Hand-off protocol with the owning proxy:

    * ``self.lock`` is held while there is nothing to do.  The proxy
      releases it after storing a message in ``proxy.msg``/``proxy.args``,
      or after clearing ``go`` to ask this thread to stop.
    * ``proxy.lock`` is held by the proxy while a message is pending.  This
      thread releases it once the message has been written (or abandoned
      because the socket failed), and when it stops.
    """

    def __init__(self, proxy):
        Thread.__init__(self)
        self.proxy = proxy
        self.go = True               # cleared by the proxy to stop the thread
        self.latestException = None  # most recent exception raised by send()
        self.lock = Lock()
        self.lock.acquire()          # run() blocks until the first hand-off

    def run(self):
        while self.go:
            self.lock.acquire()
            if not self.go:
                # Stop request: hand the proxy lock back so close() completes.
                self.proxy.lock.release()
                break
            msg = self.proxy.msg
            args = self.proxy.args
            sent = 0
            while sent < len(msg):
                try:
                    sent += self.proxy.sock.send(msg[sent:], *args)
                except Exception as e:
                    # A failing socket fails again on every retry.  Retrying
                    # with nothing sent would spin forever and never release
                    # proxy.lock, so drop the rest of this message instead.
                    self.latestException = e
                    break
            self.proxy.lock.release()
00038
class SockProxy(object):
    """Wrap a connected socket so that writes are done by a Sender thread.

    ``send()`` waits until any previously queued message has been written.
    It then hands the new message to the thread and returns without
    waiting for that write.

    While ``raw`` is True, messages are passed through unchanged.  When it is
    False (serveForever clears it after a draft-hybi handshake), the first and
    last characters of each message are dropped and the rest is sent as one
    unmasked RFC 6455 text frame.  The dropped characters are presumably the
    caller's legacy 0x00/0xFF hixie framing bytes -- TODO confirm with callers.
    """

    def __init__(self, sock):
        self.raw = True
        self.sock = sock
        self.lock = Lock()  # held from send() until the Sender has written it
        self.msg = ''
        self.args = []
        self.sender = Sender(self)
        self.sender.start()

    def send(self, msg, *args):
        """Queue ``msg`` for writing; ``args`` are passed on to ``socket.send``.

        Does nothing once close() has been called.
        """
        if not self.raw:
            payload = msg[1:-1]
            length = len(payload)
            # 0x81 = FIN bit + text opcode.  It is followed by a 7-bit length,
            # or by 126/127 plus a 16/64-bit big-endian extended length.
            if length < 126:
                header = '\x81' + pack('!B', length)
            elif length <= 65535:
                header = '\x81' + pack('!B', 126) + pack('!H', length)
            else:
                header = '\x81' + pack('!B', 127) + pack('!Q', length)
            msg = header + payload
        self.lock.acquire()  # wait for the previous message to be written
        if not self.sender.go:
            # After close() the Sender thread has exited and would never
            # release self.lock again, blocking every later send()/close().
            self.lock.release()
            return
        self.msg = msg
        self.args = args
        self.sender.lock.release()  # wake the Sender

    def close(self):
        """Stop the Sender thread once any pending message has been written.

        Calling close() again is a no-op.
        """
        self.lock.acquire()  # wait for any in-flight message
        if not self.sender.go:
            self.lock.release()
            return
        self.sender.go = False
        # Wake the Sender; it sees go == False, releases self.lock and exits.
        self.sender.lock.release()
00069
class Session(object):
    """Per-connection state for serveForever's byte-by-byte parser.

    The integer class attributes below name the parser states; the
    ``state`` attribute holds the current one.
    """

    # Parser states, numbered 0..10 in this order.
    (handshake,
     determineResponse,
     ready,
     receiveKey,
     sentinel,
     closed,
     finAndOpCode,
     sixteenBitLength,
     sixtyFourBitLength,
     maskingKey,
     payload) = range(11)

    def __init__(self, sock):
        # Outgoing data goes through a SockProxy and its Sender thread.
        self.sock = SockProxy(sock)

        # Parser bookkeeping.
        self.state = Session.handshake
        self.buffer = []  # characters collected while in the current state
        self.count = 0    # bytes consumed in the current state
        self.length = 0   # payload length of the current hybi frame
        self.mask = 0     # replaced by the 4 masking-key characters of a frame
        self.fin = 0      # FIN bit of the current hybi frame
        self.opcode = 0   # opcode of the current hybi frame
        self.msg = []     # characters of the hybi message being assembled

        # Not read by the parser code visible here; presumably used by the
        # frame/output handlers (TODO confirm).
        self.data = {}
        self.publishers = {}
        self.authStatus = ''
        self.latest = {}
        self.latestQueue = {}
        self.latestSubs = {}
        self.latestLock = Lock()

    def transition(self, state, flush=False):
        """Enter ``state`` with the per-state byte counter reset to zero.

        With ``flush`` a fresh, empty buffer is started; otherwise the
        existing buffer list is kept as-is.
        """
        self.state, self.count = state, 0
        if not flush:
            return
        self.buffer = []
00105
def hybiHandshake(buffer):
    """Build the reply to a draft-hybi / RFC 6455 opening handshake.

    :param buffer: the client's request line and headers, as a string.
    :returns: the complete ``101 Switching Protocols`` response, including
        the terminating blank line.
    :raises ValueError: if the request has no ``Sec-WebSocket-Key`` header.
    """
    # HTTP field names are case-insensitive.  Anchoring at a line start keeps
    # a longer header name that merely ends in this one from matching.
    found = re.search(r'^Sec-WebSocket-Key:[ \t]*([^\r\n]*)', buffer,
                      re.IGNORECASE | re.MULTILINE)
    if found is None:
        raise ValueError("missing Sec-WebSocket-Key header")
    key = found.group(1).strip()  # surrounding whitespace is not part of the key
    magicString = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

    digestInput = key + magicString
    if not isinstance(digestInput, bytes):  # text on Python 3, bytes on Python 2
        digestInput = digestInput.encode('latin-1')
    accept = b64encode(hashlib.sha1(digestInput).digest())
    if not isinstance(accept, str):  # b64encode returns bytes on Python 3
        accept = accept.decode('ascii')

    return ''.join([
        "HTTP/1.1 101 Switching Protocols\r\n",
        "Upgrade: websocket\r\n",
        "Connection: Upgrade\r\n",
        "Sec-WebSocket-Accept: ", accept, "\r\n",
        "\r\n",
    ])
00121
def googleHandshake(buffer):
    """Build the reply to a draft-hixie-75 ("Google") opening handshake.

    :param buffer: the client's request line and headers, as a string.
    :returns: the ``101 Web Socket Protocol Handshake`` response, including
        the terminating blank line.
    :raises ValueError: if the request lacks a Host or Origin header.
    """
    def header(name):
        # Field names are case-insensitive.  Anchoring at a line start stops
        # e.g. "X-Forwarded-Host" from being taken for "Host".
        found = re.search(r'^' + name + r':[ \t]*([^\r\n]*)', buffer,
                          re.IGNORECASE | re.MULTILINE)
        if found is None:
            raise ValueError("missing %s header" % name)
        return found.group(1).strip()

    # Hixie clients compare WebSocket-Location with the URL they requested,
    # so echo the resource from the request line ("GET /demo HTTP/1.1").
    requestLine = re.match(r'[A-Za-z]+ (/\S*) HTTP/', buffer)
    resource = requestLine.group(1) if requestLine else '/'

    return ''.join([
        "HTTP/1.1 101 Web Socket Protocol Handshake\r\n",
        "Upgrade: WebSocket\r\n",
        "Connection: Upgrade\r\n",
        "WebSocket-Location: ws://", header('Host'), resource, "\r\n",
        "WebSocket-Origin: ", header('Origin'), "\r\n",
        "\r\n",
    ])
00138
def actualHandshake(buffer):
    """Build the reply to a draft-hixie-76 opening handshake.

    :param buffer: the client's request line and headers, followed by the
        8 raw challenge bytes ("key3") sent after the blank line.
    :returns: the response headers, a blank line and the 16-byte MD5
        challenge answer, as one string.  Returns None if the request is
        malformed, i.e. if any of these hold:

        * a header is missing;
        * a key has no spaces;
        * a key's digits are not divisible by its space count;
        * a key part does not fit in an unsigned 32-bit integer.
    """
    try:
        head = buffer[:-8]
        key3 = buffer[-8:]

        def header(pattern):
            found = re.search(pattern, head, re.IGNORECASE | re.MULTILINE)
            if found is None:
                raise ValueError("missing header: %s" % pattern)
            return found.group(1)

        def keyPart(name):
            # The key's digits, read as one number, divided by the number of
            # spaces in the key; None if that is not a whole number.
            value = header(r'^' + name + r': ([^\r\n]*)')
            spaces = value.count(' ')
            number = int(re.sub(r'[^0-9]', '', value))
            if spaces == 0 or number % spaces != 0:
                return None
            return number // spaces  # integer division on Python 2 and 3

        part1 = keyPart('Sec-WebSocket-Key1')
        part2 = keyPart('Sec-WebSocket-Key2')
        if part1 is None or part2 is None:
            return None

        if not isinstance(key3, bytes):  # text on Python 3, bytes on Python 2
            key3 = key3.encode('latin-1')
        # Each part is packed as an *unsigned* 32-bit big-endian integer.
        # Signed '!ii' rejected every part >= 2**31.  Parts >= 2**32 still
        # raise struct.error, which the except below turns into None.
        response = hashlib.md5(pack('!II', part1, part2) + key3).digest()
        if not isinstance(response, str):
            response = response.decode('latin-1')

        # The client checks Sec-WebSocket-Location against the URL it asked for.
        requestLine = re.match(r'[A-Za-z]+ (/\S*) HTTP/', head)
        resource = requestLine.group(1) if requestLine else '/'

        return ''.join([
            "HTTP/1.1 101 Web Socket Protocol Handshake\r\n",
            "Upgrade: WebSocket\r\n",
            "Connection: Upgrade\r\n",
            "Sec-WebSocket-Location: ws://",
            header(r'^Host:[ \t]*([^\r\n]*)').strip(), resource, "\r\n",
            "Sec-WebSocket-Origin: ",
            header(r'^Origin:[ \t]*([^\r\n]*)').strip(), "\r\n",
            "\r\n",
            response,
        ])

    except Exception:
        # Any parse or pack failure means the request is malformed.
        return None
00187
def defaultHandleFrame(frame, session):
    """Fallback frame handler: discard ``frame`` and move ``session`` back to Session.ready."""
    readyState = Session.ready
    session.transition(readyState)
00190
def defaultHandleOutput(session):
    """Fallback output handler: produce nothing for ``session``."""
    return None
00193
def defaultLoop():
    """Default ``loop`` predicate for serveForever: keep serving indefinitely.

    serveForever runs ``while loop():``.  Returning None (a false value)
    would make the server shut down before its first iteration.
    """
    return True
00196
def serveForever(handleFrame=defaultHandleFrame, handleOutput=defaultHandleOutput,
                 loop=defaultLoop, host='', port=9090, hz=1.0/100000):
    """Run a select()-based WebSocket / raw-socket server until loop() is false.

    Accepts TCP connections on (host, port) and identifies each client's
    opening handshake:

    * draft-hixie-76 if Sec-WebSocket-Key2 is present;
    * a plain "raw" socket if the request starts with "raw";
    * draft-hybi if Sec-WebSocket-Key is present;
    * otherwise draft-hixie-75 ("Google").

    It answers the handshake, then parses the bytes that follow with the
    connection's Session state machine.

    :param handleFrame: called as ``handleFrame(frame, session)`` with each
        complete incoming message.  For hixie/raw connections it is expected
        to move the session back to ``Session.ready`` (as defaultHandleFrame
        does) so the next frame is recognised.  Setting ``Session.closed``
        closes the connection at the end of the loop iteration.
    :param handleOutput: called as ``handleOutput(session)`` for every
        connection select() reports writable, once per loop iteration.
        ``socket.error`` raised by it is ignored.
    :param loop: polled before every iteration; the server stops once it
        returns a false value.
    :param host: interface to bind ('' = all interfaces).
    :param port: TCP port to listen on.
    :param hz: timeout in seconds for each of the two select() calls per
        iteration (a period, despite the name).

    All sockets are closed when the function exits, normally or by exception.
    """
    incoming = []
    outgoing = []
    sessions = {}  # client socket fileno() -> Session

    # Handshake classifiers, compiled once.
    hixie76Request = re.compile(r'.*Sec-WebSocket-Key2', re.DOTALL)
    rawRequest = re.compile(r'raw')
    hybiRequest = re.compile(r'.*Sec-WebSocket-Key[^0-9]', re.DOTALL)

    def closeSocket(sock):
        """Close one client connection and forget its session; repeat calls are no-ops."""
        if sock not in incoming and sock not in outgoing:
            return  # already closed
        fd = sock.fileno()
        print("closed %s" % (fd,))
        session = sessions.pop(fd, None)
        if session is not None:
            session.sock.close()
        if sock in incoming:
            incoming.remove(sock)
        if sock in outgoing:
            outgoing.remove(sock)
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except socket.error:
            print("Socket did not close smoothly.")
        # Close even if shutdown failed (e.g. the peer already reset the
        # connection) so the descriptor is not leaked.
        sock.close()

    def closeSockets():
        """Close every connection whose session was marked Session.closed."""
        for session in list(sessions.values()):
            if session.state == Session.closed:
                closeSocket(session.sock.sock)

    def sendAll(sock, data):
        """Write all of ``data``; on a socket error close the connection and return False."""
        try:
            sock.sendall(data)
        except socket.error:
            closeSocket(sock)
            return False
        return True

    def sendHandshake(sock, build, request):
        """Send ``build(request)`` as the handshake reply; False if the connection was closed."""
        try:
            response = build(request)
        except Exception as e:
            # A malformed request must only cost this client, not the server.
            print("Malformed handshake (%s); closing connection." % (e,))
            closeSocket(sock)
            return False
        return sendAll(sock, response)

    def sendControl(sock, session, frame):
        """Write a control frame without interleaving it with the Sender thread's output.

        SockProxy holds its lock while a queued message is being written, so
        holding it here guarantees the frame is not spliced into another one.
        """
        proxyLock = session.sock.lock
        proxyLock.acquire()
        try:
            sock.sendall(frame)
            ok = True
        except socket.error:
            ok = False
        finally:
            proxyLock.release()
        if not ok:
            closeSocket(sock)  # outside the lock: closing the proxy takes it too
        return ok

    def finishFrame(sock, session, payload):
        """Act on one complete, unmasked hybi frame (``payload``: list of characters).

        Returns False if the connection was closed.
        """
        opcode = session.opcode
        if opcode in (0, 1):  # continuation / text
            session.msg = session.msg + payload
            # Deliver only when the final fragment (FIN bit set) has arrived.
            if session.fin == 1:
                handleFrame(''.join(session.msg), session)
                session.msg = []
        elif opcode == 8:  # close
            closeSocket(sock)
            return False
        elif opcode == 9:  # ping -> pong (opcode 0xA) echoing the ping's data
            if len(payload) > 125:  # control frames may not carry more
                closeSocket(sock)
                return False
            pong = '\x8a' + pack('!B', len(payload)) + ''.join(payload)
            return sendControl(sock, session, pong)
        return True

    def readFrom(sock):
        """Receive one chunk from a client and run it through its session's state machine."""
        session = sessions[sock.fileno()]
        try:
            buff = sock.recv(8192)
        except socket.error:
            closeSocket(sock)
            return
        if not buff:  # the peer closed the connection
            closeSocket(sock)
            return
        buffLength = len(buff)

        idx = -1
        while idx + 1 < buffLength:
            idx += 1
            data = buff[idx]
            session.buffer.append(data)

            if session.state == Session.handshake:
                if ''.join(session.buffer[-4:]) == "\r\n\r\n":
                    session.transition(Session.determineResponse)

            if session.state == Session.determineResponse:
                call = ''.join(session.buffer[:-1])
                print("------")
                print(call)
                print("------")
                if hixie76Request.match(call) is not None:
                    print("\"original\" websocket handshake (will be deprecated)")
                    # Falls through: this byte already counts towards key3.
                    session.transition(Session.receiveKey)
                elif rawRequest.match(call) is not None:
                    print("raw socket")
                    session.transition(Session.ready)
                elif hybiRequest.match(call) is not None:
                    print("draft-ietf-hybi-thewebsocketprotocol-06 (preliminary, _may_ be deprecated)")
                    if not sendHandshake(sock, hybiHandshake, call):
                        return
                    session.sock.raw = False
                    session.transition(Session.finAndOpCode)
                    continue
                else:
                    print("Google handshake (will be deprecated)")
                    if not sendHandshake(sock, googleHandshake, call):
                        return
                    session.transition(Session.ready)
                print("------\n")

            if session.state == Session.receiveKey:
                # draft-hixie-76: 8 challenge bytes follow the headers.
                if session.count >= 8:
                    resp = actualHandshake(''.join(session.buffer))
                    if resp is None or not sendAll(sock, resp):
                        closeSocket(sock)
                        return
                    session.transition(Session.ready)
                    continue
                session.count += 1

            if session.state == Session.ready:
                # Hixie/raw text frames start with a byte whose high bit is clear.
                if (unpack('!b', data)[0] >> 7) == 0:
                    session.transition(Session.sentinel, flush=True)
                    continue
                print("Binary data frames are unsupported")
                closeSocket(sock)
                return

            if session.state == Session.sentinel:
                # Hixie/raw frames end at the next 0xFF byte.
                found = buff.find('\xff', idx)
                if found != -1:
                    session.buffer.extend(buff[idx + 1:found])
                    frame = ''.join(session.buffer)
                    if frame[-1] == '\xff':
                        frame = frame[:-1]
                    handleFrame(frame, session)
                    idx = found
                else:
                    session.buffer.extend(buff[idx + 1:])
                    idx = buffLength - 1

            if session.state == Session.finAndOpCode:
                session.count += 1
                if session.count == 1:
                    first = unpack('!B', data)[0]
                    session.fin = first >> 7
                    session.opcode = first & 15
                elif session.count == 2:
                    second = unpack('!B', data)[0]
                    if (second >> 7) != 1:
                        print("Only masked frames are supported.")
                        closeSocket(sock)
                        return
                    session.length = second & 127
                    if session.length == 126:
                        session.transition(Session.sixteenBitLength, flush=True)
                    elif session.length == 127:
                        session.transition(Session.sixtyFourBitLength, flush=True)
                    else:
                        session.transition(Session.maskingKey, flush=True)
                    continue

            if session.state == Session.sixteenBitLength:
                session.count += 1
                if session.count == 2:
                    session.length = unpack('!H', ''.join(session.buffer))[0]
                    session.transition(Session.maskingKey, flush=True)
                    continue

            if session.state == Session.sixtyFourBitLength:
                session.count += 1
                if session.count == 8:
                    session.length = unpack('!Q', ''.join(session.buffer))[0]
                    session.transition(Session.maskingKey, flush=True)
                    continue

            if session.state == Session.maskingKey:
                session.count += 1
                if session.count == 4:
                    session.mask = session.buffer + []
                    if session.length > 0:
                        session.transition(Session.payload, flush=True)
                        continue
                    # Empty payload: the frame is already complete.  Entering
                    # the payload state would take the next frame's first
                    # byte as payload and never finish.
                    if not finishFrame(sock, session, []):
                        return
                    if session.state == Session.closed:
                        return  # the handler asked for the connection to close
                    session.transition(Session.finAndOpCode, flush=True)
                    continue

            if session.state == Session.payload:
                # Take as much of the rest of the payload as this chunk holds
                # in one slice instead of byte by byte.
                before = len(session.buffer)
                span = idx + (session.length - session.count)
                session.buffer.extend(buff[idx + 1:span])
                more = len(session.buffer) - before
                session.count += more + 1
                idx += more
                if session.count == session.length:
                    mask = session.mask
                    payload = [chr(ord(c) ^ ord(mask[i % 4]))
                               for i, c in enumerate(session.buffer)]
                    if not finishFrame(sock, session, payload):
                        return
                    if session.state == Session.closed:
                        return  # the handler asked for the connection to close
                    session.transition(Session.finAndOpCode, flush=True)

    try:
        serverSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Registered immediately so the finally clause closes it even if
        # bind() or listen() fails.
        incoming.append(serverSocket)
        serverSocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        serverSocket.bind((host, port))
        serverSocket.listen(5)

        while loop():
            readable, _, _ = select(incoming, [], [], hz)
            for sock in readable:
                if sock is serverSocket:
                    try:
                        connection, address = serverSocket.accept()
                    except socket.error:
                        continue  # the client went away before accept()
                    print("Connection from %s:%s" % address)
                    sessions[connection.fileno()] = Session(connection)
                    incoming.append(connection)
                    outgoing.append(connection)
                    print("%s concurrent connections.\n" % (len(incoming),))
                else:
                    readFrom(sock)

            _, writable, _ = select([], outgoing, [], hz)
            for sock in writable:
                try:
                    handleOutput(sessions[sock.fileno()])
                except socket.error:
                    pass  # best effort: output errors are ignored here

            closeSockets()

    finally:
        for sock in incoming + [s for s in outgoing if s not in incoming]:
            try:
                fd = sock.fileno()
            except socket.error:
                continue  # already closed
            session = sessions.get(fd)
            if session is not None:
                session.sock.close()
            print("closing sock: %s" % (fd,))
            try:
                sock.shutdown(socket.SHUT_WR)
            except socket.error:
                pass  # e.g. the listening socket, which is not connected
            try:
                sock.close()
            except socket.error:
                print("failed to close sock: %s" % (fd,))