Skip to content

Commit

Permalink
fix wiring with new generated code
Browse files Browse the repository at this point in the history
  • Loading branch information
iscai-msft committed Jun 17, 2020
1 parent ac09492 commit fd1deb6
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 141 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
LifetimeAction,
KeyVaultCertificate
)
from ._shared.multi_api import ApiVersion
from ._shared.client_base import ApiVersion

__all__ = [
"ApiVersion",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .http_challenge import HttpChallenge
from . import http_challenge_cache as HttpChallengeCache


__all__ = [
"ChallengeAuthPolicy",
"ChallengeAuthPolicyBase",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
# ------------------------------------
from typing import TYPE_CHECKING

from azure.core.pipeline import AsyncPipeline
from azure.core.pipeline.transport import AioHttpTransport

from . import AsyncChallengeAuthPolicy
from .client_base import _get_policies
from .multi_api import load_generated_api
from .client_base import ApiVersion
from .._user_agent import USER_AGENT
from .._generated.aio import KeyVaultClient as _KeyVaultClient

if TYPE_CHECKING:
try:
Expand All @@ -21,16 +22,7 @@
# AsyncTokenCredential is a typing_extensions.Protocol; we don't depend on that package
pass


def _build_pipeline(config: "Configuration", transport: "AsyncHttpTransport" = None, **kwargs: "Any") -> AsyncPipeline:
policies = _get_policies(config, **kwargs)
if transport is None:
from azure.core.pipeline.transport import AioHttpTransport

transport = AioHttpTransport(**kwargs)

return AsyncPipeline(transport, policies=policies)

DEFAULT_VERSION = ApiVersion.V7_1_preview

class AsyncKeyVaultClientBase(object):
def __init__(self, vault_url: str, credential: "AsyncTokenCredential", **kwargs: "Any") -> None:
Expand All @@ -49,18 +41,26 @@ def __init__(self, vault_url: str, credential: "AsyncTokenCredential", **kwargs:
self._client = client
return

api_version = kwargs.pop("api_version", None)
generated = load_generated_api(api_version, aio=True)
api_version = kwargs.pop("api_version", DEFAULT_VERSION)

pipeline = kwargs.pop("pipeline", None)
if not pipeline:
config = generated.config_cls(credential, **kwargs)
config.authentication_policy = AsyncChallengeAuthPolicy(credential)
pipeline = _build_pipeline(config, **kwargs)

# generated clients don't use their credentials parameter
self._client = generated.client_cls(credentials="", pipeline=pipeline)
self._models = generated.models
transport = kwargs.pop("transport", AioHttpTransport(**kwargs))

try:
self._client = _KeyVaultClient(
api_version=api_version,
pipeline=pipeline,
transport=transport,
authentication_policy=AsyncChallengeAuthPolicy(credential),
sdk_moniker=USER_AGENT,
**kwargs
)
self._models = _KeyVaultClient.models(api_version=api_version)
except NotImplementedError:
raise NotImplementedError(
"This package doesn't support API version '{}'. ".format(api_version)
+ "Supported versions: {}".format(", ".join(v.value for v in ApiVersion))
)

@property
def vault_url(self) -> str:
Expand All @@ -78,4 +78,4 @@ async def close(self) -> None:
Calling this method is unnecessary when using the client as a context manager.
"""
await self._client.__aexit__()
await self._client.close()
Original file line number Diff line number Diff line change
Expand Up @@ -3,53 +3,31 @@
# Licensed under the MIT License.
# ------------------------------------
from typing import TYPE_CHECKING
from enum import Enum

from azure.core.pipeline import Pipeline
from azure.core.pipeline.policies import (
ContentDecodePolicy,
UserAgentPolicy,
DistributedTracingPolicy,
HttpLoggingPolicy,
)
from azure.core.pipeline.transport import RequestsTransport

from .multi_api import load_generated_api
from .challenge_auth_policy import ChallengeAuthPolicy
from . import ChallengeAuthPolicy
from .._generated import KeyVaultClient as _KeyVaultClient
from .._user_agent import USER_AGENT

if TYPE_CHECKING:
# pylint:disable=unused-import
# pylint:disable=unused-import,ungrouped-imports
from typing import Any
from azure.core.credentials import TokenCredential
from azure.core.pipeline.transport import HttpTransport
from azure.core.configuration import Configuration

class ApiVersion(str, Enum):
"""Key Vault API versions supported by this package"""

def _get_policies(config, **kwargs):
logging_policy = HttpLoggingPolicy(**kwargs)
logging_policy.allowed_header_names.add("x-ms-keyvault-network-info")
#: this is the default version
V7_1_preview = "7.1-preview"
V7_0 = "7.0"
V2016_10_01 = "2016-10-01"

return [
config.headers_policy,
UserAgentPolicy(base_user_agent=USER_AGENT, **kwargs),
config.proxy_policy,
ContentDecodePolicy(),
config.redirect_policy,
config.retry_policy,
config.authentication_policy,
config.logging_policy,
DistributedTracingPolicy(**kwargs),
logging_policy,
]


def _build_pipeline(config, transport=None, **kwargs):
# type: (Configuration, HttpTransport, **Any) -> Pipeline
policies = _get_policies(config)
if transport is None:
transport = RequestsTransport(**kwargs)

return Pipeline(transport, policies=policies)
DEFAULT_VERSION = ApiVersion.V7_1_preview


class KeyVaultClientBase(object):
Expand All @@ -70,18 +48,26 @@ def __init__(self, vault_url, credential, **kwargs):
self._client = client
return

api_version = kwargs.pop("api_version", None)
generated = load_generated_api(api_version)
api_version = kwargs.pop("api_version", DEFAULT_VERSION)

pipeline = kwargs.pop("pipeline", None)
if not pipeline:
config = generated.config_cls(credential, **kwargs)
config.authentication_policy = ChallengeAuthPolicy(credential)
pipeline = _build_pipeline(config, **kwargs)
transport = kwargs.pop("transport", RequestsTransport(**kwargs))
try:
self._client = _KeyVaultClient(
api_version=api_version,
pipeline=pipeline,
transport=transport,
authentication_policy=ChallengeAuthPolicy(credential),
sdk_moniker=USER_AGENT,
**kwargs
)
self._models = _KeyVaultClient.models(api_version=api_version)
except NotImplementedError:
raise NotImplementedError(
"This package doesn't support API version '{}'. ".format(api_version)
+ "Supported versions: {}".format(", ".join(v.value for v in ApiVersion))
)

# generated clients don't use their credentials parameter
self._client = generated.client_cls(credentials="", pipeline=pipeline)
self._models = generated.models

@property
def vault_url(self):
Expand All @@ -103,4 +89,4 @@ def close(self):
Calling this method is unnecessary when using the client as a context manager.
"""
self._client.__exit__()
self._client.close()

This file was deleted.

0 comments on commit fd1deb6

Please sign in to comment.