websocket_test.py
Go to the documentation of this file.
00001 from __future__ import absolute_import, division, print_function, with_statement
00002 
00003 import traceback
00004 
00005 from tornado.concurrent import Future
00006 from tornado.httpclient import HTTPError, HTTPRequest
00007 from tornado.log import gen_log, app_log
00008 from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
00009 from tornado.test.util import unittest
00010 from tornado.web import Application, RequestHandler
00011 from tornado.util import u
00012 
00013 try:
00014     import tornado.websocket
00015     from tornado.util import _websocket_mask_python
00016 except ImportError:
00017     # The unittest module presents misleading errors on ImportError
00018     # (it acts as if websocket_test could not be found, hiding the underlying
00019     # error).  If we get an ImportError here (which could happen due to
00020     # TORNADO_EXTENSION=1), print some extra information before failing.
00021     traceback.print_exc()
00022     raise
00023 
00024 from tornado.websocket import WebSocketHandler, websocket_connect, WebSocketError
00025 
00026 try:
00027     from tornado import speedups
00028 except ImportError:
00029     speedups = None
00030 
00031 
class TestWebSocketHandler(WebSocketHandler):
    """WebSocket handler base for tests that reports its own closing.

    The future received through ``initialize`` is resolved from
    ``on_close`` with the ``(close_code, close_reason)`` pair, giving a
    test something to wait on so that cleanup of the handler's socket is
    deterministic.
    """
    def initialize(self, close_future):
        # Remember the future supplied by the URL spec; on_close resolves it.
        self.close_future = close_future

    def on_close(self):
        outcome = (self.close_code, self.close_reason)
        self.close_future.set_result(outcome)
00042 
00043 
class EchoHandler(TestWebSocketHandler):
    """Sends each received message straight back to the peer."""
    def on_message(self, message):
        # Reply with a binary frame only when the payload arrived as bytes.
        is_binary = isinstance(message, bytes)
        self.write_message(message, binary=is_binary)
00047 
00048 
class ErrorInOnMessageHandler(TestWebSocketHandler):
    """Handler whose message callback fails on every message."""
    def on_message(self, message):
        # Deliberate ZeroDivisionError: the test wants an uncaught
        # exception raised from inside on_message.
        1 / 0
00052 
00053 
class HeaderHandler(TestWebSocketHandler):
    """Replies on open with the value of the request's ``X-Test`` header."""
    def open(self):
        # In a websocket context, many RequestHandler methods
        # raise RuntimeErrors; set_status is expected to be one of them.
        try:
            self.set_status(503)
        except RuntimeError:
            pass
        else:
            # Reaching this branch means set_status was wrongly accepted.
            raise Exception("did not get expected exception")
        header_value = self.request.headers.get('X-Test', '')
        self.write_message(header_value)
00064 
00065 
class NonWebSocketHandler(RequestHandler):
    """Plain HTTP handler; GET writes a fixed body."""
    def get(self):
        body = 'ok'
        self.write(body)
00069 
00070 
class CloseReasonHandler(TestWebSocketHandler):
    """Closes the connection from the server side as soon as it opens."""
    def open(self):
        # Close with an explicit code and reason for the client to observe.
        self.close(code=1001, reason="goodbye")
00074 
00075 
class WebSocketTest(AsyncHTTPTestCase):
    """Client/server websocket tests run against an in-process HTTP server.

    ``get_app`` installs the handlers under test.  All websocket handlers
    share ``self.close_future``, which is resolved from the handler's
    ``on_close``; tests wait on it so that the server side of the
    connection is finished before the test ends.
    """
    def get_app(self):
        self.close_future = Future()
        # The same initialize() arguments are used for every websocket
        # handler, so build them once.
        handler_kwargs = dict(close_future=self.close_future)
        return Application([
            ('/echo', EchoHandler, handler_kwargs),
            ('/non_ws', NonWebSocketHandler),
            ('/header', HeaderHandler, handler_kwargs),
            ('/close_reason', CloseReasonHandler, handler_kwargs),
            ('/error_in_on_message', ErrorInOnMessageHandler,
             handler_kwargs),
        ])

    def _ws_url(self, path):
        """Return the ``ws://`` URL for ``path`` on this test's server."""
        return 'ws://localhost:%d%s' % (self.get_http_port(), path)

    def _ws_connect(self, path, headers=None):
        """Begin a websocket connection to ``path`` on this test's server.

        When ``headers`` is given the URL is wrapped in an ``HTTPRequest``
        carrying those headers.  The connection always runs on the test's
        own IOLoop.  Returns whatever ``websocket_connect`` returns (the
        object the tests ``yield``).
        """
        target = self._ws_url(path)
        if headers is not None:
            target = HTTPRequest(target, headers=headers)
        return websocket_connect(target, io_loop=self.io_loop)

    def test_http_request(self):
        # WS server, HTTP client.
        response = self.fetch('/echo')
        self.assertEqual(response.code, 400)

    @gen_test
    def test_websocket_gen(self):
        ws = yield self._ws_connect('/echo')
        ws.write_message('hello')
        response = yield ws.read_message()
        self.assertEqual(response, 'hello')
        ws.close()
        yield self.close_future

    def test_websocket_callbacks(self):
        # Callback-style API: results are delivered to self.stop as futures.
        websocket_connect(
            self._ws_url('/echo'),
            io_loop=self.io_loop, callback=self.stop)
        ws = self.wait().result()
        ws.write_message('hello')
        ws.read_message(self.stop)
        response = self.wait().result()
        self.assertEqual(response, 'hello')
        self.close_future.add_done_callback(lambda f: self.stop())
        ws.close()
        self.wait()

    @gen_test
    def test_binary_message(self):
        ws = yield self._ws_connect('/echo')
        ws.write_message(b'hello \xe9', binary=True)
        response = yield ws.read_message()
        self.assertEqual(response, b'hello \xe9')
        ws.close()
        yield self.close_future

    @gen_test
    def test_unicode_message(self):
        ws = yield self._ws_connect('/echo')
        ws.write_message(u('hello \u00e9'))
        response = yield ws.read_message()
        self.assertEqual(response, u('hello \u00e9'))
        ws.close()
        yield self.close_future

    @gen_test
    def test_error_in_on_message(self):
        ws = yield self._ws_connect('/error_in_on_message')
        ws.write_message('hello')
        with ExpectLog(app_log, "Uncaught exception"):
            response = yield ws.read_message()
        self.assertIs(response, None)
        ws.close()
        yield self.close_future

    @gen_test
    def test_websocket_http_fail(self):
        with self.assertRaises(HTTPError) as cm:
            yield self._ws_connect('/notfound')
        self.assertEqual(cm.exception.code, 404)

    @gen_test
    def test_websocket_http_success(self):
        with self.assertRaises(WebSocketError):
            yield self._ws_connect('/non_ws')

    @gen_test
    def test_websocket_network_fail(self):
        # Bind a port and release it again so that nothing is listening
        # on it when the client tries to connect.
        sock, port = bind_unused_port()
        sock.close()
        with self.assertRaises(IOError):
            with ExpectLog(gen_log, ".*"):
                yield websocket_connect(
                    'ws://localhost:%d/' % port,
                    io_loop=self.io_loop,
                    connect_timeout=3600)

    @gen_test
    def test_websocket_close_buffered_data(self):
        ws = yield self._ws_connect('/echo')
        ws.write_message('hello')
        ws.write_message('world')
        # Close the underlying stream directly, right after queuing writes.
        ws.stream.close()
        yield self.close_future

    @gen_test
    def test_websocket_headers(self):
        # Ensure that arbitrary headers can be passed through websocket_connect.
        ws = yield self._ws_connect('/header', headers={'X-Test': 'hello'})
        response = yield ws.read_message()
        self.assertEqual(response, 'hello')
        ws.close()
        yield self.close_future

    @gen_test
    def test_server_close_reason(self):
        ws = yield self._ws_connect('/close_reason')
        msg = yield ws.read_message()
        # A message of None means the other side closed the connection.
        self.assertIs(msg, None)
        self.assertEqual(ws.close_code, 1001)
        self.assertEqual(ws.close_reason, "goodbye")
        # Wait for the handler's on_close as the other connection tests
        # do, so the server side is finished before the test ends.
        yield self.close_future

    @gen_test
    def test_client_close_reason(self):
        ws = yield self._ws_connect('/echo')
        ws.close(1001, 'goodbye')
        code, reason = yield self.close_future
        self.assertEqual(code, 1001)
        self.assertEqual(reason, 'goodbye')

    @gen_test
    def test_check_origin_valid_no_path(self):
        port = self.get_http_port()
        headers = {'Origin': 'http://localhost:%d' % port}

        ws = yield self._ws_connect('/echo', headers=headers)
        ws.write_message('hello')
        response = yield ws.read_message()
        self.assertEqual(response, 'hello')
        ws.close()
        yield self.close_future

    @gen_test
    def test_check_origin_valid_with_path(self):
        port = self.get_http_port()
        headers = {'Origin': 'http://localhost:%d/something' % port}

        ws = yield self._ws_connect('/echo', headers=headers)
        ws.write_message('hello')
        response = yield ws.read_message()
        self.assertEqual(response, 'hello')
        ws.close()
        yield self.close_future

    @gen_test
    def test_check_origin_invalid_partial_url(self):
        port = self.get_http_port()
        # Origin without a scheme is not a full URL.
        headers = {'Origin': 'localhost:%d' % port}

        with self.assertRaises(HTTPError) as cm:
            yield self._ws_connect('/echo', headers=headers)
        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_check_origin_invalid(self):
        # Host is localhost, which should not be accessible from some other
        # domain
        headers = {'Origin': 'http://somewhereelse.com'}

        with self.assertRaises(HTTPError) as cm:
            yield self._ws_connect('/echo', headers=headers)

        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_check_origin_invalid_subdomains(self):
        # Subdomains should be disallowed by default.  If we could pass a
        # resolver to websocket_connect we could test sibling domains as well.
        headers = {'Origin': 'http://subtenant.localhost'}

        with self.assertRaises(HTTPError) as cm:
            yield self._ws_connect('/echo', headers=headers)

        self.assertEqual(cm.exception.code, 403)
00285 
00286 
class MaskFunctionMixin(object):
    """Shared checks for a websocket masking implementation.

    Subclasses should define ``self.mask(mask, data)``.
    """
    def test_mask(self):
        # Each row is (mask, data, expected masked output).
        vectors = [
            (b'abcd', b'', b''),
            (b'abcd', b'b', b'\x03'),
            (b'abcd', b'54321', b'TVPVP'),
            (b'ZXCV', b'98765432', b'c`t`olpd'),
            # Include test cases with \x00 bytes (to ensure that the C
            # extension isn't depending on null-terminated strings) and
            # bytes with the high bit set (to smoke out signedness issues).
            (b'\x00\x01\x02\x03',
             b'\xff\xfb\xfd\xfc\xfe\xfa',
             b'\xff\xfa\xff\xff\xfe\xfb'),
            (b'\xff\xfb\xfd\xfc',
             b'\x00\x01\x02\x03\x04\x05',
             b'\xff\xfa\xff\xff\xfb\xfe'),
        ]
        for key, payload, expected in vectors:
            self.assertEqual(self.mask(key, payload), expected)
00303 
00304 
class PythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
    """Runs the shared mask checks against ``_websocket_mask_python``."""
    def mask(self, mask, data):
        masked = _websocket_mask_python(mask, data)
        return masked
00308 
00309 
@unittest.skipIf(speedups is None, "tornado.speedups module not present")
class CythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
    """Runs the shared mask checks against ``speedups.websocket_mask``.

    Skipped when the ``tornado.speedups`` module could not be imported.
    """
    def mask(self, mask, data):
        masked = speedups.websocket_mask(mask, data)
        return masked


rosbridge_tools
Author(s): Jonathan Mace
autogenerated on Sun Dec 28 2014 11:43:22