00001 from __future__ import absolute_import, division, print_function, with_statement
00002
00003 from hashlib import md5
00004
00005 from tornado.escape import utf8
00006 from tornado.httpclient import HTTPRequest
00007 from tornado.stack_context import ExceptionStackContext
00008 from tornado.testing import AsyncHTTPTestCase
00009 from tornado.test import httpclient_test
00010 from tornado.test.util import unittest
00011 from tornado.web import Application, RequestHandler
00012
00013 try:
00014 import pycurl
00015 except ImportError:
00016 pycurl = None
00017
00018 if pycurl is not None:
00019 from tornado.curl_httpclient import CurlAsyncHTTPClient
00020
00021
@unittest.skipIf(pycurl is None, "pycurl module not present")
class CurlHTTPClientCommonTestCase(httpclient_test.HTTPClientCommonTestCase):
    """Run the shared HTTP client test suite against CurlAsyncHTTPClient."""

    def get_http_client(self):
        """Return a curl-based client bound to this test's IOLoop.

        IPv6 is disabled through the client defaults.
        """
        client = CurlAsyncHTTPClient(io_loop=self.io_loop,
                                     defaults=dict(allow_ipv6=False))
        # Sanity check that we really got the curl implementation.
        # assertIsInstance reports the actual type if this ever fails.
        self.assertIsInstance(client, CurlAsyncHTTPClient)
        return client
00030
00031
class DigestAuthHandler(RequestHandler):
    """Minimal server side of HTTP Digest authentication for the curl tests.

    Without an ``Authorization`` header it answers 401 with a
    ``WWW-Authenticate: Digest`` challenge. With one, it checks that the
    client echoed the challenge fields back and verifies the response hash
    computed as MD5(H(user:realm:pass):nonce:H(method:path)), i.e. without
    ``qop``. The credentials are fixed at ``foo``/``bar``. The handler
    writes ``ok`` on a match and ``fail`` otherwise.
    """

    @staticmethod
    def _parse_auth_params(params):
        """Parse a ``key=value, key="quoted value"`` list into a dict.

        Commas inside double-quoted values do not end an item, so values
        such as a ``uri`` that contains a comma survive intact. Whitespace
        around items, keys and values is ignored, and empty items (from
        stray commas) are skipped. One pair of surrounding double quotes is
        removed from each value. A backslash inside quotes stops the next
        character from ending the quoted section; the backslash itself is
        kept verbatim.

        Raises ValueError if a non-empty item has no ``=``.
        """
        items = []
        current = []
        in_quotes = False
        escaped = False
        for ch in params:
            if escaped:
                escaped = False
            elif in_quotes and ch == '\\':
                escaped = True
            elif ch == '"':
                in_quotes = not in_quotes
            elif ch == ',' and not in_quotes:
                items.append(''.join(current))
                current = []
                continue
            current.append(ch)
        items.append(''.join(current))

        param_dict = {}
        for item in items:
            item = item.strip()
            if not item:
                continue
            key, value = item.split('=', 1)
            key = key.strip()
            value = value.strip()
            # Length guard: an empty value must not raise IndexError.
            if len(value) >= 2 and value[0] == '"' and value[-1] == '"':
                value = value[1:-1]
            param_dict[key] = value
        return param_dict

    def get(self):
        realm = 'test'
        opaque = 'asdf'
        # The nonce is fixed so the expected response hash is deterministic.
        nonce = "1234"
        username = 'foo'
        password = 'bar'

        auth_header = self.request.headers.get('Authorization', None)
        if auth_header is None:
            # First leg: no credentials yet, so issue the Digest challenge.
            self.set_status(401)
            self.set_header('WWW-Authenticate',
                            'Digest realm="%s", nonce="%s", opaque="%s"' %
                            (realm, nonce, opaque))
            return

        # partition() tolerates a header with no parameters (unlike split).
        auth_mode, _, params = auth_header.partition(' ')
        assert auth_mode == 'Digest'
        param_dict = self._parse_auth_params(params)

        # The client must echo back exactly what the challenge offered.
        assert param_dict['realm'] == realm
        assert param_dict['opaque'] == opaque
        assert param_dict['nonce'] == nonce
        assert param_dict['username'] == username
        assert param_dict['uri'] == self.request.path

        h1 = md5(utf8('%s:%s:%s' % (username, realm, password))).hexdigest()
        h2 = md5(utf8('%s:%s' % (self.request.method,
                                 self.request.path))).hexdigest()
        digest = md5(utf8('%s:%s:%s' % (h1, nonce, h2))).hexdigest()
        self.write('ok' if digest == param_dict['response'] else 'fail')
00069
00070
class CustomReasonHandler(RequestHandler):
    """Respond 200 with a non-standard reason phrase on the status line."""

    def get(self):
        self.set_status(200, reason="Custom reason")
00074
00075
class CustomFailReasonHandler(RequestHandler):
    """Respond 400 with a non-standard reason phrase on the status line."""

    def get(self):
        self.set_status(400, reason="Custom reason")
00079
00080
@unittest.skipIf(pycurl is None, "pycurl module not present")
class CurlHTTPClientTestCase(AsyncHTTPTestCase):
    """Curl-specific client behaviour not covered by the shared test suite."""

    def get_http_client(self):
        """Return the curl client under test, bound to this test's IOLoop.

        This overrides the factory hook (as CurlHTTPClientCommonTestCase
        does) instead of reassigning ``self.http_client`` after
        ``super().setUp()``. Reassigning made the base class build a
        default client first and then drop it without closing it.
        """
        return CurlAsyncHTTPClient(io_loop=self.io_loop,
                                   defaults=dict(allow_ipv6=False))

    def get_app(self):
        return Application([
            ('/digest', DigestAuthHandler),
            ('/custom_reason', CustomReasonHandler),
            ('/custom_fail_reason', CustomFailReasonHandler),
        ])

    def test_prepare_curl_callback_stack_context(self):
        """An exception in prepare_curl_callback reaches the caller's context."""
        exc_info = []

        def error_handler(typ, value, tb):
            exc_info.append((typ, value, tb))
            self.stop()
            # Returning True marks the exception as handled.
            return True

        with ExceptionStackContext(error_handler):
            # The callback deliberately raises ZeroDivisionError.
            request = HTTPRequest(self.get_url('/'),
                                  prepare_curl_callback=lambda curl: 1 / 0)
        self.http_client.fetch(request, callback=self.stop)
        self.wait()
        self.assertEqual(1, len(exc_info))
        self.assertIs(exc_info[0][0], ZeroDivisionError)

    def test_digest_auth(self):
        response = self.fetch('/digest', auth_mode='digest',
                              auth_username='foo', auth_password='bar')
        self.assertEqual(response.body, b'ok')

    def test_custom_reason(self):
        response = self.fetch('/custom_reason')
        self.assertEqual(response.code, 200)
        self.assertEqual(response.reason, "Custom reason")

    def test_fail_custom_reason(self):
        response = self.fetch('/custom_fail_reason')
        self.assertEqual(response.code, 400)
        self.assertEqual(str(response.error), "HTTP 400: Custom reason")