diff --git a/python/ray/autoscaler/_private/kuberay/autoscaling_config.py b/python/ray/autoscaler/_private/kuberay/autoscaling_config.py
index 41122047eb9d..b4ec5aba6701 100644
--- a/python/ray/autoscaler/_private/kuberay/autoscaling_config.py
+++ b/python/ray/autoscaler/_private/kuberay/autoscaling_config.py
@@ -49,10 +49,10 @@ class AutoscalingConfigProducer:
     """
 
     def __init__(self, ray_cluster_name, ray_cluster_namespace):
-        self._headers, self._verify = node_provider.load_k8s_secrets()
-        self._ray_cr_url = node_provider.url_from_resource(
-            namespace=ray_cluster_namespace, path=f"rayclusters/{ray_cluster_name}"
+        self.kubernetes_api_client = node_provider.KubernetesHttpApiClient(
+            namespace=ray_cluster_namespace
         )
+        self._ray_cr_path = f"rayclusters/{ray_cluster_name}"
 
     def __call__(self):
         ray_cr = self._fetch_ray_cr_from_k8s_with_retries()
@@ -67,7 +67,7 @@ def _fetch_ray_cr_from_k8s_with_retries(self) -> Dict[str, Any]:
         """
         for i in range(1, MAX_RAYCLUSTER_FETCH_TRIES + 1):
             try:
-                return self._fetch_ray_cr_from_k8s()
+                return self.kubernetes_api_client.get(self._ray_cr_path)
             except requests.HTTPError as e:
                 if i < MAX_RAYCLUSTER_FETCH_TRIES:
                     logger.exception(
@@ -80,18 +80,6 @@ def _fetch_ray_cr_from_k8s_with_retries(self) -> Dict[str, Any]:
         # This branch is inaccessible. Raise to satisfy mypy.
         raise AssertionError
 
-    def _fetch_ray_cr_from_k8s(self) -> Dict[str, Any]:
-        result = requests.get(
-            self._ray_cr_url,
-            headers=self._headers,
-            timeout=node_provider.KUBERAY_REQUEST_TIMEOUT_S,
-            verify=self._verify,
-        )
-        if not result.status_code == 200:
-            result.raise_for_status()
-        ray_cr = result.json()
-        return ray_cr
-
 
 def _derive_autoscaling_config_from_ray_cr(ray_cr: Dict[str, Any]) -> Dict[str, Any]:
     provider_config = _generate_provider_config(ray_cr["metadata"]["namespace"])
@@ -179,7 +167,7 @@ def _generate_legacy_autoscaling_config_fields() -> Dict[str, Any]:
 
 
 def _generate_available_node_types_from_ray_cr_spec(
-    ray_cr_spec: Dict[str, Any]
+    ray_cr_spec: Dict[str, Any],
 ) -> Dict[str, Any]:
     """Formats autoscaler "available_node_types" field based on the Ray CR's group
     specs.
diff --git a/python/ray/autoscaler/_private/kuberay/node_provider.py b/python/ray/autoscaler/_private/kuberay/node_provider.py
index 6e788564d7a9..0bf01e550443 100644
--- a/python/ray/autoscaler/_private/kuberay/node_provider.py
+++ b/python/ray/autoscaler/_private/kuberay/node_provider.py
@@ -1,3 +1,4 @@
+import datetime
 import json
 import logging
 import os
@@ -54,6 +55,8 @@
 # Key for GKE label that identifies which multi-host replica a pod belongs to
 REPLICA_INDEX_KEY = "replicaIndex"
 
+TOKEN_REFRESH_PERIOD = datetime.timedelta(minutes=1)
+
 
 # Design:
 # Each modification the autoscaler wants to make is posted to the API server goal state
@@ -264,7 +267,24 @@ class KubernetesHttpApiClient(IKubernetesHttpApiClient):
     def __init__(self, namespace: str, kuberay_crd_version: str = KUBERAY_CRD_VER):
         self._kuberay_crd_version = kuberay_crd_version
         self._namespace = namespace
-        self._headers, self._verify = load_k8s_secrets()
+        self._token_expires_at = datetime.datetime.now() + TOKEN_REFRESH_PERIOD
+        self._headers, self._verify = None, None
+
+    def _get_refreshed_headers_and_verify(self):
+        """Return request headers and TLS verify setting, reloading them if stale.
+
+        The secrets are loaded with load_k8s_secrets() on first use and reloaded
+        once TOKEN_REFRESH_PERIOD has elapsed since the previous load.
+
+        Returns:
+            The (headers, verify) pair to pass to the `requests` calls.
+        """
+        secrets_missing = self._headers is None or self._verify is None
+        if secrets_missing or datetime.datetime.now() >= self._token_expires_at:
+            logger.info("Refreshing K8s API client token and certs.")
+            self._headers, self._verify = load_k8s_secrets()
+            self._token_expires_at = datetime.datetime.now() + TOKEN_REFRESH_PERIOD
+        return self._headers, self._verify
 
     def get(self, path: str) -> Dict[str, Any]:
         """Wrapper for REST GET of resource with proper headers.
@@ -283,11 +303,13 @@ def get(self, path: str) -> Dict[str, Any]:
             path=path,
             kuberay_crd_version=self._kuberay_crd_version,
         )
+
+        headers, verify = self._get_refreshed_headers_and_verify()
         result = requests.get(
             url,
-            headers=self._headers,
+            headers=headers,
             timeout=KUBERAY_REQUEST_TIMEOUT_S,
-            verify=self._verify,
+            verify=verify,
         )
         if not result.status_code == 200:
             result.raise_for_status()
@@ -311,11 +333,13 @@ def patch(self, path: str, payload: List[Dict[str, Any]]) -> Dict[str, Any]:
             path=path,
             kuberay_crd_version=self._kuberay_crd_version,
         )
+        headers, verify = self._get_refreshed_headers_and_verify()
         result = requests.patch(
             url,
             json.dumps(payload),
-            headers={**self._headers, "Content-type": "application/json-patch+json"},
-            verify=self._verify,
+            headers={**headers, "Content-type": "application/json-patch+json"},
+            timeout=KUBERAY_REQUEST_TIMEOUT_S,
+            verify=verify,
         )
         if not result.status_code == 200:
             result.raise_for_status()
diff --git a/python/ray/tests/kuberay/test_autoscaling_config.py b/python/ray/tests/kuberay/test_autoscaling_config.py
index 7fe8759fc1c2..82aec91ff969 100644
--- a/python/ray/tests/kuberay/test_autoscaling_config.py
+++ b/python/ray/tests/kuberay/test_autoscaling_config.py
@@ -395,17 +395,22 @@ def test_autoscaling_config_fetch_retries(exception, num_exceptions):
     AutoscalingConfigProducer._fetch_ray_cr_from_k8s_with_retries.
     """
 
-    class MockAutoscalingConfigProducer(AutoscalingConfigProducer):
-        def __init__(self, *args, **kwargs):
+    class MockKubernetesHttpApiClient:
+        def __init__(self):
             self.exception_counter = 0
 
-        def _fetch_ray_cr_from_k8s(self) -> Dict[str, Any]:
+        def get(self, *args, **kwargs):
             if self.exception_counter < num_exceptions:
                 self.exception_counter += 1
                 raise exception
             else:
                 return {"ok-key": "ok-value"}
 
+    class MockAutoscalingConfigProducer(AutoscalingConfigProducer):
+        def __init__(self, *args, **kwargs):
+            self.kubernetes_api_client = MockKubernetesHttpApiClient()
+            self._ray_cr_path = "rayclusters/mock"
+
     config_producer = MockAutoscalingConfigProducer()
     # Patch retry backoff period.
     with mock.patch(