Skip to content

Commit

Permalink
KMS changes [(#1723)](#1723)
Browse files Browse the repository at this point in the history
use byte parameters instead of strings
  • Loading branch information
daniel-sanche authored and busunkim96 committed Jun 4, 2020
1 parent f210de2 commit dc873e2
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 38 deletions.
41 changes: 19 additions & 22 deletions kms/snippets/asymmetric.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.rom googleapiclient import discovery

# [START kms_asymmetric_imports]
import base64
import hashlib

from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec, padding, utils
# [END kms_asymmetric_imports]


# [START kms_get_asymmetric_public]
Expand All @@ -43,35 +45,34 @@ def getAsymmetricPublicKey(client, key_path):
# [START kms_decrypt_rsa]
def decryptRSA(ciphertext, client, key_path):
"""
Decrypt a given ciphertext using an 'RSA_DECRYPT_OAEP_2048_SHA256' private
key stored on Cloud KMS
Decrypt the input ciphertext (bytes) using an
'RSA_DECRYPT_OAEP_2048_SHA256' private key stored on Cloud KMS
"""
request_body = {'ciphertext': base64.b64encode(ciphertext).decode('utf-8')}
request = client.projects() \
.locations() \
.keyRings() \
.cryptoKeys() \
.cryptoKeyVersions() \
.asymmetricDecrypt(name=key_path,
body={'ciphertext': ciphertext})
body=request_body)
response = request.execute()
plaintext = base64.b64decode(response['plaintext']).decode('utf-8')
plaintext = base64.b64decode(response['plaintext'])
return plaintext
# [END kms_decrypt_rsa]


# [START kms_encrypt_rsa]
def encryptRSA(message, client, key_path):
def encryptRSA(plaintext, client, key_path):
"""
Encrypt message locally using an 'RSA_DECRYPT_OAEP_2048_SHA256' public
key retrieved from Cloud KMS
Encrypt the input plaintext (bytes) locally using an
'RSA_DECRYPT_OAEP_2048_SHA256' public key retrieved from Cloud KMS
"""
public_key = getAsymmetricPublicKey(client, key_path)
pad = padding.OAEP(mgf=padding.MGF1(algorithm=hashes.SHA256()),
algorithm=hashes.SHA256(),
label=None)
ciphertext = public_key.encrypt(message.encode('ascii'), pad)
ciphertext = base64.b64encode(ciphertext).decode('utf-8')
return ciphertext
return public_key.encrypt(plaintext, pad)
# [END kms_encrypt_rsa]


Expand All @@ -82,7 +83,7 @@ def signAsymmetric(message, client, key_path):
"""
# Note: some key algorithms will require a different hash function
# For example, EC_SIGN_P384_SHA384 requires SHA384
digest_bytes = hashlib.sha256(message.encode('ascii')).digest()
digest_bytes = hashlib.sha256(message).digest()
digest64 = base64.b64encode(digest_bytes)

digest_JSON = {'sha256': digest64.decode('utf-8')}
Expand All @@ -94,24 +95,22 @@ def signAsymmetric(message, client, key_path):
.asymmetricSign(name=key_path,
body={'digest': digest_JSON})
response = request.execute()
return response.get('signature', None)
return base64.b64decode(response.get('signature', None))
# [END kms_sign_asymmetric]


# [START kms_verify_signature_rsa]
def verifySignatureRSA(signature, message, client, key_path):
"""
Verify the validity of an 'RSA_SIGN_PSS_2048_SHA256' signature for the
specified plaintext message
specified message
"""
public_key = getAsymmetricPublicKey(client, key_path)

digest_bytes = hashlib.sha256(message.encode('ascii')).digest()
sig_bytes = base64.b64decode(signature)
digest_bytes = hashlib.sha256(message).digest()

try:
# Attempt verification
public_key.verify(sig_bytes,
public_key.verify(signature,
digest_bytes,
padding.PSS(mgf=padding.MGF1(hashes.SHA256()),
salt_length=32),
Expand All @@ -127,16 +126,14 @@ def verifySignatureRSA(signature, message, client, key_path):
def verifySignatureEC(signature, message, client, key_path):
"""
Verify the validity of an 'EC_SIGN_P256_SHA256' signature
for the specified plaintext message
for the specified message
"""
public_key = getAsymmetricPublicKey(client, key_path)

digest_bytes = hashlib.sha256(message.encode('ascii')).digest()
sig_bytes = base64.b64decode(signature)
digest_bytes = hashlib.sha256(message).digest()

try:
# Attempt verification
public_key.verify(sig_bytes,
public_key.verify(signature,
digest_bytes,
ec.ECDSA(utils.Prehashed(hashes.SHA256())))
# No errors were thrown. Verification was successful
Expand Down
40 changes: 24 additions & 16 deletions kms/snippets/asymmetric_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from os import environ
from time import sleep

Expand Down Expand Up @@ -89,6 +88,7 @@ class TestKMSSamples:
.format(parent, keyring, ecSignId)

message = 'test message 123'
message_bytes = message.encode('utf-8')

client = discovery.build('cloudkms', 'v1')

Expand All @@ -99,44 +99,52 @@ def test_get_public_key(self):
assert isinstance(ec_key, _EllipticCurvePublicKey), 'expected EC key'

def test_rsa_encrypt_decrypt(self):
ciphertext = sample.encryptRSA(self.message,
ciphertext = sample.encryptRSA(self.message_bytes,
self.client,
self.rsaDecrypt)
# ciphertext should be 344 characters with base64 and RSA 2048
assert len(ciphertext) == 344, \
'ciphertext should be 344 chars; got {}'.format(len(ciphertext))
assert ciphertext[-2:] == '==', 'cipher text should end with =='
plaintext = sample.decryptRSA(ciphertext, self.client, self.rsaDecrypt)
# ciphertext should be 256 characters with base64 and RSA 2048
assert len(ciphertext) == 256, \
'ciphertext should be 256 chars; got {}'.format(len(ciphertext))
plaintext_bytes = sample.decryptRSA(ciphertext,
self.client,
self.rsaDecrypt)
assert plaintext_bytes == self.message_bytes
plaintext = plaintext_bytes.decode('utf-8')
assert plaintext == self.message

def test_rsa_sign_verify(self):
sig = sample.signAsymmetric(self.message, self.client, self.rsaSign)
sig = sample.signAsymmetric(self.message_bytes,
self.client,
self.rsaSign)
# ciphertext should be 344 characters with base64 and RSA 2048
assert len(sig) == 344, \
'sig should be 344 chars; got {}'.format(len(sig))
assert sig[-2:] == '==', 'sig should end with =='
assert len(sig) == 256, \
'sig should be 256 chars; got {}'.format(len(sig))
success = sample.verifySignatureRSA(sig,
self.message,
self.message_bytes,
self.client,
self.rsaSign)
assert success is True, 'RSA verification failed'
changed_bytes = self.message_bytes + b'.'
success = sample.verifySignatureRSA(sig,
self.message+'.',
changed_bytes,
self.client,
self.rsaSign)
assert success is False, 'verify should fail with modified message'

def test_ec_sign_verify(self):
sig = sample.signAsymmetric(self.message, self.client, self.ecSign)
sig = sample.signAsymmetric(self.message_bytes,
self.client,
self.ecSign)
assert len(sig) > 50 and len(sig) < 300, \
'sig outside expected length range'
success = sample.verifySignatureEC(sig,
self.message,
self.message_bytes,
self.client,
self.ecSign)
assert success is True, 'EC verification failed'
changed_bytes = self.message_bytes + b'.'
success = sample.verifySignatureEC(sig,
self.message+'.',
changed_bytes,
self.client,
self.ecSign)
assert success is False, 'verify should fail with modified message'

0 comments on commit dc873e2

Please sign in to comment.