Skip to content

Commit

Permalink
add type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
dimbleby committed Nov 11, 2024
1 parent 090da62 commit 61a2a42
Show file tree
Hide file tree
Showing 16 changed files with 1,164 additions and 908 deletions.
18 changes: 13 additions & 5 deletions src/connectedk8s/azext_connectedk8s/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,21 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
from __future__ import annotations

from typing import TYPE_CHECKING

from azure.cli.core import AzCommandsLoader

from azext_connectedk8s._help import helps

if TYPE_CHECKING:
from azure.cli.core import AzCli
from knack.commands import CLICommand


class Connectedk8sCommandsLoader(AzCommandsLoader):
def __init__(self, cli_ctx=None):
class Connectedk8sCommandsLoader(AzCommandsLoader): # type: ignore[misc]
def __init__(self, cli_ctx: AzCli | None = None) -> None:
from azure.cli.core.commands import CliCommandType

from azext_connectedk8s._client_factory import cf_connectedk8s
Expand All @@ -20,13 +27,14 @@ def __init__(self, cli_ctx=None):
)
super().__init__(cli_ctx=cli_ctx, custom_command_type=connectedk8s_custom)

def load_command_table(self, args):
def load_command_table(self, args: list[str] | None) -> dict[str, CLICommand]:
from azext_connectedk8s.commands import load_command_table

load_command_table(self, args)
return self.command_table
command_table: dict[str, CLICommand] = self.command_table
return command_table

def load_arguments(self, command):
def load_arguments(self, command: CLICommand) -> None:
from azext_connectedk8s._params import load_arguments

load_arguments(self, command)
Expand Down
133 changes: 93 additions & 40 deletions src/connectedk8s/azext_connectedk8s/_client_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
from __future__ import annotations

import os
from collections import namedtuple
from typing import TYPE_CHECKING, Any

import requests
from azure.cli.core import telemetry
Expand All @@ -14,93 +16,142 @@

import azext_connectedk8s._constants as consts

if TYPE_CHECKING:
from azure.cli.core import AzCli
from azure.mgmt.hybridcompute.operations import PrivateLinkScopesOperations
from azure.mgmt.resource import ResourceManagementClient
from azure.mgmt.resource.resources.v2022_09_01.operations import (
ProvidersOperations,
ResourceGroupsOperations,
)

from azext_connectedk8s.vendored_sdks import ConnectedKubernetesClient
from azext_connectedk8s.vendored_sdks.operations import ConnectedClusterOperations
from azext_connectedk8s.vendored_sdks.preview_2024_07_01 import (
ConnectedKubernetesClient as ConnectedKubernetesClient20240701,
)
from azext_connectedk8s.vendored_sdks.preview_2024_07_01.operations import (
ConnectedClusterOperations as ConnectedClusterOperations20240701,
)

AccessToken = namedtuple("AccessToken", ["token", "expires_on"])


def cf_connectedk8s(cli_ctx, *_):
def cf_connectedk8s(cli_ctx: AzCli, *_: Any) -> ConnectedKubernetesClient:
from azext_connectedk8s.vendored_sdks import ConnectedKubernetesClient

if os.getenv(consts.Azure_Access_Token_Variable):
client: ConnectedKubernetesClient
access_token = os.getenv(consts.Azure_Access_Token_Variable)
if access_token is not None:
validate_custom_token()
credential = AccessTokenCredential(
access_token=os.getenv(consts.Azure_Access_Token_Variable)
)
return get_mgmt_service_client(
credential = AccessTokenCredential(access_token=access_token)
client = get_mgmt_service_client(
cli_ctx,
ConnectedKubernetesClient,
subscription_id=os.getenv("AZURE_SUBSCRIPTION_ID"),
credential=credential,
)
return get_mgmt_service_client(cli_ctx, ConnectedKubernetesClient)
return client

client = get_mgmt_service_client(cli_ctx, ConnectedKubernetesClient)
return client


def cf_connected_cluster(cli_ctx, _):
def cf_connected_cluster(cli_ctx: AzCli, _: Any) -> ConnectedClusterOperations:
return cf_connectedk8s(cli_ctx).connected_cluster


def cf_connectedk8s_prev_2024_07_01(cli_ctx, *_):
def cf_connectedk8s_prev_2024_07_01(
cli_ctx: AzCli, *_: Any
) -> ConnectedKubernetesClient20240701:
from azext_connectedk8s.vendored_sdks.preview_2024_07_01 import (
ConnectedKubernetesClient,
)

if os.getenv(consts.Azure_Access_Token_Variable):
client: ConnectedKubernetesClient
access_token = os.getenv(consts.Azure_Access_Token_Variable)
if access_token is not None:
validate_custom_token()
credential = AccessTokenCredential(
access_token=os.getenv(consts.Azure_Access_Token_Variable)
)
return get_mgmt_service_client(
credential = AccessTokenCredential(access_token=access_token)
client = get_mgmt_service_client(
cli_ctx,
ConnectedKubernetesClient,
subscription_id=os.getenv("AZURE_SUBSCRIPTION_ID"),
credential=credential,
)
return get_mgmt_service_client(cli_ctx, ConnectedKubernetesClient)
return client

client = get_mgmt_service_client(cli_ctx, ConnectedKubernetesClient)
return client


def cf_connected_cluster_prev_2024_07_01(cli_ctx, _):
def cf_connected_cluster_prev_2024_07_01(
cli_ctx: AzCli, _: Any
) -> ConnectedClusterOperations20240701:
return cf_connectedk8s_prev_2024_07_01(cli_ctx).connected_cluster


def cf_connectedmachine(cli_ctx, subscription_id):
def cf_connectedmachine(
cli_ctx: AzCli, subscription_id: str | None
) -> PrivateLinkScopesOperations:
from azure.mgmt.hybridcompute import HybridComputeManagementClient

if os.getenv(consts.Azure_Access_Token_Variable):
credential = AccessTokenCredential(
access_token=os.getenv(consts.Azure_Access_Token_Variable)
)
return get_mgmt_service_client(
client: HybridComputeManagementClient
access_token = os.getenv(consts.Azure_Access_Token_Variable)
if access_token is not None:
credential = AccessTokenCredential(access_token=access_token)
client = get_mgmt_service_client(
cli_ctx,
HybridComputeManagementClient,
subscription_id=subscription_id,
credential=credential,
).private_link_scopes
return get_mgmt_service_client(
)
return client.private_link_scopes

client = get_mgmt_service_client(
cli_ctx, HybridComputeManagementClient, subscription_id=subscription_id
).private_link_scopes
)
return client.private_link_scopes


def cf_resource_groups(cli_ctx, subscription_id=None):
return _resource_client_factory(cli_ctx, subscription_id).resource_groups
def cf_resource_groups(
cli_ctx: AzCli, subscription_id: str | None = None
) -> ResourceGroupsOperations:
resource_groups: ResourceGroupsOperations = _resource_client_factory(
cli_ctx, subscription_id
).resource_groups
return resource_groups


def _resource_client_factory(cli_ctx, subscription_id=None):
if os.getenv(consts.Azure_Access_Token_Variable):
credential = AccessTokenCredential(
access_token=os.getenv(consts.Azure_Access_Token_Variable)
)
return get_mgmt_service_client(
def _resource_client_factory(
cli_ctx: AzCli, subscription_id: str | None = None
) -> ResourceManagementClient:
client: ResourceManagementClient

access_token = os.getenv(consts.Azure_Access_Token_Variable)
if access_token is not None:
credential = AccessTokenCredential(access_token=access_token)
client = get_mgmt_service_client(
cli_ctx,
ResourceType.MGMT_RESOURCE_RESOURCES,
subscription_id=subscription_id,
credential=credential,
)
return get_mgmt_service_client(
return client

client = get_mgmt_service_client(
cli_ctx, ResourceType.MGMT_RESOURCE_RESOURCES, subscription_id=subscription_id
)
return client


def resource_providers_client(cli_ctx, subscription_id=None):
return _resource_client_factory(cli_ctx, subscription_id).providers
def resource_providers_client(
cli_ctx: AzCli, subscription_id: str | None = None
) -> ProvidersOperations:
providers: ProvidersOperations = _resource_client_factory(
cli_ctx, subscription_id
).providers
return providers

# Alternate: This should also work
# subscription_id = get_subscription_id(cli_ctx)
Expand All @@ -111,23 +162,25 @@ def resource_providers_client(cli_ctx, subscription_id=None):
class AccessTokenCredential:
"""Simple access token Authentication. Returns the access token as-is."""

def __init__(self, access_token):
def __init__(self, access_token: str) -> None:
self.access_token = access_token

def get_token(self, *arg, **kwargs):
def get_token(self, *arg: Any, **kwargs: Any) -> AccessToken:
import time

# Assume the access token expires in 60 minutes
return AccessToken(self.access_token, int(time.time()) + 3600)

def signed_session(self, session=None):
def signed_session(
self, session: requests.Session | None = None
) -> requests.Session:
session = session or requests.Session()
header = "{} {}".format("Bearer", self.access_token)
session.headers["Authorization"] = header
return session


def validate_custom_token():
def validate_custom_token() -> None:
if os.getenv("AZURE_SUBSCRIPTION_ID") is None:
telemetry.set_exception(
exception="Required environment variable 'AZURE_SUBSCRIPTION_ID' is not set, when "
Expand Down
53 changes: 39 additions & 14 deletions src/connectedk8s/azext_connectedk8s/_clientproxyutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
from __future__ import annotations

import base64
import json
Expand All @@ -10,6 +11,7 @@
import sys
import time
from base64 import b64decode, b64encode
from typing import TYPE_CHECKING, Any, NoReturn

import requests
import yaml
Expand All @@ -27,14 +29,19 @@

import azext_connectedk8s._constants as consts

if TYPE_CHECKING:
from subprocess import Popen

from knack.commands import CLICommand

logger = get_logger(__name__)


def check_if_port_is_open(port):
def check_if_port_is_open(port: int) -> bool:
try:
connections = net_connections(kind="inet")
for tup in connections:
if int(tup[3][1]) == int(port):
if int(tup[3][1]) == port: # type: ignore[misc]
return True
except Exception as e:
telemetry.set_exception(
Expand All @@ -48,18 +55,27 @@ def check_if_port_is_open(port):
return False


def close_subprocess_and_raise_cli_error(proc_subprocess, msg):
def close_subprocess_and_raise_cli_error(
proc_subprocess: Popen[bytes], msg: str
) -> NoReturn:
proc_subprocess.terminate()
raise CLIInternalError(msg)


def check_if_csp_is_running(clientproxy_process):
def check_if_csp_is_running(clientproxy_process: Popen[bytes]) -> bool:
return clientproxy_process.poll() is None


def make_api_call_with_retries(
uri, data, method, tls_verify, fault_type, summary, cli_error, clientproxy_process
):
uri: str,
data: dict[str, Any],
method: str,
tls_verify: bool,
fault_type: str,
summary: str,
cli_error: str,
clientproxy_process: Popen[bytes],
) -> requests.Response:
for i in range(consts.API_CALL_RETRIES):
try:
response = requests.request(method, uri, json=data, verify=tls_verify)
Expand All @@ -76,9 +92,12 @@ def make_api_call_with_retries(
clientproxy_process, cli_error + str(e)
)

assert False


def fetch_pop_publickey_kid(api_server_port, clientproxy_process):
requestbody = {}
def fetch_pop_publickey_kid(
api_server_port: int, clientproxy_process: Popen[bytes]
) -> str:
poppublickey_uri = f"https://localhost:{api_server_port}/identity/poppublickey"
# Needed to prevent skip tls warning from printing to the console
original_stderr = sys.stderr
Expand All @@ -87,7 +106,7 @@ def fetch_pop_publickey_kid(api_server_port, clientproxy_process):

get_publickey_response = make_api_call_with_retries(
poppublickey_uri,
requestbody,
{},
"get",
False,
consts.Get_PublicKey_Info_Fault_Type,
Expand All @@ -98,12 +117,18 @@ def fetch_pop_publickey_kid(api_server_port, clientproxy_process):

sys.stderr = original_stderr
publickey_info = json.loads(get_publickey_response.text)
kid = publickey_info["publicKey"]["kid"]
kid: str = publickey_info["publicKey"]["kid"]

return kid


def fetch_and_post_at_to_csp(cmd, api_server_port, tenant_id, kid, clientproxy_process):
def fetch_and_post_at_to_csp(
cmd: CLICommand,
api_server_port: int,
tenant_id: str,
kid: str,
clientproxy_process: Popen[bytes],
) -> requests.Response:
req_cnfJSON = {"kid": kid, "xms_ksl": "sw"}
req_cnf = base64.urlsafe_b64encode(json.dumps(req_cnfJSON).encode("utf-8")).decode(
"utf-8"
Expand Down Expand Up @@ -160,8 +185,8 @@ def fetch_and_post_at_to_csp(cmd, api_server_port, tenant_id, kid, clientproxy_p
return post_at_response


def insert_token_in_kubeconfig(data, token):
b64kubeconfig = data["kubeconfigs"][0]["value"]
def insert_token_in_kubeconfig(data: dict[str, Any], token: str) -> str:
b64kubeconfig: str = data["kubeconfigs"][0]["value"]
decoded_kubeconfig_str = b64decode(b64kubeconfig).decode("utf-8")
dict_yaml = yaml.safe_load(decoded_kubeconfig_str)
dict_yaml["users"][0]["user"]["token"] = token
Expand All @@ -170,7 +195,7 @@ def insert_token_in_kubeconfig(data, token):
return b64kubeconfig


def check_process(processName):
def check_process(processName: str) -> bool:
"""
Check if there is any running process that contains the given name processName.
"""
Expand Down
Loading

0 comments on commit 61a2a42

Please sign in to comment.