From 68b736955050c44e205fd4bfbcfb6044965d3e10 Mon Sep 17 00:00:00 2001 From: fulder Date: Tue, 25 Aug 2020 09:53:01 +0200 Subject: [PATCH] Test creating superclass for sign algorithm --- httpsig/sign.py | 10 +++------- httpsig/sign_algorithms.py | 20 ++++++++++++++------ httpsig/verify.py | 5 ++--- 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/httpsig/sign.py b/httpsig/sign.py index 1a0a3fe..af63f3e 100644 --- a/httpsig/sign.py +++ b/httpsig/sign.py @@ -5,7 +5,7 @@ from Crypto.Hash import HMAC from Crypto.PublicKey import RSA from Crypto.Signature import PKCS1_v1_5 -from .sign_algorithms import SIGN_ALGORITHMS +from .sign_algorithms import SignAlgorithm from .utils import * DEFAULT_SIGN_ALGORITHM = "hs2019" @@ -19,15 +19,12 @@ class Signer(object): Password-protected keyfiles are not supported. """ - def __init__(self, secret, algorithm=None, sign_algorithm=None): + def __init__(self, secret, algorithm=None, sign_algorithm: SignAlgorithm=None): if algorithm is None: algorithm = DEFAULT_SIGN_ALGORITHM assert algorithm in ALGORITHMS, "Unknown algorithm" - if sign_algorithm is not None and sign_algorithm.__class__.__name__ not in SIGN_ALGORITHMS: - raise HttpSigException("Unsupported digital signature algorithm") - if algorithm != DEFAULT_SIGN_ALGORITHM: print("Algorithm: {} is deprecated please update to {}".format(algorithm, DEFAULT_SIGN_ALGORITHM)) @@ -79,7 +76,7 @@ def sign(self, data): signed = self._sign_rsa(data) elif self._hash: signed = self._sign_hmac(data) - elif self.sign_algorithm.__class__.__name__ in SIGN_ALGORITHMS: + elif isinstance(self.sign_algorithm, SignAlgorithm): signed = self.sign_algorithm.sign(self.secret, data) if not signed: raise SystemError('No valid encryptor found.') @@ -98,7 +95,6 @@ class HeaderSigner(Signer): match the algorithm) :param algorithm: one of the seven specified algorithms :param sign_algorithm: required for 'hs2019' algorithm. Sign algorithm for the secret - :param sign_algorithm: Custom salt length for 'hs2019' and 'PSS' sign algorithm. :param headers: a list of http headers to be included in the signing string, defaulting to ['date']. :param sign_header: header used to include signature, defaulting to diff --git a/httpsig/sign_algorithms.py b/httpsig/sign_algorithms.py index 8018175..a10fb4c 100644 --- a/httpsig/sign_algorithms.py +++ b/httpsig/sign_algorithms.py @@ -4,11 +4,24 @@ from Crypto.PublicKey import RSA from Crypto.Signature import PKCS1_PSS from httpsig.utils import HttpSigException, HASHES +from abc import ABCMeta, abstractmethod DEFAULT_HASH_ALGORITHM = "sha512" -class PSS(object): +class SignAlgorithm(object): + __metaclass__ = ABCMeta + + @abstractmethod + def sign(self, *args): + raise NotImplementedError() + + @abstractmethod + def verify(self, *args): + raise NotImplementedError() + + +class PSS(SignAlgorithm): def __init__(self, hash_algorithm=DEFAULT_HASH_ALGORITHM, salt_length=None, mgfunc=None): if hash_algorithm not in HASHES: @@ -46,8 +59,3 @@ def verify(self, public_key, data, signature): h = self.hash_algorithm.new() h.update(data) return pss.verify(h, base64.b64decode(signature)) - - -SIGN_ALGORITHMS = frozenset([ - "PSS" -]) diff --git a/httpsig/verify.py b/httpsig/verify.py index b6bc26e..c529391 100644 --- a/httpsig/verify.py +++ b/httpsig/verify.py @@ -5,7 +5,7 @@ import six from .sign import Signer, DEFAULT_SIGN_ALGORITHM -from .sign_algorithms import SIGN_ALGORITHMS +from .sign_algorithms import SignAlgorithm from .utils import * @@ -38,7 +38,7 @@ def _verify(self, data, signature): s = base64.b64decode(signature) return ct_bytes_compare(h, s) - elif self.sign_algorithm.__class__.__name__ in SIGN_ALGORITHMS: + elif isinstance(self.sign_algorithm, SignAlgorithm): return self.sign_algorithm.verify(self.secret, data, signature) else: @@ -72,7 +72,6 @@ def __init__(self, headers, secret, required_headers=None, method=None, Default is 'authorization'. :param sign_algorithm: Required for 'hs2019' algorithm, specifies the digital signature algorithm (derived from keyId) to use. - :param sign_algorithm: Custom salt length for 'hs2019' and 'PSS' sign algorithm. """ required_headers = required_headers or ['date'] self.headers = CaseInsensitiveDict(headers)