00001
00002 from __future__ import absolute_import, division, print_function, with_statement
00003
00004 from tornado import gen
00005 from tornado.log import app_log
00006 from tornado.stack_context import (StackContext, wrap, NullContext, StackContextInconsistentError,
00007 ExceptionStackContext, run_with_stack_context, _state)
00008 from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, ExpectLog, gen_test
00009 from tornado.test.util import unittest
00010 from tornado.web import asynchronous, Application, RequestHandler
00011 import contextlib
00012 import functools
00013 import logging
00014
00015
class TestRequestHandler(RequestHandler):
    """Handler whose GET hops across two IOLoop callbacks and then fails.

    ``get`` schedules ``part2``, ``part2`` schedules ``part3``, and
    ``part3`` raises.  ``write_error`` writes a marker string telling the
    test whether the error it received was the one raised by ``part3``.
    """

    def __init__(self, app, request, io_loop):
        super(TestRequestHandler, self).__init__(app, request)
        # Loop used to schedule the follow-up steps of the request.
        self.io_loop = io_loop

    @asynchronous
    def get(self):
        logging.debug('in get()')
        # Continue the request in a separate IOLoop callback.
        self.io_loop.add_callback(self.part2)

    def part2(self):
        logging.debug('in part2()')
        # A second hop, so the failure happens two callbacks after get().
        self.io_loop.add_callback(self.part3)

    def part3(self):
        logging.debug('in part3()')
        raise Exception('test exception')

    def write_error(self, status_code, **kwargs):
        """Write a marker describing which error reached this handler."""
        saw_expected = ('exc_info' in kwargs and
                        str(kwargs['exc_info'][1]) == 'test exception')
        self.write('got expected exception' if saw_expected
                   else 'unexpected failure')
00043
00044
class HTTPStackContextTest(AsyncHTTPTestCase):
    """End-to-end check of error routing for TestRequestHandler.

    The exception raised in a later IOLoop callback of the request must
    still reach the handler's ``write_error`` and produce a 500 response.
    """

    def get_app(self):
        # The dict supplies the io_loop argument of TestRequestHandler.
        return Application([('/', TestRequestHandler,
                             dict(io_loop=self.io_loop))])

    def test_stack_context(self):
        # The handler raises on purpose, so the uncaught-exception log line
        # is expected; waiting inside ExpectLog makes it cover that line.
        with ExpectLog(app_log, "Uncaught exception GET /"):
            self.http_client.fetch(self.get_url('/'), self.handle_response)
            self.wait()
        self.assertEqual(self.response.code, 500)
        # assertIn (unlike assertTrue(x in y)) shows the actual body when
        # the check fails.
        self.assertIn(b'got expected exception', self.response.body)

    def handle_response(self, response):
        """Store *response* for the assertions and resume the test."""
        self.response = response
        self.stop()
00060
00061
class StackContextTest(AsyncTestCase):
    """Tests for StackContext / ExceptionStackContext propagation.

    Contexts built from :meth:`context` record their names in
    ``active_contexts`` while entered, and the tests assert on that list
    from inside IOLoop callbacks.
    """

    def setUp(self):
        super(StackContextTest, self).setUp()
        # Names of the `context` managers currently entered, outermost first.
        self.active_contexts = []

    @contextlib.contextmanager
    def context(self, name):
        """Record *name* as active for the duration of the ``with`` body.

        On a normal exit it also checks that *name* is the innermost entry,
        i.e. that contexts are left in LIFO order.
        """
        self.active_contexts.append(name)
        yield
        innermost = self.active_contexts.pop()
        self.assertEqual(innermost, name)

    def _tracked_context(self, name):
        """Return a StackContext whose factory is ``self.context(name)``."""
        return StackContext(functools.partial(self.context, name))

    def test_exit_library_context(self):
        # Models a library that captures the caller's contexts, does its own
        # work under an internal 'library' context, then calls back into the
        # application.
        def library_function(callback):
            # Capture the caller's contexts before adding the library's own.
            app_callback = wrap(callback)
            with self._tracked_context('library'):
                self.io_loop.add_callback(
                    functools.partial(library_inner_callback, app_callback))

        def library_inner_callback(callback):
            # During the library's own work both contexts are active.
            self.assertEqual(self.active_contexts[-2:],
                             ['application', 'library'])
            callback()

        def final_callback():
            # The application context must be innermost when the
            # application's callback runs.
            self.assertEqual(self.active_contexts[-1], 'application')
            self.stop()

        with self._tracked_context('application'):
            library_function(final_callback)
        self.wait()

    def test_deactivate(self):
        deactivators = []

        def check_deactivation():
            self.assertEqual(self.active_contexts, ['c1', 'c2', 'c3'])
            deactivators[1]()
            # The contexts entered for this callback are unchanged; only
            # the callback scheduled next stops running under 'c2'.
            self.assertEqual(self.active_contexts, ['c1', 'c2', 'c3'])
            self.io_loop.add_callback(check_after)

        def check_after():
            self.assertEqual(self.active_contexts, ['c1', 'c3'])
            self.stop()

        def entering(name, next_step):
            """Build a callback that enters *name* and schedules next_step."""
            def step():
                with self._tracked_context(name) as deactivate:
                    deactivators.append(deactivate)
                    self.io_loop.add_callback(next_step)
            return step

        enter_c3 = entering('c3', check_deactivation)
        enter_c2 = entering('c2', enter_c3)
        enter_c1 = entering('c1', enter_c2)
        self.io_loop.add_callback(enter_c1)
        self.wait()

    def test_deactivate_order(self):
        # Deactivate the outermost, middle and innermost context in turn and
        # check which contexts are still applied to a wrapped function.
        def check_contexts():
            # The context sequence and the chain reached through
            # old_contexts[1] must hold the same objects, in reverse order.
            contexts, link = _state.contexts
            chained = []
            while link is not None:
                chained.append(link)
                link = link.old_contexts[1]
            self.assertEqual(list(reversed(contexts)), chained)
            return list(self.active_contexts)

        def make_wrapped_function():
            """Wrap check_contexts in the contexts c0, c1 and c2.

            Returns the wrapped function together with the three
            deactivation callbacks, outermost first.
            """
            # NullContext is entered first so that, presumably, only the
            # three contexts entered here are captured by wrap().
            with NullContext():
                with self._tracked_context('c0') as d0:
                    with self._tracked_context('c1') as d1:
                        with self._tracked_context('c2') as d2:
                            return wrap(check_contexts), [d0, d1, d2]

        # (index of the context to deactivate or None, expected survivors)
        cases = [
            (None, ['c0', 'c1', 'c2']),
            (0, ['c1', 'c2']),
            (1, ['c0', 'c2']),
            (2, ['c0', 'c1']),
        ]
        for index, expected in cases:
            func, deactivators = make_wrapped_function()
            if index is not None:
                deactivators[index]()
            self.assertEqual(func(), expected)

    def test_isolation_nonempty(self):
        # middle -> finish is a chain started under 'c1'.  middle happens to
        # be invoked while 'c2' is active, but 'c2' must not leak into finish.
        def start():
            with self._tracked_context('c1'):
                chained = wrap(middle)
            with self._tracked_context('c2'):
                chained()

        def middle():
            self.assertIn('c1', self.active_contexts)
            self.io_loop.add_callback(finish)

        def finish():
            self.assertIn('c1', self.active_contexts)
            self.assertNotIn('c2', self.active_contexts)
            self.stop()

        self.io_loop.add_callback(start)
        self.wait()

    def test_isolation_empty(self):
        # Same as test_isolation_nonempty, but the chain is wrapped under
        # NullContext instead of a named context.
        def start():
            with NullContext():
                chained = wrap(middle)
            with self._tracked_context('c2'):
                chained()

        def middle():
            self.io_loop.add_callback(finish)

        def finish():
            self.assertNotIn('c2', self.active_contexts)
            self.stop()

        self.io_loop.add_callback(start)
        self.wait()

    def test_yield_in_with(self):
        @gen.engine
        def f():
            self.callback = yield gen.Callback('a')
            with self._tracked_context('c1'):
                # Suspending the generator here leaves 'c1' entered when
                # f() returns; that must be reported as an inconsistency.
                yield gen.Wait('a')

        with self.assertRaises(StackContextInconsistentError):
            f()
            self.wait()

        # Cleanup: fire and discard the callback created inside f().
        self.callback()
        del self.callback

    @gen_test
    def test_yield_outside_with(self):
        # Safe variant of test_yield_in_with: only add_callback runs inside
        # the context, and the yield happens after the context has exited.
        key_callback = yield gen.Callback('k1')
        with self._tracked_context('c1'):
            self.io_loop.add_callback(key_callback)
        yield gen.Wait('k1')

    def test_yield_in_with_exception_stack_context(self):
        # Same failure mode as test_yield_in_with, for ExceptionStackContext.
        def never_handle(typ, value, tb):
            return False

        @gen.engine
        def f():
            with ExceptionStackContext(never_handle):
                yield gen.Task(self.io_loop.add_callback)

        with self.assertRaises(StackContextInconsistentError):
            f()
            self.wait()

    @gen_test
    def test_yield_outside_with_exception_stack_context(self):
        def never_handle(typ, value, tb):
            return False

        key_callback = yield gen.Callback('k1')
        with ExceptionStackContext(never_handle):
            self.io_loop.add_callback(key_callback)
        yield gen.Wait('k1')

    @gen_test
    def test_run_with_stack_context(self):
        @gen.coroutine
        def outer():
            self.assertEqual(self.active_contexts, ['c1'])
            yield run_with_stack_context(self._tracked_context('c2'), inner)
            self.assertEqual(self.active_contexts, ['c1'])

        @gen.coroutine
        def inner():
            self.assertEqual(self.active_contexts, ['c1', 'c2'])
            yield gen.Task(self.io_loop.add_callback)
            self.assertEqual(self.active_contexts, ['c1', 'c2'])

        self.assertEqual(self.active_contexts, [])
        yield run_with_stack_context(self._tracked_context('c1'), outer)
        self.assertEqual(self.active_contexts, [])
00286
if __name__ == '__main__':
    # Run the tests in this module when it is executed as a script.
    unittest.main()