diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/__init__.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/__init__.py index d644e8db7e1..31d88019da1 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/__init__.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/__init__.py @@ -24,7 +24,7 @@ LifetimeAction, KeyVaultCertificate ) -from ._shared.multi_api import ApiVersion +from ._shared.client_base import ApiVersion __all__ = [ "ApiVersion", diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/__init__.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/__init__.py index a8fd2a41d71..e13f15a61c7 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/__init__.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/__init__.py @@ -15,6 +15,7 @@ from .http_challenge import HttpChallenge from . import http_challenge_cache as HttpChallengeCache + __all__ = [ "ChallengeAuthPolicy", "ChallengeAuthPolicyBase", diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_client_base.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_client_base.py index 84c827a3a93..fb68aa345d5 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_client_base.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_client_base.py @@ -4,11 +4,12 @@ # ------------------------------------ from typing import TYPE_CHECKING -from azure.core.pipeline import AsyncPipeline +from azure.core.pipeline.transport import AioHttpTransport from . import AsyncChallengeAuthPolicy -from .client_base import _get_policies -from .multi_api import load_generated_api +from .client_base import ApiVersion +from .._user_agent import USER_AGENT +from .._generated.aio import KeyVaultClient as _KeyVaultClient if TYPE_CHECKING: try: @@ -21,16 +22,7 @@ # AsyncTokenCredential is a typing_extensions.Protocol; we don't depend on that package pass - -def _build_pipeline(config: "Configuration", transport: "AsyncHttpTransport" = None, **kwargs: "Any") -> AsyncPipeline: - policies = _get_policies(config, **kwargs) - if transport is None: - from azure.core.pipeline.transport import AioHttpTransport - - transport = AioHttpTransport(**kwargs) - - return AsyncPipeline(transport, policies=policies) - +DEFAULT_VERSION = ApiVersion.V7_1_preview class AsyncKeyVaultClientBase(object): def __init__(self, vault_url: str, credential: "AsyncTokenCredential", **kwargs: "Any") -> None: @@ -49,18 +41,26 @@ def __init__(self, vault_url: str, credential: "AsyncTokenCredential", **kwargs: self._client = client return - api_version = kwargs.pop("api_version", None) - generated = load_generated_api(api_version, aio=True) + api_version = kwargs.pop("api_version", DEFAULT_VERSION) pipeline = kwargs.pop("pipeline", None) - if not pipeline: - config = generated.config_cls(credential, **kwargs) - config.authentication_policy = AsyncChallengeAuthPolicy(credential) - pipeline = _build_pipeline(config, **kwargs) - - # generated clients don't use their credentials parameter - self._client = generated.client_cls(credentials="", pipeline=pipeline) - self._models = generated.models + transport = kwargs.pop("transport", AioHttpTransport(**kwargs)) + + try: + self._client = _KeyVaultClient( + api_version=api_version, + pipeline=pipeline, + transport=transport, + authentication_policy=AsyncChallengeAuthPolicy(credential), + sdk_moniker=USER_AGENT, + **kwargs + ) + self._models = _KeyVaultClient.models(api_version=api_version) + except NotImplementedError: + raise NotImplementedError( + "This package doesn't support API version '{}'. ".format(api_version) + + "Supported versions: {}".format(", ".join(v.value for v in ApiVersion)) + ) @property def vault_url(self) -> str: @@ -78,4 +78,4 @@ async def close(self) -> None: Calling this method is unnecessary when using the client as a context manager. """ - await self._client.__aexit__() + await self._client.close() diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/client_base.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/client_base.py index b1e1a2e997d..f19d1226769 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/client_base.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/client_base.py @@ -3,53 +3,31 @@ # Licensed under the MIT License. # ------------------------------------ from typing import TYPE_CHECKING +from enum import Enum -from azure.core.pipeline import Pipeline -from azure.core.pipeline.policies import ( - ContentDecodePolicy, - UserAgentPolicy, - DistributedTracingPolicy, - HttpLoggingPolicy, -) from azure.core.pipeline.transport import RequestsTransport -from .multi_api import load_generated_api -from .challenge_auth_policy import ChallengeAuthPolicy +from . import ChallengeAuthPolicy +from .._generated import KeyVaultClient as _KeyVaultClient from .._user_agent import USER_AGENT if TYPE_CHECKING: - # pylint:disable=unused-import + # pylint:disable=unused-import,ungrouped-imports from typing import Any from azure.core.credentials import TokenCredential from azure.core.pipeline.transport import HttpTransport from azure.core.configuration import Configuration +class ApiVersion(str, Enum): + """Key Vault API versions supported by this package""" -def _get_policies(config, **kwargs): - logging_policy = HttpLoggingPolicy(**kwargs) - logging_policy.allowed_header_names.add("x-ms-keyvault-network-info") + #: this is the default version + V7_1_preview = "7.1-preview" + V7_0 = "7.0" + V2016_10_01 = "2016-10-01" - return [ - config.headers_policy, - UserAgentPolicy(base_user_agent=USER_AGENT, **kwargs), - config.proxy_policy, - ContentDecodePolicy(), - config.redirect_policy, - config.retry_policy, - config.authentication_policy, - config.logging_policy, - DistributedTracingPolicy(**kwargs), - logging_policy, - ] - -def _build_pipeline(config, transport=None, **kwargs): - # type: (Configuration, HttpTransport, **Any) -> Pipeline - policies = _get_policies(config) - if transport is None: - transport = RequestsTransport(**kwargs) - - return Pipeline(transport, policies=policies) +DEFAULT_VERSION = ApiVersion.V7_1_preview class KeyVaultClientBase(object): @@ -70,18 +48,26 @@ def __init__(self, vault_url, credential, **kwargs): self._client = client return - api_version = kwargs.pop("api_version", None) - generated = load_generated_api(api_version) + api_version = kwargs.pop("api_version", DEFAULT_VERSION) pipeline = kwargs.pop("pipeline", None) - if not pipeline: - config = generated.config_cls(credential, **kwargs) - config.authentication_policy = ChallengeAuthPolicy(credential) - pipeline = _build_pipeline(config, **kwargs) + transport = kwargs.pop("transport", RequestsTransport(**kwargs)) + try: + self._client = _KeyVaultClient( + api_version=api_version, + pipeline=pipeline, + transport=transport, + authentication_policy=ChallengeAuthPolicy(credential), + sdk_moniker=USER_AGENT, + **kwargs + ) + self._models = _KeyVaultClient.models(api_version=api_version) + except NotImplementedError: + raise NotImplementedError( + "This package doesn't support API version '{}'. ".format(api_version) + + "Supported versions: {}".format(", ".join(v.value for v in ApiVersion)) + ) - # generated clients don't use their credentials parameter - self._client = generated.client_cls(credentials="", pipeline=pipeline) - self._models = generated.models @property def vault_url(self): @@ -103,4 +89,4 @@ def close(self): Calling this method is unnecessary when using the client as a context manager. """ - self._client.__exit__() + self._client.close() diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/multi_api.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/multi_api.py deleted file mode 100644 index 83576cb9a26..00000000000 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/multi_api.py +++ /dev/null @@ -1,73 +0,0 @@ -# ------------------------------------ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -# ------------------------------------ -from collections import namedtuple -from enum import Enum -from typing import TYPE_CHECKING - -from .._generated.v7_1_preview.version import VERSION as V7_1_PREVIEW_VERSION -from .._generated.v7_0.version import VERSION as V7_0_VERSION -from .._generated.v2016_10_01.version import VERSION as V2016_10_01_VERSION - -if TYPE_CHECKING: - from typing import Union - - -class ApiVersion(Enum): - """Key Vault API versions supported by this package""" - - #: this is the default version - V7_1_preview = V7_1_PREVIEW_VERSION - V7_0 = V7_0_VERSION - V2016_10_01 = V2016_10_01_VERSION - - -DEFAULT_VERSION = ApiVersion.V7_1_preview - -GeneratedApi = namedtuple("GeneratedApi", ("models", "client_cls", "config_cls")) - - -def load_generated_api(api_version, aio=False): - # type: (Union[ApiVersion, str], bool) -> GeneratedApi - api_version = api_version or DEFAULT_VERSION - try: - # api_version could be a string; map it to an instance of ApiVersion - # (this is a no-op if it's already an instance of ApiVersion) - api_version = ApiVersion(api_version) - except ValueError: - # api_version is unknown to ApiVersion - raise NotImplementedError( - "This package doesn't support API version '{}'. ".format(api_version) - + "Supported versions: {}".format(", ".join(v.value for v in ApiVersion)) - ) - - if api_version == ApiVersion.V7_1_preview: - from .._generated.v7_1_preview import models - - if aio: - from .._generated.v7_1_preview.aio import KeyVaultClient - from .._generated.v7_1_preview.aio._configuration_async import KeyVaultClientConfiguration - else: - from .._generated.v7_1_preview import KeyVaultClient # type: ignore - from .._generated.v7_1_preview._configuration import KeyVaultClientConfiguration # type: ignore - elif api_version == ApiVersion.V7_0: - from .._generated.v7_0 import models # type: ignore - - if aio: - from .._generated.v7_0.aio import KeyVaultClient # type: ignore - from .._generated.v7_0.aio._configuration_async import KeyVaultClientConfiguration # type: ignore - else: - from .._generated.v7_0 import KeyVaultClient # type: ignore - from .._generated.v7_0._configuration import KeyVaultClientConfiguration # type: ignore - elif api_version == ApiVersion.V2016_10_01: - from .._generated.v2016_10_01 import models # type: ignore - - if aio: - from .._generated.v2016_10_01.aio import KeyVaultClient # type: ignore - from .._generated.v2016_10_01.aio._configuration_async import KeyVaultClientConfiguration # type: ignore - else: - from .._generated.v2016_10_01 import KeyVaultClient # type: ignore - from .._generated.v2016_10_01._configuration import KeyVaultClientConfiguration # type: ignore - - return GeneratedApi(models=models, client_cls=KeyVaultClient, config_cls=KeyVaultClientConfiguration)