client_stream_unary_interceptor_test.py
Go to the documentation of this file.
1 # Copyright 2020 The gRPC Authors.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 import asyncio
15 import datetime
16 import logging
17 import unittest
18 
19 import grpc
20 from grpc.experimental import aio
21 
22 from src.proto.grpc.testing import messages_pb2
23 from src.proto.grpc.testing import test_pb2_grpc
24 from tests.unit.framework.common import test_constants
25 from tests_aio.unit._common import CountingRequestIterator
26 from tests_aio.unit._common import inject_callbacks
27 from tests_aio.unit._constants import UNREACHABLE_TARGET
28 from tests_aio.unit._test_base import AioTestBase
29 from tests_aio.unit._test_server import start_test_server
30 
# Timeout in seconds (per the ``_S`` suffix); only used to derive the
# microsecond interval below.
_SHORT_TIMEOUT_S = 1.0

# Number of requests every streaming call in these tests sends.
_NUM_STREAM_REQUESTS = 5
# Length in bytes of each request payload body.
_REQUEST_PAYLOAD_SIZE = 7
# _SHORT_TIMEOUT_S converted to microseconds.
# NOTE(review): not referenced by any test in this file.
_RESPONSE_INTERVAL_US = int(_SHORT_TIMEOUT_S * 1000 * 1000)
36 
37 
class _StreamUnaryInterceptorEmpty(aio.StreamUnaryClientInterceptor):
    """Pass-through interceptor that forwards every call unchanged.

    Serves as the baseline interceptor: the call still goes through the
    interceptor machinery, but nothing is altered or recorded.
    """

    async def intercept_stream_unary(self, continuation, client_call_details,
                                     request_iterator):
        """Forward the call details and request iterator untouched."""
        return await continuation(client_call_details, request_iterator)

    def assert_in_final_state(self, test: unittest.TestCase):
        """No-op: this interceptor keeps no state to validate.

        Present so tests can call ``assert_in_final_state`` uniformly on
        every interceptor class they parametrize over.
        """
46 
47 
class _StreamUnaryInterceptorWithRequestIterator(
        aio.StreamUnaryClientInterceptor):
    """Interceptor that routes the request stream through a counter.

    The incoming request iterator is wrapped in ``CountingRequestIterator``
    before the call continues. ``assert_in_final_state`` then checks that the
    wrapper's ``request_cnt`` equals ``_NUM_STREAM_REQUESTS``.
    """

    # Set by intercept_stream_unary; stays None until the interceptor runs.
    request_iterator = None

    async def intercept_stream_unary(self, continuation, client_call_details,
                                     request_iterator):
        """Wrap ``request_iterator`` in a counter and continue the call."""
        # Stored on the instance so assert_in_final_state can read the count.
        self.request_iterator = CountingRequestIterator(request_iterator)
        return await continuation(client_call_details, self.request_iterator)

    def assert_in_final_state(self, test: unittest.TestCase):
        """Assert, through ``test``, that every expected request was counted.

        Args:
            test: the running test case used to report assertion failures.
        """
        # Report a readable failure instead of an AttributeError when the
        # interceptor was never invoked.
        test.assertIsNotNone(self.request_iterator,
                             'intercept_stream_unary was never called')
        test.assertEqual(_NUM_STREAM_REQUESTS,
                         self.request_iterator.request_cnt)
60 
61 
class TestStreamUnaryClientInterceptor(AioTestBase):
    """Tests for stream-unary client interceptors on aio channels."""

    # Interceptor implementations that the parametrized tests iterate over.
    _INTERCEPTOR_CLASSES = (_StreamUnaryInterceptorEmpty,
                            _StreamUnaryInterceptorWithRequestIterator)

    async def setUp(self):
        """Start a test server; most tests connect to its target."""
        self._server_target, self._server = await start_test_server()

    async def tearDown(self):
        """Stop the server started in setUp."""
        await self._server.stop(None)

    @staticmethod
    def _build_request():
        """Return a request whose payload is _REQUEST_PAYLOAD_SIZE zero bytes."""
        payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
        return messages_pb2.StreamingInputCallRequest(payload=payload)

    @staticmethod
    async def _request_iterator(request):
        """Asynchronously yield ``request`` _NUM_STREAM_REQUESTS times."""
        for _ in range(_NUM_STREAM_REQUESTS):
            yield request

    async def _assert_call_succeeded(self, call, response):
        """Assert that ``call`` finished successfully with ``response``.

        Checks three things:
        - the aggregated payload size covers every request sent;
        - the call ended with OK status, empty metadata, empty details and
          an empty debug error string;
        - ``cancel()`` on the finished call returns False.
        """
        self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE,
                         response.aggregated_payload_size)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)
        self.assertEqual(await call.initial_metadata(), aio.Metadata())
        self.assertEqual(await call.trailing_metadata(), aio.Metadata())
        self.assertEqual(await call.details(), '')
        self.assertEqual(await call.debug_error_string(), '')
        self.assertEqual(call.cancel(), False)
        self.assertEqual(call.cancelled(), False)
        self.assertEqual(call.done(), True)

    async def test_intercepts(self):
        # Successful RPC driven by a request iterator.
        for interceptor_class in self._INTERCEPTOR_CLASSES:
            with self.subTest(name=interceptor_class):
                interceptor = interceptor_class()
                channel = aio.insecure_channel(self._server_target,
                                               interceptors=[interceptor])
                # `async with` closes the channel even when an assertion
                # fails; subTest records that failure and the loop goes on.
                async with channel:
                    stub = test_pb2_grpc.TestServiceStub(channel)
                    call = stub.StreamingInputCall(
                        self._request_iterator(self._build_request()))

                    response = await call

                    await self._assert_call_succeeded(call, response)
                    interceptor.assert_in_final_state(self)

    async def test_intercepts_using_write(self):
        # Successful RPC driven by explicit write()/done_writing() calls.
        for interceptor_class in self._INTERCEPTOR_CLASSES:
            with self.subTest(name=interceptor_class):
                interceptor = interceptor_class()
                channel = aio.insecure_channel(self._server_target,
                                               interceptors=[interceptor])
                async with channel:
                    stub = test_pb2_grpc.TestServiceStub(channel)
                    request = self._build_request()

                    call = stub.StreamingInputCall()
                    for _ in range(_NUM_STREAM_REQUESTS):
                        await call.write(request)
                    await call.done_writing()

                    response = await call

                    await self._assert_call_succeeded(call, response)
                    interceptor.assert_in_final_state(self)

    async def test_add_done_callback_interceptor_task_not_finished(self):
        # Callbacks are injected before the call has been awaited.
        for interceptor_class in self._INTERCEPTOR_CLASSES:
            with self.subTest(name=interceptor_class):
                channel = aio.insecure_channel(
                    self._server_target, interceptors=[interceptor_class()])
                async with channel:
                    stub = test_pb2_grpc.TestServiceStub(channel)
                    call = stub.StreamingInputCall(
                        self._request_iterator(self._build_request()))

                    validation = inject_callbacks(call)

                    await call
                    await validation

    async def test_add_done_callback_interceptor_task_finished(self):
        # Callbacks are injected only after the call has completed.
        for interceptor_class in self._INTERCEPTOR_CLASSES:
            with self.subTest(name=interceptor_class):
                channel = aio.insecure_channel(
                    self._server_target, interceptors=[interceptor_class()])
                async with channel:
                    stub = test_pb2_grpc.TestServiceStub(channel)
                    call = stub.StreamingInputCall(
                        self._request_iterator(self._build_request()))

                    await call

                    validation = inject_callbacks(call)
                    await validation

    async def test_multiple_interceptors_request_iterator(self):
        # Two chained interceptors of the same class on one channel.
        for interceptor_class in self._INTERCEPTOR_CLASSES:
            with self.subTest(name=interceptor_class):
                interceptors = [interceptor_class(), interceptor_class()]
                channel = aio.insecure_channel(self._server_target,
                                               interceptors=interceptors)
                async with channel:
                    stub = test_pb2_grpc.TestServiceStub(channel)
                    call = stub.StreamingInputCall(
                        self._request_iterator(self._build_request()))

                    response = await call

                    await self._assert_call_succeeded(call, response)
                    for interceptor in interceptors:
                        interceptor.assert_in_final_state(self)

    async def test_intercepts_request_iterator_rpc_error(self):
        # An unreachable target makes the iterator-driven RPC fail with
        # UNAVAILABLE.
        for interceptor_class in self._INTERCEPTOR_CLASSES:
            with self.subTest(name=interceptor_class):
                channel = aio.insecure_channel(
                    UNREACHABLE_TARGET, interceptors=[interceptor_class()])
                async with channel:
                    stub = test_pb2_grpc.TestServiceStub(channel)

                    # When there is an error the request iterator is no
                    # longer consumed.
                    call = stub.StreamingInputCall(
                        self._request_iterator(self._build_request()))

                    with self.assertRaises(
                            aio.AioRpcError) as exception_context:
                        await call

                    self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                                     exception_context.exception.code())
                    self.assertTrue(call.done())
                    self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                                     await call.code())

    async def test_intercepts_request_iterator_rpc_error_using_write(self):
        # An unreachable target makes the write()-driven RPC fail with
        # UNAVAILABLE.
        for interceptor_class in self._INTERCEPTOR_CLASSES:
            with self.subTest(name=interceptor_class):
                channel = aio.insecure_channel(
                    UNREACHABLE_TARGET, interceptors=[interceptor_class()])
                async with channel:
                    stub = test_pb2_grpc.TestServiceStub(channel)
                    request = self._build_request()

                    call = stub.StreamingInputCall()

                    # Once the RPC has failed, write() raises.
                    with self.assertRaises(asyncio.InvalidStateError):
                        for _ in range(_NUM_STREAM_REQUESTS):
                            await call.write(request)

                    with self.assertRaises(
                            aio.AioRpcError) as exception_context:
                        await call

                    self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                                     exception_context.exception.code())
                    self.assertTrue(call.done())
                    self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                                     await call.code())

    async def test_cancel_before_rpc(self):
        # Cancel while the interceptor is blocked before starting the RPC.
        interceptor_reached = asyncio.Event()
        wait_for_ever = self.loop.create_future()

        class Interceptor(aio.StreamUnaryClientInterceptor):

            async def intercept_stream_unary(self, continuation,
                                             client_call_details,
                                             request_iterator):
                # Never calls the continuation: block until cancelled.
                interceptor_reached.set()
                await wait_for_ever

        channel = aio.insecure_channel(self._server_target,
                                       interceptors=[Interceptor()])
        async with channel:
            stub = test_pb2_grpc.TestServiceStub(channel)
            request = self._build_request()

            call = stub.StreamingInputCall()

            self.assertFalse(call.cancelled())
            self.assertFalse(call.done())

            await interceptor_reached.wait()
            self.assertTrue(call.cancel())

            # Writes on the cancelled call raise.
            with self.assertRaises(asyncio.InvalidStateError):
                for _ in range(_NUM_STREAM_REQUESTS):
                    await call.write(request)

            with self.assertRaises(asyncio.CancelledError):
                await call

            self.assertTrue(call.cancelled())
            self.assertTrue(call.done())
            self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
            self.assertIsNone(await call.initial_metadata())
            self.assertIsNone(await call.trailing_metadata())

    async def test_cancel_after_rpc(self):
        # Cancel while the interceptor is blocked after starting the RPC.
        interceptor_reached = asyncio.Event()
        wait_for_ever = self.loop.create_future()

        class Interceptor(aio.StreamUnaryClientInterceptor):

            async def intercept_stream_unary(self, continuation,
                                             client_call_details,
                                             request_iterator):
                # Start the RPC, then block without ever returning the call.
                # NOTE(review): `call` is unused but deliberately left bound;
                # dropping the reference might let the call object be
                # garbage-collected while blocked. Confirm before removing.
                call = await continuation(client_call_details,
                                          request_iterator)
                interceptor_reached.set()
                await wait_for_ever

        channel = aio.insecure_channel(self._server_target,
                                       interceptors=[Interceptor()])
        async with channel:
            stub = test_pb2_grpc.TestServiceStub(channel)
            request = self._build_request()

            call = stub.StreamingInputCall()

            self.assertFalse(call.cancelled())
            self.assertFalse(call.done())

            await interceptor_reached.wait()
            self.assertTrue(call.cancel())

            # Writes on the cancelled call raise.
            with self.assertRaises(asyncio.InvalidStateError):
                for _ in range(_NUM_STREAM_REQUESTS):
                    await call.write(request)

            with self.assertRaises(asyncio.CancelledError):
                await call

            self.assertTrue(call.cancelled())
            self.assertTrue(call.done())
            self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
            self.assertIsNone(await call.initial_metadata())
            self.assertIsNone(await call.trailing_metadata())

    async def test_cancel_while_writing(self):
        # Cancel either before the first write or after one write.
        for num_writes_before_cancel in (0, 1):
            with self.subTest(name="Num writes before cancel: {}".format(
                    num_writes_before_cancel)):
                channel = aio.insecure_channel(
                    UNREACHABLE_TARGET,
                    interceptors=[_StreamUnaryInterceptorEmpty()])
                async with channel:
                    stub = test_pb2_grpc.TestServiceStub(channel)
                    request = self._build_request()

                    call = stub.StreamingInputCall()

                    with self.assertRaises(asyncio.InvalidStateError):
                        for i in range(_NUM_STREAM_REQUESTS):
                            if i == num_writes_before_cancel:
                                self.assertTrue(call.cancel())
                            await call.write(request)

                    with self.assertRaises(asyncio.CancelledError):
                        await call

                    self.assertTrue(call.cancelled())
                    self.assertTrue(call.done())
                    self.assertEqual(await call.code(),
                                     grpc.StatusCode.CANCELLED)

    async def test_cancel_by_the_interceptor(self):
        # The interceptor itself cancels the call it just started.

        class Interceptor(aio.StreamUnaryClientInterceptor):

            async def intercept_stream_unary(self, continuation,
                                             client_call_details,
                                             request_iterator):
                call = await continuation(client_call_details,
                                          request_iterator)
                call.cancel()
                return call

        channel = aio.insecure_channel(UNREACHABLE_TARGET,
                                       interceptors=[Interceptor()])
        async with channel:
            stub = test_pb2_grpc.TestServiceStub(channel)
            request = self._build_request()

            call = stub.StreamingInputCall()

            with self.assertRaises(asyncio.InvalidStateError):
                for _ in range(_NUM_STREAM_REQUESTS):
                    await call.write(request)

            with self.assertRaises(asyncio.CancelledError):
                await call

            self.assertTrue(call.cancelled())
            self.assertTrue(call.done())
            self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)

    async def test_exception_raised_by_interceptor(self):
        # An exception raised by the interceptor surfaces from both write()
        # and awaiting the call.

        class InterceptorException(Exception):
            pass

        class Interceptor(aio.StreamUnaryClientInterceptor):

            async def intercept_stream_unary(self, continuation,
                                             client_call_details,
                                             request_iterator):
                raise InterceptorException

        channel = aio.insecure_channel(UNREACHABLE_TARGET,
                                       interceptors=[Interceptor()])
        async with channel:
            stub = test_pb2_grpc.TestServiceStub(channel)
            request = self._build_request()

            call = stub.StreamingInputCall()

            with self.assertRaises(InterceptorException):
                for _ in range(_NUM_STREAM_REQUESTS):
                    await call.write(request)

            with self.assertRaises(InterceptorException):
                await call

    async def test_intercepts_prohibit_mixing_style(self):
        # A call created with a request iterator rejects write() and
        # done_writing().
        channel = aio.insecure_channel(
            self._server_target, interceptors=[_StreamUnaryInterceptorEmpty()])
        async with channel:
            stub = test_pb2_grpc.TestServiceStub(channel)
            request = self._build_request()

            call = stub.StreamingInputCall(self._request_iterator(request))

            with self.assertRaises(grpc._cython.cygrpc.UsageError):
                await call.write(request)

            with self.assertRaises(grpc._cython.cygrpc.UsageError):
                await call.done_writing()
514 
515 
# Run the tests directly with debug logging and verbose unittest output.
if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)
tests_aio.unit.client_stream_unary_interceptor_test._StreamUnaryInterceptorWithRequestIterator
Definition: client_stream_unary_interceptor_test.py:49
http2_test_server.format
format
Definition: http2_test_server.py:118
tests_aio.unit._test_base.AioTestBase.loop
def loop(self)
Definition: _test_base.py:55
tests_aio.unit.client_stream_unary_interceptor_test.TestStreamUnaryClientInterceptor._server
_server
Definition: client_stream_unary_interceptor_test.py:65
tests_aio.unit._test_server
Definition: tests_aio/unit/_test_server.py:1
capstone.range
range
Definition: third_party/bloaty/third_party/capstone/bindings/python/capstone/__init__.py:6
tests_aio.unit.client_stream_unary_interceptor_test._StreamUnaryInterceptorEmpty
Definition: client_stream_unary_interceptor_test.py:38
tests_aio.unit._test_server.start_test_server
def start_test_server(port=0, secure=False, server_credentials=None, interceptors=None)
Definition: tests_aio/unit/_test_server.py:128
tests_aio.unit._test_base
Definition: _test_base.py:1
tests_aio.unit._common.CountingRequestIterator
Definition: tests/tests_aio/unit/_common.py:77
xds_interop_client.int
int
Definition: xds_interop_client.py:113
grpc::experimental
Definition: include/grpcpp/channel.h:46
tests_aio.unit.client_stream_unary_interceptor_test._StreamUnaryInterceptorWithRequestIterator.request_iterator
request_iterator
Definition: client_stream_unary_interceptor_test.py:52
tests_aio.unit.client_stream_unary_interceptor_test.TestStreamUnaryClientInterceptor.setUp
def setUp(self)
Definition: client_stream_unary_interceptor_test.py:64
messages_pb2.Payload
Payload
Definition: messages_pb2.py:583
tests_aio.unit._common.inject_callbacks
def inject_callbacks(aio.Call call)
Definition: tests/tests_aio/unit/_common.py:48
tests_aio.unit.client_stream_unary_interceptor_test._StreamUnaryInterceptorWithRequestIterator.intercept_stream_unary
def intercept_stream_unary(self, continuation, client_call_details, request_iterator)
Definition: client_stream_unary_interceptor_test.py:51
tests_aio.unit.client_stream_unary_interceptor_test._StreamUnaryInterceptorEmpty.intercept_stream_unary
def intercept_stream_unary(self, continuation, client_call_details, request_iterator)
Definition: client_stream_unary_interceptor_test.py:40
tests_aio.unit._common
Definition: tests/tests_aio/unit/_common.py:1
tests_aio.unit.client_stream_unary_interceptor_test.TestStreamUnaryClientInterceptor
Definition: client_stream_unary_interceptor_test.py:62
messages_pb2.StreamingInputCallRequest
StreamingInputCallRequest
Definition: messages_pb2.py:611
stop
static const char stop[]
Definition: benchmark-async-pummel.c:35
tests.unit.framework.common
Definition: src/python/grpcio_tests/tests/unit/framework/common/__init__.py:1
tests_aio.unit._constants
Definition: _constants.py:1
tests_aio.unit._test_base.AioTestBase
Definition: _test_base.py:49


grpc
Author(s):
autogenerated on Fri May 16 2025 02:57:55