17 from concurrent
import futures
24 from typing
import DefaultDict, Dict, List, Mapping, Sequence, Set, Tuple
30 from src.proto.grpc.testing
import empty_pb2
31 from src.proto.grpc.testing
import messages_pb2
32 from src.proto.grpc.testing
import test_pb2
33 from src.proto.grpc.testing
import test_pb2_grpc
# Module logging setup: attach a stderr StreamHandler to the root logger.
# `formatter` is a module-level name so that the file handler created for
# --log_file can reuse the exact same line layout.
logger = logging.getLogger()
console_handler = logging.StreamHandler()
formatter = logging.Formatter(
    fmt='%(asctime)s: %(levelname)-8s %(message)s')
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
41 _SUPPORTED_METHODS = (
46 _METHOD_CAMEL_TO_CAPS_SNAKE = {
47 "UnaryCall":
"UNARY_CALL",
48 "EmptyCall":
"EMPTY_CALL",
51 _METHOD_STR_TO_ENUM = {
52 "UnaryCall": messages_pb2.ClientConfigureRequest.UNARY_CALL,
53 "EmptyCall": messages_pb2.ClientConfigureRequest.EMPTY_CALL,
# Inverse of _METHOD_STR_TO_ENUM: ClientConfigureRequest RPC-type enum value
# -> method name. Used to translate the `types` of an incoming Configure
# request back into the method names the client works with.
_METHOD_ENUM_TO_STR = {v: k for k, v in _METHOD_STR_TO_ENUM.items()}
# Per-method RPC metadata: method name -> sequence of (key, value) pairs
# configured for that method (built from the --metadata flag).
PerMethodMetadataType = Mapping[str, Sequence[Tuple[str, str]]]

# Upper bound on how long a channel thread blocks in `config.condition.wait()`
# before looping back to re-check `_stop_event`.
_CONFIG_CHANGE_TIMEOUT = datetime.timedelta(milliseconds=500)
67 _rpcs_by_peer: DefaultDict[str, int]
68 _rpcs_by_method: DefaultDict[str, DefaultDict[str, int]]
71 _condition: threading.Condition
79 lambda: collections.defaultdict(int))
84 """Records statistics for a single RPC."""
96 self, timeout_sec: int) -> messages_pb2.LoadBalancerStatsResponse:
97 """Blocks until a full response has been collected."""
100 timeout=float(timeout_sec))
101 response = messages_pb2.LoadBalancerStatsResponse()
103 response.rpcs_by_peer[peer] = count
105 for peer, count
in count_by_peer.items():
106 response.rpcs_by_method[method].rpcs_by_peer[peer] = count
# Process-wide client state shared by the channel threads and the servicers.

# NOTE(review): the code that acquires this lock is not visible here;
# presumably it serializes updates to the RPC id and counters below — confirm
# against callers.
_global_lock = threading.Lock()

# Each channel's send loop keeps issuing RPCs until this event is set.
_stop_event = threading.Event()

# Counter used to assign RPC ids; a stats request starts counting at
# `_global_rpc_id + 1`.
_global_rpc_id: int = 0

# Active stats watchers; each is notified via on_rpc_complete() when an RPC
# finishes.
_watchers: Set[_StatsWatcher] = set()

# gRPC server hosting the stats/configure services. None until it is created
# in _run(); stopped by the SIGINT handler.
_global_server = None

# Cumulative per-method counters (method name -> count). They are mutated in
# place (`+= 1`), so they are annotated as DefaultDict rather than the
# read-only Mapping.
_global_rpcs_started: DefaultDict[str, int] = collections.defaultdict(int)
_global_rpcs_succeeded: DefaultDict[str, int] = collections.defaultdict(int)
_global_rpcs_failed: DefaultDict[str, int] = collections.defaultdict(int)

# method name -> numeric status code -> number of RPCs that finished with it.
_global_rpc_statuses: DefaultDict[str, DefaultDict[int, int]] = (
    collections.defaultdict(lambda: collections.defaultdict(int)))
126 logger.warning(
"Received SIGINT")
128 _global_server.stop(
None)
135 super(_LoadBalancerStatsServicer).
__init__()
138 self, request: messages_pb2.LoadBalancerStatsRequest,
140 ) -> messages_pb2.LoadBalancerStatsResponse:
141 logger.info(
"Received stats request.")
146 start = _global_rpc_id + 1
147 end = start + request.num_rpcs
149 _watchers.add(watcher)
150 response = watcher.await_rpc_stats_response(request.timeout_sec)
152 _watchers.remove(watcher)
153 logger.info(
"Returning stats response: %s", response)
157 self, request: messages_pb2.LoadBalancerAccumulatedStatsRequest,
159 ) -> messages_pb2.LoadBalancerAccumulatedStatsResponse:
160 logger.info(
"Received cumulative stats request.")
161 response = messages_pb2.LoadBalancerAccumulatedStatsResponse()
163 for method
in _SUPPORTED_METHODS:
164 caps_method = _METHOD_CAMEL_TO_CAPS_SNAKE[method]
165 response.num_rpcs_started_by_method[
166 caps_method] = _global_rpcs_started[method]
167 response.num_rpcs_succeeded_by_method[
168 caps_method] = _global_rpcs_succeeded[method]
169 response.num_rpcs_failed_by_method[
170 caps_method] = _global_rpcs_failed[method]
171 response.stats_per_method[
172 caps_method].rpcs_started = _global_rpcs_started[method]
173 for code, count
in _global_rpc_statuses[method].
items():
174 response.stats_per_method[caps_method].result[code] = count
175 logger.info(
"Returning cumulative stats response.")
179 def _start_rpc(method: str, metadata: Sequence[Tuple[str, str]],
180 request_id: int, stub: test_pb2_grpc.TestServiceStub,
181 timeout: float, futures: Mapping[int, Tuple[
grpc.Future,
183 logger.debug(f
"Sending {method} request to backend: {request_id}")
184 if method ==
"UnaryCall":
188 elif method ==
"EmptyCall":
189 future = stub.EmptyCall.future(empty_pb2.Empty(),
193 raise ValueError(f
"Unrecognized method '{method}'.")
194 futures[request_id] = (future, method)
198 print_response: bool) ->
None:
199 exception = future.exception()
202 _global_rpc_statuses[method][future.code().value[0]] += 1
203 if exception
is not None:
205 _global_rpcs_failed[method] += 1
206 if exception.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
207 logger.error(f
"RPC {rpc_id} timed out")
209 logger.error(exception)
211 response = future.result()
213 for metadatum
in future.initial_metadata():
214 if metadatum[0] ==
"hostname":
215 hostname = metadatum[1]
218 hostname = response.hostname
219 if future.code() == grpc.StatusCode.OK:
221 _global_rpcs_succeeded[method] += 1
224 _global_rpcs_failed[method] += 1
226 if future.code() == grpc.StatusCode.OK:
227 logger.debug(
"Successful response.")
229 logger.debug(f
"RPC failed: {call}")
231 for watcher
in _watchers:
232 watcher.on_rpc_complete(rpc_id, hostname, method)
236 print_response: bool) ->
None:
237 logger.debug(
"Removing completed RPCs")
239 for future_id, (future, method)
in futures.items():
241 _on_rpc_done(future_id, future, method, args.print_response)
242 done.append(future_id)
248 logger.info(
"Cancelling all remaining RPCs")
249 for future, _
in futures.values():
254 """Configuration for a single client channel.
256 Instances of this class are meant to be dealt with as PODs. That is,
257 data member should be accessed directly. This class is not thread-safe.
258 When accessing any of its members, the lock member should be held.
261 def __init__(self, method: str, metadata: Sequence[Tuple[str, str]],
262 qps: int, server: str, rpc_timeout_sec: int,
263 print_response: bool, secure_mode: bool):
277 global _global_rpc_id
278 with config.condition:
279 server = config.server
281 if config.secure_mode:
288 stub = test_pb2_grpc.TestServiceStub(channel)
290 while not _stop_event.is_set():
291 with config.condition:
293 config.condition.wait(
294 timeout=_CONFIG_CHANGE_TIMEOUT.total_seconds())
297 duration_per_query = 1.0 / float(config.qps)
300 request_id = _global_rpc_id
302 _global_rpcs_started[config.method] += 1
304 end = start + duration_per_query
305 _start_rpc(config.method, config.metadata, request_id, stub,
306 float(config.rpc_timeout_sec), futures)
307 print_response = config.print_response
309 logger.debug(f
"Currently {len(futures)} in-flight RPCs")
312 time.sleep(end - now)
318 test_pb2_grpc.XdsUpdateClientConfigureServiceServicer):
320 def __init__(self, per_method_configs: Mapping[str, _ChannelConfiguration],
322 super(_XdsUpdateClientConfigureServicer).
__init__()
327 self, request: messages_pb2.ClientConfigureRequest,
329 ) -> messages_pb2.ClientConfigureResponse:
330 logger.info(
"Received Configure RPC: %s", request)
331 method_strs = [_METHOD_ENUM_TO_STR[t]
for t
in request.types]
332 for method
in _SUPPORTED_METHODS:
333 method_enum = _METHOD_STR_TO_ENUM[method]
335 if method
in method_strs:
337 metadata = ((md.key, md.value)
338 for md
in request.metadata
339 if md.type == method_enum)
342 if request.timeout_sec == 0:
343 timeout_sec = channel_config.rpc_timeout_sec
345 timeout_sec = request.timeout_sec
350 timeout_sec = channel_config.rpc_timeout_sec
351 with channel_config.condition:
352 channel_config.qps = qps
353 channel_config.metadata = list(metadata)
354 channel_config.rpc_timeout_sec = timeout_sec
355 channel_config.condition.notify_all()
356 return messages_pb2.ClientConfigureResponse()
360 """An object grouping together threads driving RPCs for a method."""
362 _channel_threads: List[threading.Thread]
365 channel_config: _ChannelConfiguration):
366 """Creates and starts a group of threads running the indicated method."""
368 for i
in range(num_channels):
369 thread = threading.Thread(target=_run_single_channel,
370 args=(channel_config,))
375 """Joins all threads referenced by the handle."""
377 channel_thread.join()
380 def _run(args: argparse.Namespace, methods: Sequence[str],
381 per_method_metadata: PerMethodMetadataType) ->
None:
382 logger.info(
"Starting python xDS Interop Client.")
383 global _global_server
386 for method
in _SUPPORTED_METHODS:
387 if method
in methods:
392 method, per_method_metadata.get(method, []), qps, args.server,
393 args.rpc_timeout_sec, args.print_response, args.secure_mode)
394 channel_configs[method] = channel_config
395 method_handles.append(
_MethodHandle(args.num_channels, channel_config))
396 _global_server =
grpc.server(futures.ThreadPoolExecutor())
397 _global_server.add_insecure_port(f
"0.0.0.0:{args.stats_port}")
398 test_pb2_grpc.add_LoadBalancerStatsServiceServicer_to_server(
400 test_pb2_grpc.add_XdsUpdateClientConfigureServiceServicer_to_server(
403 channelz.add_channelz_servicer(_global_server)
405 _global_server.start()
406 _global_server.wait_for_termination()
407 for method_handle
in method_handles:
412 metadata = metadata_arg.split(
",")
if args.metadata
else []
413 per_method_metadata = collections.defaultdict(list)
414 for metadatum
in metadata:
415 elems = metadatum.split(
":")
418 f
"'{metadatum}' was not in the form 'METHOD:KEY:VALUE'")
419 if elems[0]
not in _SUPPORTED_METHODS:
420 raise ValueError(f
"Unrecognized method '{elems[0]}'")
421 per_method_metadata[elems[0]].append((elems[1], elems[2]))
422 return per_method_metadata
426 methods = rpc_arg.split(
",")
427 if set(methods) -
set(_SUPPORTED_METHODS):
428 raise ValueError(
"--rpc supported methods: {}".
format(
429 ", ".join(_SUPPORTED_METHODS)))
434 if arg.lower()
in (
"true",
"yes",
"y"):
436 elif arg.lower()
in (
"false",
"no",
"n"):
439 raise argparse.ArgumentTypeError(f
"Could not parse '{arg}' as a bool.")
442 if __name__ ==
"__main__":
443 parser = argparse.ArgumentParser(
444 description=
'Run Python XDS interop client.')
449 help=
"The number of channels from which to send requests.")
450 parser.add_argument(
"--print_response",
453 help=
"Write RPC response to STDOUT.")
458 help=
"The number of queries to send from each channel per second.")
459 parser.add_argument(
"--rpc_timeout_sec",
462 help=
"The per-RPC timeout in seconds.")
463 parser.add_argument(
"--server",
464 default=
"localhost:50051",
465 help=
"The address of the server.")
470 help=
"The port on which to expose the peer distribution stats service.")
475 help=
"If specified, uses xDS credentials to connect to the server.")
476 parser.add_argument(
'--verbose',
477 help=
'verbose log output',
480 parser.add_argument(
"--log_file",
483 help=
"A file to log to.")
484 rpc_help =
"A comma-delimited list of RPC methods to run. Must be one of "
485 rpc_help +=
", ".join(_SUPPORTED_METHODS)
487 parser.add_argument(
"--rpc", default=
"UnaryCall", type=str, help=rpc_help)
489 "A comma-delimited list of 3-tuples of the form " +
490 "METHOD:KEY:VALUE, e.g. " +
491 "EmptyCall:key1:value1,UnaryCall:key2:value2,EmptyCall:k3:v3")
492 parser.add_argument(
"--metadata", default=
"", type=str, help=metadata_help)
493 args = parser.parse_args()
494 signal.signal(signal.SIGINT, _handle_sigint)
496 logger.setLevel(logging.DEBUG)
498 file_handler = logging.FileHandler(args.log_file, mode=
'a')
499 file_handler.setFormatter(formatter)
500 logger.addHandler(file_handler)