14 """Test the functionality of server interceptors."""
19 from typing
import Any, Awaitable, Callable, Tuple
26 from src.proto.grpc.testing
import messages_pb2
27 from src.proto.grpc.testing
import test_pb2_grpc
31 _NUM_STREAM_RESPONSES = 5
32 _REQUEST_PAYLOAD_SIZE = 7
33 _RESPONSE_PAYLOAD_SIZE = 42
38 def __init__(self, tag: str, record: list) ->
None:
47 self.
record.append(self.
tag +
':intercept_service')
48 return await continuation(handler_call_details)
66 return await self.
_fn(continuation, handler_call_details)
71 interceptor: aio.ServerInterceptor) -> aio.ServerInterceptor:
73 async
def intercept_service(
78 if condition(handler_call_details):
79 return await interceptor.intercept_service(continuation,
81 return await continuation(handler_call_details)
87 """An interceptor that caches response based on request message."""
98 handler = await continuation(handler_call_details)
101 if handler
and (handler.request_streaming
or
102 handler.response_streaming):
105 def wrapper(behavior: Callable[
106 [messages_pb2.SimpleRequest, aio.ServicerContext],
107 messages_pb2.SimpleResponse]):
109 @functools.wraps(behavior)
111 request: messages_pb2.SimpleRequest,
112 context: aio.ServicerContext
113 ) -> messages_pb2.SimpleResponse:
115 self.
cache_store[request.response_size] = await behavior(
125 *interceptors: aio.ServerInterceptor
126 ) -> Tuple[aio.Server, test_pb2_grpc.TestServiceStub]:
127 """Creates a server-stub pair with given interceptors.
129 Returning the server object to protect it from being garbage collected.
132 channel = aio.insecure_channel(server_target)
133 return server, test_pb2_grpc.TestServiceStub(channel)
140 class InvalidInterceptor:
141 """Just an invalid Interceptor"""
143 with self.assertRaises(ValueError):
145 interceptors=(InvalidInterceptor(),))
154 async
with aio.insecure_channel(server_target)
as channel:
155 multicallable = channel.unary_unary(
156 '/grpc.testing.TestService/UnaryCall',
157 request_serializer=messages_pb2.SimpleRequest.SerializeToString,
158 response_deserializer=messages_pb2.SimpleResponse.FromString)
160 response = await call
164 self.assertSequenceEqual([
165 'log1:intercept_service',
166 'log2:intercept_service',
168 self.assertIsInstance(response, messages_pb2.SimpleResponse)
175 async
with aio.insecure_channel(server_target)
as channel:
176 multicallable = channel.unary_unary(
177 '/grpc.testing.TestService/UnaryCall',
178 request_serializer=messages_pb2.SimpleRequest.SerializeToString,
179 response_deserializer=messages_pb2.SimpleResponse.FromString)
181 response = await call
182 code = await call.code()
184 self.assertSequenceEqual([
'log1:intercept_service'], record)
185 self.assertIsInstance(response, messages_pb2.SimpleResponse)
186 self.assertEqual(code, grpc.StatusCode.OK)
191 lambda x: (
'secret',
'42')
in x.invocation_metadata,
195 conditional_interceptor,
199 async
with aio.insecure_channel(server_target)
as channel:
200 multicallable = channel.unary_unary(
201 '/grpc.testing.TestService/UnaryCall',
202 request_serializer=messages_pb2.SimpleRequest.SerializeToString,
203 response_deserializer=messages_pb2.SimpleResponse.FromString)
205 metadata = aio.Metadata((
'key',
'value'),)
209 self.assertSequenceEqual([
210 'log1:intercept_service',
211 'log2:intercept_service',
215 metadata = aio.Metadata((
'key',
'value'), (
'secret',
'42'))
219 self.assertSequenceEqual([
220 'log1:intercept_service',
221 'log3:intercept_service',
222 'log2:intercept_service',
237 response = await stub.UnaryCall(
239 self.assertEqual(1,
len(interceptor.cache_store[42].payload.body))
240 self.assertEqual(interceptor.cache_store[42], response)
243 response = await stub.UnaryCall(
245 self.assertEqual(1337,
len(interceptor.cache_store[1337].payload.body))
246 self.assertEqual(interceptor.cache_store[1337], response)
247 response = await stub.UnaryCall(
249 self.assertEqual(interceptor.cache_store[1337], response)
258 for _
in range(_NUM_STREAM_RESPONSES):
259 request.response_parameters.append(
263 call = stub.StreamingOutputCall(request)
266 async
for response
in call:
267 self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
len(response.payload.body))
268 self.assertEqual(await call.code(), grpc.StatusCode.OK)
270 self.assertSequenceEqual([
271 'log_unary_stream:intercept_service',
280 call = stub.StreamingInputCall()
287 for _
in range(_NUM_STREAM_RESPONSES):
288 await call.write(request)
289 await call.done_writing()
292 response = await call
293 self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
294 self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
295 response.aggregated_payload_size)
297 self.assertEqual(await call.code(), grpc.StatusCode.OK)
299 self.assertSequenceEqual([
300 'log_stream_unary:intercept_service',
313 for _
in range(_NUM_STREAM_RESPONSES):
317 call = stub.StreamingInputCall(
gen())
320 response = await call
321 self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
322 self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
323 response.aggregated_payload_size)
325 self.assertEqual(await call.code(), grpc.StatusCode.OK)
327 self.assertSequenceEqual([
328 'log_stream_stream:intercept_service',
if __name__ == '__main__':
    # When run directly: enable DEBUG logging and verbose per-test output.
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)