22 from src.proto.grpc.testing
import messages_pb2
23 from src.proto.grpc.testing
import test_pb2_grpc
31 _SHORT_TIMEOUT_S = 1.0
33 _NUM_STREAM_REQUESTS = 5
34 _REQUEST_PAYLOAD_SIZE = 7
35 _RESPONSE_INTERVAL_US =
int(_SHORT_TIMEOUT_S * 1000 * 1000)
42 return await continuation(client_call_details, request_iterator)
44 def assert_in_final_state(self, test: unittest.TestCase):
49 aio.StreamUnaryClientInterceptor):
57 def assert_in_final_state(self, test: unittest.TestCase):
58 test.assertEqual(_NUM_STREAM_REQUESTS,
67 async
def tearDown(self):
70 async
def test_intercepts(self):
71 for interceptor_class
in (_StreamUnaryInterceptorEmpty,
72 _StreamUnaryInterceptorWithRequestIterator):
74 with self.subTest(name=interceptor_class):
75 interceptor = interceptor_class()
76 channel = aio.insecure_channel(self._server_target,
77 interceptors=[interceptor])
78 stub = test_pb2_grpc.TestServiceStub(channel)
81 _REQUEST_PAYLOAD_SIZE)
85 async
def request_iterator():
86 for _
in range(_NUM_STREAM_REQUESTS):
89 call = stub.StreamingInputCall(request_iterator())
93 self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE,
94 response.aggregated_payload_size)
95 self.assertEqual(await call.code(), grpc.StatusCode.OK)
96 self.assertEqual(await call.initial_metadata(), aio.Metadata())
97 self.assertEqual(await call.trailing_metadata(), aio.Metadata())
98 self.assertEqual(await call.details(),
'')
99 self.assertEqual(await call.debug_error_string(),
'')
100 self.assertEqual(call.cancel(),
False)
101 self.assertEqual(call.cancelled(),
False)
102 self.assertEqual(call.done(),
True)
104 interceptor.assert_in_final_state(self)
106 await channel.close()
108 async
def test_intercepts_using_write(self):
109 for interceptor_class
in (_StreamUnaryInterceptorEmpty,
110 _StreamUnaryInterceptorWithRequestIterator):
112 with self.subTest(name=interceptor_class):
113 interceptor = interceptor_class()
114 channel = aio.insecure_channel(self._server_target,
115 interceptors=[interceptor])
116 stub = test_pb2_grpc.TestServiceStub(channel)
119 _REQUEST_PAYLOAD_SIZE)
123 call = stub.StreamingInputCall()
125 for _
in range(_NUM_STREAM_REQUESTS):
126 await call.write(request)
128 await call.done_writing()
130 response = await call
132 self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE,
133 response.aggregated_payload_size)
134 self.assertEqual(await call.code(), grpc.StatusCode.OK)
135 self.assertEqual(await call.initial_metadata(), aio.Metadata())
136 self.assertEqual(await call.trailing_metadata(), aio.Metadata())
137 self.assertEqual(await call.details(),
'')
138 self.assertEqual(await call.debug_error_string(),
'')
139 self.assertEqual(call.cancel(),
False)
140 self.assertEqual(call.cancelled(),
False)
141 self.assertEqual(call.done(),
True)
143 interceptor.assert_in_final_state(self)
145 await channel.close()
147 async
def test_add_done_callback_interceptor_task_not_finished(self):
148 for interceptor_class
in (_StreamUnaryInterceptorEmpty,
149 _StreamUnaryInterceptorWithRequestIterator):
151 with self.subTest(name=interceptor_class):
152 interceptor = interceptor_class()
154 channel = aio.insecure_channel(self._server_target,
155 interceptors=[interceptor])
156 stub = test_pb2_grpc.TestServiceStub(channel)
159 _REQUEST_PAYLOAD_SIZE)
163 async
def request_iterator():
164 for _
in range(_NUM_STREAM_REQUESTS):
167 call = stub.StreamingInputCall(request_iterator())
171 response = await call
175 await channel.close()
177 async
def test_add_done_callback_interceptor_task_finished(self):
178 for interceptor_class
in (_StreamUnaryInterceptorEmpty,
179 _StreamUnaryInterceptorWithRequestIterator):
181 with self.subTest(name=interceptor_class):
182 interceptor = interceptor_class()
184 channel = aio.insecure_channel(self._server_target,
185 interceptors=[interceptor])
186 stub = test_pb2_grpc.TestServiceStub(channel)
189 _REQUEST_PAYLOAD_SIZE)
193 async
def request_iterator():
194 for _
in range(_NUM_STREAM_REQUESTS):
197 call = stub.StreamingInputCall(request_iterator())
199 response = await call
205 await channel.close()
207 async
def test_multiple_interceptors_request_iterator(self):
208 for interceptor_class
in (_StreamUnaryInterceptorEmpty,
209 _StreamUnaryInterceptorWithRequestIterator):
211 with self.subTest(name=interceptor_class):
213 interceptors = [interceptor_class(), interceptor_class()]
214 channel = aio.insecure_channel(self._server_target,
215 interceptors=interceptors)
216 stub = test_pb2_grpc.TestServiceStub(channel)
219 _REQUEST_PAYLOAD_SIZE)
223 async
def request_iterator():
224 for _
in range(_NUM_STREAM_REQUESTS):
227 call = stub.StreamingInputCall(request_iterator())
229 response = await call
231 self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE,
232 response.aggregated_payload_size)
233 self.assertEqual(await call.code(), grpc.StatusCode.OK)
234 self.assertEqual(await call.initial_metadata(), aio.Metadata())
235 self.assertEqual(await call.trailing_metadata(), aio.Metadata())
236 self.assertEqual(await call.details(),
'')
237 self.assertEqual(await call.debug_error_string(),
'')
238 self.assertEqual(call.cancel(),
False)
239 self.assertEqual(call.cancelled(),
False)
240 self.assertEqual(call.done(),
True)
242 for interceptor
in interceptors:
243 interceptor.assert_in_final_state(self)
245 await channel.close()
247 async
def test_intercepts_request_iterator_rpc_error(self):
248 for interceptor_class
in (_StreamUnaryInterceptorEmpty,
249 _StreamUnaryInterceptorWithRequestIterator):
251 with self.subTest(name=interceptor_class):
252 channel = aio.insecure_channel(
253 UNREACHABLE_TARGET, interceptors=[interceptor_class()])
254 stub = test_pb2_grpc.TestServiceStub(channel)
257 _REQUEST_PAYLOAD_SIZE)
263 async
def request_iterator():
264 for _
in range(_NUM_STREAM_REQUESTS):
267 call = stub.StreamingInputCall(request_iterator())
269 with self.assertRaises(aio.AioRpcError)
as exception_context:
272 self.assertEqual(grpc.StatusCode.UNAVAILABLE,
273 exception_context.exception.code())
274 self.assertTrue(call.done())
275 self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())
277 await channel.close()
279 async
def test_intercepts_request_iterator_rpc_error_using_write(self):
280 for interceptor_class
in (_StreamUnaryInterceptorEmpty,
281 _StreamUnaryInterceptorWithRequestIterator):
283 with self.subTest(name=interceptor_class):
284 channel = aio.insecure_channel(
285 UNREACHABLE_TARGET, interceptors=[interceptor_class()])
286 stub = test_pb2_grpc.TestServiceStub(channel)
289 _REQUEST_PAYLOAD_SIZE)
293 call = stub.StreamingInputCall()
296 with self.assertRaises(asyncio.InvalidStateError):
297 for _
in range(_NUM_STREAM_REQUESTS):
298 await call.write(request)
300 with self.assertRaises(aio.AioRpcError)
as exception_context:
303 self.assertEqual(grpc.StatusCode.UNAVAILABLE,
304 exception_context.exception.code())
305 self.assertTrue(call.done())
306 self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())
308 await channel.close()
310 async
def test_cancel_before_rpc(self):
312 interceptor_reached = asyncio.Event()
313 wait_for_ever = self.
loop.create_future()
315 class Interceptor(aio.StreamUnaryClientInterceptor):
317 async
def intercept_stream_unary(self, continuation,
320 interceptor_reached.set()
323 channel = aio.insecure_channel(self._server_target,
324 interceptors=[Interceptor()])
325 stub = test_pb2_grpc.TestServiceStub(channel)
330 call = stub.StreamingInputCall()
332 self.assertFalse(call.cancelled())
333 self.assertFalse(call.done())
335 await interceptor_reached.wait()
336 self.assertTrue(call.cancel())
339 with self.assertRaises(asyncio.InvalidStateError):
340 for _
in range(_NUM_STREAM_REQUESTS):
341 await call.write(request)
343 with self.assertRaises(asyncio.CancelledError):
346 self.assertTrue(call.cancelled())
347 self.assertTrue(call.done())
348 self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
349 self.assertEqual(await call.initial_metadata(),
None)
350 self.assertEqual(await call.trailing_metadata(),
None)
351 await channel.close()
353 async
def test_cancel_after_rpc(self):
355 interceptor_reached = asyncio.Event()
356 wait_for_ever = self.
loop.create_future()
358 class Interceptor(aio.StreamUnaryClientInterceptor):
360 async
def intercept_stream_unary(self, continuation,
363 call = await continuation(client_call_details, request_iterator)
364 interceptor_reached.set()
367 channel = aio.insecure_channel(self._server_target,
368 interceptors=[Interceptor()])
369 stub = test_pb2_grpc.TestServiceStub(channel)
374 call = stub.StreamingInputCall()
376 self.assertFalse(call.cancelled())
377 self.assertFalse(call.done())
379 await interceptor_reached.wait()
380 self.assertTrue(call.cancel())
383 with self.assertRaises(asyncio.InvalidStateError):
384 for _
in range(_NUM_STREAM_REQUESTS):
385 await call.write(request)
387 with self.assertRaises(asyncio.CancelledError):
390 self.assertTrue(call.cancelled())
391 self.assertTrue(call.done())
392 self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
393 self.assertEqual(await call.initial_metadata(),
None)
394 self.assertEqual(await call.trailing_metadata(),
None)
395 await channel.close()
397 async
def test_cancel_while_writing(self):
399 for num_writes_before_cancel
in (0, 1):
400 with self.subTest(name=
"Num writes before cancel: {}".
format(
401 num_writes_before_cancel)):
403 channel = aio.insecure_channel(
406 stub = test_pb2_grpc.TestServiceStub(channel)
409 _REQUEST_PAYLOAD_SIZE)
413 call = stub.StreamingInputCall()
415 with self.assertRaises(asyncio.InvalidStateError):
416 for i
in range(_NUM_STREAM_REQUESTS):
417 if i == num_writes_before_cancel:
418 self.assertTrue(call.cancel())
419 await call.write(request)
421 with self.assertRaises(asyncio.CancelledError):
424 self.assertTrue(call.cancelled())
425 self.assertTrue(call.done())
426 self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
428 await channel.close()
430 async
def test_cancel_by_the_interceptor(self):
432 class Interceptor(aio.StreamUnaryClientInterceptor):
434 async
def intercept_stream_unary(self, continuation,
437 call = await continuation(client_call_details, request_iterator)
441 channel = aio.insecure_channel(UNREACHABLE_TARGET,
442 interceptors=[Interceptor()])
443 stub = test_pb2_grpc.TestServiceStub(channel)
448 call = stub.StreamingInputCall()
450 with self.assertRaises(asyncio.InvalidStateError):
451 for i
in range(_NUM_STREAM_REQUESTS):
452 await call.write(request)
454 with self.assertRaises(asyncio.CancelledError):
457 self.assertTrue(call.cancelled())
458 self.assertTrue(call.done())
459 self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
461 await channel.close()
463 async
def test_exception_raised_by_interceptor(self):
465 class InterceptorException(Exception):
468 class Interceptor(aio.StreamUnaryClientInterceptor):
470 async
def intercept_stream_unary(self, continuation,
473 raise InterceptorException
475 channel = aio.insecure_channel(UNREACHABLE_TARGET,
476 interceptors=[Interceptor()])
477 stub = test_pb2_grpc.TestServiceStub(channel)
482 call = stub.StreamingInputCall()
484 with self.assertRaises(InterceptorException):
485 for i
in range(_NUM_STREAM_REQUESTS):
486 await call.write(request)
488 with self.assertRaises(InterceptorException):
491 await channel.close()
493 async
def test_intercepts_prohibit_mixing_style(self):
494 channel = aio.insecure_channel(
496 stub = test_pb2_grpc.TestServiceStub(channel)
501 async
def request_iterator():
502 for _
in range(_NUM_STREAM_REQUESTS):
505 call = stub.StreamingInputCall(request_iterator())
507 with self.assertRaises(grpc._cython.cygrpc.UsageError):
508 await call.write(request)
510 with self.assertRaises(grpc._cython.cygrpc.UsageError):
511 await call.done_writing()
513 await channel.close()
516 if __name__ ==
'__main__':
517 logging.basicConfig(level=logging.DEBUG)
518 unittest.main(verbosity=2)