Skip to content

Commit

Permalink
Merge pull request #27 from AzureArcForKubernetes/jorgedaboub/msal_cs…
Browse files Browse the repository at this point in the history
…p_migration

Add AT refresh for Proxy
  • Loading branch information
JorgeDaboub authored Dec 5, 2024
2 parents 77ac7cd + 401c18f commit c9058ca
Show file tree
Hide file tree
Showing 7 changed files with 256 additions and 174 deletions.
12 changes: 0 additions & 12 deletions src/connectedk8s/azext_connectedk8s/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1542,18 +1542,6 @@ def az_cli(args_str: str) -> Any:
return True


# def is_cli_using_msal_auth():
# response_cli_version = az_cli("version --output json")
# try:
# cli_version = response_cli_version['azure-cli']
# except Exception as ex:
# raise CLIInternalError(f"Unable to decode the az cli version installed: {ex}")
# if version.parse(cli_version) >= version.parse(consts.AZ_CLI_ADAL_TO_MSAL_MIGRATE_VERSION):
# return True
# else:
# return False


def is_cli_using_msal_auth() -> bool:
response_cli_version = az_cli("version --output json")
try:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
22 changes: 22 additions & 0 deletions src/connectedk8s/azext_connectedk8s/clientproxyhelper/_enums.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
from __future__ import annotations

from enum import Enum


class ProxyStatus(Enum):
FirstRun = 0
HCTokenRefresh = 1
AccessTokenRefresh = 2
AllRefresh = 3

@classmethod
def should_hc_token_refresh(cls, status: ProxyStatus) -> bool:
return status in {cls.FirstRun, cls.HCTokenRefresh, cls.AllRefresh}

@classmethod
def should_access_token_refresh(cls, status: ProxyStatus) -> bool:
return status in {cls.FirstRun, cls.AccessTokenRefresh, cls.AllRefresh}
124 changes: 124 additions & 0 deletions src/connectedk8s/azext_connectedk8s/clientproxyhelper/_proxylogic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from azure.cli.core import telemetry

import azext_connectedk8s._constants as consts
import azext_connectedk8s.clientproxyhelper._utils as clientproxyutils

from ..vendored_sdks.models import (
ListClusterUserCredentialProperties,
)

if TYPE_CHECKING:
from subprocess import Popen

from knack.commands import CLICommmand
from requests.models import Response

from azext_connectedk8s.vendored_sdks.preview_2024_07_01.models import (
CredentialResults,
)
from azext_connectedk8s.vendored_sdks.preview_2024_07_01.operations import (
ConnectedClusterOperations,
)


def handle_post_at_to_csp(
cmd: CLICommmand,
api_server_port: int,
tenant_id: str,
clientproxy_process: Popen[bytes],
) -> int:
kid = clientproxyutils.fetch_pop_publickey_kid(api_server_port, clientproxy_process)
post_at_response, expiry = clientproxyutils.fetch_and_post_at_to_csp(
cmd, api_server_port, tenant_id, kid, clientproxy_process
)

if post_at_response.status_code != 200:
if (
post_at_response.status_code == 500
and "public key expired" in post_at_response.text
):
# Handle public key rotation
telemetry.set_exception(
exception=post_at_response.text,
fault_type=consts.PoP_Public_Key_Expried_Fault_Type,
summary="PoP public key has expired",
)
kid = clientproxyutils.fetch_pop_publickey_kid(
api_server_port, clientproxy_process
) # Fetch rotated public key
# Retry posting AT with the new public key
post_at_response, expiry = clientproxyutils.fetch_and_post_at_to_csp(
cmd, api_server_port, tenant_id, kid, clientproxy_process
)
# If after second try we still dont get a 200, raise error
if post_at_response.status_code != 200:
telemetry.set_exception(
exception=post_at_response.text,
fault_type=consts.Post_AT_To_ClientProxy_Failed_Fault_Type,
summary="Failed to post access token to client proxy",
)
clientproxyutils.close_subprocess_and_raise_cli_error(
clientproxy_process,
"Failed to post access token to client proxy" + post_at_response.text,
)

return expiry


def get_cluster_user_credentials(
client: ConnectedClusterOperations,
resource_group_name: str,
cluster_name: str,
auth_method: str,
) -> CredentialResults:
list_prop = ListClusterUserCredentialProperties(
authentication_method=auth_method, client_proxy=True
)

result: CredentialResults = client.list_cluster_user_credential( # type: ignore[call-overload]
resource_group_name,
cluster_name,
list_prop,
)
return result


def post_register_to_proxy(
data: dict[str, Any],
token: str | None,
client_proxy_port: int,
subscription_id: str,
resource_group_name: str,
cluster_name: str,
clientproxy_process: Popen[bytes],
) -> Response:
if token is not None:
data["kubeconfigs"][0]["value"] = clientproxyutils.insert_token_in_kubeconfig(
data, token
)

uri = (
f"http://localhost:{client_proxy_port}/subscriptions/{subscription_id}/resourceGroups/{resource_group_name}"
f"/providers/Microsoft.Kubernetes/connectedClusters/{cluster_name}/register?api-version=2020-10-01"
)

# Posting hybrid connection details to proxy in order to get kubeconfig
response = clientproxyutils.make_api_call_with_retries(
uri,
data,
"post",
False,
consts.Post_Hybridconn_Fault_Type,
"Unable to post hybrid connection details to clientproxy",
"Failed to pass hybrid connection details to proxy.",
clientproxy_process,
)
return response
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@

from knack.commands import CLICommand

from azext_connectedk8s.vendored_sdks.preview_2024_07_01.models import (
CredentialResults,
)

logger = get_logger(__name__)


Expand Down Expand Up @@ -128,7 +132,7 @@ def fetch_and_post_at_to_csp(
tenant_id: str,
kid: str,
clientproxy_process: Popen[bytes],
) -> requests.Response:
) -> tuple[requests.Response, int]:
req_cnfJSON = {"kid": kid, "xms_ksl": "sw"}
req_cnf = base64.urlsafe_b64encode(json.dumps(req_cnfJSON).encode("utf-8")).decode(
"utf-8"
Expand Down Expand Up @@ -182,7 +186,7 @@ def fetch_and_post_at_to_csp(
)

sys.stderr = original_stderr
return post_at_response
return post_at_response, accessToken.expires_on


def insert_token_in_kubeconfig(data: dict[str, Any], token: str) -> str:
Expand All @@ -195,6 +199,26 @@ def insert_token_in_kubeconfig(data: dict[str, Any], token: str) -> str:
return b64kubeconfig


# Prepare data as needed by client proxy executable
def prepare_clientproxy_data(response: CredentialResults) -> dict[str, Any]:
data: dict[str, Any] = {}
data["kubeconfigs"] = []
kubeconfig = {}
kubeconfig["name"] = "Kubeconfig"
kubeconfig["value"] = b64encode(response.kubeconfigs[0].value).decode("utf-8") # type: ignore[index]
data["kubeconfigs"].append(kubeconfig)
data["hybridConnectionConfig"] = {}
data["hybridConnectionConfig"]["relay"] = response.hybrid_connection_config.relay # type: ignore[attr-defined]
data["hybridConnectionConfig"]["hybridConnectionName"] = (
response.hybrid_connection_config.hybrid_connection_name # type: ignore[attr-defined]
)
data["hybridConnectionConfig"]["token"] = response.hybrid_connection_config.token # type: ignore[attr-defined]
data["hybridConnectionConfig"]["expirationTime"] = (
response.hybrid_connection_config.expiration_time # type: ignore[attr-defined]
)
return data


def check_process(processName: str) -> bool:
"""
Check if there is any running process that contains the given name processName.
Expand Down
Loading

0 comments on commit c9058ca

Please sign in to comment.