21 from typing
import List, Optional, Tuple
23 from absl
import flags
24 from absl.testing
import absltest
28 from framework
import xds_flags
29 from framework
import xds_k8s_flags
30 from framework
import xds_url_map_testcase
43 logger = logging.getLogger(__name__)
46 _CHECK_LOCAL_CERTS = flags.DEFINE_bool(
49 help=
"Security Tests also check the value of local certs")
50 flags.adopt_module_key_flags(xds_flags)
51 flags.adopt_module_key_flags(xds_k8s_flags)
54 TrafficDirectorManager = traffic_director.TrafficDirectorManager
55 TrafficDirectorAppNetManager = traffic_director.TrafficDirectorAppNetManager
56 TrafficDirectorSecureManager = traffic_director.TrafficDirectorSecureManager
57 XdsTestServer = server_app.XdsTestServer
58 XdsTestClient = client_app.XdsTestClient
59 KubernetesServerRunner = server_app.KubernetesServerRunner
60 KubernetesClientRunner = client_app.KubernetesClientRunner
61 LoadBalancerStatsResponse = grpc_testing.LoadBalancerStatsResponse
62 _ChannelState = grpc_channelz.ChannelState
63 _timedelta = datetime.timedelta
64 ClientConfig = grpc_csds.ClientConfig
66 _TD_CONFIG_MAX_WAIT_SEC = 600
70 """Indicates that TD config hasn't propagated yet, and it's safe to retry"""
75 client_runner: KubernetesClientRunner
78 gcp_api_manager: gcp.api.GcpApiManager
79 gcp_service_account: Optional[str]
80 k8s_api_manager: k8s.KubernetesApiManager
81 secondary_k8s_api_manager: k8s.KubernetesApiManager
85 resource_suffix: str =
''
88 resource_suffix_randomize: bool =
True
89 server_maintenance_port: Optional[int]
91 server_runner: KubernetesServerRunner
94 td: TrafficDirectorManager
95 td_bootstrap_image: str
99 """Overridden by the test class to decide if the config is supported.
102 A bool indicates if the given config is supported.
109 """Hook method for setting up class fixture before running tests in
112 logger.info(
'----- Testing %s -----', cls.__name__)
113 logger.info(
'Logs timezone: %s', time.localtime().tm_zone)
131 if xds_flags.RESOURCE_SUFFIX.value
is not None:
151 xds_k8s_flags.DEBUG_USE_PORT_FORWARDING.value
153 xds_k8s_flags.ENABLE_WORKLOAD_IDENTITY.value
158 xds_k8s_flags.KUBE_CONTEXT.value)
160 xds_k8s_flags.SECONDARY_KUBE_CONTEXT.value)
176 wait_for_healthy_status=True,
178 max_rate_per_endpoint: Optional[int] =
None):
179 if server_runner
is None:
180 server_runner = self.server_runner
182 neg_name, neg_zones = server_runner.k8s_namespace.get_service_neg(
186 self.td.backend_service_add_neg_backends(
187 neg_name, neg_zones, max_rate_per_endpoint=max_rate_per_endpoint)
188 if wait_for_healthy_status:
189 self.td.wait_for_backends_healthy_status()
192 if server_runner
is None:
193 server_runner = self.server_runner
195 neg_name, neg_zones = server_runner.k8s_namespace.get_service_neg(
199 self.td.backend_service_remove_neg_backends(neg_name, neg_zones)
202 test_client: XdsTestClient,
203 num_rpcs: int = 100):
206 failed =
int(lb_stats.num_failures)
207 self.assertLessEqual(
210 msg=f
'Expected all RPCs to succeed: {failed} of {num_rpcs} failed')
214 before: grpc_testing.LoadBalancerAccumulatedStatsResponse,
215 after: grpc_testing.LoadBalancerAccumulatedStatsResponse):
216 """Only diffs stats_per_method, as the other fields are deprecated."""
217 diff = grpc_testing.LoadBalancerAccumulatedStatsResponse()
218 for method, method_stats
in after.stats_per_method.items():
219 for status, count
in method_stats.result.items():
220 count -= before.stats_per_method[method].result[status]
222 raise AssertionError(
"Diff of count shouldn't be negative")
224 diff.stats_per_method[method].result[status] = count
229 method: str) ->
None:
230 """Assert all RPCs for a method are completing with a certain status."""
232 before_stats = test_client.get_load_balancer_accumulated_stats()
233 response_type =
'LoadBalancerAccumulatedStatsResponse'
234 logging.info(
'Received %s from test client %s: before:\n%s',
235 response_type, test_client.ip, before_stats)
236 time.sleep(duration.total_seconds())
237 after_stats = test_client.get_load_balancer_accumulated_stats()
238 logging.info(
'Received %s from test client %s: after:\n%s',
239 response_type, test_client.ip, after_stats)
243 stats = diff_stats.stats_per_method[method]
244 status = status_code.value[0]
245 for found_status, count
in stats.result.items():
246 if found_status != status
and count > 0:
247 self.fail(f
"Expected only status {status} but found status "
248 f
"{found_status} for method {method}:\n{diff_stats}")
249 self.assertGreater(stats.result[status_code.value[0]], 0)
252 test_client: XdsTestClient,
253 servers: List[XdsTestServer],
254 num_rpcs: int = 100):
255 retryer = retryers.constant_retryer(
256 wait_fixed=datetime.timedelta(seconds=1),
257 timeout=datetime.timedelta(seconds=_TD_CONFIG_MAX_WAIT_SEC),
258 log_level=logging.INFO)
262 except retryers.RetryError:
264 'Rpcs did not go to expected servers before timeout %s',
265 _TD_CONFIG_MAX_WAIT_SEC)
268 servers: List[XdsTestServer],
270 server_names = [server.pod_name
for server
in servers]
271 logger.info(
'Verifying RPCs go to %s', server_names)
273 failed =
int(lb_stats.num_failures)
274 self.assertLessEqual(
277 msg=f
'Expected all RPCs to succeed: {failed} of {num_rpcs} failed')
278 for server_name
in server_names:
279 self.assertIn(server_name, lb_stats.rpcs_by_peer,
280 f
'{server_name} did not receive RPCs')
281 for peer
in lb_stats.rpcs_by_peer.keys():
282 self.assertIn(peer, server_names,
283 f
'Unexpected server {peer} received RPCs')
286 config = test_client.csds.fetch_client_status(log_level=logging.INFO)
287 self.assertIsNotNone(config)
295 for xds_config
in config.xds_config:
296 seen.add(xds_config.WhichOneof(
'per_xds_config'))
297 for generic_xds_config
in config.generic_xds_configs:
298 if re.search(
r'\.Listener$', generic_xds_config.type_url):
299 seen.add(
'listener_config')
300 elif re.search(
r'\.RouteConfiguration$',
301 generic_xds_config.type_url):
302 seen.add(
'route_config')
303 elif re.search(
r'\.Cluster$', generic_xds_config.type_url):
304 seen.add(
'cluster_config')
305 elif re.search(
r'\.ClusterLoadAssignment$',
306 generic_xds_config.type_url):
307 seen.add(
'endpoint_config')
308 logger.debug(
'Received xDS config dump: %s',
309 json_format.MessageToJson(config, indent=2))
310 self.assertSameElements(want, seen)
313 self, test_client: XdsTestClient,
314 previous_route_config_version: str, retry_wait_second: int,
315 timeout_second: int):
316 retryer = retryers.constant_retryer(
317 wait_fixed=datetime.timedelta(seconds=retry_wait_second),
318 timeout=datetime.timedelta(seconds=timeout_second),
319 retry_on_exceptions=(TdPropagationRetryableError,),
321 log_level=logging.INFO)
323 for attempt
in retryer:
326 raw_config = test_client.csds.fetch_client_status(
327 log_level=logging.INFO)
329 json_format.MessageToDict(raw_config))
330 route_config_version = dumped_config.rds_version
331 if previous_route_config_version == route_config_version:
333 'Routing config not propagated yet. Retrying.')
335 "CSDS not get updated routing config corresponding"
336 " to the second set of url maps")
340 (
'[SUCCESS] Confirmed successful RPC with the '
341 'updated routing config, version=%s'),
342 route_config_version)
343 except retryers.RetryError
as retry_error:
345 (
'Retry exhausted. TD routing config propagation failed after '
346 'timeout %ds. Last seen client config dump: %s'),
347 timeout_second, dumped_config)
351 test_client: XdsTestClient,
352 num_rpcs: Optional[int] = 100):
354 failed =
int(lb_stats.num_failures)
358 msg=f
'Expected all RPCs to fail: {failed} of {num_rpcs} failed')
362 num_rpcs: int) -> LoadBalancerStatsResponse:
363 lb_stats = test_client.get_load_balancer_stats(num_rpcs=num_rpcs)
365 'Received LoadBalancerStatsResponse from test client %s:\n%s',
366 test_client.ip, lb_stats)
371 for backend, rpcs_count
in lb_stats.rpcs_by_peer.items():
375 msg=f
'Backend {backend} did not receive a single RPC')
379 metaclass=abc.ABCMeta):
380 """Isolated test case.
382 Base class for test cases where infra resources are created before
383 each test, and destroyed after.
387 """Hook method for setting up the test fixture before exercising it."""
392 logger.info(
'Test run resource prefix: %s, suffix: %s',
410 self.
td.create_firewall_rule(
425 raise NotImplementedError
429 raise NotImplementedError
433 raise NotImplementedError
436 logger.info(
'----- TestMethod %s teardown -----', self.id())
437 retryer = retryers.constant_retryer(wait_fixed=
_timedelta(seconds=10),
439 log_level=logging.INFO)
442 except retryers.RetryError:
443 logger.exception(
'Got error during teardown')
453 """Regular test case base class for testing PSM features in isolation."""
457 """Hook method for setting up class fixture before running tests in
463 KubernetesServerRunner.DEFAULT_MAINTENANCE_PORT
509 **kwargs) -> List[XdsTestServer]:
510 if server_runner
is None:
512 test_servers = server_runner.run(
513 replica_count=replica_count,
517 for test_server
in test_servers:
523 **kwargs) -> XdsTestClient:
526 test_client.wait_for_active_server_channel()
531 td: TrafficDirectorAppNetManager
544 """Test case base class for testing PSM security features in isolation."""
545 td: TrafficDirectorSecureManager
550 PLAINTEXT = enum.auto()
554 """Hook method for setting up class fixture before running tests in
565 KubernetesServerRunner.DEFAULT_SECURE_MODE_MAINTENANCE_PORT
588 deployment_template=
'server-secure.deployment.yaml',
603 deployment_template=
'client-secure.deployment.yaml',
610 replica_count=replica_count,
631 test_server: XdsTestServer,
633 wait_for_active_server_channel=
True,
634 **kwargs) -> XdsTestClient:
638 if wait_for_active_server_channel:
639 test_client.wait_for_active_server_channel()
643 test_client: XdsTestClient,
644 test_server: XdsTestServer):
646 test_client, test_server)
647 server_security: grpc_channelz.Security = server_socket.security
648 client_security: grpc_channelz.Security = client_socket.security
659 raise TypeError(
'Incorrect security mode')
662 server_security: grpc_channelz.Security):
663 self.assertEqual(client_security.WhichOneof(
'model'),
665 msg=
'(mTLS) Client socket security model must be TLS')
666 self.assertEqual(server_security.WhichOneof(
'model'),
668 msg=
'(mTLS) Server socket security model must be TLS')
669 server_tls, client_tls = server_security.tls, client_security.tls
672 self.assertNotEmpty(client_tls.remote_certificate,
673 msg=
"(mTLS) Client remote certificate is missing")
676 server_tls.local_certificate,
677 msg=
"(mTLS) Server local certificate is missing")
679 server_tls.local_certificate,
680 client_tls.remote_certificate,
681 msg=
"(mTLS) Server local certificate must match client's "
682 "remote certificate")
685 self.assertNotEmpty(server_tls.remote_certificate,
686 msg=
"(mTLS) Server remote certificate is missing")
689 client_tls.local_certificate,
690 msg=
"(mTLS) Client local certificate is missing")
692 server_tls.remote_certificate,
693 client_tls.local_certificate,
694 msg=
"(mTLS) Server remote certificate must match client's "
698 server_security: grpc_channelz.Security):
699 self.assertEqual(client_security.WhichOneof(
'model'),
701 msg=
'(TLS) Client socket security model must be TLS')
702 self.assertEqual(server_security.WhichOneof(
'model'),
704 msg=
'(TLS) Server socket security model must be TLS')
705 server_tls, client_tls = server_security.tls, client_security.tls
708 self.assertNotEmpty(client_tls.remote_certificate,
709 msg=
"(TLS) Client remote certificate is missing")
711 self.assertNotEmpty(server_tls.local_certificate,
712 msg=
"(TLS) Server local certificate is missing")
714 server_tls.local_certificate,
715 client_tls.remote_certificate,
716 msg=
"(TLS) Server local certificate must match client "
717 "remote certificate")
721 server_tls.remote_certificate,
722 msg=
"(TLS) Server remote certificate must be empty in TLS mode. "
723 "Is server security incorrectly configured for mTLS?")
725 client_tls.local_certificate,
726 msg=
"(TLS) Client local certificate must be empty in TLS mode. "
727 "Is client security incorrectly configured for mTLS?")
730 server_tls, client_tls = server_security.tls, client_security.tls
733 server_tls.local_certificate,
734 msg=
"(Plaintext) Server local certificate must be empty.")
736 client_tls.local_certificate,
737 msg=
"(Plaintext) Client local certificate must be empty.")
741 server_tls.remote_certificate,
742 msg=
"(Plaintext) Server remote certificate must be empty.")
744 client_tls.local_certificate,
745 msg=
"(Plaintext) Client local certificate must be empty.")
749 test_client: XdsTestClient,
751 times: Optional[int] =
None,
752 delay: Optional[_timedelta] =
None):
754 Asserts that the client repeatedly cannot reach the server.
756 With negative tests we can't be absolutely certain expected failure
757 state is not caused by something else.
758 To mitigate this, we repeat the checks several times, and expect
759 all of them to succeed.
761 This is useful in case the channel eventually stabilizes, and RPCs pass.
764 test_client: An instance of XdsTestClient
765 times: Optional; A positive number of times to confirm that
766 the server is unreachable. Defaults to `3` attempts.
767 delay: Optional; Specifies how long to wait before the next check.
768 Defaults to `10` seconds.
770 if times
is None or times < 1:
775 for i
in range(1, times + 1):
778 logger.info(
'Check %s passed, waiting %s before the next check',
780 time.sleep(delay.total_seconds())
787 channel = test_client.wait_for_server_channel_state(
788 state=_ChannelState.TRANSIENT_FAILURE)
790 test_client.channelz.list_channel_subchannels(channel))
791 self.assertLen(subchannels,
793 msg=
"Client channel must have exactly one subchannel "
794 "in state TRANSIENT_FAILURE.")
798 test_client: XdsTestClient, test_server: XdsTestServer
799 ) -> Tuple[grpc_channelz.Socket, grpc_channelz.Socket]:
800 client_sock = test_client.get_active_server_channel_socket()
801 server_sock = test_server.get_server_socket_matching_client(client_sock)
802 return client_sock, server_sock
806 if security.WhichOneof(
'model') ==
'other':
807 return f
'other: <{security.other.name}={security.other.value}>'
809 return (f
'local: <{cls.debug_cert(security.tls.local_certificate)}>, '
810 f
'remote: <{cls.debug_cert(security.tls.remote_certificate)}>')
816 sha1 = hashlib.sha1(cert)
817 return f
'sha1={sha1.hexdigest()}, len={len(cert)}'