Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ClientCertificateCredential to use AadClient #11719

Merged
merged 2 commits into from
Jun 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 0 additions & 52 deletions sdk/identity/azure-identity/azure/identity/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,6 @@
# Licensed under the MIT License.
# ------------------------------------
import abc
import binascii

from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.backends import default_backend
from msal.oauth2cli import JwtSigner
import six

try:
ABC = abc.ABC
Expand Down Expand Up @@ -41,48 +34,3 @@ def __init__(self, tenant_id, client_id, secret, **kwargs): # pylint:disable=un
)
self._form_data = {"client_id": client_id, "client_secret": secret, "grant_type": "client_credentials"}
super(ClientSecretCredentialBase, self).__init__()


class CertificateCredentialBase(ABC):
"""Sans I/O base for certificate credentials"""

def __init__(self, tenant_id, client_id, certificate_path, **kwargs): # pylint:disable=unused-argument
# type: (str, str, str, **Any) -> None
if not certificate_path:
raise ValueError(
"'certificate_path' must be the path to a PEM file containing an x509 certificate and its private key"
)

super(CertificateCredentialBase, self).__init__()

password = kwargs.pop("password", None)
if isinstance(password, six.text_type):
password = password.encode(encoding="utf-8")

with open(certificate_path, "rb") as f:
pem_bytes = f.read()

private_key = serialization.load_pem_private_key(pem_bytes, password=password, backend=default_backend())
cert = x509.load_pem_x509_certificate(pem_bytes, default_backend())
fingerprint = cert.fingerprint(hashes.SHA1()) #nosec

self._client = self._get_auth_client(tenant_id, **kwargs)
self._client_id = client_id
self._signer = JwtSigner(private_key, "RS256", sha1_thumbprint=binascii.hexlify(fingerprint))

def _get_request_data(self, *scopes):
assertion = self._signer.sign_assertion(audience=self._client.auth_url, issuer=self._client_id)
if isinstance(assertion, six.binary_type):
assertion = assertion.decode("utf-8")

return {
"client_assertion": assertion,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_id": self._client_id,
"grant_type": "client_credentials",
"scope": " ".join(scopes),
}

@abc.abstractmethod
def _get_auth_client(self, tenant_id, **kwargs):
pass
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
# ------------------------------------
from typing import TYPE_CHECKING

from .._authn_client import AuthnClient
from .._base import CertificateCredentialBase
from .._internal import AadClient, CertificateCredentialBase

if TYPE_CHECKING:
from azure.core.credentials import AccessToken
Expand Down Expand Up @@ -42,11 +41,10 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
if not scopes:
raise ValueError("'get_token' requires at least one scope")

token = self._client.get_cached_token(scopes)
token = self._client.get_cached_access_token(scopes)
if not token:
data = self._get_request_data(*scopes)
token = self._client.request_token(scopes, form_data=data)
token = self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs)
return token

def _get_auth_client(self, tenant_id, **kwargs):
return AuthnClient(tenant=tenant_id, **kwargs)
def _get_auth_client(self, tenant_id, client_id, **kwargs):
return AadClient(tenant_id, client_id, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def get_default_authority():
from .aad_client import AadClient
from .aad_client_base import AadClientBase
from .auth_code_redirect_handler import AuthCodeRedirectServer
from .aadclient_certificate import AadClientCertificate
from .certificate_credential_base import CertificateCredentialBase
from .exception_wrapper import wrap_exceptions
from .msal_credentials import ConfidentialClientCredential, InteractiveCredential, PublicClientCredential
from .msal_transport_adapter import MsalTransportAdapter, MsalTransportResponse
Expand All @@ -56,6 +58,8 @@ def _scopes_to_resource(*scopes):
"AadClient",
"AadClientBase",
"AuthCodeRedirectServer",
"AadClientCertificate",
"CertificateCredentialBase",
"ConfidentialClientCredential",
"get_default_authority",
"InteractiveCredential",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from azure.core.credentials import AccessToken
from azure.core.pipeline.policies import HTTPPolicy, SansIOHTTPPolicy
from azure.core.pipeline.transport import HttpTransport
from .._internal import AadClientCertificate

Policy = Union[HTTPPolicy, SansIOHTTPPolicy]

Expand All @@ -41,6 +42,14 @@ def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_
content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response)
return self._process_response(response=content, scopes=scopes, now=now)

def obtain_token_by_client_certificate(self, scopes, certificate, **kwargs):
# type: (Sequence[str], AadClientCertificate, **Any) -> AccessToken
request = self._get_client_certificate_request(scopes, certificate)
now = int(time.time())
response = self._pipeline.run(request, stream=False, **kwargs)
content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response)
return self._process_response(response=content, scopes=scopes, now=now)

def obtain_token_by_refresh_token(self, scopes, refresh_token, **kwargs):
# type: (str, Sequence[str], **Any) -> AccessToken
request = self._get_refresh_token_request(scopes, refresh_token)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@
# Licensed under the MIT License.
# ------------------------------------
import abc
import base64
import copy
import json
import time
from uuid import uuid4

import six
from msal import TokenCache

from azure.core.pipeline.transport import HttpRequest
Expand All @@ -29,6 +33,7 @@
from azure.core.pipeline import AsyncPipeline, Pipeline
from azure.core.pipeline.policies import AsyncHTTPPolicy, HTTPPolicy, SansIOHTTPPolicy
from azure.core.pipeline.transport import AsyncHttpTransport, HttpTransport
from .._internal import AadClientCertificate

PipelineType = Union[AsyncPipeline, Pipeline]
PolicyType = Union[AsyncHTTPPolicy, HTTPPolicy, SansIOHTTPPolicy]
Expand Down Expand Up @@ -62,6 +67,10 @@ def get_cached_refresh_tokens(self, scopes):
def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_secret=None, **kwargs):
pass

@abc.abstractmethod
def obtain_token_by_client_certificate(self, scopes, certificate, **kwargs):
pass

@abc.abstractmethod
def obtain_token_by_refresh_token(self, scopes, refresh_token, **kwargs):
pass
Expand Down Expand Up @@ -90,8 +99,7 @@ def _process_response(self, response, scopes, now):
return AccessToken(response_copy["access_token"], expires_on)

def _get_auth_code_request(self, scopes, code, redirect_uri, client_secret=None):
# type: (str, str, Sequence[str], Optional[str]) -> HttpRequest

# type: (Sequence[str], str, str, Optional[str]) -> HttpRequest
data = {
"client_id": self._client_id,
"code": code,
Expand All @@ -107,9 +115,49 @@ def _get_auth_code_request(self, scopes, code, redirect_uri, client_secret=None)
)
return request

def _get_refresh_token_request(self, scopes, refresh_token):
# type: (str, Sequence[str]) -> HttpRequest
def _get_client_certificate_request(self, scopes, certificate):
# type: (Sequence[str], AadClientCertificate) -> HttpRequest
assertion = self._get_jwt_assertion(certificate)
data = {
"client_assertion": assertion,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_id": self._client_id,
"grant_type": "client_credentials",
"scope": " ".join(scopes),
}

request = HttpRequest(
"POST", self._token_endpoint, headers={"Content-Type": "application/x-www-form-urlencoded"}, data=data
)
return request

def _get_jwt_assertion(self, certificate):
# type: (AadClientCertificate) -> str
now = int(time.time())
header = six.ensure_binary(
json.dumps({"typ": "JWT", "alg": "RS256", "x5t": certificate.thumbprint}), encoding="utf-8"
)
payload = six.ensure_binary(
json.dumps(
{
"jti": str(uuid4()),
"aud": self._token_endpoint,
"iss": self._client_id,
"sub": self._client_id,
"nbf": now,
"exp": now + (60 * 30),
}
),
encoding="utf-8",
)
jws = base64.urlsafe_b64encode(header) + b"." + base64.urlsafe_b64encode(payload)
signature = certificate.sign(jws)
jwt_bytes = jws + b"." + base64.urlsafe_b64encode(signature)

return jwt_bytes.decode("utf-8")

def _get_refresh_token_request(self, scopes, refresh_token):
# type: (Sequence[str], str) -> HttpRequest
data = {
"grant_type": "refresh_token",
"refresh_token": refresh_token,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import base64
from typing import TYPE_CHECKING

from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.backends import default_backend
import six

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Optional


class AadClientCertificate(object):
"""Wraps 'cryptography' to provide the crypto operations AadClient requires for certificate authentication.

:param bytes pem_bytes: bytes of a a PEM-encoded certificate including the private key
:param bytes password: (optional) the certificate's password
"""
def __init__(self, pem_bytes, password=None):
# type: (bytes, Optional[bytes]) -> None
cert = x509.load_pem_x509_certificate(pem_bytes, default_backend())
fingerprint = cert.fingerprint(hashes.SHA1()) # nosec
self._private_key = serialization.load_pem_private_key(pem_bytes, password=password, backend=default_backend())
self._thumbprint = six.ensure_str(base64.urlsafe_b64encode(fingerprint), encoding="utf-8")

@property
def thumbprint(self):
# type: () -> str
"""The certificate's SHA1 thumbprint as a base64url-encoded string"""
return self._thumbprint

def sign(self, plaintext):
# type: (bytes) -> bytes
"""Sign bytes using RS256"""
return self._private_key.sign(plaintext, padding.PKCS1v15(), hashes.SHA256())
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import abc

import six
from azure.identity._internal import AadClientCertificate

try:
ABC = abc.ABC
except AttributeError: # Python 2.7, abc exists, but not ABC
ABC = abc.ABCMeta("ABC", (object,), {"__slots__": ()}) # type: ignore

try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False

if TYPE_CHECKING:
# pylint:disable=unused-import
from typing import Any


class CertificateCredentialBase(ABC):
def __init__(self, tenant_id, client_id, certificate_path, **kwargs):
# type: (str, str, str, **Any) -> None
if not certificate_path:
raise ValueError(
"'certificate_path' must be the path to a PEM file containing an x509 certificate and its private key"
)

super(CertificateCredentialBase, self).__init__()

password = kwargs.pop("password", None)
if isinstance(password, six.text_type):
password = password.encode(encoding="utf-8")

with open(certificate_path, "rb") as f:
pem_bytes = f.read()

self._certificate = AadClientCertificate(pem_bytes, password=password)
self._client = self._get_auth_client(tenant_id, client_id, **kwargs)

@abc.abstractmethod
def _get_auth_client(self, tenant_id, client_id, **kwargs):
pass
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from typing import TYPE_CHECKING

from .base import AsyncCredentialBase
from .._authn_client import AsyncAuthnClient
from ..._base import CertificateCredentialBase
from .._internal import AadClient
from ..._internal import CertificateCredentialBase

if TYPE_CHECKING:
from typing import Any
Expand Down Expand Up @@ -51,11 +51,10 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py
if not scopes:
raise ValueError("'get_token' requires at least one scope")

token = self._client.get_cached_token(scopes)
token = self._client.get_cached_access_token(scopes)
if not token:
data = self._get_request_data(*scopes)
token = await self._client.request_token(scopes, form_data=data)
return token # type: ignore
token = await self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs)
return token

def _get_auth_client(self, tenant_id, **kwargs):
return AsyncAuthnClient(tenant=tenant_id, **kwargs)
def _get_auth_client(self, tenant_id, client_id, **kwargs):
return AadClient(tenant_id, client_id, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from azure.core.credentials import AccessToken
from azure.core.pipeline.policies import AsyncHTTPPolicy, SansIOHTTPPolicy
from azure.core.pipeline.transport import AsyncHttpTransport
from ..._internal import AadClientCertificate

Policy = Union[AsyncHTTPPolicy, SansIOHTTPPolicy]

Expand Down Expand Up @@ -58,6 +59,14 @@ async def obtain_token_by_authorization_code(
content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response)
return self._process_response(response=content, scopes=scopes, now=now)

async def obtain_token_by_client_certificate(self, scopes, certificate, **kwargs):
# type: (Sequence[str], AadClientCertificate, **Any) -> AccessToken
request = self._get_client_certificate_request(scopes, certificate)
now = int(time.time())
response = await self._pipeline.run(request, stream=False, **kwargs)
content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response)
return self._process_response(response=content, scopes=scopes, now=now)

async def obtain_token_by_refresh_token(
self, scopes: "Sequence[str]", refresh_token: str, **kwargs: "Any"
) -> "AccessToken":
Expand Down