00001 from __future__ import absolute_import, division, print_function, with_statement
00002
00003 import traceback
00004
00005 from tornado.concurrent import Future
00006 from tornado.httpclient import HTTPError, HTTPRequest
00007 from tornado.log import gen_log, app_log
00008 from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
00009 from tornado.test.util import unittest
00010 from tornado.web import Application, RequestHandler
00011 from tornado.util import u
00012
00013 try:
00014 import tornado.websocket
00015 from tornado.util import _websocket_mask_python
00016 except ImportError:
00017
00018
00019
00020
00021 traceback.print_exc()
00022 raise
00023
00024 from tornado.websocket import WebSocketHandler, websocket_connect, WebSocketError
00025
00026 try:
00027 from tornado import speedups
00028 except ImportError:
00029 speedups = None
00030
00031
class TestWebSocketHandler(WebSocketHandler):
    """Base class for test handlers that exposes the on_close event.

    The application hands each handler a ``close_future`` through the URL
    spec's keyword arguments. That future is resolved with a
    ``(close_code, close_reason)`` tuple once the connection closes, so a
    test can wait for deterministic cleanup of the associated socket.
    """
    def initialize(self, close_future):
        # Keep the future supplied via the URL spec's kwargs dict.
        self.close_future = close_future

    def on_close(self):
        code_and_reason = (self.close_code, self.close_reason)
        self.close_future.set_result(code_and_reason)
00042
00043
class EchoHandler(TestWebSocketHandler):
    """Send every received message straight back to the client.

    Messages that arrive as ``bytes`` are written back with ``binary=True``.
    All other messages are written back as text.
    """
    def on_message(self, message):
        is_binary = isinstance(message, bytes)
        self.write_message(message, binary=is_binary)
00047
00048
class ErrorInOnMessageHandler(TestWebSocketHandler):
    """Raise from on_message to exercise the uncaught-exception path."""
    def on_message(self, message):
        # Deliberate ZeroDivisionError. The corresponding test expects it to
        # be logged as an uncaught exception and the client to read None.
        1 / 0
00052
00053
class HeaderHandler(TestWebSocketHandler):
    """Echo the request's ``X-Test`` header back as the first message."""
    def open(self):
        # Once the connection is a websocket, HTTP-response methods such as
        # set_status are expected to raise RuntimeError. If the call
        # succeeds instead, fail loudly.
        try:
            self.set_status(503)
        except RuntimeError:
            pass
        else:
            raise Exception("did not get expected exception")
        header_value = self.request.headers.get('X-Test', '')
        self.write_message(header_value)
00064
00065
class NonWebSocketHandler(RequestHandler):
    """Plain HTTP endpoint that answers GET without upgrading.

    It is used to check that websocket_connect rejects such endpoints.
    """
    def get(self):
        body = 'ok'
        self.write(body)
00069
00070
class CloseReasonHandler(TestWebSocketHandler):
    """Close the connection as soon as it opens, with code 1001 and
    reason "goodbye"."""
    def open(self):
        code, reason = 1001, "goodbye"
        self.close(code, reason)
00074
00075
class WebSocketTest(AsyncHTTPTestCase):
    """End-to-end tests pairing tornado's websocket server and client.

    ``get_app`` creates a fresh ``self.close_future`` and passes it to every
    websocket handler. The future resolves with ``(close_code,
    close_reason)`` when the server-side handler's ``on_close`` runs. Tests
    that open a websocket wait on it, so the server side is torn down
    before the test ends.
    """
    def get_app(self):
        self.close_future = Future()
        return Application([
            ('/echo', EchoHandler, dict(close_future=self.close_future)),
            ('/non_ws', NonWebSocketHandler),
            ('/header', HeaderHandler, dict(close_future=self.close_future)),
            ('/close_reason', CloseReasonHandler,
             dict(close_future=self.close_future)),
            ('/error_in_on_message', ErrorInOnMessageHandler,
             dict(close_future=self.close_future)),
        ])

    def test_http_request(self):
        # WS server, HTTP client: a plain GET to a websocket endpoint is
        # rejected with 400.
        response = self.fetch('/echo')
        self.assertEqual(response.code, 400)

    @gen_test
    def test_websocket_gen(self):
        ws = yield websocket_connect(
            'ws://localhost:%d/echo' % self.get_http_port(),
            io_loop=self.io_loop)
        ws.write_message('hello')
        response = yield ws.read_message()
        self.assertEqual(response, 'hello')
        ws.close()
        yield self.close_future

    def test_websocket_callbacks(self):
        # Same round trip as test_websocket_gen, driven by callbacks and
        # self.stop/self.wait instead of a coroutine.
        websocket_connect(
            'ws://localhost:%d/echo' % self.get_http_port(),
            io_loop=self.io_loop, callback=self.stop)
        ws = self.wait().result()
        ws.write_message('hello')
        ws.read_message(self.stop)
        response = self.wait().result()
        self.assertEqual(response, 'hello')
        self.close_future.add_done_callback(lambda f: self.stop())
        ws.close()
        self.wait()

    @gen_test
    def test_binary_message(self):
        ws = yield websocket_connect(
            'ws://localhost:%d/echo' % self.get_http_port())
        ws.write_message(b'hello \xe9', binary=True)
        response = yield ws.read_message()
        self.assertEqual(response, b'hello \xe9')
        ws.close()
        yield self.close_future

    @gen_test
    def test_unicode_message(self):
        ws = yield websocket_connect(
            'ws://localhost:%d/echo' % self.get_http_port())
        ws.write_message(u('hello \u00e9'))
        response = yield ws.read_message()
        self.assertEqual(response, u('hello \u00e9'))
        ws.close()
        yield self.close_future

    @gen_test
    def test_error_in_on_message(self):
        ws = yield websocket_connect(
            'ws://localhost:%d/error_in_on_message' % self.get_http_port())
        ws.write_message('hello')
        # The server logs the handler's exception, and the client then
        # reads None instead of a reply.
        with ExpectLog(app_log, "Uncaught exception"):
            response = yield ws.read_message()
        self.assertIs(response, None)
        ws.close()
        yield self.close_future

    @gen_test
    def test_websocket_http_fail(self):
        # An unknown path answers with an ordinary 404, surfaced as
        # HTTPError.
        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(
                'ws://localhost:%d/notfound' % self.get_http_port(),
                io_loop=self.io_loop)
        self.assertEqual(cm.exception.code, 404)

    @gen_test
    def test_websocket_http_success(self):
        # The endpoint answers successfully but never upgrades the
        # connection, which the client reports as WebSocketError.
        with self.assertRaises(WebSocketError):
            yield websocket_connect(
                'ws://localhost:%d/non_ws' % self.get_http_port(),
                io_loop=self.io_loop)

    @gen_test
    def test_websocket_network_fail(self):
        # Bind and immediately release a port so nothing is listening on
        # it. The huge connect_timeout is presumably there so the failure
        # comes from the refused connection rather than a timeout.
        sock, port = bind_unused_port()
        sock.close()
        with self.assertRaises(IOError):
            with ExpectLog(gen_log, ".*"):
                yield websocket_connect(
                    'ws://localhost:%d/' % port,
                    io_loop=self.io_loop,
                    connect_timeout=3600)

    @gen_test
    def test_websocket_close_buffered_data(self):
        # Drop the client's underlying stream right after queueing writes.
        # The server's on_close must still fire.
        ws = yield websocket_connect(
            'ws://localhost:%d/echo' % self.get_http_port())
        ws.write_message('hello')
        ws.write_message('world')
        ws.stream.close()
        yield self.close_future

    @gen_test
    def test_websocket_headers(self):
        # Ensure that arbitrary headers can be passed through
        # websocket_connect.
        ws = yield websocket_connect(
            HTTPRequest('ws://localhost:%d/header' % self.get_http_port(),
                        headers={'X-Test': 'hello'}))
        response = yield ws.read_message()
        self.assertEqual(response, 'hello')
        ws.close()
        yield self.close_future

    @gen_test
    def test_server_close_reason(self):
        ws = yield websocket_connect(
            'ws://localhost:%d/close_reason' % self.get_http_port())
        msg = yield ws.read_message()
        # A message of None means the other side closed the connection.
        self.assertIs(msg, None)
        self.assertEqual(ws.close_code, 1001)
        self.assertEqual(ws.close_reason, "goodbye")
        # Wait for the server-side handler's on_close, as every other
        # websocket test does, so its socket is cleaned up before the
        # test ends.
        yield self.close_future

    @gen_test
    def test_client_close_reason(self):
        ws = yield websocket_connect(
            'ws://localhost:%d/echo' % self.get_http_port())
        ws.close(1001, 'goodbye')
        code, reason = yield self.close_future
        self.assertEqual(code, 1001)
        self.assertEqual(reason, 'goodbye')

    @gen_test
    def test_check_origin_valid_no_path(self):
        # An Origin matching the request host is accepted.
        port = self.get_http_port()

        url = 'ws://localhost:%d/echo' % port
        headers = {'Origin': 'http://localhost:%d' % port}

        ws = yield websocket_connect(HTTPRequest(url, headers=headers),
                                     io_loop=self.io_loop)
        ws.write_message('hello')
        response = yield ws.read_message()
        self.assertEqual(response, 'hello')
        ws.close()
        yield self.close_future

    @gen_test
    def test_check_origin_valid_with_path(self):
        # A path component in the Origin does not affect the match.
        port = self.get_http_port()

        url = 'ws://localhost:%d/echo' % port
        headers = {'Origin': 'http://localhost:%d/something' % port}

        ws = yield websocket_connect(HTTPRequest(url, headers=headers),
                                     io_loop=self.io_loop)
        ws.write_message('hello')
        response = yield ws.read_message()
        self.assertEqual(response, 'hello')
        ws.close()
        yield self.close_future

    @gen_test
    def test_check_origin_invalid_partial_url(self):
        # An Origin without a scheme is rejected.
        port = self.get_http_port()

        url = 'ws://localhost:%d/echo' % port
        headers = {'Origin': 'localhost:%d' % port}

        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers),
                                    io_loop=self.io_loop)
        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_check_origin_invalid(self):
        port = self.get_http_port()

        url = 'ws://localhost:%d/echo' % port
        # The host is localhost, which should not be reachable from some
        # other domain under the default origin check.
        headers = {'Origin': 'http://somewhereelse.com'}

        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers),
                                    io_loop=self.io_loop)

        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_check_origin_invalid_subdomains(self):
        port = self.get_http_port()

        url = 'ws://localhost:%d/echo' % port
        # Subdomains of the request host are disallowed by default as well.
        headers = {'Origin': 'http://subtenant.localhost'}

        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers),
                                    io_loop=self.io_loop)

        self.assertEqual(cm.exception.code, 403)
00285
00286
class MaskFunctionMixin(object):
    """Shared test cases for a websocket masking function.

    Concrete subclasses mix in ``unittest.TestCase`` and provide
    ``mask(mask, data)``, which XORs ``data`` with the repeating 4-byte
    ``mask``.
    """

    def test_mask(self):
        # Each entry is (mask, data, expected masked output).
        cases = [
            (b'abcd', b'', b''),
            (b'abcd', b'b', b'\x03'),
            (b'abcd', b'54321', b'TVPVP'),
            (b'ZXCV', b'98765432', b'c`t`olpd'),
            # The next two mix \x00 bytes with bytes that have the high bit
            # set. They would trip up an implementation that relied on
            # NUL-terminated strings or signed chars.
            (b'\x00\x01\x02\x03', b'\xff\xfb\xfd\xfc\xfe\xfa',
             b'\xff\xfa\xff\xff\xfe\xfb'),
            (b'\xff\xfb\xfd\xfc', b'\x00\x01\x02\x03\x04\x05',
             b'\xff\xfa\xff\xff\xfb\xfe'),
        ]
        for key, data, expected in cases:
            self.assertEqual(self.mask(key, data), expected)
00303
00304
class PythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
    """Run the shared masking cases against the pure-Python implementation."""
    def mask(self, mask, data):
        masker = _websocket_mask_python
        return masker(mask, data)
00308
00309
@unittest.skipIf(speedups is None, "tornado.speedups module not present")
class CythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
    """Run the shared masking cases against the compiled speedups module.

    Skipped when the optional ``tornado.speedups`` extension failed to
    import at module load time.
    """
    def mask(self, mask, data):
        masker = speedups.websocket_mask
        return masker(mask, data)