00001
00002
00003
00004 from __future__ import absolute_import, division, with_statement
00005 from tornado import httpclient, simple_httpclient, netutil
00006 from tornado.escape import json_decode, utf8, _unicode, recursive_unicode, native_str
00007 from tornado.httpserver import HTTPServer
00008 from tornado.httputil import HTTPHeaders
00009 from tornado.iostream import IOStream
00010 from tornado.simple_httpclient import SimpleAsyncHTTPClient
00011 from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase, AsyncTestCase
00012 from tornado.util import b, bytes_type
00013 from tornado.web import Application, RequestHandler
00014 import os
00015 import shutil
00016 import socket
00017 import sys
00018 import tempfile
00019
00020 try:
00021 import ssl
00022 except ImportError:
00023 ssl = None
00024
00025
class HandlerBaseTestCase(AsyncHTTPTestCase, LogTrapTestCase):
    """Base class for tests that serve one handler mounted at ``/``.

    Subclasses must define a nested ``Handler`` class; ``get_app`` looks
    it up on the concrete test class and routes the root path to it.
    """
    def get_app(self):
        handler_cls = type(self).Handler
        return Application([('/', handler_cls)])

    def fetch_json(self, *args, **kwargs):
        """Fetch a URL (arguments forwarded to ``self.fetch``) and decode
        the response body as JSON.

        ``rethrow()`` is invoked before decoding; it is expected to raise
        for error responses so failures surface as exceptions.
        """
        resp = self.fetch(*args, **kwargs)
        resp.rethrow()
        payload = resp.body
        return json_decode(payload)
00034
00035
class HelloWorldRequestHandler(RequestHandler):
    """Minimal handler: greets on GET and reports the POST body size.

    The optional ``protocol`` initializer argument names the scheme a GET
    request is expected to arrive over (the SSL tests pass ``"https"``).
    """
    def initialize(self, protocol="http"):
        self.expected_protocol = protocol

    def get(self):
        actual = self.request.protocol
        # A scheme mismatch raises AssertionError inside the handler.
        assert actual == self.expected_protocol
        self.finish("Hello world")

    def post(self):
        byte_count = len(self.request.body)
        self.finish("Got %d bytes in POST" % byte_count)
00046
00047
class BaseSSLTest(AsyncHTTPTestCase, LogTrapTestCase):
    """Common fixture for HTTPS server tests.

    Subclasses implement ``get_ssl_version`` to choose the protocol the
    server socket is configured with.  The server certificate and key are
    read from ``test.crt`` / ``test.key`` in this module's directory, and
    the app serves ``HelloWorldRequestHandler`` expecting ``https``.
    """
    def get_ssl_version(self):
        """Return the ``ssl.PROTOCOL_*`` constant for the server; abstract."""
        raise NotImplementedError()

    def setUp(self):
        super(BaseSSLTest, self).setUp()
        # Replace ``self.http_client`` with a dedicated, non-shared
        # SimpleAsyncHTTPClient.  NOTE(review): presumably so these tests do
        # not depend on whichever client implementation is configured
        # globally -- confirm.
        self.http_client = SimpleAsyncHTTPClient(io_loop=self.io_loop,
                                                 force_instance=True)

    def get_app(self):
        return Application([('/', HelloWorldRequestHandler,
                             dict(protocol="https"))])

    def get_httpserver_options(self):
        # The test certificate/key pair is expected to live next to this file.
        test_dir = os.path.dirname(__file__)
        return dict(ssl_options=dict(
                certfile=os.path.join(test_dir, 'test.crt'),
                keyfile=os.path.join(test_dir, 'test.key'),
                ssl_version=self.get_ssl_version()))

    def fetch(self, path, **kwargs):
        """Fetch ``path`` from the test server over HTTPS and return the
        response.

        Extra keyword arguments are forwarded to the client's ``fetch``.
        Certificate validation is turned off; presumably because the test
        certificate is not signed by a trusted CA.
        """
        # Rewrite only the scheme prefix (assumes get_url returns an
        # "http://..." URL).  An unbounded str.replace would also corrupt
        # any "http" appearing in the path or query string.
        url = self.get_url(path).replace('http', 'https', 1)
        self.http_client.fetch(url,
                               self.stop,
                               validate_cert=False,
                               **kwargs)
        return self.wait()
00079
00080
class SSLTestMixin(object):
    """Test cases shared by the SSL protocol-version test classes.

    Meant to be mixed into ``BaseSSLTest`` subclasses, which provide
    ``fetch`` (over https), ``get_url`` and ``http_client``.
    """
    def test_ssl(self):
        self.assertEqual(self.fetch('/').body, b("Hello world"))

    def test_large_post(self):
        payload = 'A' * 5000
        resp = self.fetch('/', method='POST', body=payload)
        self.assertEqual(resp.body, b("Got 5000 bytes in POST"))

    def test_non_ssl_request(self):
        # Send a plain-http request to the https server.  The client
        # timeouts are set very high so that the expected 599 cannot be
        # produced by those timeouts firing during the test.
        plain_url = self.get_url("/")
        self.http_client.fetch(plain_url, self.stop,
                               request_timeout=3600,
                               connect_timeout=3600)
        resp = self.wait()
        self.assertEqual(resp.code, 599)
00101
00102
00103
00104
00105
00106
00107
class SSLv23Test(BaseSSLTest, SSLTestMixin):
    """Shared SSL tests against a server configured with PROTOCOL_SSLv23."""
    def get_ssl_version(self):
        version = ssl.PROTOCOL_SSLv23
        return version
00111
00112
class SSLv3Test(BaseSSLTest, SSLTestMixin):
    """Shared SSL tests against a server configured with PROTOCOL_SSLv3."""
    def get_ssl_version(self):
        version = ssl.PROTOCOL_SSLv3
        return version
00116
00117
class TLSv1Test(BaseSSLTest, SSLTestMixin):
    """Shared SSL tests against a server configured with PROTOCOL_TLSv1."""
    def get_ssl_version(self):
        version = ssl.PROTOCOL_TLSv1
        return version
00121
# Only defined when this ssl module exposes an SSLv2 protocol constant.
if hasattr(ssl, 'PROTOCOL_SSLv2'):
    class SSLv2Test(BaseSSLTest):
        """Verify that requests to an SSLv2-only server do not succeed.

        Deliberately does not mix in ``SSLTestMixin``: the only expectation
        here is failure.
        """
        def get_ssl_version(self):
            return ssl.PROTOCOL_SSLv2

        def test_sslv2_fail(self):
            # Two outcomes are acceptable: the fetch raises ssl.SSLError, or
            # it completes with the client-side error code 599.
            try:
                # NOTE(review): the short request_timeout presumably bounds
                # the wait if the server's connection close is not noticed
                # promptly on some platforms -- confirm.
                response = self.fetch('/', request_timeout=1)
            except ssl.SSLError:
                # Treated as the expected failure (presumably some builds
                # expose the constant without real SSLv2 support).
                return
            else:
                self.assertEqual(response.code, 599)
00147
# Remove SSL test classes that cannot run in this environment so the test
# loader never collects them.
if ssl is None:
    # No ssl module at all: none of the HTTPS tests can run.
    del BaseSSLTest
    del SSLv23Test
    del SSLv3Test
    del TLSv1Test
elif getattr(ssl, 'OPENSSL_VERSION_INFO', (0, 0)) < (1, 0):
    # OpenSSL older than 1.0, or an ssl module too old to report its
    # version.  NOTE(review): presumably the client cannot negotiate with
    # SSLv3-only / TLSv1-only servers in that setup -- confirm.
    del SSLv3Test
    del TLSv1Test
elif not hasattr(ssl, 'PROTOCOL_SSLv3'):
    # ssl builds without SSLv3 support omit the constant entirely; keeping
    # SSLv3Test would make every test fail with AttributeError in
    # get_ssl_version rather than exercising anything.
    del SSLv3Test
00161
00162
class MultipartTestHandler(RequestHandler):
    """Echo selected parts of a multipart/form-data POST back to the client.

    The response reports the ``X-Header-Encoding-Test`` header, the
    ``argument`` form field, and the filename and body of the first upload
    in the ``files`` field.
    """
    def post(self):
        # Same lookup order as a single dict literal would use, so the
        # first missing piece is still the one that fails.
        header_value = self.request.headers["X-Header-Encoding-Test"]
        form_argument = self.get_argument("argument")
        upload = self.request.files["files"][0]
        self.finish({
            "header": header_value,
            "argument": form_argument,
            "filename": upload.filename,
            "filebody": _unicode(upload["body"]),
        })
00170
00171
class RawRequestHTTPConnection(simple_httpclient._HTTPConnection):
    """An ``_HTTPConnection`` that writes caller-supplied raw request bytes.

    ``set_request`` must be called with the complete request (request line,
    headers, blank line and body) before ``_on_connect`` runs.  The bytes are
    sent verbatim in place of whatever the parent class would send.
    """
    def set_request(self, request):
        self.__next_request = request

    def _on_connect(self, parsed, parsed_hostname):
        # Overrides the parent's connect hook.  ``parsed`` and
        # ``parsed_hostname`` are accepted for signature compatibility only.
        raw_bytes = self.__next_request
        self.stream.write(raw_bytes)
        self.__next_request = None
        # Hand the response headers to the parent's (presumably standard)
        # header-parsing callback.
        self.stream.read_until(b("\r\n\r\n"), self._on_headers)
00180
00181
00182
00183
class HTTPConnectionTest(AsyncHTTPTestCase, LogTrapTestCase):
    """Low-level tests that send hand-built requests to the server."""
    def get_handlers(self):
        return [("/multipart", MultipartTestHandler),
                ("/hello", HelloWorldRequestHandler)]

    def get_app(self):
        return Application(self.get_handlers())

    def raw_fetch(self, headers, body):
        """Send a hand-built request and return the response.

        ``headers`` is a list of byte strings (request line first, then
        header lines) without line terminators.  A ``Content-Length`` line
        computed from ``body`` is appended automatically.  ``rethrow()`` is
        called on the response before it is returned.
        """
        client = SimpleAsyncHTTPClient(self.io_loop)
        placeholder = httpclient.HTTPRequest(self.get_url("/"))
        conn = RawRequestHTTPConnection(self.io_loop, client, placeholder,
                                        None, self.stop,
                                        1024 * 1024)
        length_line = utf8("Content-Length: %d\r\n" % len(body))
        head = b("\r\n").join(headers + [length_line])
        conn.set_request(head + b("\r\n") + body)
        response = self.wait()
        client.close()
        response.rethrow()
        return response

    def test_multipart_form(self):
        # The header byte \xe9 must come back as u"\u00e9" (latin1), while
        # part contents are utf8-encoded and must round-trip as text.
        request_headers = [
            b("POST /multipart HTTP/1.0"),
            b("Content-Type: multipart/form-data; boundary=1234567890"),
            b("X-Header-encoding-test: \xe9"),
        ]
        body_lines = [
            b("Content-Disposition: form-data; name=argument"),
            b(""),
            u"\u00e1".encode("utf-8"),
            b("--1234567890"),
            u'Content-Disposition: form-data; name="files"; filename="\u00f3"'.encode("utf8"),
            b(""),
            u"\u00fa".encode("utf-8"),
            b("--1234567890--"),
            b(""),
        ]
        response = self.raw_fetch(request_headers, b("\r\n").join(body_lines))
        echoed = json_decode(response.body)
        self.assertEqual(u"\u00e9", echoed["header"])
        self.assertEqual(u"\u00e1", echoed["argument"])
        self.assertEqual(u"\u00f3", echoed["filename"])
        self.assertEqual(u"\u00fa", echoed["filebody"])

    def test_100_continue(self):
        # Talk HTTP/1.1 over a bare stream so the interim "100" response can
        # be observed before the request body is sent.
        stream = IOStream(socket.socket(), io_loop=self.io_loop)
        stream.connect(("localhost", self.get_http_port()), callback=self.stop)
        self.wait()

        request_head = b("\r\n").join([b("POST /hello HTTP/1.1"),
                                       b("Content-Length: 1024"),
                                       b("Expect: 100-continue"),
                                       b("Connection: close"),
                                       b("\r\n")])
        stream.write(request_head, callback=self.stop)
        self.wait()

        stream.read_until(b("\r\n\r\n"), self.stop)
        interim = self.wait()
        self.assertTrue(interim.startswith(b("HTTP/1.1 100 ")), interim)

        # Now send the promised 1024-byte body and read the final response.
        stream.write(b("a") * 1024)
        stream.read_until(b("\r\n"), self.stop)
        status_line = self.wait()
        self.assertTrue(status_line.startswith(b("HTTP/1.1 200")), status_line)

        stream.read_until(b("\r\n\r\n"), self.stop)
        raw_headers = self.wait()
        parsed = HTTPHeaders.parse(native_str(raw_headers.decode('latin1')))
        stream.read_bytes(int(parsed["Content-Length"]), self.stop)
        reply = self.wait()
        self.assertEqual(reply, b("Got 1024 bytes in POST"))
        stream.close()
00259
00260
class EchoHandler(RequestHandler):
    """Write the request's query/body arguments back, converted to unicode."""
    def get(self):
        args = recursive_unicode(self.request.arguments)
        self.write(args)
00264
00265
class TypeCheckHandler(RequestHandler):
    """Report request values whose Python type is not exactly as expected.

    ``prepare`` collects mismatches into ``self.errors`` (check name ->
    message); GET and POST write that dict back, so an empty response means
    every checked value had exactly the expected type.  The request must
    carry at least one header, one cookie and one argument, otherwise the
    ``[0]`` lookups raise IndexError.
    """
    def prepare(self):
        self.errors = {}
        fields = [
            ('method', str),
            ('uri', str),
            ('version', str),
            ('remote_ip', str),
            ('protocol', str),
            ('host', str),
            ('path', str),
            ('query', str),
            ]
        for field, expected_type in fields:
            self.check_type(field, getattr(self.request, field), expected_type)

        # keys()/values() return non-indexable views on Python 3, so
        # materialize them with list() before taking the first element.
        # On Python 2 this selects the same element as direct indexing.
        self.check_type('header_key',
                        list(self.request.headers.keys())[0], str)
        self.check_type('header_value',
                        list(self.request.headers.values())[0], str)

        self.check_type('cookie_key',
                        list(self.request.cookies.keys())[0], str)
        self.check_type('cookie_value',
                        list(self.request.cookies.values())[0].value, str)

        self.check_type('arg_key',
                        list(self.request.arguments.keys())[0], str)
        self.check_type('arg_value',
                        list(self.request.arguments.values())[0][0],
                        bytes_type)

    def post(self):
        self.check_type('body', self.request.body, bytes_type)
        self.write(self.errors)

    def get(self):
        self.write(self.errors)

    def check_type(self, name, obj, expected_type):
        """Record an error under ``name`` unless ``type(obj)`` is exactly
        ``expected_type`` (a subclass counts as a mismatch)."""
        actual_type = type(obj)
        if expected_type != actual_type:
            self.errors[name] = "expected %s, got %s" % (expected_type,
                                                         actual_type)
00304
00305
class HTTPServerTest(AsyncHTTPTestCase, LogTrapTestCase):
    """End-to-end request-parsing tests against a real HTTPServer."""
    def get_app(self):
        routes = [
            ("/echo", EchoHandler),
            ("/typecheck", TypeCheckHandler),
            ("//doubleslash", EchoHandler),
        ]
        return Application(routes)

    def test_query_string_encoding(self):
        # %C3%A9 is the percent-encoded utf8 form of u"\u00e9".
        echoed = json_decode(self.fetch("/echo?foo=%C3%A9").body)
        self.assertEqual(echoed, {u"foo": [u"\u00e9"]})

    def test_types(self):
        headers = {"Cookie": "foo=bar"}
        # A GET carrying a query argument, then a POST carrying a body
        # argument; TypeCheckHandler must report no type errors for either.
        attempts = [
            ("/typecheck?foo=bar", {}),
            ("/typecheck", dict(method="POST", body="foo=bar")),
        ]
        for path, extra in attempts:
            response = self.fetch(path, headers=headers, **extra)
            self.assertEqual(json_decode(response.body), {})

    def test_double_slash(self):
        # NOTE(review): presumably guards against a leading "//" being
        # treated as a scheme-relative URL (//host/path) rather than a path.
        response = self.fetch("//doubleslash")
        self.assertEqual(200, response.code)
        self.assertEqual(json_decode(response.body), {})
00335
00336
class XHeaderTest(HandlerBaseTestCase):
    """With ``xheaders=True`` a valid X-Real-IP header overrides remote_ip."""
    class Handler(RequestHandler):
        def get(self):
            self.write(dict(remote_ip=self.request.remote_ip))

    def get_httpserver_options(self):
        return dict(xheaders=True)

    def test_ip_headers(self):
        # (fetch keyword arguments, remote_ip the handler should report).
        # Missing or invalid X-Real-IP values fall back to 127.0.0.1.
        cases = [
            ({}, "127.0.0.1"),
            (dict(headers={"X-Real-IP": "4.4.4.4"}), "4.4.4.4"),
            (dict(headers={"X-Real-IP": "2620:0:1cfe:face:b00c::3"}),
             "2620:0:1cfe:face:b00c::3"),
            (dict(headers={"X-Real-IP": "4.4.4.4<script>"}), "127.0.0.1"),
            (dict(headers={"X-Real-IP": "www.google.com"}), "127.0.0.1"),
        ]
        for fetch_kwargs, expected_ip in cases:
            observed = self.fetch_json("/", **fetch_kwargs)["remote_ip"]
            self.assertEqual(observed, expected_ip)
00368
00369
class UnixSocketTest(AsyncTestCase, LogTrapTestCase):
    """HTTPServers can listen on Unix sockets too.

    Why would you want to do this? Nginx can proxy to backends listening
    on unix sockets, for one thing (and managing a namespace for unix
    sockets can be easier than managing a bunch of TCP port numbers).

    Unfortunately, there's no way to specify a unix socket in a url for
    an HTTP client, so we have to test this by hand.
    """
    def setUp(self):
        super(UnixSocketTest, self).setUp()
        # Per-test scratch directory that holds the socket file.
        self.tmpdir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.tmpdir)
        super(UnixSocketTest, self).tearDown()

    def test_unix_socket(self):
        sockfile = os.path.join(self.tmpdir, "test.sock")
        sock = netutil.bind_unix_socket(sockfile)
        app = Application([("/hello", HelloWorldRequestHandler)])
        server = HTTPServer(app, io_loop=self.io_loop)
        server.add_socket(sock)
        stream = IOStream(socket.socket(socket.AF_UNIX), io_loop=self.io_loop)
        try:
            stream.connect(sockfile, self.stop)
            self.wait()
            stream.write(b("GET /hello HTTP/1.0\r\n\r\n"))
            stream.read_until(b("\r\n"), self.stop)
            response = self.wait()
            self.assertEqual(response, b("HTTP/1.0 200 OK\r\n"))
            stream.read_until(b("\r\n\r\n"), self.stop)
            # Convert the latin1-decoded text to the native str type before
            # parsing, so HTTPHeaders sees the same type on Python 2 and 3.
            headers = HTTPHeaders.parse(
                native_str(self.wait().decode('latin1')))
            stream.read_bytes(int(headers["Content-Length"]), self.stop)
            body = self.wait()
            self.assertEqual(body, b("Hello world"))
        finally:
            # Release the client stream and stop the server even when an
            # assertion or wait() fails part-way through.
            stream.close()
            server.stop()
00408
# Drop the Unix-socket test on cygwin and wherever AF_UNIX is unavailable.
if sys.platform == 'cygwin' or not hasattr(socket, 'AF_UNIX'):
    del UnixSocketTest