"""Test of gRPC Python interceptors."""
17 from concurrent
import futures
# Toy (de)serializers used by the streaming multi-callables. Serializing a
# request doubles its bytes and deserializing keeps the second half.
# Serializing a response triples its bytes and deserializing keeps the first
# third. So each pair round-trips any bytestring.
_SERIALIZE_REQUEST = lambda bytestring: bytestring * 2
_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2:]
_SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3
_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) // 3]

# Sentinel request payload. The handlers compare incoming requests against it,
# and the tests that send it expect the RPC to end in an exception.
_EXCEPTION_REQUEST = b'\x09\x0a'

# Fully-qualified method paths. The server-side generic handler dispatches on
# them, and the client-side multi-callables target them.
_UNARY_UNARY = '/test/UnaryUnary'
_UNARY_STREAM = '/test/UnaryStream'
_STREAM_UNARY = '/test/StreamUnary'
_STREAM_STREAM = '/test/StreamStream'
75 if servicer_context
is not None:
76 servicer_context.set_trailing_metadata(((
80 if request == _EXCEPTION_REQUEST:
85 if request == _EXCEPTION_REQUEST:
87 for _
in range(test_constants.STREAM_LENGTH):
91 if servicer_context
is not None:
92 servicer_context.set_trailing_metadata(((
98 if servicer_context
is not None:
99 servicer_context.invocation_metadata()
101 response_elements = []
102 for request
in request_iterator:
104 response_elements.append(request)
106 if servicer_context
is not None:
107 servicer_context.set_trailing_metadata(((
111 if _EXCEPTION_REQUEST
in response_elements:
113 return b
''.join(response_elements)
117 if servicer_context
is not None:
118 servicer_context.set_trailing_metadata(((
122 for request
in request_iterator:
123 if request == _EXCEPTION_REQUEST:
132 def __init__(self, request_streaming, response_streaming,
133 request_deserializer, response_serializer, unary_unary,
134 unary_stream, stream_unary, stream_stream):
151 if handler_call_details.method == _UNARY_UNARY:
153 self.
_handler.handle_unary_unary,
None,
None,
155 elif handler_call_details.method == _UNARY_STREAM:
157 _SERIALIZE_RESPONSE,
None,
158 self.
_handler.handle_unary_stream,
None,
None)
159 elif handler_call_details.method == _STREAM_UNARY:
161 _SERIALIZE_RESPONSE,
None,
None,
162 self.
_handler.handle_stream_unary,
None)
163 elif handler_call_details.method == _STREAM_STREAM:
171 return channel.unary_unary(_UNARY_UNARY)
175 return channel.unary_stream(_UNARY_STREAM,
176 request_serializer=_SERIALIZE_REQUEST,
177 response_deserializer=_DESERIALIZE_RESPONSE)
181 return channel.stream_unary(_STREAM_UNARY,
182 request_serializer=_SERIALIZE_REQUEST,
183 response_deserializer=_DESERIALIZE_RESPONSE)
187 return channel.stream_stream(_STREAM_STREAM)
191 collections.namedtuple(
192 '_ClientCallDetails',
193 (
'method',
'timeout',
'metadata',
'credentials')),
204 self.
_fn = interceptor_function
207 new_details, new_request_iterator, postprocess = self.
_fn(
208 client_call_details,
iter((request,)),
False,
False)
209 response = continuation(new_details,
next(new_request_iterator))
210 return postprocess(response)
if postprocess
else response
214 new_details, new_request_iterator, postprocess = self.
_fn(
215 client_call_details,
iter((request,)),
False,
True)
216 response_it = continuation(new_details, new_request_iterator)
217 return postprocess(response_it)
if postprocess
else response_it
221 new_details, new_request_iterator, postprocess = self.
_fn(
222 client_call_details, request_iterator,
True,
False)
223 response = continuation(new_details,
next(new_request_iterator))
224 return postprocess(response)
if postprocess
else response
228 new_details, new_request_iterator, postprocess = self.
_fn(
229 client_call_details, request_iterator,
True,
True)
230 response_it = continuation(new_details, new_request_iterator)
231 return postprocess(response_it)
if postprocess
else response_it
245 self.
record.append(self.
tag +
':intercept_service')
246 return continuation(handler_call_details)
249 self.
record.append(self.
tag +
':intercept_unary_unary')
250 result = continuation(client_call_details, request)
254 result,
type(result))
258 result,
type(result))
263 self.
record.append(self.
tag +
':intercept_unary_stream')
264 return continuation(client_call_details, request)
268 self.
record.append(self.
tag +
':intercept_stream_unary')
269 result = continuation(client_call_details, request_iterator)
280 self.
record.append(self.
tag +
':intercept_stream_stream')
281 return continuation(client_call_details, request_iterator)
287 ignored_client_call_details, ignored_request):
288 raise test_control.Defect()
293 def intercept_call(client_call_details, request_iterator, request_streaming,
294 ignored_response_streaming):
295 if request_streaming:
296 return client_call_details,
wrapper(request_iterator),
None
298 return client_call_details, request_iterator,
None
305 def intercept_call(client_call_details, request_iterator,
306 ignored_request_streaming, ignored_response_streaming):
308 if client_call_details.metadata:
309 metadata = list(client_call_details.metadata)
315 client_call_details.method, client_call_details.timeout, metadata,
316 client_call_details.credentials)
317 return client_call_details, request_iterator,
None
328 return self.
_fn(continuation, handler_call_details)
333 def intercept_service(continuation, handler_call_details):
334 if condition(handler_call_details):
335 return interceptor.intercept_service(continuation,
336 handler_call_details)
337 return continuation(handler_call_details)
347 self.
_server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
351 lambda x: (
'secret',
'42')
in x.invocation_metadata,
355 options=((
'grpc.so_reuseport', 0),),
358 conditional_interceptor,
361 port = self.
_server.add_insecure_port(
'[::]:0')
374 def triple(request_iterator):
377 item =
next(request_iterator)
381 except StopIteration:
387 b
'\x07\x08' for _
in range(test_constants.STREAM_LENGTH))
394 'InterceptedStreamRequestBlockingUnaryResponseWithCall'),))
396 responses = tuple(response_iterator)
397 self.assertEqual(
len(responses), 3 * test_constants.STREAM_LENGTH)
404 'InterceptedStreamRequestBlockingUnaryResponseWithCall'),))
406 responses = tuple(response_iterator)
407 self.assertEqual(
len(responses), test_constants.STREAM_LENGTH)
413 request = b
'\x07\x08'
416 call_future = multi_callable.future(
419 'InterceptedUnaryRequestBlockingUnaryResponse'),))
421 self.assertIsNotNone(call_future.exception())
422 self.assertEqual(call_future.code(), grpc.StatusCode.INTERNAL)
425 request = b
'\x07\x08'
436 multi_callable.with_call(
440 'InterceptedUnaryRequestBlockingUnaryResponseWithCall'),))
442 self.assertSequenceEqual(self.
_record, [
443 'c1:intercept_unary_unary',
'c2:intercept_unary_unary',
444 's1:intercept_service',
's3:intercept_service',
445 's2:intercept_service'
449 request = b
'\x07\x08'
461 'InterceptedUnaryRequestBlockingUnaryResponse'),))
463 self.assertSequenceEqual(self.
_record, [
464 'c1:intercept_unary_unary',
'c2:intercept_unary_unary',
465 's1:intercept_service',
's2:intercept_service'
469 request = _EXCEPTION_REQUEST
482 'InterceptedUnaryRequestBlockingUnaryResponse'),))
483 exception = exception_context.exception
484 self.assertFalse(exception.cancelled())
485 self.assertFalse(exception.running())
486 self.assertTrue(exception.done())
492 request = b
'\x07\x08'
501 multi_callable.with_call(
505 'InterceptedUnaryRequestBlockingUnaryResponseWithCall'),))
507 self.assertSequenceEqual(self.
_record, [
508 'c1:intercept_unary_unary',
'c2:intercept_unary_unary',
509 's1:intercept_service',
's2:intercept_service'
513 request = b
'\x07\x08'
521 response_future = multi_callable.future(
523 metadata=((
'test',
'InterceptedUnaryRequestFutureUnaryResponse'),))
524 response_future.result()
526 self.assertSequenceEqual(self.
_record, [
527 'c1:intercept_unary_unary',
'c2:intercept_unary_unary',
528 's1:intercept_service',
's2:intercept_service'
532 request = b
'\x37\x58'
542 metadata=((
'test',
'InterceptedUnaryRequestStreamResponse'),))
543 tuple(response_iterator)
545 self.assertSequenceEqual(self.
_record, [
546 'c1:intercept_unary_stream',
'c2:intercept_unary_stream',
547 's1:intercept_service',
's2:intercept_service'
551 request = _EXCEPTION_REQUEST
561 metadata=((
'test',
'InterceptedUnaryRequestStreamResponse'),))
563 tuple(response_iterator)
564 exception = exception_context.exception
565 self.assertFalse(exception.cancelled())
566 self.assertFalse(exception.running())
567 self.assertTrue(exception.done())
574 b
'\x07\x08' for _
in range(test_constants.STREAM_LENGTH))
575 request_iterator =
iter(requests)
586 'InterceptedStreamRequestBlockingUnaryResponse'),))
588 self.assertSequenceEqual(self.
_record, [
589 'c1:intercept_stream_unary',
'c2:intercept_stream_unary',
590 's1:intercept_service',
's2:intercept_service'
595 b
'\x07\x08' for _
in range(test_constants.STREAM_LENGTH))
596 request_iterator =
iter(requests)
604 multi_callable.with_call(
608 'InterceptedStreamRequestBlockingUnaryResponseWithCall'),))
610 self.assertSequenceEqual(self.
_record, [
611 'c1:intercept_stream_unary',
'c2:intercept_stream_unary',
612 's1:intercept_service',
's2:intercept_service'
617 b
'\x07\x08' for _
in range(test_constants.STREAM_LENGTH))
618 request_iterator =
iter(requests)
626 response_future = multi_callable.future(
628 metadata=((
'test',
'InterceptedStreamRequestFutureUnaryResponse'),))
629 response_future.result()
631 self.assertSequenceEqual(self.
_record, [
632 'c1:intercept_stream_unary',
'c2:intercept_stream_unary',
633 's1:intercept_service',
's2:intercept_service'
638 _EXCEPTION_REQUEST
for _
in range(test_constants.STREAM_LENGTH))
639 request_iterator =
iter(requests)
647 response_future = multi_callable.future(
649 metadata=((
'test',
'InterceptedStreamRequestFutureUnaryResponse'),))
651 response_future.result()
652 exception = exception_context.exception
653 self.assertFalse(exception.cancelled())
654 self.assertFalse(exception.running())
655 self.assertTrue(exception.done())
662 b
'\x77\x58' for _
in range(test_constants.STREAM_LENGTH))
663 request_iterator =
iter(requests)
673 metadata=((
'test',
'InterceptedStreamRequestStreamResponse'),))
674 tuple(response_iterator)
676 self.assertSequenceEqual(self.
_record, [
677 'c1:intercept_stream_stream',
'c2:intercept_stream_stream',
678 's1:intercept_service',
's2:intercept_service'
683 _EXCEPTION_REQUEST
for _
in range(test_constants.STREAM_LENGTH))
684 request_iterator =
iter(requests)
694 metadata=((
'test',
'InterceptedStreamRequestStreamResponse'),))
696 tuple(response_iterator)
697 exception = exception_context.exception
698 self.assertFalse(exception.cancelled())
699 self.assertFalse(exception.running())
700 self.assertTrue(exception.done())
# Script entry point: enable basic logging, then run this module's tests
# verbosely through unittest.
if __name__ == '__main__':
    logging.basicConfig()
    unittest.main(verbosity=2)