14 """Implementations of interoperability test methods."""
26 from typing
import Any, Optional, Union
28 from google
import auth
as google_auth
29 from google.auth
import environment_vars
as google_auth_environment_vars
30 from google.auth.transport
import grpc
as google_auth_transport_grpc
31 from google.auth.transport
import requests
as google_auth_transport_requests
35 from src.proto.grpc.testing
import empty_pb2
36 from src.proto.grpc.testing
import messages_pb2
37 from src.proto.grpc.testing
import test_pb2_grpc
39 _INITIAL_METADATA_KEY =
"x-grpc-test-echo-initial"
40 _TRAILING_METADATA_KEY =
"x-grpc-test-echo-trailing-bin"
45 code = await call.code()
46 if code != expected_code:
47 raise ValueError(
'expected code %s, got %s' %
48 (expected_code, await call.code()))
52 details = await call.details()
53 if details != expected_details:
54 raise ValueError(
'expected message %s, got %s' %
55 (expected_details, await call.details()))
60 expected_details: str) ->
None:
66 messages_pb2.SimpleResponse, messages_pb2.StreamingOutputCallResponse],
68 expected_length: int) ->
None:
69 if response.payload.type
is not expected_type:
70 raise ValueError(
'expected payload type %s, got %s' %
71 (expected_type,
type(response.payload.type)))
72 elif len(response.payload.body) != expected_length:
73 raise ValueError(
'expected payload body size %d, got %d' %
74 (expected_length,
len(response.payload.body)))
78 stub: test_pb2_grpc.TestServiceStub, fill_username: bool,
80 ) -> messages_pb2.SimpleResponse:
83 response_type=messages_pb2.COMPRESSABLE,
86 fill_username=fill_username,
87 fill_oauth_scope=fill_oauth_scope)
88 response = await stub.UnaryCall(request, credentials=call_credentials)
93 async
def _empty_unary(stub: test_pb2_grpc.TestServiceStub) ->
None:
94 response = await stub.EmptyCall(empty_pb2.Empty())
95 if not isinstance(response, empty_pb2.Empty):
96 raise TypeError(
'response is of type "%s", not empty_pb2.Empty!' %
105 payload_body_sizes = (
112 async
def request_gen():
113 for size
in payload_body_sizes:
117 response = await stub.StreamingInputCall(request_gen())
118 if response.aggregated_payload_size !=
sum(payload_body_sizes):
119 raise ValueError(
'incorrect size %d!' %
120 response.aggregated_payload_size)
132 response_type=messages_pb2.COMPRESSABLE,
133 response_parameters=(
139 call = stub.StreamingOutputCall(request)
141 response = await call.read()
146 async
def _ping_pong(stub: test_pb2_grpc.TestServiceStub) ->
None:
147 request_response_sizes = (
153 request_payload_sizes = (
160 call = stub.FullDuplexCall()
161 for response_size, payload_size
in zip(request_response_sizes,
162 request_payload_sizes):
164 response_type=messages_pb2.COMPRESSABLE,
166 size=response_size),),
169 await call.write(request)
170 response = await call.read()
173 await call.done_writing()
178 call = stub.StreamingInputCall()
180 if not call.cancelled():
181 raise ValueError(
'expected cancelled method to return True')
182 code = await call.code()
183 if code
is not grpc.StatusCode.CANCELLED:
184 raise ValueError(
'expected status code CANCELLED')
188 request_response_sizes = (
194 request_payload_sizes = (
201 call = stub.FullDuplexCall()
203 response_size = request_response_sizes[0]
204 payload_size = request_payload_sizes[0]
206 response_type=messages_pb2.COMPRESSABLE,
208 size=response_size),),
211 await call.write(request)
218 except asyncio.CancelledError:
219 assert await call.code()
is grpc.StatusCode.CANCELLED
221 raise ValueError(
'expected call to be cancelled')
225 request_payload_size = 27182
226 time_limit = datetime.timedelta(seconds=1)
228 call = stub.FullDuplexCall(timeout=time_limit.total_seconds())
231 response_type=messages_pb2.COMPRESSABLE,
234 interval_us=
int(time_limit.total_seconds() * 2 * 10**6)),))
235 await call.write(request)
236 await call.done_writing()
239 except aio.AioRpcError
as rpc_error:
240 if rpc_error.code()
is not grpc.StatusCode.DEADLINE_EXCEEDED:
243 raise ValueError(
'expected call to exceed deadline')
247 call = stub.FullDuplexCall()
248 await call.done_writing()
249 assert await call.read() == aio.EOF
253 details =
'test status message'
254 status = grpc.StatusCode.UNKNOWN
258 response_type=messages_pb2.COMPRESSABLE,
263 call = stub.UnaryCall(request)
267 call = stub.FullDuplexCall()
269 response_type=messages_pb2.COMPRESSABLE,
274 await call.write(request)
275 await call.done_writing()
278 except aio.AioRpcError
as rpc_error:
279 assert rpc_error.code() == status
284 call = stub.UnimplementedCall(empty_pb2.Empty())
289 call = stub.UnimplementedCall(empty_pb2.Empty())
294 initial_metadata_value =
"test_initial_metadata_value"
295 trailing_metadata_value = b
"\x0a\x0b\x0a\x0b\x0a\x0b"
296 metadata = aio.Metadata(
297 (_INITIAL_METADATA_KEY, initial_metadata_value),
298 (_TRAILING_METADATA_KEY, trailing_metadata_value),
301 async
def _validate_metadata(call):
302 initial_metadata = await call.initial_metadata()
303 if initial_metadata[_INITIAL_METADATA_KEY] != initial_metadata_value:
304 raise ValueError(
'expected initial metadata %s, got %s' %
305 (initial_metadata_value,
306 initial_metadata[_INITIAL_METADATA_KEY]))
308 trailing_metadata = await call.trailing_metadata()
309 if trailing_metadata[_TRAILING_METADATA_KEY] != trailing_metadata_value:
310 raise ValueError(
'expected trailing metadata %s, got %s' %
311 (trailing_metadata_value,
312 trailing_metadata[_TRAILING_METADATA_KEY]))
316 response_type=messages_pb2.COMPRESSABLE,
319 call = stub.UnaryCall(request, metadata=metadata)
320 await _validate_metadata(call)
323 call = stub.FullDuplexCall(metadata=metadata)
325 response_type=messages_pb2.COMPRESSABLE,
327 await call.write(request)
329 await call.done_writing()
330 await _validate_metadata(call)
334 args: argparse.Namespace):
336 if args.default_service_account != response.username:
337 raise ValueError(
'expected username %s, got %s' %
338 (args.default_service_account, response.username))
342 args: argparse.Namespace):
343 json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
344 wanted_email = json.load(
open(json_key_filename,
'r'))[
'client_email']
346 if wanted_email != response.username:
347 raise ValueError(
'expected username %s, got %s' %
348 (wanted_email, response.username))
349 if args.oauth_scope.find(response.oauth_scope) == -1:
351 'expected to find oauth scope "{}" in received "{}"'.
format(
352 response.oauth_scope, args.oauth_scope))
356 json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
357 wanted_email = json.load(
open(json_key_filename,
'r'))[
'client_email']
359 if wanted_email != response.username:
360 raise ValueError(
'expected username %s, got %s' %
361 (wanted_email, response.username))
365 args: argparse.Namespace):
366 json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
367 wanted_email = json.load(
open(json_key_filename,
'r'))[
'client_email']
368 google_credentials, unused_project_id = google_auth.default(
369 scopes=[args.oauth_scope])
371 google_auth_transport_grpc.AuthMetadataPlugin(
372 credentials=google_credentials,
373 request=google_auth_transport_requests.Request()))
376 if wanted_email != response.username:
377 raise ValueError(
'expected username %s, got %s' %
378 (wanted_email, response.username))
382 details = b
'\t\ntest with whitespace\r\nand Unicode BMP \xe2\x98\xba and non-BMP \xf0\x9f\x98\x88\t\n'.
decode(
384 status = grpc.StatusCode.UNKNOWN
388 response_type=messages_pb2.COMPRESSABLE,
393 call = stub.UnaryCall(request)
399 EMPTY_UNARY =
'empty_unary'
400 LARGE_UNARY =
'large_unary'
401 SERVER_STREAMING =
'server_streaming'
402 CLIENT_STREAMING =
'client_streaming'
403 PING_PONG =
'ping_pong'
404 CANCEL_AFTER_BEGIN =
'cancel_after_begin'
405 CANCEL_AFTER_FIRST_RESPONSE =
'cancel_after_first_response'
406 TIMEOUT_ON_SLEEPING_SERVER =
'timeout_on_sleeping_server'
407 EMPTY_STREAM =
'empty_stream'
408 STATUS_CODE_AND_MESSAGE =
'status_code_and_message'
409 UNIMPLEMENTED_METHOD =
'unimplemented_method'
410 UNIMPLEMENTED_SERVICE =
'unimplemented_service'
411 CUSTOM_METADATA =
"custom_metadata"
412 COMPUTE_ENGINE_CREDS =
'compute_engine_creds'
413 OAUTH2_AUTH_TOKEN =
'oauth2_auth_token'
414 JWT_TOKEN_CREDS =
'jwt_token_creds'
415 PER_RPC_CREDS =
'per_rpc_creds'
416 SPECIAL_STATUS_MESSAGE =
'special_status_message'
419 _TEST_CASE_IMPLEMENTATION_MAPPING = {
420 TestCase.EMPTY_UNARY: _empty_unary,
421 TestCase.LARGE_UNARY: _large_unary,
422 TestCase.SERVER_STREAMING: _server_streaming,
423 TestCase.CLIENT_STREAMING: _client_streaming,
424 TestCase.PING_PONG: _ping_pong,
425 TestCase.CANCEL_AFTER_BEGIN: _cancel_after_begin,
426 TestCase.CANCEL_AFTER_FIRST_RESPONSE: _cancel_after_first_response,
427 TestCase.TIMEOUT_ON_SLEEPING_SERVER: _timeout_on_sleeping_server,
428 TestCase.EMPTY_STREAM: _empty_stream,
429 TestCase.STATUS_CODE_AND_MESSAGE: _status_code_and_message,
430 TestCase.UNIMPLEMENTED_METHOD: _unimplemented_method,
431 TestCase.UNIMPLEMENTED_SERVICE: _unimplemented_service,
432 TestCase.CUSTOM_METADATA: _custom_metadata,
433 TestCase.COMPUTE_ENGINE_CREDS: _compute_engine_creds,
434 TestCase.OAUTH2_AUTH_TOKEN: _oauth2_auth_token,
435 TestCase.JWT_TOKEN_CREDS: _jwt_token_creds,
436 TestCase.PER_RPC_CREDS: _per_rpc_creds,
437 TestCase.SPECIAL_STATUS_MESSAGE: _special_status_message,
443 stub: test_pb2_grpc.TestServiceStub,
444 args: Optional[argparse.Namespace] =
None) ->
None:
445 method = _TEST_CASE_IMPLEMENTATION_MAPPING.get(case)
447 raise NotImplementedError(f
'Test case "{case}" not implemented!')
449 num_params =
len(inspect.signature(method).parameters)
452 elif num_params == 2:
456 raise ValueError(f
'Failed to run case [{case}]: args is None')
458 raise ValueError(f
'Invalid number of parameters [{num_params}]')