Skip to content

Commit

Permalink
add tool to generate test cases for BLS operators
Browse files Browse the repository at this point in the history
  • Loading branch information
arvidn committed May 16, 2023
1 parent cdfe1c8 commit 2e4aac5
Show file tree
Hide file tree
Showing 7 changed files with 771 additions and 0 deletions.
120 changes: 120 additions & 0 deletions op-tests/test-blspy-g1.txt

Large diffs are not rendered by default.

120 changes: 120 additions & 0 deletions op-tests/test-blspy-g2.txt

Large diffs are not rendered by default.

212 changes: 212 additions & 0 deletions op-tests/test-blspy-hash.txt

Large diffs are not rendered by default.

32 changes: 32 additions & 0 deletions op-tests/test-blspy-pairing.txt

Large diffs are not rendered by default.

32 changes: 32 additions & 0 deletions op-tests/test-blspy-verify.txt

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions src/test_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,11 @@ use rstest::rstest;
#[case("test-core-ops")]
#[case("test-more-ops")]
#[case("test-bls-ops")]
#[case("test-blspy-g1")]
#[case("test-blspy-g2")]
#[case("test-blspy-hash")]
#[case("test-blspy-pairing")]
#[case("test-blspy-verify")]
fn test_ops(#[case] filename: &str) {
use std::fs::read_to_string;

Expand Down
250 changes: 250 additions & 0 deletions tools/generate-bls-tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
import blspy
from random import randbytes, randint, seed, sample

def bytes_in_atom(n: int) -> int:
if n == 0:
return 0
if n <= 0x7f:
return 1
if n <= 0x7fff:
return 2
if n <= 0x7fffff:
return 3
if n <= 0x7fffffff:
return 4
if n <= 0x7fffffffff:
return 5
assert False

def flip_bit(b: bytes) -> bytearray:
idx = randint(0, len(b) - 1)
bit = 1 << randint(0, 7)
ret = bytearray(b)
ret[idx] ^= bit
return ret

def print_validation_test_case(f1, f2, num_cases, filter_pk, filter_msg, filter_sig, expect: str):
sks = sample(secret_keys, randint(1,min(10, num_cases)))
cost = 4999087
messages = []
sigs = []

args = ""
for sk in sks:
pk = sk.get_g1()
msg = randbytes(randint(3,40))
cost += len(msg) * 122 + 43 * 135
cost += 4515438
messages.append(msg)
sigs.append(blspy.AugSchemeMPL.sign(sk, msg))
args += f"(0x{bytes(filter_pk(pk)).hex()} . 0x{filter_msg(msg).hex()}) "

agg_sig = blspy.AugSchemeMPL.aggregate(sigs)

f1.write(f"bls_verify 0x{bytes(filter_sig(agg_sig)).hex()} ")
f1.write(args)
f1.write(f"=> {expect} | {cost}\n")

# interleave tests for bls_pairing_identity using the same parameters
cost = 4999087
f2.write("bls_pairing_identity ")
for sk, msg in zip(sks, messages):
pk = sk.get_g1()
cost += 4515438

# in the AUG scheme we prepend the public key to the message before
# hashing it to the G2 point
g2 = blspy.AugSchemeMPL.g2_from_message(bytes(pk) + filter_msg(msg))
f2.write(f"(0x{bytes(filter_pk(pk)).hex()} . 0x{bytes(g2).hex()}) ")

# this is the low-level pairing operation, we also need to include the
# signature and the negated generator
gen = "b7f1d3a73197d7942695638c4fa9ac0fc3688c4f9774b905a14e3a3f171bac586c55e83ff97a1aeffb3af00adb22c6bb"
f2.write(f"(0x{gen} . 0x{bytes(filter_sig(agg_sig)).hex()}) ")
cost += 4515438

f2.write(f"=> {expect} | {cost}\n")


seed(1337)

SIZE = 30

# generate a bunch of G1 points
g1_points = []
secret_keys = []
for i in range(SIZE):
seed = randbytes(32)
sk = blspy.AugSchemeMPL.key_gen(seed)
secret_keys.append(sk)
g1_points.append(sk.get_g1())

# generate a bunch of G2 points
g2_points = []
for i in range(SIZE):
seed = randbytes(32)
g2_points.append(blspy.AugSchemeMPL.g2_from_message(seed))

# generate a bunch of GT points
gt_points = []
for g1, g2 in zip(g1_points, g2_points):
gt_points.append(g2.pair(g1))

with open("../op-tests/test-blspy-g1.txt", "w+") as f:
f.write("; This file was generated by tools/generate-bls-tests.py\n\n")

# g1_add
aggregate = None
for g1 in g1_points:
if aggregate is None:
aggregate = g1
continue

cost = 101094 + 1343980 * 2 + 48 * 10
result = aggregate + g1
f.write(f"g1_add 0x{bytes(aggregate).hex()} 0x{bytes(g1).hex()} => 0x{bytes(result).hex()} | {cost}\n")

aggregate = result

# g1_subtract
aggregate = None
for g1 in g1_points:
if aggregate is None:
aggregate = g1
continue

cost = 2857918
result = aggregate + g1.negate()
f.write(f"g1_subtract 0x{bytes(aggregate).hex()} 0x{bytes(g1).hex()} => 0x{bytes(result).hex()} | {cost}\n")

aggregate = result

# g1_multiply
for g1 in g1_points:
scalar = randint(-100, 100)
cost = 2154347 + bytes_in_atom(scalar) * 12 + 48 * 10
# blspy does not expose multiplication, so we simulate it
result = blspy.G1Element()
if scalar < 0:
for i in range(-scalar):
result += g1
result = result.negate()
else:
for i in range(scalar):
result += g1
f.write(f"g1_multiply 0x{bytes(g1).hex()} {scalar} => 0x{bytes(result).hex()} | {cost}\n")

# g1_negate
for g1 in g1_points:

cost = 471259
result = g1.negate()
f.write(f"g1_negate 0x{bytes(g1).hex()} => 0x{bytes(result).hex()} | {cost}\n")

aggregate = result

with open("../op-tests/test-blspy-g2.txt", "w+") as f:
f.write("; This file was generated by tools/generate-bls-tests.py\n\n")

# g2_add
aggregate = None
for g2 in g2_points:
if aggregate is None:
aggregate = g2
continue

cost = 11135562
result = aggregate + g2
f.write(f"g2_add 0x{bytes(aggregate).hex()} 0x{bytes(g2).hex()} => 0x{bytes(result).hex()} | {cost}\n")

aggregate = result


# g2_subtract
aggregate = None
for g2 in g2_points:
if aggregate is None:
aggregate = g2
continue

cost = 11137794
result = aggregate + g2.negate()
f.write(f"g2_subtract 0x{bytes(aggregate).hex()} 0x{bytes(g2).hex()} => 0x{bytes(result).hex()} | {cost}\n")

aggregate = result

# g2_multiply
for g2 in g2_points:
scalar = randint(-100, 100)
cost = 10078145 + bytes_in_atom(scalar) * 12 + 96 * 10
# blspy does not expose multiplication, so we simulate it
result = blspy.G2Element()
if scalar < 0:
for i in range(-scalar):
result += g2
result = result.negate()
else:
for i in range(scalar):
result += g2
f.write(f"g2_multiply 0x{bytes(g2).hex()} {scalar} => 0x{bytes(result).hex()} | {cost}\n")

# g2_negate
for g2 in g2_points:

cost = 1882659
result = g2.negate()
f.write(f"g2_negate 0x{bytes(g2).hex()} => 0x{bytes(result).hex()} | {cost}\n")

aggregate = result

with open("../op-tests/test-blspy-hash.txt", "w+") as f:
f.write("; This file was generated by tools/generate-bls-tests.py\n\n")

# g1_map
for i in range(SIZE):
msg = randbytes(randint(3,40))
g1 = blspy.G1Element.from_message(msg, "BLS_SIG_BLS12381G1_XMD:SHA-256_SSWU_RO_AUG_")
cost = 610907 + len(msg) * 122 + 43 * 135 + 48 * 10
f.write(f"g1_map 0x{bytes(msg).hex()} \"BLS_SIG_BLS12381G1_XMD:SHA-256_SSWU_RO_AUG_\" => 0x{bytes(g1).hex()} | {cost}\n")
f.write(f"g1_map 0x{bytes(msg).hex()} => 0x{bytes(g1).hex()} | {cost}\n")
g1 = blspy.G1Element.from_message(msg, "BLS_SIG_BLS12381G1_XMD:SHA-256_SSWU_RO_POP_")
f.write(f"g1_map 0x{bytes(msg).hex()} \"BLS_SIG_BLS12381G1_XMD:SHA-256_SSWU_RO_POP_\" => 0x{bytes(g1).hex()} | {cost}\n")

# g2_map
for i in range(SIZE):
msg = randbytes(randint(3,40))
g2 = blspy.AugSchemeMPL.g2_from_message(msg)
cost = 3386788 + len(msg) * 122
f.write(f"g2_map 0x{bytes(msg).hex()} \"BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_AUG_\" => 0x{bytes(g2).hex()} | {cost}\n")
# this scheme is the default, and doesn't need to be specified
# it has the same cost
f.write(f"g2_map 0x{bytes(msg).hex()} => 0x{bytes(g2).hex()} | {cost}\n")

g2 = blspy.BasicSchemeMPL.g2_from_message(msg)
f.write(f"g2_map 0x{bytes(msg).hex()} \"BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_NUL_\" => 0x{bytes(g2).hex()} | {cost}\n")

g2 = blspy.PopSchemeMPL.g2_from_message(msg)
f.write(f"g2_map 0x{bytes(msg).hex()} \"BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_POP_\" => 0x{bytes(g2).hex()} | {cost}\n")

with open("../op-tests/test-blspy-verify.txt", "w+") as f1, \
open("../op-tests/test-blspy-pairing.txt", "w+") as f2:
f1.write("; This file was generated by tools/generate-bls-tests.py\n\n")
f2.write("; This file was generated by tools/generate-bls-tests.py\n\n")


# bls_verify
# bls_pairing_identity
for k in range(SIZE // 2):
print_validation_test_case(f1, f2, SIZE, lambda pk: pk, lambda msg: msg, lambda sig: sig, "1")

# negative tests (alter public key)
for k in range(5):
print_validation_test_case(f1, f2, 3, lambda pk: pk.negate(), lambda msg: msg, lambda sig: sig, "0")

# negative tests (alter message)
for k in range(5):
print_validation_test_case(f1, f2, 3, lambda pk: pk, flip_bit, lambda sig: sig, "0")

# negative tests (alter signature)
for k in range(5):
print_validation_test_case(f1, f2, 3, lambda pk: pk, lambda msg: msg, lambda sig: sig.negate(), "0")

0 comments on commit 2e4aac5

Please sign in to comment.