00001 
00002 
00003 
00004 from __future__ import absolute_import, division, with_statement
00005 from tornado import httpclient, simple_httpclient, netutil
00006 from tornado.escape import json_decode, utf8, _unicode, recursive_unicode, native_str
00007 from tornado.httpserver import HTTPServer
00008 from tornado.httputil import HTTPHeaders
00009 from tornado.iostream import IOStream
00010 from tornado.simple_httpclient import SimpleAsyncHTTPClient
00011 from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase, AsyncTestCase
00012 from tornado.util import b, bytes_type
00013 from tornado.web import Application, RequestHandler
00014 import os
00015 import shutil
00016 import socket
00017 import sys
00018 import tempfile
00019 
00020 try:
00021     import ssl
00022 except ImportError:
00023     ssl = None
00024 
00025 
class HandlerBaseTestCase(AsyncHTTPTestCase, LogTrapTestCase):
    """Base class for tests that serve a single handler at ``/``.

    The handler is taken from a ``Handler`` attribute looked up on the
    concrete test class, so each subclass is expected to define one.
    """

    def get_app(self):
        """Return an application that routes ``/`` to the class's ``Handler``."""
        handler_class = type(self).Handler
        return Application([('/', handler_class)])

    def fetch_json(self, *args, **kwargs):
        """Fetch via ``self.fetch`` and return the JSON-decoded body.

        All arguments are forwarded unchanged to ``self.fetch``.
        ``rethrow()`` is called on the response before its body is decoded.
        """
        reply = self.fetch(*args, **kwargs)
        reply.rethrow()
        return json_decode(reply.body)
00034 
00035 
class HelloWorldRequestHandler(RequestHandler):
    """Handler that greets on GET and reports the body size on POST."""

    def initialize(self, protocol="http"):
        """Store the protocol that GET requests are expected to use."""
        self.expected_protocol = protocol

    def get(self):
        """Check the request protocol, then reply with a fixed greeting."""
        actual_protocol = self.request.protocol
        assert actual_protocol == self.expected_protocol
        self.finish("Hello world")

    def post(self):
        """Reply with the number of bytes received in the request body."""
        body_length = len(self.request.body)
        self.finish("Got %d bytes in POST" % body_length)
00046 
00047 
class BaseSSLTest(AsyncHTTPTestCase, LogTrapTestCase):
    """Base class for tests that run the HTTP server with SSL enabled.

    Subclasses choose the protocol by overriding ``get_ssl_version``.
    """

    def get_ssl_version(self):
        """Return the ``ssl.PROTOCOL_*`` constant the server should use.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def setUp(self):
        """Create ``self.http_client`` on the test's IOLoop."""
        super(BaseSSLTest, self).setUp()
        # force_instance=True is passed so that this test gets a client
        # object of its own (presumably instead of a shared per-IOLoop
        # instance -- confirm against SimpleAsyncHTTPClient).
        self.http_client = SimpleAsyncHTTPClient(io_loop=self.io_loop,
                                                 force_instance=True)

    def get_app(self):
        """Return an app whose ``/`` handler expects the "https" protocol."""
        return Application([('/', HelloWorldRequestHandler,
                             dict(protocol="https"))])

    def get_httpserver_options(self):
        """Return server options enabling SSL.

        The certificate and key are ``test.crt`` / ``test.key`` located in
        the same directory as this file; the protocol version comes from
        ``get_ssl_version``.
        """
        test_dir = os.path.dirname(__file__)
        return dict(ssl_options=dict(
                certfile=os.path.join(test_dir, 'test.crt'),
                keyfile=os.path.join(test_dir, 'test.key'),
                ssl_version=self.get_ssl_version()))

    def fetch(self, path, **kwargs):
        """Fetch ``path`` from the test server over https.

        Certificate validation is disabled for the request.  Extra keyword
        arguments are forwarded to ``http_client.fetch``.  Returns whatever
        ``self.wait()`` yields (the value handed to ``self.stop``).
        """
        # Only the leading scheme must change.  An unbounded replace would
        # also rewrite any later "http" inside the path or query string.
        url = self.get_url(path).replace('http', 'https', 1)
        self.http_client.fetch(url,
                               self.stop,
                               validate_cert=False,
                               **kwargs)
        return self.wait()
00079 
00080 
class SSLTestMixin(object):
    """Test methods shared by the per-protocol SSL test cases.

    The host class must provide ``fetch``, ``get_url``, ``http_client``,
    ``stop``, ``wait`` and the usual ``assert*`` helpers.
    """

    def test_ssl(self):
        """A GET of ``/`` returns the greeting body."""
        reply = self.fetch('/')
        self.assertEqual(reply.body, b("Hello world"))

    def test_large_post(self):
        """A 5000-byte POST body is reported back with its full length."""
        payload = 'A' * 5000
        reply = self.fetch('/', method='POST', body=payload)
        self.assertEqual(reply.body, b("Got 5000 bytes in POST"))

    def test_non_ssl_request(self):
        """A request to the unmodified ``get_url`` URL ends with code 599."""
        # The URL from get_url() is used as-is here (no http -> https
        # rewrite), and the client is called directly rather than through
        # self.fetch.  NOTE(review): the one-hour timeouts presumably ensure
        # the 599 comes from the connection failing rather than from a
        # client-side timeout -- confirm.
        plain_url = self.get_url("/")
        self.http_client.fetch(plain_url, self.stop,
                               request_timeout=3600,
                               connect_timeout=3600)
        reply = self.wait()
        self.assertEqual(reply.code, 599)
00101 
00102 
00103 
00104 
00105 
00106 
00107 
class SSLv23Test(BaseSSLTest, SSLTestMixin):
    """Runs the shared SSL tests with the server using PROTOCOL_SSLv23."""

    def get_ssl_version(self):
        """Return ``ssl.PROTOCOL_SSLv23``."""
        return ssl.PROTOCOL_SSLv23
00111 
00112 
class SSLv3Test(BaseSSLTest, SSLTestMixin):
    """Runs the shared SSL tests with the server using PROTOCOL_SSLv3."""

    def get_ssl_version(self):
        """Return ``ssl.PROTOCOL_SSLv3``."""
        return ssl.PROTOCOL_SSLv3
00116 
00117 
class TLSv1Test(BaseSSLTest, SSLTestMixin):
    """Runs the shared SSL tests with the server using PROTOCOL_TLSv1."""

    def get_ssl_version(self):
        """Return ``ssl.PROTOCOL_TLSv1``."""
        return ssl.PROTOCOL_TLSv1
00121 
# SSLv2Test is only defined when the ssl module exposes PROTOCOL_SSLv2.
if hasattr(ssl, 'PROTOCOL_SSLv2'):
    class SSLv2Test(BaseSSLTest):
        """Runs the server with PROTOCOL_SSLv2 and expects requests to fail."""

        def get_ssl_version(self):
            """Return ``ssl.PROTOCOL_SSLv2``."""
            return ssl.PROTOCOL_SSLv2

        def test_sslv2_fail(self):
            """A fetch either raises ``ssl.SSLError`` or yields code 599.

            Both outcomes count as a pass.
            """
            try:
                # A one-second request timeout bounds how long the attempt
                # can take.
                reply = self.fetch('/', request_timeout=1)
            except ssl.SSLError:
                # The failure surfaced as an exception: accepted outcome.
                return
            else:
                self.assertEqual(reply.code, 599)
00147 
if ssl is None:
    # The ssl module could not be imported: remove every SSL test case so
    # none of them is left in the module namespace.
    del BaseSSLTest, SSLv23Test, SSLv3Test, TLSv1Test
elif getattr(ssl, 'OPENSSL_VERSION_INFO', (0, 0)) < (1, 0):
    # OpenSSL older than 1.0 (a missing OPENSSL_VERSION_INFO is treated as
    # (0, 0)): drop the SSLv3 and TLSv1 cases and keep SSLv23.
    # NOTE(review): presumably the client handshake is rejected by
    # SSLv3/TLSv1 servers on old OpenSSL builds -- confirm.
    del SSLv3Test, TLSv1Test
00161 
00162 
class MultipartTestHandler(RequestHandler):
    """Echoes selected parts of a multipart POST back as a dict."""

    def post(self):
        """Reply with the test header, the form argument and the first file.

        The response has the keys ``header``, ``argument``, ``filename``
        and ``filebody``; the file body is converted with ``_unicode``.
        """
        # Values are gathered in the same order as the response keys.
        header_value = self.request.headers["X-Header-Encoding-Test"]
        argument_value = self.get_argument("argument")
        upload_name = self.request.files["files"][0].filename
        upload_body = _unicode(self.request.files["files"][0]["body"])
        self.finish({"header": header_value,
                     "argument": argument_value,
                     "filename": upload_name,
                     "filebody": upload_body,
                     })
00170 
00171 
class RawRequestHTTPConnection(simple_httpclient._HTTPConnection):
    """Connection that sends caller-supplied raw bytes as its request.

    ``_on_connect`` is overridden so that, once connected, the bytes given
    to ``set_request`` are written to the stream verbatim.
    """

    def set_request(self, request):
        """Store the raw request bytes to send when the connection opens."""
        self.__next_request = request

    def _on_connect(self, parsed, parsed_hostname):
        """Write the stored request, clear it, and start reading headers."""
        raw_request = self.__next_request
        self.stream.write(raw_request)
        self.__next_request = None
        header_terminator = b("\r\n\r\n")
        self.stream.read_until(header_terminator, self._on_headers)
00180 
00181 
00182 
00183 
class HTTPConnectionTest(AsyncHTTPTestCase, LogTrapTestCase):
    """Tests that talk to the server with hand-built HTTP requests."""

    def get_handlers(self):
        """Return the URL routing table used by ``get_app``."""
        return [("/multipart", MultipartTestHandler),
                ("/hello", HelloWorldRequestHandler)]

    def get_app(self):
        """Return an application built from ``get_handlers``."""
        return Application(self.get_handlers())

    def raw_fetch(self, headers, body):
        """Send a raw request and return the response.

        Args:
            headers: list of byte strings -- the request line followed by
                header lines, without line terminators.  A Content-Length
                header for ``body`` is appended automatically.
            body: byte string sent after the blank line.

        Returns:
            The response delivered to ``self.stop``, after ``rethrow()``
            has been called on it.
        """
        client = SimpleAsyncHTTPClient(self.io_loop)
        try:
            conn = RawRequestHTTPConnection(self.io_loop, client,
                                            httpclient.HTTPRequest(self.get_url("/")),
                                            None, self.stop,
                                            1024 * 1024)
            conn.set_request(
                b("\r\n").join(headers +
                               [utf8("Content-Length: %d\r\n" % len(body))]) +
                b("\r\n") + body)
            response = self.wait()
        finally:
            # Close the client even when wait() raises (e.g. on timeout) so
            # a failing test does not leak it.
            client.close()
        response.rethrow()
        return response

    def test_multipart_form(self):
        """Header bytes and multipart body parts are decoded as expected.

        The header value is sent as the single byte 0xE9 and must come back
        as U+00E9; the argument, filename and file body are sent UTF-8
        encoded and must come back as the corresponding characters.
        """
        response = self.raw_fetch([
                b("POST /multipart HTTP/1.0"),
                b("Content-Type: multipart/form-data; boundary=1234567890"),
                b("X-Header-encoding-test: \xe9"),
                ],
                                  b("\r\n").join([
                    b("Content-Disposition: form-data; name=argument"),
                    b(""),
                    u"\u00e1".encode("utf-8"),
                    b("--1234567890"),
                    u'Content-Disposition: form-data; name="files"; filename="\u00f3"'.encode("utf8"),
                    b(""),
                    u"\u00fa".encode("utf-8"),
                    b("--1234567890--"),
                    b(""),
                    ]))
        data = json_decode(response.body)
        self.assertEqual(u"\u00e9", data["header"])
        self.assertEqual(u"\u00e1", data["argument"])
        self.assertEqual(u"\u00f3", data["filename"])
        self.assertEqual(u"\u00fa", data["filebody"])

    def test_100_continue(self):
        """The server answers ``Expect: 100-continue`` before the body.

        The request headers are sent first; the test then requires an
        ``HTTP/1.1 100`` response before sending the 1024-byte body, and
        finally checks the 200 response and its body.
        """
        stream = IOStream(socket.socket(), io_loop=self.io_loop)
        try:
            stream.connect(("localhost", self.get_http_port()), callback=self.stop)
            self.wait()
            # Joining with a trailing b("\r\n") element ends the header
            # block with the required blank line.
            stream.write(b("\r\n").join([b("POST /hello HTTP/1.1"),
                                         b("Content-Length: 1024"),
                                         b("Expect: 100-continue"),
                                         b("Connection: close"),
                                         b("\r\n")]), callback=self.stop)
            self.wait()
            stream.read_until(b("\r\n\r\n"), self.stop)
            data = self.wait()
            self.assertTrue(data.startswith(b("HTTP/1.1 100 ")), data)
            stream.write(b("a") * 1024)
            stream.read_until(b("\r\n"), self.stop)
            first_line = self.wait()
            self.assertTrue(first_line.startswith(b("HTTP/1.1 200")), first_line)
            stream.read_until(b("\r\n\r\n"), self.stop)
            header_data = self.wait()
            headers = HTTPHeaders.parse(native_str(header_data.decode('latin1')))
            stream.read_bytes(int(headers["Content-Length"]), self.stop)
            body = self.wait()
            self.assertEqual(body, b("Got 1024 bytes in POST"))
        finally:
            # Release the socket even if a wait or assertion above fails.
            stream.close()
00259 
00260 
class EchoHandler(RequestHandler):
    """Writes the request's arguments back to the client."""

    def get(self):
        """Write ``request.arguments`` after passing it through ``recursive_unicode``."""
        echoed_arguments = recursive_unicode(self.request.arguments)
        self.write(echoed_arguments)
00264 
00265 
class TypeCheckHandler(RequestHandler):
    """Reports request attributes whose type differs from the expected one.

    ``prepare`` fills ``self.errors`` with one entry per mismatch, and
    ``get`` / ``post`` write that dict as the response, so an empty dict
    means every check passed.
    """

    def prepare(self):
        """Check the types of request fields, headers, cookies and arguments.

        Only the first header, cookie and argument are inspected, so the
        request must carry at least one of each.
        """
        self.errors = {}
        fields = [
            ('method', str),
            ('uri', str),
            ('version', str),
            ('remote_ip', str),
            ('protocol', str),
            ('host', str),
            ('path', str),
            ('query', str),
            ]
        for field, expected_type in fields:
            self.check_type(field, getattr(self.request, field), expected_type)

        # keys() and values() are wrapped in list() because they return
        # non-indexable views on Python 3; on Python 2 this is a plain copy.
        self.check_type('header_key', list(self.request.headers.keys())[0], str)
        self.check_type('header_value', list(self.request.headers.values())[0], str)

        self.check_type('cookie_key', list(self.request.cookies.keys())[0], str)
        self.check_type('cookie_value', list(self.request.cookies.values())[0].value, str)

        self.check_type('arg_key', list(self.request.arguments.keys())[0], str)
        self.check_type('arg_value', list(self.request.arguments.values())[0][0], bytes_type)

    def post(self):
        """Additionally check the body type, then write the error dict."""
        self.check_type('body', self.request.body, bytes_type)
        self.write(self.errors)

    def get(self):
        """Write the error dict collected by ``prepare``."""
        self.write(self.errors)

    def check_type(self, name, obj, expected_type):
        """Record an error under ``name`` if ``obj`` is not ``expected_type``.

        The comparison is on the exact type, so an instance of a subclass
        of ``expected_type`` is reported as a mismatch.
        """
        actual_type = type(obj)
        if expected_type != actual_type:
            self.errors[name] = "expected %s, got %s" % (expected_type,
                                                         actual_type)
00304 
00305 
class HTTPServerTest(AsyncHTTPTestCase, LogTrapTestCase):
    """Server-level tests for argument decoding, value types and routing."""

    def get_app(self):
        """Return an app with the echo, typecheck and double-slash routes."""
        routes = [("/echo", EchoHandler),
                  ("/typecheck", TypeCheckHandler),
                  ("//doubleslash", EchoHandler),
                  ]
        return Application(routes)

    def test_query_string_encoding(self):
        """A percent-encoded UTF-8 query value is echoed as U+00E9."""
        reply = self.fetch("/echo?foo=%C3%A9")
        decoded = json_decode(reply.body)
        self.assertEqual(decoded, {u"foo": [u"\u00e9"]})

    def test_types(self):
        """The typecheck handler reports no mismatches for GET or POST."""
        cookie_headers = {"Cookie": "foo=bar"}

        get_reply = self.fetch("/typecheck?foo=bar", headers=cookie_headers)
        self.assertEqual(json_decode(get_reply.body), {})

        post_reply = self.fetch("/typecheck", method="POST", body="foo=bar",
                                headers=cookie_headers)
        self.assertEqual(json_decode(post_reply.body), {})

    def test_double_slash(self):
        """A path starting with two slashes reaches its handler.

        The request carries no arguments, so the echoed JSON is an empty
        object.
        """
        reply = self.fetch("//doubleslash")
        self.assertEqual(200, reply.code)
        self.assertEqual(json_decode(reply.body), {})
00335 
00336 
class XHeaderTest(HandlerBaseTestCase):
    """Checks how ``X-Real-IP`` affects ``remote_ip`` with ``xheaders=True``."""

    class Handler(RequestHandler):
        """Reports the request's ``remote_ip`` as JSON."""

        def get(self):
            reported = dict(remote_ip=self.request.remote_ip)
            self.write(reported)

    def get_httpserver_options(self):
        """Return server options with ``xheaders`` turned on."""
        return dict(xheaders=True)

    def test_ip_headers(self):
        """Valid X-Real-IP values are used; invalid ones leave 127.0.0.1."""
        # Without the header the reported address is 127.0.0.1.
        self.assertEqual(self.fetch_json("/")["remote_ip"],
                         "127.0.0.1")

        # (request headers, expected remote_ip), checked in this order.
        cases = [
            ({"X-Real-IP": "4.4.4.4"}, "4.4.4.4"),
            ({"X-Real-IP": "2620:0:1cfe:face:b00c::3"},
             "2620:0:1cfe:face:b00c::3"),
            ({"X-Real-IP": "4.4.4.4<script>"}, "127.0.0.1"),
            ({"X-Real-IP": "www.google.com"}, "127.0.0.1"),
        ]
        for request_headers, expected_ip in cases:
            reported_ip = self.fetch_json("/", headers=request_headers)["remote_ip"]
            self.assertEqual(reported_ip, expected_ip)
00368 
00369 
class UnixSocketTest(AsyncTestCase, LogTrapTestCase):
    """HTTPServers can listen on Unix sockets too.

    Why would you want to do this?  Nginx can proxy to backends listening
    on unix sockets, for one thing (and managing a namespace for unix
    sockets can be easier than managing a bunch of TCP port numbers).

    Unfortunately, there's no way to specify a unix socket in a url for
    an HTTP client, so the request and response are handled by hand over
    a raw IOStream.
    """

    def setUp(self):
        """Create a temporary directory to hold the socket file."""
        super(UnixSocketTest, self).setUp()
        self.tmpdir = tempfile.mkdtemp()

    def tearDown(self):
        """Remove the temporary directory, then run the base teardown."""
        shutil.rmtree(self.tmpdir)
        super(UnixSocketTest, self).tearDown()

    def test_unix_socket(self):
        """A GET sent over a Unix socket gets the greeting back."""
        socket_path = os.path.join(self.tmpdir, "test.sock")
        listener = netutil.bind_unix_socket(socket_path)
        hello_app = Application([("/hello", HelloWorldRequestHandler)])
        server = HTTPServer(hello_app, io_loop=self.io_loop)
        server.add_socket(listener)

        stream = IOStream(socket.socket(socket.AF_UNIX), io_loop=self.io_loop)
        stream.connect(socket_path, self.stop)
        self.wait()

        stream.write(b("GET /hello HTTP/1.0\r\n\r\n"))

        # Status line.
        stream.read_until(b("\r\n"), self.stop)
        status_line = self.wait()
        self.assertEqual(status_line, b("HTTP/1.0 200 OK\r\n"))

        # Header block; Content-Length gives the body size to read.
        stream.read_until(b("\r\n\r\n"), self.stop)
        raw_headers = self.wait()
        headers = HTTPHeaders.parse(raw_headers.decode('latin1'))
        body_length = int(headers["Content-Length"])

        stream.read_bytes(body_length, self.stop)
        body = self.wait()
        self.assertEqual(body, b("Hello world"))

        stream.close()
        server.stop()
00408 
# Drop the Unix-socket test where AF_UNIX is missing from the socket module,
# and on cygwin (NOTE(review): presumably unsupported there -- confirm).
if sys.platform == 'cygwin' or not hasattr(socket, 'AF_UNIX'):
    del UnixSocketTest