Skip to content

Commit

Permalink
Merge GH #3345 In SM2 encryption cache the hash and KDF objects
Browse files Browse the repository at this point in the history
  • Loading branch information
randombit committed Mar 5, 2023
2 parents 4bca07a + ad77e5a commit ab7dd3c
Showing 1 changed file with 25 additions and 34 deletions.
59 changes: 25 additions & 34 deletions src/lib/pubkey/sm2/sm2_enc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,11 @@ class SM2_Encryption_Operation final : public PK_Ops::Encryption
RandomNumberGenerator& rng,
const std::string& kdf_hash) :
m_group(key.domain()),
m_kdf_hash(kdf_hash),
m_ws(EC_Point::WORKSPACE_SIZE),
m_mul_public_point(key.public_point(), rng, m_ws)
{
std::unique_ptr<HashFunction> hash = HashFunction::create_or_throw(m_kdf_hash);
m_hash_size = hash->output_length();
m_hash = HashFunction::create_or_throw(kdf_hash);
m_kdf = KDF::create_or_throw("KDF2(" + kdf_hash + ")");
}

size_t max_input_bits() const override
Expand All @@ -43,16 +42,13 @@ class SM2_Encryption_Operation final : public PK_Ops::Encryption
const size_t elem_size = m_group.get_order_bytes();
const size_t der_overhead = 16;

return der_overhead + 2*elem_size + m_hash_size + ptext_len;
return der_overhead + 2*elem_size + m_hash->output_length() + ptext_len;
}

secure_vector<uint8_t> encrypt(const uint8_t msg[],
size_t msg_len,
RandomNumberGenerator& rng) override
{
std::unique_ptr<HashFunction> hash = HashFunction::create_or_throw(m_kdf_hash);
std::unique_ptr<KDF> kdf = KDF::create_or_throw("KDF2(" + m_kdf_hash + ")");

const size_t p_bytes = m_group.get_p_bytes();

const BigInt k = m_group.random_scalar(rng);
Expand All @@ -79,16 +75,16 @@ class SM2_Encryption_Operation final : public PK_Ops::Encryption
kdf_input += y2_bytes;

const secure_vector<uint8_t> kdf_output =
kdf->derive_key(msg_len, kdf_input.data(), kdf_input.size());
m_kdf->derive_key(msg_len, kdf_input.data(), kdf_input.size());

secure_vector<uint8_t> masked_msg(msg_len);
xor_buf(masked_msg.data(), msg, kdf_output.data(), msg_len);

hash->update(x2_bytes);
hash->update(msg, msg_len);
hash->update(y2_bytes);
std::vector<uint8_t> C3(hash->output_length());
hash->final(C3.data());
m_hash->update(x2_bytes);
m_hash->update(msg, msg_len);
m_hash->update(y2_bytes);
std::vector<uint8_t> C3(m_hash->output_length());
m_hash->final(C3.data());

return DER_Encoder()
.start_sequence()
Expand All @@ -102,11 +98,10 @@ class SM2_Encryption_Operation final : public PK_Ops::Encryption

private:
const EC_Group m_group;
const std::string m_kdf_hash;

std::unique_ptr<HashFunction> m_hash;
std::unique_ptr<KDF> m_kdf;
std::vector<BigInt> m_ws;
EC_Point_Var_Point_Precompute m_mul_public_point;
size_t m_hash_size;
};

class SM2_Decryption_Operation final : public PK_Ops::Decryption
Expand All @@ -116,11 +111,10 @@ class SM2_Decryption_Operation final : public PK_Ops::Decryption
RandomNumberGenerator& rng,
const std::string& kdf_hash) :
m_key(key),
m_rng(rng),
m_kdf_hash(kdf_hash)
m_rng(rng)
{
std::unique_ptr<HashFunction> hash = HashFunction::create_or_throw(m_kdf_hash);
m_hash_size = hash->output_length();
m_hash = HashFunction::create_or_throw(kdf_hash);
m_kdf = KDF::create_or_throw("KDF2(" + kdf_hash + ")");
}

size_t plaintext_length(size_t ptext_len) const override
Expand All @@ -131,10 +125,10 @@ class SM2_Decryption_Operation final : public PK_Ops::Decryption
*/
const size_t elem_size = m_key.domain().get_order_bytes();

if(ptext_len < 2*elem_size + m_hash_size)
if(ptext_len < 2*elem_size + m_hash->output_length())
return 0;

return ptext_len - (2*elem_size + m_hash_size);
return ptext_len - (2*elem_size + m_hash->output_length());
}

secure_vector<uint8_t> decrypt(uint8_t& valid_mask,
Expand All @@ -147,11 +141,8 @@ class SM2_Decryption_Operation final : public PK_Ops::Decryption

valid_mask = 0x00;

std::unique_ptr<HashFunction> hash = HashFunction::create_or_throw(m_kdf_hash);
std::unique_ptr<KDF> kdf = KDF::create_or_throw("KDF2(" + m_kdf_hash + ")");

// Too short to be valid - no timing problem from early return
if(ciphertext_len < 1 + p_bytes*2 + hash->output_length())
if(ciphertext_len < 1 + p_bytes*2 + m_hash->output_length())
{
return secure_vector<uint8_t>();
}
Expand Down Expand Up @@ -211,16 +202,16 @@ class SM2_Decryption_Operation final : public PK_Ops::Decryption
kdf_input += y2_bytes;

const secure_vector<uint8_t> kdf_output =
kdf->derive_key(masked_msg.size(), kdf_input.data(), kdf_input.size());
m_kdf->derive_key(masked_msg.size(), kdf_input.data(), kdf_input.size());

xor_buf(masked_msg.data(), kdf_output.data(), kdf_output.size());

hash->update(x2_bytes);
hash->update(masked_msg);
hash->update(y2_bytes);
secure_vector<uint8_t> u = hash->final();
m_hash->update(x2_bytes);
m_hash->update(masked_msg);
m_hash->update(y2_bytes);
secure_vector<uint8_t> u = m_hash->final();

if(constant_time_compare(u.data(), C3.data(), hash->output_length()) == false)
if(constant_time_compare(u.data(), C3.data(), m_hash->output_length()) == false)
return secure_vector<uint8_t>();

valid_mask = 0xFF;
Expand All @@ -229,9 +220,9 @@ class SM2_Decryption_Operation final : public PK_Ops::Decryption
private:
const SM2_Encryption_PrivateKey& m_key;
RandomNumberGenerator& m_rng;
const std::string m_kdf_hash;
std::vector<BigInt> m_ws;
size_t m_hash_size;
std::unique_ptr<HashFunction> m_hash;
std::unique_ptr<KDF> m_kdf;
};

}
Expand Down

0 comments on commit ab7dd3c

Please sign in to comment.