diff --git a/kms/api-client/asymmetric.py b/kms/api-client/asymmetric.py index 4d4ebcb4f3b2..bc313aaa8494 100644 --- a/kms/api-client/asymmetric.py +++ b/kms/api-client/asymmetric.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License.rom googleapiclient import discovery +# [START kms_asymmetric_imports] import base64 import hashlib @@ -20,6 +21,7 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import ec, padding, utils +# [END kms_asymmetric_imports] # [START kms_get_asymmetric_public] @@ -43,35 +45,34 @@ def getAsymmetricPublicKey(client, key_path): # [START kms_decrypt_rsa] def decryptRSA(ciphertext, client, key_path): """ - Decrypt a given ciphertext using an 'RSA_DECRYPT_OAEP_2048_SHA256' private - key stored on Cloud KMS + Decrypt the input ciphertext (bytes) using an + 'RSA_DECRYPT_OAEP_2048_SHA256' private key stored on Cloud KMS """ + request_body = {'ciphertext': base64.b64encode(ciphertext).decode('utf-8')} request = client.projects() \ .locations() \ .keyRings() \ .cryptoKeys() \ .cryptoKeyVersions() \ .asymmetricDecrypt(name=key_path, - body={'ciphertext': ciphertext}) + body=request_body) response = request.execute() - plaintext = base64.b64decode(response['plaintext']).decode('utf-8') + plaintext = base64.b64decode(response['plaintext']) return plaintext # [END kms_decrypt_rsa] # [START kms_encrypt_rsa] -def encryptRSA(message, client, key_path): +def encryptRSA(plaintext, client, key_path): """ - Encrypt message locally using an 'RSA_DECRYPT_OAEP_2048_SHA256' public - key retrieved from Cloud KMS + Encrypt the input plaintext (bytes) locally using an + 'RSA_DECRYPT_OAEP_2048_SHA256' public key retrieved from Cloud KMS """ public_key = getAsymmetricPublicKey(client, key_path) pad = padding.OAEP(mgf=padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None) - ciphertext = public_key.encrypt(message.encode('ascii'), pad) - ciphertext = base64.b64encode(ciphertext).decode('utf-8') - return ciphertext + return public_key.encrypt(plaintext, pad) # [END kms_encrypt_rsa] @@ -82,7 +83,7 @@ def signAsymmetric(message, client, key_path): """ # Note: some key algorithms will require a different hash function # For example, EC_SIGN_P384_SHA384 requires SHA384 - digest_bytes = hashlib.sha256(message.encode('ascii')).digest() + digest_bytes = hashlib.sha256(message).digest() digest64 = base64.b64encode(digest_bytes) digest_JSON = {'sha256': digest64.decode('utf-8')} @@ -94,7 +95,7 @@ def signAsymmetric(message, client, key_path): .asymmetricSign(name=key_path, body={'digest': digest_JSON}) response = request.execute() - return response.get('signature', None) + return base64.b64decode(response.get('signature', None)) # [END kms_sign_asymmetric] @@ -102,16 +103,14 @@ def signAsymmetric(message, client, key_path): def verifySignatureRSA(signature, message, client, key_path): """ Verify the validity of an 'RSA_SIGN_PSS_2048_SHA256' signature for the - specified plaintext message + specified message """ public_key = getAsymmetricPublicKey(client, key_path) - - digest_bytes = hashlib.sha256(message.encode('ascii')).digest() - sig_bytes = base64.b64decode(signature) + digest_bytes = hashlib.sha256(message).digest() try: # Attempt verification - public_key.verify(sig_bytes, + public_key.verify(signature, digest_bytes, padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=32), @@ -127,16 +126,14 @@ def verifySignatureRSA(signature, message, client, key_path): def verifySignatureEC(signature, message, client, key_path): """ Verify the validity of an 'EC_SIGN_P256_SHA256' signature - for the specified plaintext message + for the specified message """ public_key = getAsymmetricPublicKey(client, key_path) - - digest_bytes = hashlib.sha256(message.encode('ascii')).digest() - sig_bytes = base64.b64decode(signature) + digest_bytes = hashlib.sha256(message).digest() try: # Attempt verification - public_key.verify(sig_bytes, + public_key.verify(signature, digest_bytes, ec.ECDSA(utils.Prehashed(hashes.SHA256()))) # No errors were thrown. Verification was successful diff --git a/kms/api-client/asymmetric_test.py b/kms/api-client/asymmetric_test.py index 5f969be7db68..4ce9b32aed5b 100644 --- a/kms/api-client/asymmetric_test.py +++ b/kms/api-client/asymmetric_test.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - from os import environ from time import sleep @@ -89,6 +88,7 @@ class TestKMSSamples: .format(parent, keyring, ecSignId) message = 'test message 123' + message_bytes = message.encode('utf-8') client = discovery.build('cloudkms', 'v1') @@ -99,44 +99,52 @@ def test_get_public_key(self): assert isinstance(ec_key, _EllipticCurvePublicKey), 'expected EC key' def test_rsa_encrypt_decrypt(self): - ciphertext = sample.encryptRSA(self.message, + ciphertext = sample.encryptRSA(self.message_bytes, self.client, self.rsaDecrypt) - # ciphertext should be 344 characters with base64 and RSA 2048 - assert len(ciphertext) == 344, \ - 'ciphertext should be 344 chars; got {}'.format(len(ciphertext)) - assert ciphertext[-2:] == '==', 'cipher text should end with ==' - plaintext = sample.decryptRSA(ciphertext, self.client, self.rsaDecrypt) + # ciphertext should be 256 characters with base64 and RSA 2048 + assert len(ciphertext) == 256, \ + 'ciphertext should be 256 chars; got {}'.format(len(ciphertext)) + plaintext_bytes = sample.decryptRSA(ciphertext, + self.client, + self.rsaDecrypt) + assert plaintext_bytes == self.message_bytes + plaintext = plaintext_bytes.decode('utf-8') assert plaintext == self.message def test_rsa_sign_verify(self): - sig = sample.signAsymmetric(self.message, self.client, self.rsaSign) + sig = sample.signAsymmetric(self.message_bytes, + self.client, + self.rsaSign) # ciphertext should be 344 characters with base64 and RSA 2048 - assert len(sig) == 344, \ - 'sig should be 344 chars; got {}'.format(len(sig)) - assert sig[-2:] == '==', 'sig should end with ==' + assert len(sig) == 256, \ + 'sig should be 256 chars; got {}'.format(len(sig)) success = sample.verifySignatureRSA(sig, - self.message, + self.message_bytes, self.client, self.rsaSign) assert success is True, 'RSA verification failed' + changed_bytes = self.message_bytes + b'.' success = sample.verifySignatureRSA(sig, - self.message+'.', + changed_bytes, self.client, self.rsaSign) assert success is False, 'verify should fail with modified message' def test_ec_sign_verify(self): - sig = sample.signAsymmetric(self.message, self.client, self.ecSign) + sig = sample.signAsymmetric(self.message_bytes, + self.client, + self.ecSign) assert len(sig) > 50 and len(sig) < 300, \ 'sig outside expected length range' success = sample.verifySignatureEC(sig, - self.message, + self.message_bytes, self.client, self.ecSign) assert success is True, 'EC verification failed' + changed_bytes = self.message_bytes + b'.' success = sample.verifySignatureEC(sig, - self.message+'.', + changed_bytes, self.client, self.ecSign) assert success is False, 'verify should fail with modified message'