server_interceptor_test.py
Go to the documentation of this file.
1 # Copyright 2020 The gRPC Authors
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 """Test the functionality of server interceptors."""
15 
16 import asyncio
17 import functools
18 import logging
19 from typing import Any, Awaitable, Callable, Tuple
20 import unittest
21 
22 import grpc
23 from grpc.experimental import aio
24 from grpc.experimental import wrap_server_method_handler
25 
26 from src.proto.grpc.testing import messages_pb2
27 from src.proto.grpc.testing import test_pb2_grpc
28 from tests_aio.unit._test_base import AioTestBase
29 from tests_aio.unit._test_server import start_test_server
30 
31 _NUM_STREAM_RESPONSES = 5
32 _REQUEST_PAYLOAD_SIZE = 7
33 _RESPONSE_PAYLOAD_SIZE = 42
34 
35 
class _LoggingInterceptor(aio.ServerInterceptor):
    """Server interceptor that records each time it is invoked.

    Every call to ``intercept_service`` appends ``'<tag>:intercept_service'``
    to the shared ``record`` list before delegating to the next step of the
    chain, so tests can assert on which interceptors ran and in what order.
    """

    def __init__(self, tag: str, record: list) -> None:
        """Stores the tag and the list that invocations are logged into.

        Args:
            tag: Prefix identifying this interceptor in ``record``.
            record: List shared with the test; entries are appended in place.
        """
        self.tag = tag
        self.record = record

    async def intercept_service(
        self, continuation: Callable[[grpc.HandlerCallDetails],
                                     Awaitable[grpc.RpcMethodHandler]],
        handler_call_details: grpc.HandlerCallDetails
    ) -> grpc.RpcMethodHandler:
        """Logs this invocation, then returns the downstream handler as-is."""
        self.record.append(self.tag + ':intercept_service')
        return await continuation(handler_call_details)
49 
50 
class _GenericInterceptor(aio.ServerInterceptor):
    """Adapts a plain coroutine function into an ``aio.ServerInterceptor``.

    The wrapped ``fn`` is called with ``(continuation, handler_call_details)``
    and its result is awaited, so ``fn`` must return an awaitable that
    resolves to the method handler (e.g. be an ``async def`` function).
    """

    def __init__(
        self, fn: Callable[[
            Callable[[grpc.HandlerCallDetails],
                     Awaitable[grpc.RpcMethodHandler]], grpc.HandlerCallDetails
        ], Any]
    ) -> None:
        """Stores the coroutine function that implements the interception."""
        self._fn = fn

    async def intercept_service(
        self, continuation: Callable[[grpc.HandlerCallDetails],
                                     Awaitable[grpc.RpcMethodHandler]],
        handler_call_details: grpc.HandlerCallDetails
    ) -> grpc.RpcMethodHandler:
        """Delegates the interception to the wrapped function."""
        return await self._fn(continuation, handler_call_details)
67 
68 
def _filter_server_interceptor(
        condition: Callable,
        interceptor: aio.ServerInterceptor) -> aio.ServerInterceptor:
    """Wraps ``interceptor`` so that it only runs for matching calls.

    Args:
        condition: Predicate called with the ``grpc.HandlerCallDetails`` of
          each incoming call.
        interceptor: Interceptor applied when ``condition`` is truthy.

    Returns:
        An interceptor that runs ``interceptor`` when ``condition`` holds and
        otherwise passes the call straight through to ``continuation``.
    """

    async def intercept_service(
        continuation: Callable[[grpc.HandlerCallDetails],
                               Awaitable[grpc.RpcMethodHandler]],
        handler_call_details: grpc.HandlerCallDetails
    ) -> grpc.RpcMethodHandler:
        if condition(handler_call_details):
            return await interceptor.intercept_service(continuation,
                                                       handler_call_details)
        # Condition not met: skip the wrapped interceptor entirely.
        return await continuation(handler_call_details)

    return _GenericInterceptor(intercept_service)
84 
85 
class _CacheInterceptor(aio.ServerInterceptor):
    """An interceptor that caches response based on request message.

    Only unary-unary handlers are wrapped. The cache key is the request's
    ``response_size`` field; the rest of the request is ignored.
    """

    def __init__(self, cache_store=None):
        """Creates the interceptor.

        Args:
            cache_store: Optional dict mapping ``response_size`` to a
              previously produced response. It is used and mutated in place,
              so callers can pre-seed it and inspect it afterwards. A new
              dict is created when omitted.
        """
        # Compare with None explicitly: ``cache_store or {}`` would replace a
        # caller-supplied *empty* dict with a private one, so the caller
        # would never see the entries cached through it.
        self.cache_store = {} if cache_store is None else cache_store

    async def intercept_service(
        self, continuation: Callable[[grpc.HandlerCallDetails],
                                     Awaitable[grpc.RpcMethodHandler]],
        handler_call_details: grpc.HandlerCallDetails
    ) -> grpc.RpcMethodHandler:
        """Returns the downstream handler, wrapped with a cache if unary-unary.

        A missing handler (``None``) and streaming handlers are returned
        unchanged.
        """
        # Get the actual handler
        handler = await continuation(handler_call_details)

        # No handler matched this method, so there is nothing to wrap.
        if handler is None:
            return None

        # Only intercept unary call RPCs
        if (handler.request_streaming or  # pytype: disable=attribute-error
                handler.response_streaming):  # pytype: disable=attribute-error
            return handler

        def wrapper(behavior: Callable[
            [messages_pb2.SimpleRequest, aio.ServicerContext],
            Awaitable[messages_pb2.SimpleResponse]]):

            @functools.wraps(behavior)
            async def cached_behavior(
                    request: messages_pb2.SimpleRequest,
                    context: aio.ServicerContext
            ) -> messages_pb2.SimpleResponse:
                key = request.response_size
                # Run the real handler only on a cache miss.
                if key not in self.cache_store:
                    self.cache_store[key] = await behavior(request, context)
                return self.cache_store[key]

            return cached_behavior

        return wrap_server_method_handler(wrapper, handler)
122 
123 
async def _create_server_stub_pair(
    *interceptors: aio.ServerInterceptor
) -> Tuple[aio.Server, test_pb2_grpc.TestServiceStub]:
    """Creates a server-stub pair with given interceptors.

    The server object is returned alongside the stub so that callers can
    hold a reference to it and keep it from being garbage collected while
    the stub is in use.

    Args:
        *interceptors: Server interceptors passed to ``start_test_server``.

    Returns:
        A ``(server, stub)`` tuple; the stub talks to ``server`` over an
        insecure channel.
    """
    server_target, server = await start_test_server(interceptors=interceptors)
    # NOTE(review): this channel is never explicitly closed; callers only
    # receive the stub. Consider returning the channel too if leaks show up.
    channel = aio.insecure_channel(server_target)
    return server, test_pb2_grpc.TestServiceStub(channel)
134 
135 
class TestServerInterceptor(AioTestBase):
    """End-to-end tests for ``aio`` server interceptors.

    Each test starts a test server with the interceptors under test and
    drives it through a channel or a generated ``TestService`` stub.
    """

    async def test_invalid_interceptor(self):

        class InvalidInterceptor:
            """Just an invalid Interceptor"""

        # Objects that are not aio.ServerInterceptor instances are expected
        # to make start_test_server raise ValueError.
        with self.assertRaises(ValueError):
            await start_test_server(interceptors=(InvalidInterceptor(),))

    async def test_executed_right_order(self):
        record = []
        server_target, _ = await start_test_server(interceptors=(
            _LoggingInterceptor('log1', record),
            _LoggingInterceptor('log2', record),
        ))

        async with aio.insecure_channel(server_target) as channel:
            multicallable = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCall',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)
            call = multicallable(messages_pb2.SimpleRequest())
            response = await call

            # Check that all interceptors were executed, and were executed
            # in the right order.
            self.assertSequenceEqual([
                'log1:intercept_service',
                'log2:intercept_service',
            ], record)
            self.assertIsInstance(response, messages_pb2.SimpleResponse)

    async def test_response_ok(self):
        record = []
        server_target, _ = await start_test_server(
            interceptors=(_LoggingInterceptor('log1', record),))

        async with aio.insecure_channel(server_target) as channel:
            multicallable = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCall',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)
            call = multicallable(messages_pb2.SimpleRequest())
            response = await call
            code = await call.code()

            self.assertSequenceEqual(['log1:intercept_service'], record)
            self.assertIsInstance(response, messages_pb2.SimpleResponse)
            self.assertEqual(code, grpc.StatusCode.OK)

    async def test_apply_different_interceptors_by_metadata(self):
        record = []
        # 'log3' only runs when the call carries the ('secret', '42') pair.
        conditional_interceptor = _filter_server_interceptor(
            lambda x: ('secret', '42') in x.invocation_metadata,
            _LoggingInterceptor('log3', record))
        server_target, _ = await start_test_server(interceptors=(
            _LoggingInterceptor('log1', record),
            conditional_interceptor,
            _LoggingInterceptor('log2', record),
        ))

        async with aio.insecure_channel(server_target) as channel:
            multicallable = channel.unary_unary(
                '/grpc.testing.TestService/UnaryCall',
                request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                response_deserializer=messages_pb2.SimpleResponse.FromString)

            # Without the secret the conditional interceptor is skipped.
            metadata = aio.Metadata(('key', 'value'),)
            call = multicallable(messages_pb2.SimpleRequest(),
                                 metadata=metadata)
            await call
            self.assertSequenceEqual([
                'log1:intercept_service',
                'log2:intercept_service',
            ], record)

            # With the secret it runs, in its registered position.
            record.clear()
            metadata = aio.Metadata(('key', 'value'), ('secret', '42'))
            call = multicallable(messages_pb2.SimpleRequest(),
                                 metadata=metadata)
            await call
            self.assertSequenceEqual([
                'log1:intercept_service',
                'log3:intercept_service',
                'log2:intercept_service',
            ], record)

    async def test_response_caching(self):
        # Prepares a preset value to help testing
        interceptor = _CacheInterceptor({
            42:
                messages_pb2.SimpleResponse(payload=messages_pb2.Payload(
                    body=b'\x42'))
        })

        # Constructs a server with the cache interceptor. The server is kept
        # referenced so it is not garbage collected during the test.
        server, stub = await _create_server_stub_pair(interceptor)

        # Tests if the cache store is used: the preset 1-byte body is
        # returned instead of a real 42-byte response.
        response = await stub.UnaryCall(
            messages_pb2.SimpleRequest(response_size=42))
        self.assertEqual(1, len(interceptor.cache_store[42].payload.body))
        self.assertEqual(interceptor.cache_store[42], response)

        # Tests response can be cached
        self.assertNotIn(1337, interceptor.cache_store)
        response = await stub.UnaryCall(
            messages_pb2.SimpleRequest(response_size=1337))
        self.assertEqual(1337, len(interceptor.cache_store[1337].payload.body))
        self.assertEqual(interceptor.cache_store[1337], response)
        response = await stub.UnaryCall(
            messages_pb2.SimpleRequest(response_size=1337))
        self.assertEqual(interceptor.cache_store[1337], response)

    async def test_interceptor_unary_stream(self):
        record = []
        server, stub = await _create_server_stub_pair(
            _LoggingInterceptor('log_unary_stream', record))

        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,))

        # Invokes the actual RPC
        call = stub.StreamingOutputCall(request)

        # Ensures the RPC goes fine and every requested response arrives
        response_count = 0
        async for response in call:
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
            response_count += 1
        self.assertEqual(_NUM_STREAM_RESPONSES, response_count)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

        self.assertSequenceEqual([
            'log_unary_stream:intercept_service',
        ], record)

    async def test_interceptor_stream_unary(self):
        record = []
        server, stub = await _create_server_stub_pair(
            _LoggingInterceptor('log_stream_unary', record))

        # Invokes the actual RPC
        call = stub.StreamingInputCall()

        # Prepares the request
        payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)

        # Sends out requests
        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(request)
        await call.done_writing()

        # Validates the responses
        response = await call
        self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
        self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
                         response.aggregated_payload_size)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

        self.assertSequenceEqual([
            'log_stream_unary:intercept_service',
        ], record)

    async def test_interceptor_stream_stream(self):
        record = []
        server, stub = await _create_server_stub_pair(
            _LoggingInterceptor('log_stream_stream', record))

        # Prepares the request: each request asks for one response payload.
        request = messages_pb2.StreamingOutputCallRequest()
        request.response_parameters.append(
            messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))

        async def gen():
            for _ in range(_NUM_STREAM_RESPONSES):
                yield request

        # Invokes the actual bidirectional-streaming RPC. StreamingInputCall
        # is stream-unary and would not exercise the stream-stream path.
        call = stub.FullDuplexCall(gen())

        # Validates the responses
        response_count = 0
        async for response in call:
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
            response_count += 1
        self.assertEqual(_NUM_STREAM_RESPONSES, response_count)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

        self.assertSequenceEqual([
            'log_stream_stream:intercept_service',
        ], record)
330 
331 
if __name__ == '__main__':
    # When run directly, emit debug logs and per-test verbose output.
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)
messages_pb2.SimpleRequest
SimpleRequest
Definition: messages_pb2.py:597
tests_aio.unit.server_interceptor_test._filter_server_interceptor
aio.ServerInterceptor _filter_server_interceptor(Callable condition, aio.ServerInterceptor interceptor)
Definition: server_interceptor_test.py:69
tests_aio.unit.server_interceptor_test._LoggingInterceptor.intercept_service
grpc.RpcMethodHandler intercept_service(self, Callable[[grpc.HandlerCallDetails], Awaitable[grpc.RpcMethodHandler]] continuation, grpc.HandlerCallDetails handler_call_details)
Definition: server_interceptor_test.py:42
tests_aio.unit.server_interceptor_test._CacheInterceptor.__init__
def __init__(self, cache_store=None)
Definition: server_interceptor_test.py:89
tests_aio.unit._test_server
Definition: tests_aio/unit/_test_server.py:1
capstone.range
range
Definition: third_party/bloaty/third_party/capstone/bindings/python/capstone/__init__.py:6
tests_aio.unit._test_server.start_test_server
def start_test_server(port=0, secure=False, server_credentials=None, interceptors=None)
Definition: tests_aio/unit/_test_server.py:128
tests_aio.unit.server_interceptor_test.TestServerInterceptor.test_apply_different_interceptors_by_metadata
def test_apply_different_interceptors_by_metadata(self)
Definition: server_interceptor_test.py:188
tests_aio.unit._test_base
Definition: _test_base.py:1
tests_aio.unit.server_interceptor_test.TestServerInterceptor
Definition: server_interceptor_test.py:136
tests_aio.unit.server_interceptor_test.TestServerInterceptor.test_response_caching
def test_response_caching(self)
Definition: server_interceptor_test.py:225
tests_aio.unit.server_interceptor_test.TestServerInterceptor.test_interceptor_unary_stream
def test_interceptor_unary_stream(self)
Definition: server_interceptor_test.py:251
grpc::experimental
Definition: include/grpcpp/channel.h:46
tests_aio.unit.server_interceptor_test._create_server_stub_pair
Tuple[aio.Server, test_pb2_grpc.TestServiceStub] _create_server_stub_pair(*aio.ServerInterceptor interceptors)
Definition: server_interceptor_test.py:124
tests_aio.unit.server_interceptor_test.TestServerInterceptor.test_invalid_interceptor
def test_invalid_interceptor(self)
Definition: server_interceptor_test.py:138
tests_aio.unit.server_interceptor_test._LoggingInterceptor.__init__
None __init__(self, str tag, list record)
Definition: server_interceptor_test.py:38
tests_aio.unit.server_interceptor_test._LoggingInterceptor.tag
tag
Definition: server_interceptor_test.py:39
gen
OPENSSL_EXPORT GENERAL_NAME * gen
Definition: x509v3.h:495
messages_pb2.ResponseParameters
ResponseParameters
Definition: messages_pb2.py:625
wrapper
grpc_channel_wrapper * wrapper
Definition: src/php/ext/grpc/channel.h:48
messages_pb2.StreamingOutputCallRequest
StreamingOutputCallRequest
Definition: messages_pb2.py:632
messages_pb2.Payload
Payload
Definition: messages_pb2.py:583
tests_aio.unit.server_interceptor_test._CacheInterceptor.intercept_service
grpc.RpcMethodHandler intercept_service(self, Callable[[grpc.HandlerCallDetails], Awaitable[grpc.RpcMethodHandler]] continuation, grpc.HandlerCallDetails handler_call_details)
Definition: server_interceptor_test.py:92
tests_aio.unit.server_interceptor_test.TestServerInterceptor.test_interceptor_stream_unary
def test_interceptor_stream_unary(self)
Definition: server_interceptor_test.py:274
tests_aio.unit.server_interceptor_test._CacheInterceptor
Definition: server_interceptor_test.py:86
tests_aio.unit.server_interceptor_test._GenericInterceptor.intercept_service
grpc.RpcMethodHandler intercept_service(self, Callable[[grpc.HandlerCallDetails], Awaitable[grpc.RpcMethodHandler]] continuation, grpc.HandlerCallDetails handler_call_details)
Definition: server_interceptor_test.py:61
tests_aio.unit.server_interceptor_test.TestServerInterceptor.test_executed_right_order
def test_executed_right_order(self)
Definition: server_interceptor_test.py:147
tests_aio.unit.server_interceptor_test.TestServerInterceptor.test_response_ok
def test_response_ok(self)
Definition: server_interceptor_test.py:170
tests_aio.unit.server_interceptor_test._LoggingInterceptor.record
record
Definition: server_interceptor_test.py:40
tests_aio.unit.server_interceptor_test._CacheInterceptor.cache_store
cache_store
Definition: server_interceptor_test.py:90
messages_pb2.StreamingInputCallRequest
StreamingInputCallRequest
Definition: messages_pb2.py:611
tests_aio.unit.server_interceptor_test._GenericInterceptor._fn
_fn
Definition: server_interceptor_test.py:54
grpc::experimental.wrap_server_method_handler
def wrap_server_method_handler(wrapper, handler)
Definition: src/python/grpcio/grpc/experimental/__init__.py:82
grpc.HandlerCallDetails
Definition: src/python/grpcio/grpc/__init__.py:1324
len
int len
Definition: abseil-cpp/absl/base/internal/low_level_alloc_test.cc:46
tests_aio.unit.server_interceptor_test._LoggingInterceptor
Definition: server_interceptor_test.py:36
messages_pb2.SimpleResponse
SimpleResponse
Definition: messages_pb2.py:604
tests_aio.unit.server_interceptor_test._GenericInterceptor
Definition: server_interceptor_test.py:51
tests_aio.unit.server_interceptor_test._GenericInterceptor.__init__
None __init__(self, Callable[[Callable[[grpc.HandlerCallDetails], Awaitable[grpc.RpcMethodHandler]], grpc.HandlerCallDetails], Any] fn)
Definition: server_interceptor_test.py:53
grpc.RpcMethodHandler
Definition: src/python/grpcio/grpc/__init__.py:1288
tests_aio.unit._test_base.AioTestBase
Definition: _test_base.py:49
tests_aio.unit.server_interceptor_test.TestServerInterceptor.test_interceptor_stream_stream
def test_interceptor_stream_stream(self)
Definition: server_interceptor_test.py:303


grpc
Author(s):
autogenerated on Thu Mar 13 2025 03:01:17