14 """Testing the compatibility between AsyncIO stack and the old stack."""
17 from concurrent.futures
import ThreadPoolExecutor
22 from typing
import Callable, Iterable, Sequence, Tuple
28 from src.proto.grpc.testing
import messages_pb2
29 from src.proto.grpc.testing
import test_pb2_grpc
36 _NUM_STREAM_RESPONSES = 5
37 _REQUEST_PAYLOAD_SIZE = 7
38 _RESPONSE_PAYLOAD_SIZE = 42
39 _REQUEST = b
'\x03\x07'
43 return ((
'iv', random.random()),)
47 os.environ.get(
'GRPC_ASYNCIO_ENGINE',
'').lower() ==
'custom_io_manager',
48 'Compatible mode needs POLLER completion queue.')
53 options=((
'grpc.so_reuseport', 0),),
54 migration_thread_pool=ThreadPoolExecutor())
62 address =
'localhost:%d' % port
75 async
def tearDown(self):
80 async
def _run_in_another_thread(self, func: Callable[[],
None]):
81 work_done = asyncio.Event()
85 self.
loop.call_soon_threadsafe(work_done.set)
87 thread = threading.Thread(target=thread_work, daemon=
True)
89 await work_done.wait()
92 async
def test_unary_unary(self):
95 timeout=test_constants.LONG_TIMEOUT)
98 def sync_work() -> None:
99 response, call = self.
_sync_stub.UnaryCall.with_call(
101 timeout=test_constants.LONG_TIMEOUT)
102 self.assertIsInstance(response, messages_pb2.SimpleResponse)
103 self.assertEqual(grpc.StatusCode.OK, call.code())
105 await self._run_in_another_thread(sync_work)
107 async
def test_unary_stream(self):
109 for _
in range(_NUM_STREAM_RESPONSES):
110 request.response_parameters.append(
114 call = self.
_async_stub.StreamingOutputCall(request)
116 for _
in range(_NUM_STREAM_RESPONSES):
118 self.assertEqual(grpc.StatusCode.OK, await call.code())
121 def sync_work() -> None:
122 response_iterator = self.
_sync_stub.StreamingOutputCall(request)
123 for response
in response_iterator:
124 assert _RESPONSE_PAYLOAD_SIZE ==
len(response.payload.body)
125 self.assertEqual(grpc.StatusCode.OK, response_iterator.code())
127 await self._run_in_another_thread(sync_work)
129 async
def test_stream_unary(self):
135 for _
in range(_NUM_STREAM_RESPONSES):
139 self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
140 response.aggregated_payload_size)
143 def sync_work() -> None:
144 response = self.
_sync_stub.StreamingInputCall(
145 iter([request] * _NUM_STREAM_RESPONSES))
146 self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
147 response.aggregated_payload_size)
149 await self._run_in_another_thread(sync_work)
151 async
def test_stream_stream(self):
153 request.response_parameters.append(
159 for _
in range(_NUM_STREAM_RESPONSES):
160 await call.write(request)
161 response = await call.read()
162 assert _RESPONSE_PAYLOAD_SIZE ==
len(response.payload.body)
164 await call.done_writing()
165 assert await call.code() == grpc.StatusCode.OK
168 def sync_work() -> None:
169 response_iterator = self.
_sync_stub.FullDuplexCall(
iter([request]))
170 for response
in response_iterator:
171 assert _RESPONSE_PAYLOAD_SIZE ==
len(response.payload.body)
172 self.assertEqual(grpc.StatusCode.OK, response_iterator.code())
174 await self._run_in_another_thread(sync_work)
180 def service(self, handler_call_details):
186 handlers=(GenericHandlers(),))
187 port = server.add_insecure_port(
'localhost:0')
190 def sync_work() -> None:
193 response = channel.unary_unary(
'/test/test')(b
'\x07\x08')
194 self.assertEqual(response, b
'\x07\x08')
196 await self._run_in_another_thread(sync_work)
198 async
def test_many_loop(self):
204 async
def async_work():
206 async_channel = aio.insecure_channel(address,
208 async_stub = test_pb2_grpc.TestServiceStub(async_channel)
211 response = await call
212 self.assertIsInstance(response, messages_pb2.SimpleResponse)
213 self.assertEqual(grpc.StatusCode.OK, await call.code())
215 loop = asyncio.new_event_loop()
216 loop.run_until_complete(async_work())
218 await self._run_in_another_thread(sync_work)
219 await server.stop(
None)
221 async
def test_sync_unary_unary_success(self):
223 @grpc.unary_unary_rpc_method_handler
224 def echo_unary_unary(request: bytes, unused_context):
230 self.assertEqual(_REQUEST, response)
232 async
def test_sync_unary_unary_metadata(self):
233 metadata = ((
'unique',
'key-42'),)
235 @grpc.unary_unary_rpc_method_handler
237 context.send_initial_metadata(metadata)
243 _common.seen_metadata(aio.Metadata(*metadata), await
244 call.initial_metadata()))
246 async
def test_sync_unary_unary_abort(self):
248 @grpc.unary_unary_rpc_method_handler
250 context.abort(grpc.StatusCode.INTERNAL,
'Test')
253 with self.assertRaises(aio.AioRpcError)
as exception_context:
256 self.assertEqual(grpc.StatusCode.INTERNAL,
257 exception_context.exception.code())
259 async
def test_sync_unary_unary_set_code(self):
261 @grpc.unary_unary_rpc_method_handler
263 context.set_code(grpc.StatusCode.INTERNAL)
266 with self.assertRaises(aio.AioRpcError)
as exception_context:
269 self.assertEqual(grpc.StatusCode.INTERNAL,
270 exception_context.exception.code())
272 async
def test_sync_unary_stream_success(self):
274 @grpc.unary_stream_rpc_method_handler
275 def echo_unary_stream(request: bytes, unused_context):
276 for _
in range(_NUM_STREAM_RESPONSES):
281 async
for response
in call:
282 self.assertEqual(_REQUEST, response)
284 async
def test_sync_unary_stream_error(self):
286 @grpc.unary_stream_rpc_method_handler
287 def error_unary_stream(request: bytes, unused_context):
288 for _
in range(_NUM_STREAM_RESPONSES):
290 raise RuntimeError(
'Test')
294 with self.assertRaises(aio.AioRpcError)
as exception_context:
295 async
for response
in call:
296 self.assertEqual(_REQUEST, response)
297 self.assertEqual(grpc.StatusCode.UNKNOWN,
298 exception_context.exception.code())
300 async
def test_sync_stream_unary_success(self):
302 @grpc.stream_unary_rpc_method_handler
303 def echo_stream_unary(request_iterator: Iterable[bytes],
305 self.assertEqual(
len(list(request_iterator)), _NUM_STREAM_RESPONSES)
309 request_iterator =
iter([_REQUEST] * _NUM_STREAM_RESPONSES)
312 self.assertEqual(_REQUEST, response)
314 async
def test_sync_stream_unary_error(self):
316 @grpc.stream_unary_rpc_method_handler
317 def echo_stream_unary(request_iterator: Iterable[bytes],
319 self.assertEqual(
len(list(request_iterator)), _NUM_STREAM_RESPONSES)
320 raise RuntimeError(
'Test')
323 request_iterator =
iter([_REQUEST] * _NUM_STREAM_RESPONSES)
324 with self.assertRaises(aio.AioRpcError)
as exception_context:
326 _common.ADHOC_METHOD)(request_iterator)
327 self.assertEqual(grpc.StatusCode.UNKNOWN,
328 exception_context.exception.code())
330 async
def test_sync_stream_stream_success(self):
332 @grpc.stream_stream_rpc_method_handler
333 def echo_stream_stream(request_iterator: Iterable[bytes],
335 for request
in request_iterator:
339 request_iterator =
iter([_REQUEST] * _NUM_STREAM_RESPONSES)
341 _common.ADHOC_METHOD)(request_iterator)
342 async
for response
in call:
343 self.assertEqual(_REQUEST, response)
345 async
def test_sync_stream_stream_error(self):
347 @grpc.stream_stream_rpc_method_handler
348 def echo_stream_stream(request_iterator: Iterable[bytes],
350 for request
in request_iterator:
352 raise RuntimeError(
'test')
355 request_iterator =
iter([_REQUEST] * _NUM_STREAM_RESPONSES)
357 _common.ADHOC_METHOD)(request_iterator)
358 with self.assertRaises(aio.AioRpcError)
as exception_context:
359 async
for response
in call:
360 self.assertEqual(_REQUEST, response)
361 self.assertEqual(grpc.StatusCode.UNKNOWN,
362 exception_context.exception.code())
365 if __name__ ==
'__main__':
366 logging.basicConfig(level=logging.DEBUG)
367 unittest.main(verbosity=2)