21 from src.proto.grpc.testing
import messages_pb2
22 from src.proto.grpc.testing
import test_pb2_grpc
# Details string the cancellation tests compare against ``call.details()``
# after the call has been cancelled from the client side.
_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!'

# Metadata that the interceptor in ``test_initial_metadata_modification``
# appends to the outgoing call details.
# NOTE(review): the closing parenthesis is restored here; the extracted text
# ends the call without one — confirm no further entries were dropped.
_INITIAL_METADATA_TO_INJECT = aio.Metadata(
    (_INITIAL_METADATA_KEY, 'extra info'),
    (_TRAILING_METADATA_KEY, b'\x13\x37'),
)

# Timeout handed to ``asyncio.wait_for`` by the done-callback tests.
_TIMEOUT_CHECK_IF_CALLBACK_WAS_CALLED = 1.0
43 async
def tearDown(self):
def test_invalid_interceptor(self):
    """Check that a channel rejects an object that is not an interceptor.

    ``InvalidInterceptor`` is a plain class that does not derive from any
    ``aio`` interceptor base; the test expects ``aio.insecure_channel`` to
    raise ``ValueError`` when an instance of it is passed in
    ``interceptors``.
    """

    class InvalidInterceptor:
        """Just an invalid Interceptor"""

    with self.assertRaises(ValueError):
        aio.insecure_channel("", interceptors=[InvalidInterceptor()])
54 async
def test_executed_right_order(self):
56 interceptors_executed = []
58 class Interceptor(aio.UnaryUnaryClientInterceptor):
59 """Interceptor used for testing if the interceptor is being called"""
61 async
def intercept_unary_unary(self, continuation,
62 client_call_details, request):
63 interceptors_executed.append(self)
64 call = await continuation(client_call_details, request)
67 interceptors = [Interceptor()
for i
in range(2)]
69 async
with aio.insecure_channel(self._server_target,
70 interceptors=interceptors)
as channel:
71 multicallable = channel.unary_unary(
72 '/grpc.testing.TestService/UnaryCall',
73 request_serializer=messages_pb2.SimpleRequest.SerializeToString,
74 response_deserializer=messages_pb2.SimpleResponse.FromString)
80 self.assertSequenceEqual(interceptors_executed, interceptors)
82 self.assertIsInstance(response, messages_pb2.SimpleResponse)
@unittest.expectedFailure
def test_modify_metadata(self):
    """Placeholder for a test of an interceptor that modifies metadata.

    Not implemented: the body unconditionally raises, and the method is
    decorated as an expected failure so that raise is not reported as a
    test failure.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
@unittest.expectedFailure
def test_modify_credentials(self):
    """Placeholder for a test of an interceptor that modifies credentials.

    Not implemented: the body unconditionally raises, and the method is
    decorated as an expected failure so that raise is not reported as a
    test failure.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
96 async
def test_status_code_Ok(self):
98 class StatusCodeOkInterceptor(aio.UnaryUnaryClientInterceptor):
99 """Interceptor used for observing status code Ok returned by the RPC"""
104 async
def intercept_unary_unary(self, continuation,
105 client_call_details, request):
106 call = await continuation(client_call_details, request)
107 code = await call.code()
108 if code == grpc.StatusCode.OK:
113 interceptor = StatusCodeOkInterceptor()
115 async
with aio.insecure_channel(self._server_target,
116 interceptors=[interceptor])
as channel:
119 multicallable = channel.unary_unary(
120 '/grpc.testing.TestService/UnaryCall',
121 request_serializer=messages_pb2.SimpleRequest.SerializeToString,
122 response_deserializer=messages_pb2.SimpleResponse.FromString)
126 self.assertTrue(interceptor.status_code_Ok_observed)
128 async
def test_add_timeout(self):
class TimeoutInterceptor(aio.UnaryUnaryClientInterceptor):
    """Interceptor used for adding a timeout to the RPC"""

    async def intercept_unary_unary(self, continuation,
                                    client_call_details, request):
        """Forward the call with its timeout overridden.

        Builds a new ``aio.ClientCallDetails`` that copies ``method``,
        ``metadata``, ``credentials`` and ``wait_for_ready`` from the
        incoming details and sets ``timeout`` to half of
        ``_constants.UNARY_CALL_WITH_SLEEP_VALUE``, then awaits
        ``continuation`` with those details and returns its result.
        """
        new_client_call_details = aio.ClientCallDetails(
            method=client_call_details.method,
            # Half of the sleep value, so the timeout is shorter than it.
            timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2,
            metadata=client_call_details.metadata,
            credentials=client_call_details.credentials,
            wait_for_ready=client_call_details.wait_for_ready)
        return await continuation(new_client_call_details, request)
143 interceptor = TimeoutInterceptor()
145 async
with aio.insecure_channel(self._server_target,
146 interceptors=[interceptor])
as channel:
148 multicallable = channel.unary_unary(
149 '/grpc.testing.TestService/UnaryCallWithSleep',
150 request_serializer=messages_pb2.SimpleRequest.SerializeToString,
151 response_deserializer=messages_pb2.SimpleResponse.FromString)
155 with self.assertRaises(aio.AioRpcError)
as exception_context:
158 self.assertEqual(exception_context.exception.code(),
159 grpc.StatusCode.DEADLINE_EXCEEDED)
161 self.assertTrue(call.done())
162 self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, await
167 class RetryInterceptor(aio.UnaryUnaryClientInterceptor):
168 """Simulates a Retry Interceptor which ends up by making
174 async
def intercept_unary_unary(self, continuation,
175 client_call_details, request):
177 new_client_call_details = aio.ClientCallDetails(
178 method=client_call_details.method,
179 timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2,
180 metadata=client_call_details.metadata,
181 credentials=client_call_details.credentials,
182 wait_for_ready=client_call_details.wait_for_ready)
185 call = await continuation(new_client_call_details, request)
190 self.
calls.append(call)
192 new_client_call_details = aio.ClientCallDetails(
193 method=client_call_details.method,
195 metadata=client_call_details.metadata,
196 credentials=client_call_details.credentials,
197 wait_for_ready=client_call_details.wait_for_ready)
199 call = await continuation(new_client_call_details, request)
200 self.
calls.append(call)
203 interceptor = RetryInterceptor()
205 async
with aio.insecure_channel(self._server_target,
206 interceptors=[interceptor])
as channel:
208 multicallable = channel.unary_unary(
209 '/grpc.testing.TestService/UnaryCallWithSleep',
210 request_serializer=messages_pb2.SimpleRequest.SerializeToString,
211 response_deserializer=messages_pb2.SimpleResponse.FromString)
217 self.assertEqual(grpc.StatusCode.OK, await call.code())
221 self.assertEqual(
len(interceptor.calls), 2)
222 self.assertEqual(await interceptor.calls[0].
code(),
223 grpc.StatusCode.DEADLINE_EXCEEDED)
224 self.assertEqual(await interceptor.calls[1].
code(),
227 async
def test_rpcresponse(self):
229 class Interceptor(aio.UnaryUnaryClientInterceptor):
230 """Raw responses are seen as regular calls"""
232 async
def intercept_unary_unary(self, continuation,
233 client_call_details, request):
234 call = await continuation(client_call_details, request)
235 response = await call
238 class ResponseInterceptor(aio.UnaryUnaryClientInterceptor):
239 """Return a raw response"""
242 async
def intercept_unary_unary(self, continuation,
243 client_call_details, request):
244 return ResponseInterceptor.response
246 interceptor, interceptor_response = Interceptor(), ResponseInterceptor()
248 async
with aio.insecure_channel(
250 interceptors=[interceptor, interceptor_response])
as channel:
252 multicallable = channel.unary_unary(
253 '/grpc.testing.TestService/UnaryCall',
254 request_serializer=messages_pb2.SimpleRequest.SerializeToString,
255 response_deserializer=messages_pb2.SimpleResponse.FromString)
258 response = await call
262 self.assertEqual(
id(response),
id(ResponseInterceptor.response))
265 self.assertTrue(call.done())
266 self.assertFalse(call.cancel())
267 self.assertFalse(call.cancelled())
268 self.assertEqual(await call.code(), grpc.StatusCode.OK)
269 self.assertEqual(await call.details(),
'')
270 self.assertEqual(await call.initial_metadata(),
None)
271 self.assertEqual(await call.trailing_metadata(),
None)
272 self.assertEqual(await call.debug_error_string(),
None)
280 async
def tearDown(self):
283 async
def test_call_ok(self):
285 class Interceptor(aio.UnaryUnaryClientInterceptor):
287 async
def intercept_unary_unary(self, continuation,
288 client_call_details, request):
289 call = await continuation(client_call_details, request)
292 async
with aio.insecure_channel(self._server_target,
293 interceptors=[Interceptor()
296 multicallable = channel.unary_unary(
297 '/grpc.testing.TestService/UnaryCall',
298 request_serializer=messages_pb2.SimpleRequest.SerializeToString,
299 response_deserializer=messages_pb2.SimpleResponse.FromString)
301 response = await call
303 self.assertTrue(call.done())
304 self.assertFalse(call.cancelled())
305 self.assertEqual(
type(response), messages_pb2.SimpleResponse)
306 self.assertEqual(await call.code(), grpc.StatusCode.OK)
307 self.assertEqual(await call.details(),
'')
308 self.assertEqual(await call.initial_metadata(), aio.Metadata())
309 self.assertEqual(await call.trailing_metadata(), aio.Metadata())
311 async
def test_call_ok_awaited(self):
313 class Interceptor(aio.UnaryUnaryClientInterceptor):
315 async
def intercept_unary_unary(self, continuation,
316 client_call_details, request):
317 call = await continuation(client_call_details, request)
321 async
with aio.insecure_channel(self._server_target,
322 interceptors=[Interceptor()
325 multicallable = channel.unary_unary(
326 '/grpc.testing.TestService/UnaryCall',
327 request_serializer=messages_pb2.SimpleRequest.SerializeToString,
328 response_deserializer=messages_pb2.SimpleResponse.FromString)
330 response = await call
332 self.assertTrue(call.done())
333 self.assertFalse(call.cancelled())
334 self.assertEqual(
type(response), messages_pb2.SimpleResponse)
335 self.assertEqual(await call.code(), grpc.StatusCode.OK)
336 self.assertEqual(await call.details(),
'')
337 self.assertEqual(await call.initial_metadata(), aio.Metadata())
338 self.assertEqual(await call.trailing_metadata(), aio.Metadata())
340 async
def test_call_rpc_error(self):
342 class Interceptor(aio.UnaryUnaryClientInterceptor):
344 async
def intercept_unary_unary(self, continuation,
345 client_call_details, request):
346 call = await continuation(client_call_details, request)
349 async
with aio.insecure_channel(self._server_target,
350 interceptors=[Interceptor()
353 multicallable = channel.unary_unary(
354 '/grpc.testing.TestService/UnaryCallWithSleep',
355 request_serializer=messages_pb2.SimpleRequest.SerializeToString,
356 response_deserializer=messages_pb2.SimpleResponse.FromString)
358 call = multicallable(
360 timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2)
362 with self.assertRaises(aio.AioRpcError)
as exception_context:
365 self.assertTrue(call.done())
366 self.assertFalse(call.cancelled())
367 self.assertEqual(await call.code(),
368 grpc.StatusCode.DEADLINE_EXCEEDED)
369 self.assertEqual(await call.details(),
'Deadline Exceeded')
370 self.assertEqual(await call.initial_metadata(), aio.Metadata())
371 self.assertEqual(await call.trailing_metadata(), aio.Metadata())
373 async
def test_call_rpc_error_awaited(self):
375 class Interceptor(aio.UnaryUnaryClientInterceptor):
377 async
def intercept_unary_unary(self, continuation,
378 client_call_details, request):
379 call = await continuation(client_call_details, request)
383 async
with aio.insecure_channel(self._server_target,
384 interceptors=[Interceptor()
387 multicallable = channel.unary_unary(
388 '/grpc.testing.TestService/UnaryCallWithSleep',
389 request_serializer=messages_pb2.SimpleRequest.SerializeToString,
390 response_deserializer=messages_pb2.SimpleResponse.FromString)
392 call = multicallable(
394 timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2)
396 with self.assertRaises(aio.AioRpcError)
as exception_context:
399 self.assertTrue(call.done())
400 self.assertFalse(call.cancelled())
401 self.assertEqual(await call.code(),
402 grpc.StatusCode.DEADLINE_EXCEEDED)
403 self.assertEqual(await call.details(),
'Deadline Exceeded')
404 self.assertEqual(await call.initial_metadata(), aio.Metadata())
405 self.assertEqual(await call.trailing_metadata(), aio.Metadata())
407 async
def test_cancel_before_rpc(self):
409 interceptor_reached = asyncio.Event()
410 wait_for_ever = self.
loop.create_future()
412 class Interceptor(aio.UnaryUnaryClientInterceptor):
414 async
def intercept_unary_unary(self, continuation,
415 client_call_details, request):
416 interceptor_reached.set()
419 async
with aio.insecure_channel(self._server_target,
420 interceptors=[Interceptor()
423 multicallable = channel.unary_unary(
424 '/grpc.testing.TestService/UnaryCall',
425 request_serializer=messages_pb2.SimpleRequest.SerializeToString,
426 response_deserializer=messages_pb2.SimpleResponse.FromString)
429 self.assertFalse(call.cancelled())
430 self.assertFalse(call.done())
432 await interceptor_reached.wait()
433 self.assertTrue(call.cancel())
435 with self.assertRaises(asyncio.CancelledError):
438 self.assertTrue(call.cancelled())
439 self.assertTrue(call.done())
440 self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
441 self.assertEqual(await call.details(),
442 _LOCAL_CANCEL_DETAILS_EXPECTATION)
443 self.assertEqual(await call.initial_metadata(),
None)
444 self.assertEqual(await call.trailing_metadata(),
None)
446 async
def test_cancel_after_rpc(self):
448 interceptor_reached = asyncio.Event()
449 wait_for_ever = self.
loop.create_future()
451 class Interceptor(aio.UnaryUnaryClientInterceptor):
453 async
def intercept_unary_unary(self, continuation,
454 client_call_details, request):
455 call = await continuation(client_call_details, request)
457 interceptor_reached.set()
460 async
with aio.insecure_channel(self._server_target,
461 interceptors=[Interceptor()
464 multicallable = channel.unary_unary(
465 '/grpc.testing.TestService/UnaryCall',
466 request_serializer=messages_pb2.SimpleRequest.SerializeToString,
467 response_deserializer=messages_pb2.SimpleResponse.FromString)
470 self.assertFalse(call.cancelled())
471 self.assertFalse(call.done())
473 await interceptor_reached.wait()
474 self.assertTrue(call.cancel())
476 with self.assertRaises(asyncio.CancelledError):
479 self.assertTrue(call.cancelled())
480 self.assertTrue(call.done())
481 self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
482 self.assertEqual(await call.details(),
483 _LOCAL_CANCEL_DETAILS_EXPECTATION)
484 self.assertEqual(await call.initial_metadata(),
None)
485 self.assertEqual(await call.trailing_metadata(),
None)
487 async
def test_cancel_inside_interceptor_after_rpc_awaiting(self):
489 class Interceptor(aio.UnaryUnaryClientInterceptor):
491 async
def intercept_unary_unary(self, continuation,
492 client_call_details, request):
493 call = await continuation(client_call_details, request)
498 async
with aio.insecure_channel(self._server_target,
499 interceptors=[Interceptor()
502 multicallable = channel.unary_unary(
503 '/grpc.testing.TestService/UnaryCall',
504 request_serializer=messages_pb2.SimpleRequest.SerializeToString,
505 response_deserializer=messages_pb2.SimpleResponse.FromString)
508 with self.assertRaises(asyncio.CancelledError):
511 self.assertTrue(call.cancelled())
512 self.assertTrue(call.done())
513 self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
514 self.assertEqual(await call.details(),
515 _LOCAL_CANCEL_DETAILS_EXPECTATION)
516 self.assertEqual(await call.initial_metadata(),
None)
517 self.assertEqual(await call.trailing_metadata(),
None)
519 async
def test_cancel_inside_interceptor_after_rpc_not_awaiting(self):
521 class Interceptor(aio.UnaryUnaryClientInterceptor):
523 async
def intercept_unary_unary(self, continuation,
524 client_call_details, request):
525 call = await continuation(client_call_details, request)
529 async
with aio.insecure_channel(self._server_target,
530 interceptors=[Interceptor()
533 multicallable = channel.unary_unary(
534 '/grpc.testing.TestService/UnaryCall',
535 request_serializer=messages_pb2.SimpleRequest.SerializeToString,
536 response_deserializer=messages_pb2.SimpleResponse.FromString)
539 with self.assertRaises(asyncio.CancelledError):
542 self.assertTrue(call.cancelled())
543 self.assertTrue(call.done())
544 self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
545 self.assertEqual(await call.details(),
546 _LOCAL_CANCEL_DETAILS_EXPECTATION)
547 self.assertEqual(await call.initial_metadata(), aio.Metadata())
549 await call.trailing_metadata(), aio.Metadata(),
550 "When the raw response is None, empty metadata is returned")
552 async
def test_initial_metadata_modification(self):
554 class Interceptor(aio.UnaryUnaryClientInterceptor):
556 async
def intercept_unary_unary(self, continuation,
557 client_call_details, request):
558 new_metadata = aio.Metadata(*client_call_details.metadata,
559 *_INITIAL_METADATA_TO_INJECT)
560 new_details = aio.ClientCallDetails(
561 method=client_call_details.method,
562 timeout=client_call_details.timeout,
563 metadata=new_metadata,
564 credentials=client_call_details.credentials,
565 wait_for_ready=client_call_details.wait_for_ready,
567 return await continuation(new_details, request)
569 async
with aio.insecure_channel(self._server_target,
570 interceptors=[Interceptor()
572 stub = test_pb2_grpc.TestServiceStub(channel)
577 _common.seen_metadatum(
578 expected_key=_INITIAL_METADATA_KEY,
579 expected_value=_INITIAL_METADATA_TO_INJECT[
580 _INITIAL_METADATA_KEY],
581 actual=await call.initial_metadata(),
585 _common.seen_metadatum(
586 expected_key=_TRAILING_METADATA_KEY,
587 expected_value=_INITIAL_METADATA_TO_INJECT[
588 _TRAILING_METADATA_KEY],
589 actual=await call.trailing_metadata(),
591 self.assertEqual(await call.code(), grpc.StatusCode.OK)
593 async
def test_add_done_callback_before_finishes(self):
594 called = asyncio.Event()
595 interceptor_can_continue = asyncio.Event()
600 class Interceptor(aio.UnaryUnaryClientInterceptor):
602 async
def intercept_unary_unary(self, continuation,
603 client_call_details, request):
605 await interceptor_can_continue.wait()
606 call = await continuation(client_call_details, request)
609 async
with aio.insecure_channel(self._server_target,
610 interceptors=[Interceptor()
613 multicallable = channel.unary_unary(
614 '/grpc.testing.TestService/UnaryCall',
615 request_serializer=messages_pb2.SimpleRequest.SerializeToString,
616 response_deserializer=messages_pb2.SimpleResponse.FromString)
618 call.add_done_callback(callback)
619 interceptor_can_continue.set()
623 await asyncio.wait_for(
625 timeout=_TIMEOUT_CHECK_IF_CALLBACK_WAS_CALLED)
627 self.fail(
"Callback was not called")
629 async
def test_add_done_callback_after_finishes(self):
630 called = asyncio.Event()
635 class Interceptor(aio.UnaryUnaryClientInterceptor):
637 async
def intercept_unary_unary(self, continuation,
638 client_call_details, request):
640 call = await continuation(client_call_details, request)
643 async
with aio.insecure_channel(self._server_target,
644 interceptors=[Interceptor()
647 multicallable = channel.unary_unary(
648 '/grpc.testing.TestService/UnaryCall',
649 request_serializer=messages_pb2.SimpleRequest.SerializeToString,
650 response_deserializer=messages_pb2.SimpleResponse.FromString)
655 call.add_done_callback(callback)
658 await asyncio.wait_for(
660 timeout=_TIMEOUT_CHECK_IF_CALLBACK_WAS_CALLED)
662 self.fail(
"Callback was not called")
664 async
def test_add_done_callback_after_finishes_before_await(self):
665 called = asyncio.Event()
670 class Interceptor(aio.UnaryUnaryClientInterceptor):
672 async
def intercept_unary_unary(self, continuation,
673 client_call_details, request):
675 call = await continuation(client_call_details, request)
678 async
with aio.insecure_channel(self._server_target,
679 interceptors=[Interceptor()
682 multicallable = channel.unary_unary(
683 '/grpc.testing.TestService/UnaryCall',
684 request_serializer=messages_pb2.SimpleRequest.SerializeToString,
685 response_deserializer=messages_pb2.SimpleResponse.FromString)
688 call.add_done_callback(callback)
693 await asyncio.wait_for(
695 timeout=_TIMEOUT_CHECK_IF_CALLBACK_WAS_CALLED)
697 self.fail(
"Callback was not called")
# Script entry point: enable DEBUG-level logging, then run every test in this
# module with verbose per-test output.
if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)