14 """Testing the done callbacks mechanism."""
23 from src.proto.grpc.testing
import messages_pb2
24 from src.proto.grpc.testing
import test_pb2_grpc
# Number of iterations used by the streaming tests' request/response loops.
_NUM_STREAM_RESPONSES = 5
# Per-request payload size; the stream-unary test expects the aggregated
# payload size to equal _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE.
_REQUEST_PAYLOAD_SIZE = 7
# Expected length of each response payload body in the streaming tests.
_RESPONSE_PAYLOAD_SIZE = 42
# Raw request/response bytes compared by the generic-handler tests.
_REQUEST = b'\x01\x02\x03'
_RESPONSE = b'\x04\x05\x06'
# NOTE(review): presumably the fully-qualified RPC paths of the registered
# handler and of an unregistered method -- their use is not visible here.
_TEST_METHOD = '/test/Test'
_FAKE_METHOD = '/test/Fake'
45 async
def tearDown(self):
49 async
def test_add_after_done(self):
51 self.assertEqual(grpc.StatusCode.OK, await call.code())
56 async
def test_unary_unary(self):
60 self.assertEqual(grpc.StatusCode.OK, await call.code())
64 async
def test_unary_stream(self):
66 for _
in range(_NUM_STREAM_RESPONSES):
67 request.response_parameters.append(
70 call = self.
_stub.StreamingOutputCall(request)
74 async
for response
in call:
76 self.assertIsInstance(response,
77 messages_pb2.StreamingOutputCallResponse)
78 self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
len(response.payload.body))
80 self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt)
81 self.assertEqual(grpc.StatusCode.OK, await call.code())
85 async
def test_stream_unary(self):
90 for _
in range(_NUM_STREAM_RESPONSES):
93 call = self.
_stub.StreamingInputCall(
gen())
97 self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
98 self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
99 response.aggregated_payload_size)
100 self.assertEqual(grpc.StatusCode.OK, await call.code())
104 async
def test_stream_stream(self):
105 call = self.
_stub.FullDuplexCall()
109 request.response_parameters.append(
112 for _
in range(_NUM_STREAM_RESPONSES):
113 await call.write(request)
114 response = await call.read()
115 self.assertIsInstance(response,
116 messages_pb2.StreamingOutputCallResponse)
117 self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
len(response.payload.body))
119 await call.done_writing()
121 self.assertEqual(grpc.StatusCode.OK, await call.code())
129 port = self.
_server.add_insecure_port(
'[::]:0')
130 self.
_channel = aio.insecure_channel(
'localhost:%d' % port)
132 async
def tearDown(self):
136 async
def _register_method_handler(self, method_handler):
137 """Registers method handler and starts the server"""
140 dict(Test=method_handler),
142 self.
_server.add_generic_rpc_handlers((generic_handler,))
145 async
def test_unary_unary(self):
146 validation_future = self.
loop.create_future()
148 async
def test_handler(request: bytes, context: aio.ServicerContext):
149 self.assertEqual(_REQUEST, request)
153 await self._register_method_handler(
156 self.assertEqual(_RESPONSE, response)
158 validation = await validation_future
161 async
def test_unary_stream(self):
162 validation_future = self.
loop.create_future()
164 async
def test_handler(request: bytes, context: aio.ServicerContext):
165 self.assertEqual(_REQUEST, request)
167 for _
in range(_NUM_STREAM_RESPONSES):
170 await self._register_method_handler(
173 async
for response
in call:
174 self.assertEqual(_RESPONSE, response)
176 validation = await validation_future
179 async
def test_stream_unary(self):
180 validation_future = self.
loop.create_future()
182 async
def test_handler(request_iterator, context: aio.ServicerContext):
185 async
for request
in request_iterator:
186 self.assertEqual(_REQUEST, request)
189 await self._register_method_handler(
192 for _
in range(_NUM_STREAM_RESPONSES):
193 await call.write(_REQUEST)
194 await call.done_writing()
195 self.assertEqual(_RESPONSE, await call)
197 validation = await validation_future
200 async
def test_stream_stream(self):
201 validation_future = self.
loop.create_future()
203 async
def test_handler(request_iterator, context: aio.ServicerContext):
206 async
for request
in request_iterator:
207 self.assertEqual(_REQUEST, request)
210 await self._register_method_handler(
213 for _
in range(_NUM_STREAM_RESPONSES):
214 await call.write(_REQUEST)
215 await call.done_writing()
216 async
for response
in call:
217 self.assertEqual(_RESPONSE, response)
219 validation = await validation_future
222 async
def test_error_in_handler(self):
223 """Errors in the handler still trigger callbacks."""
224 validation_future = self.
loop.create_future()
226 async
def test_handler(request: bytes, context: aio.ServicerContext):
227 self.assertEqual(_REQUEST, request)
229 raise RuntimeError(
'A test RuntimeError')
231 await self._register_method_handler(
233 with self.assertRaises(aio.AioRpcError)
as exception_context:
235 rpc_error = exception_context.exception
236 self.assertEqual(grpc.StatusCode.UNKNOWN, rpc_error.code())
238 validation = await validation_future
241 async
def test_error_in_callback(self):
242 """Errors in the callback won't be propagated to client."""
243 validation_future = self.
loop.create_future()
245 async
def test_handler(request: bytes, context: aio.ServicerContext):
246 self.assertEqual(_REQUEST, request)
248 def exception_raiser(unused_context):
249 raise RuntimeError(
'A test RuntimeError')
251 context.add_done_callback(exception_raiser)
255 await self._register_method_handler(
259 self.assertEqual(_RESPONSE, response)
262 validation = await validation_future
263 with self.assertRaises(asyncio.TimeoutError):
268 with self.assertRaises(aio.AioRpcError)
as exception_context:
270 rpc_error = exception_context.exception
271 self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code())
# Script entry point: only runs when the module is executed directly.
if __name__ == '__main__':
    # Configure the root logger at DEBUG level before the tests start.
    logging.basicConfig(level=logging.DEBUG)
    # Run the test cases with verbose per-test output.
    unittest.main(verbosity=2)