From 83d99b3eaf7130a6bed00fe2e2d6dfc82e816f29 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Wed, 18 Dec 2024 15:38:03 +0100 Subject: [PATCH 01/10] add support for CSHAKE --- cng/hash.go | 147 +++++++++++++++++++++ cng/hash_test.go | 205 ++++++++++++++++++++++++++++++ internal/bcrypt/bcrypt_windows.go | 30 +++-- 3 files changed, 371 insertions(+), 11 deletions(-) diff --git a/cng/hash.go b/cng/hash.go index 87b1c95..223088a 100644 --- a/cng/hash.go +++ b/cng/hash.go @@ -11,6 +11,7 @@ import ( "crypto" "hash" "runtime" + "slices" "unsafe" "github.com/microsoft/go-crypto-winnative/internal/bcrypt" @@ -304,3 +305,149 @@ func (h *hashX) Sum(in []byte) []byte { } return append(in, h.buf...) } + +// SumSHAKE128 applies the SHAKE128 extendable output function to data and +// returns an output of the given length in bytes. +func SumSHAKE128(data []byte, length int) []byte { + out := make([]byte, length) + if err := hashOneShot(bcrypt.CSHAKE128_ALGORITHM, data, out); err != nil { + panic("bcrypt: CSHAKE128_ALGORITHM failed") + } + return out +} + +// SumSHAKE256 applies the SHAKE256 extendable output function to data and +// returns an output of the given length in bytes. +func SumSHAKE256(data []byte, length int) []byte { + out := make([]byte, length) + if err := hashOneShot(bcrypt.CSHAKE256_ALGORITHM, data, out); err != nil { + panic("bcrypt: CSHAKE128_ALGORITHM failed") + } + return out +} + +// SHAKE is an instance of a SHAKE extendable output function. +type SHAKE struct { + alg *hashAlgorithm + ctx bcrypt.HASH_HANDLE + n, s []byte +} + +func newShake(id string, N, S []byte) *SHAKE { + alg, err := loadHash(id, bcrypt.ALG_NONE_FLAG) + if err != nil { + panic(err) + } + h := &SHAKE{alg: alg, n: slices.Clone(N), s: slices.Clone(S)} + err = bcrypt.CreateHash(h.alg.handle, &h.ctx, nil, nil, 0) + if err != nil { + panic(err) + } + if len(N) != 0 { + if err := bcrypt.SetProperty(bcrypt.HANDLE(h.ctx), utf16PtrFromString(bcrypt.FUNCTION_NAME_STRING), N, 0); err != nil { + panic(err) + } + } + if len(S) != 0 { + if err := bcrypt.SetProperty(bcrypt.HANDLE(h.ctx), utf16PtrFromString(bcrypt.CUSTOMIZATION_STRING), S, 0); err != nil { + panic(err) + } + } + runtime.SetFinalizer(h, (*SHAKE).finalize) + return h +} + +// NewSHAKE128 creates a new SHAKE128 XOF. +func NewSHAKE128() *SHAKE { + return newShake(bcrypt.CSHAKE128_ALGORITHM, nil, nil) +} + +// NewSHAKE256 creates a new SHAKE256 XOF. +func NewSHAKE256() *SHAKE { + return newShake(bcrypt.CSHAKE256_ALGORITHM, nil, nil) +} + +// NewCSHAKE128 creates a new cSHAKE128 XOF. +// +// N is used to define functions based on cSHAKE, it can be empty when plain +// cSHAKE is desired. S is a customization byte string used for domain +// separation. When N and S are both empty, this is equivalent to NewSHAKE128. +func NewCSHAKE128(N, S []byte) *SHAKE { + return newShake(bcrypt.CSHAKE128_ALGORITHM, N, S) +} + +// NewCSHAKE256 creates a new cSHAKE256 XOF. +// +// N is used to define functions based on cSHAKE, it can be empty when plain +// cSHAKE is desired. S is a customization byte string used for domain +// separation. When N and S are both empty, this is equivalent to NewSHAKE256. +func NewCSHAKE256(N, S []byte) *SHAKE { + return newShake(bcrypt.CSHAKE256_ALGORITHM, N, S) +} + +func (h *SHAKE) finalize() { + bcrypt.DestroyHash(h.ctx) +} + +// Write absorbs more data into the XOF's state. +// +// It panics if any output has already been read. +func (s *SHAKE) Write(p []byte) (n int, err error) { + if len(p) == 0 { + return 0, nil + } + defer runtime.KeepAlive(s) + for n < len(p) && err == nil { + nn := len32(p[n:]) + err = bcrypt.HashData(s.ctx, p[n:n+nn], 0) + n += nn + } + if err != nil { + panic(err) + } + return len(p), nil +} + +// Read squeezes more output from the XOF. +// +// Any call to Write after a call to Read will panic. +func (s *SHAKE) Read(p []byte) (n int, err error) { + if len(p) == 0 { + return 0, nil + } + defer runtime.KeepAlive(s) + for n < len(p) && err == nil { + nn := len32(p[n:]) + err = bcrypt.FinishHash(s.ctx, p[n:n+nn], bcrypt.HASH_DONT_RESET_FLAG) + n += nn + } + if err != nil { + panic(err) + } + return len(p), nil +} + +// Reset resets the XOF to its initial state. +func (s *SHAKE) Reset() { + defer runtime.KeepAlive(s) + bcrypt.DestroyHash(s.ctx) + err := bcrypt.CreateHash(s.alg.handle, &s.ctx, nil, nil, 0) + if err != nil { + panic(err) + } + if len(s.n) != 0 { + if err := bcrypt.SetProperty(bcrypt.HANDLE(s.ctx), utf16PtrFromString(bcrypt.FUNCTION_NAME_STRING), s.n, 0); err != nil { + panic(err) + } + } + if len(s.s) != 0 { + if err := bcrypt.SetProperty(bcrypt.HANDLE(s.ctx), utf16PtrFromString(bcrypt.CUSTOMIZATION_STRING), s.s, 0); err != nil { + panic(err) + } + } +} + +// BlockSize returns the rate of the XOF. +func (s *SHAKE) BlockSize() int { + return int(s.alg.blockSize) +} diff --git a/cng/hash_test.go b/cng/hash_test.go index 21a7fa8..c692ca6 100644 --- a/cng/hash_test.go +++ b/cng/hash_test.go @@ -9,8 +9,10 @@ package cng_test import ( "bytes" "crypto" + "encoding/hex" "hash" "io" + "math/rand" "testing" "github.com/microsoft/go-crypto-winnative/cng" @@ -212,3 +214,206 @@ func BenchmarkSHA256_OneShot(b *testing.B) { cng.SHA256(buf) } } + +// testShakes contains functions that return *sha3.SHAKE instances for +// with output-length equal to the KAT length. +var testShakes = map[string]struct { + constructor func(N []byte, S []byte) *cng.SHAKE + defAlgoName string + defCustomStr string +}{ + // NewCSHAKE without customization produces same result as SHAKE + "SHAKE128": {cng.NewCSHAKE128, "", ""}, + "SHAKE256": {cng.NewCSHAKE256, "", ""}, + "cSHAKE128": {cng.NewCSHAKE128, "CSHAKE128", "CustomString"}, + "cSHAKE256": {cng.NewCSHAKE256, "CSHAKE256", "CustomString"}, +} + +// TestCSHAKESqueezing checks that squeezing the full output a single time produces +// the same output as repeatedly squeezing the instance. +func TestCSHAKESqueezing(t *testing.T) { + const testString = "brekeccakkeccak koax koax" + for algo, v := range testShakes { + d0 := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr)) + d0.Write([]byte(testString)) + ref := make([]byte, 32) + d0.Read(ref) + + d1 := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr)) + d1.Write([]byte(testString)) + var multiple []byte + for range ref { + d1.Read(make([]byte, 0)) + one := make([]byte, 1) + d1.Read(one) + multiple = append(multiple, one...) + } + if !bytes.Equal(ref, multiple) { + t.Errorf("%s: squeezing %d bytes one at a time failed", algo, len(ref)) + } + } +} + +// sequentialBytes produces a buffer of size consecutive bytes 0x00, 0x01, ..., used for testing. +func sequentialBytes(size int) []byte { + alignmentOffset := rand.Intn(8) + result := make([]byte, size+alignmentOffset)[alignmentOffset:] + for i := range result { + result[i] = byte(i) + } + return result +} + +func TestCSHAKEReset(t *testing.T) { + out1 := make([]byte, 32) + out2 := make([]byte, 32) + + for _, v := range testShakes { + // Calculate hash for the first time + c := v.constructor(nil, []byte{0x99, 0x98}) + c.Write(sequentialBytes(0x100)) + c.Read(out1) + + // Calculate hash again + c.Reset() + c.Write(sequentialBytes(0x100)) + c.Read(out2) + + if !bytes.Equal(out1, out2) { + t.Error("\nExpected:\n", out1, "\ngot:\n", out2) + } + } +} + +func TestCSHAKEAccumulated(t *testing.T) { + t.Run("CSHAKE128", func(t *testing.T) { + testCSHAKEAccumulated(t, cng.NewCSHAKE128, (1600-256)/8, + "bb14f8657c6ec5403d0b0e2ef3d3393497e9d3b1a9a9e8e6c81dbaa5fd809252") + }) + t.Run("CSHAKE256", func(t *testing.T) { + testCSHAKEAccumulated(t, cng.NewCSHAKE256, (1600-512)/8, + "0baaf9250c6e25f0c14ea5c7f9bfde54c8a922c8276437db28f3895bdf6eeeef") + }) +} + +func testCSHAKEAccumulated(t *testing.T, newCSHAKE func(N, S []byte) *cng.SHAKE, rate int64, exp string) { + rnd := newCSHAKE(nil, nil) + acc := newCSHAKE(nil, nil) + for n := 0; n < 200; n++ { + N := make([]byte, n) + rnd.Read(N) + for s := 0; s < 200; s++ { + S := make([]byte, s) + rnd.Read(S) + + c := newCSHAKE(N, S) + io.CopyN(c, rnd, 100 /* < rate */) + io.CopyN(acc, c, 200) + + c.Reset() + io.CopyN(c, rnd, rate) + io.CopyN(acc, c, 200) + + c.Reset() + io.CopyN(c, rnd, 200 /* > rate */) + io.CopyN(acc, c, 200) + } + } + out := make([]byte, 32) + acc.Read(out) + if got := hex.EncodeToString(out); got != exp { + t.Errorf("got %s, want %s", got, exp) + } +} + +func TestCSHAKELargeS(t *testing.T) { + const s = (1<<32)/8 + 1000 // s * 8 > 2^32 + S := make([]byte, s) + rnd := cng.NewSHAKE128() + rnd.Read(S) + c := cng.NewCSHAKE128(nil, S) + io.CopyN(c, rnd, 1000) + out := make([]byte, 32) + c.Read(out) + + exp := "2cb9f237767e98f2614b8779cf096a52da9b3a849280bbddec820771ae529cf0" + if got := hex.EncodeToString(out); got != exp { + t.Errorf("got %s, want %s", got, exp) + } +} + +func TestCSHAKESum(t *testing.T) { + const testString = "hello world" + t.Run("CSHAKE128", func(t *testing.T) { + h := cng.NewCSHAKE128(nil, nil) + h.Write([]byte(testString[:5])) + h.Write([]byte(testString[5:])) + want := make([]byte, 32) + h.Read(want) + got := cng.SumSHAKE128([]byte(testString), 32) + if !bytes.Equal(got, want) { + t.Errorf("got:%x want:%x", got, want) + } + }) + t.Run("CSHAKE256", func(t *testing.T) { + h := cng.NewCSHAKE256(nil, nil) + h.Write([]byte(testString[:5])) + h.Write([]byte(testString[5:])) + want := make([]byte, 32) + h.Read(want) + got := cng.SumSHAKE256([]byte(testString), 32) + if !bytes.Equal(got, want) { + t.Errorf("got:%x want:%x", got, want) + } + }) +} + +// benchmarkHash tests the speed to hash num buffers of buflen each. +func benchmarkHash(b *testing.B, h hash.Hash, size, num int) { + b.StopTimer() + h.Reset() + data := sequentialBytes(size) + b.SetBytes(int64(size * num)) + b.StartTimer() + + var state []byte + for i := 0; i < b.N; i++ { + for j := 0; j < num; j++ { + h.Write(data) + } + state = h.Sum(state[:0]) + } + b.StopTimer() + h.Reset() +} + +// benchmarkCSHAKE is specialized to the Shake instances, which don't +// require a copy on reading output. +func benchmarkCSHAKE(b *testing.B, h *cng.SHAKE, size, num int) { + b.StopTimer() + h.Reset() + data := sequentialBytes(size) + d := make([]byte, 32) + + b.SetBytes(int64(size * num)) + b.StartTimer() + + for i := 0; i < b.N; i++ { + h.Reset() + for j := 0; j < num; j++ { + h.Write(data) + } + h.Read(d) + } +} + +func BenchmarkSHA3_512_MTU(b *testing.B) { benchmarkHash(b, cng.NewSHA3_512(), 1350, 1) } +func BenchmarkSHA3_384_MTU(b *testing.B) { benchmarkHash(b, cng.NewSHA3_384(), 1350, 1) } +func BenchmarkSHA3_256_MTU(b *testing.B) { benchmarkHash(b, cng.NewSHA3_256(), 1350, 1) } + +func BenchmarkCSHAKE128_MTU(b *testing.B) { benchmarkCSHAKE(b, cng.NewSHAKE128(), 1350, 1) } +func BenchmarkCSHAKE256_MTU(b *testing.B) { benchmarkCSHAKE(b, cng.NewSHAKE256(), 1350, 1) } +func BenchmarkCSHAKE256_16x(b *testing.B) { benchmarkCSHAKE(b, cng.NewSHAKE256(), 16, 1024) } +func BenchmarkCSHAKE256_1MiB(b *testing.B) { benchmarkCSHAKE(b, cng.NewSHAKE256(), 1024, 1024) } + +func BenchmarkCSHA3_512_1MiB(b *testing.B) { benchmarkHash(b, cng.NewSHA3_512(), 1024, 1024) } diff --git a/internal/bcrypt/bcrypt_windows.go b/internal/bcrypt/bcrypt_windows.go index 090c74a..e3255e2 100644 --- a/internal/bcrypt/bcrypt_windows.go +++ b/internal/bcrypt/bcrypt_windows.go @@ -22,6 +22,8 @@ const ( SHA3_256_ALGORITHM = "SHA3-256" SHA3_384_ALGORITHM = "SHA3-384" SHA3_512_ALGORITHM = "SHA3-512" + CSHAKE128_ALGORITHM = "CSHAKE128" + CSHAKE256_ALGORITHM = "CSHAKE256" AES_ALGORITHM = "AES" RC4_ALGORITHM = "RC4" RSA_ALGORITHM = "RSA" @@ -47,17 +49,19 @@ const ( ) const ( - HASH_LENGTH = "HashDigestLength" - HASH_BLOCK_LENGTH = "HashBlockLength" - CHAINING_MODE = "ChainingMode" - CHAIN_MODE_ECB = "ChainingModeECB" - CHAIN_MODE_CBC = "ChainingModeCBC" - CHAIN_MODE_GCM = "ChainingModeGCM" - KEY_LENGTH = "KeyLength" - KEY_LENGTHS = "KeyLengths" - SIGNATURE_LENGTH = "SignatureLength" - BLOCK_LENGTH = "BlockLength" - ECC_CURVE_NAME = "ECCCurveName" + HASH_LENGTH = "HashDigestLength" + HASH_BLOCK_LENGTH = "HashBlockLength" + CHAINING_MODE = "ChainingMode" + CHAIN_MODE_ECB = "ChainingModeECB" + CHAIN_MODE_CBC = "ChainingModeCBC" + CHAIN_MODE_GCM = "ChainingModeGCM" + KEY_LENGTH = "KeyLength" + KEY_LENGTHS = "KeyLengths" + SIGNATURE_LENGTH = "SignatureLength" + BLOCK_LENGTH = "BlockLength" + ECC_CURVE_NAME = "ECCCurveName" + FUNCTION_NAME_STRING = "FunctionNameString" + CUSTOMIZATION_STRING = "CustomizationString" ) const ( @@ -113,6 +117,10 @@ const ( USE_SYSTEM_PREFERRED_RNG = 0x00000002 ) +const ( + HASH_DONT_RESET_FLAG = 0x00000001 +) + const ( KDF_RAW_SECRET = "TRUNCATE" ) From 804906ffa2db1854b3c994eed41d881eb080ab53 Mon Sep 17 00:00:00 2001 From: Quim Muntal Date: Wed, 18 Dec 2024 15:47:41 +0100 Subject: [PATCH 02/10] Update cng/hash.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- cng/hash.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cng/hash.go b/cng/hash.go index 223088a..e8339f8 100644 --- a/cng/hash.go +++ b/cng/hash.go @@ -321,7 +321,7 @@ func SumSHAKE128(data []byte, length int) []byte { func SumSHAKE256(data []byte, length int) []byte { out := make([]byte, length) if err := hashOneShot(bcrypt.CSHAKE256_ALGORITHM, data, out); err != nil { - panic("bcrypt: CSHAKE128_ALGORITHM failed") + panic("bcrypt: CSHAKE256_ALGORITHM failed") } return out } From c95a35fd045a4ce68db8938d1e9194f942d2ebcd Mon Sep 17 00:00:00 2001 From: qmuntal Date: Wed, 18 Dec 2024 15:52:27 +0100 Subject: [PATCH 03/10] add Supports* functions --- cng/hash.go | 14 ++++++++++++++ cng/hash_test.go | 34 +++++++++++++++++++++++++++++++--- 2 files changed, 45 insertions(+), 3 deletions(-) diff --git a/cng/hash.go b/cng/hash.go index 223088a..06d6b68 100644 --- a/cng/hash.go +++ b/cng/hash.go @@ -306,6 +306,20 @@ func (h *hashX) Sum(in []byte) []byte { return append(in, h.buf...) } +// SupportsSHAKE128 returns true if the SHAKE128 extendable output function is +// supported. +func SupportsSHAKE128() bool { + _, err := loadHash(bcrypt.CSHAKE128_ALGORITHM, bcrypt.ALG_NONE_FLAG) + return err == nil +} + +// SupportsSHAKE256 returns true if the SHAKE256 extendable output function is +// supported. +func SupportsSHAKE256() bool { + _, err := loadHash(bcrypt.CSHAKE256_ALGORITHM, bcrypt.ALG_NONE_FLAG) + return err == nil +} + // SumSHAKE128 applies the SHAKE128 extendable output function to data and // returns an output of the given length in bytes. func SumSHAKE128(data []byte, length int) []byte { diff --git a/cng/hash_test.go b/cng/hash_test.go index c692ca6..c49a4d2 100644 --- a/cng/hash_test.go +++ b/cng/hash_test.go @@ -225,8 +225,8 @@ var testShakes = map[string]struct { // NewCSHAKE without customization produces same result as SHAKE "SHAKE128": {cng.NewCSHAKE128, "", ""}, "SHAKE256": {cng.NewCSHAKE256, "", ""}, - "cSHAKE128": {cng.NewCSHAKE128, "CSHAKE128", "CustomString"}, - "cSHAKE256": {cng.NewCSHAKE256, "CSHAKE256", "CustomString"}, + "CSHAKE128": {cng.NewCSHAKE128, "CSHAKE128", "CustomString"}, + "CSHAKE256": {cng.NewCSHAKE256, "CSHAKE256", "CustomString"}, } // TestCSHAKESqueezing checks that squeezing the full output a single time produces @@ -234,6 +234,13 @@ var testShakes = map[string]struct { func TestCSHAKESqueezing(t *testing.T) { const testString = "brekeccakkeccak koax koax" for algo, v := range testShakes { + if algo == "SHAKE128" && !cng.SupportsSHAKE128() { + t.Skip("skipping: not supported") + } + if algo == "SHAKE256" && !cng.SupportsSHAKE256() { + t.Skip("skipping: not supported") + } + d0 := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr)) d0.Write([]byte(testString)) ref := make([]byte, 32) @@ -268,7 +275,13 @@ func TestCSHAKEReset(t *testing.T) { out1 := make([]byte, 32) out2 := make([]byte, 32) - for _, v := range testShakes { + for algo, v := range testShakes { + if algo == "SHAKE128" && !cng.SupportsSHAKE128() { + t.Skip("skipping: not supported") + } + if algo == "SHAKE256" && !cng.SupportsSHAKE256() { + t.Skip("skipping: not supported") + } // Calculate hash for the first time c := v.constructor(nil, []byte{0x99, 0x98}) c.Write(sequentialBytes(0x100)) @@ -287,10 +300,16 @@ func TestCSHAKEReset(t *testing.T) { func TestCSHAKEAccumulated(t *testing.T) { t.Run("CSHAKE128", func(t *testing.T) { + if !cng.SupportsSHAKE128() { + t.Skip("skipping: not supported") + } testCSHAKEAccumulated(t, cng.NewCSHAKE128, (1600-256)/8, "bb14f8657c6ec5403d0b0e2ef3d3393497e9d3b1a9a9e8e6c81dbaa5fd809252") }) t.Run("CSHAKE256", func(t *testing.T) { + if !cng.SupportsSHAKE256() { + t.Skip("skipping: not supported") + } testCSHAKEAccumulated(t, cng.NewCSHAKE256, (1600-512)/8, "0baaf9250c6e25f0c14ea5c7f9bfde54c8a922c8276437db28f3895bdf6eeeef") }) @@ -327,6 +346,9 @@ func testCSHAKEAccumulated(t *testing.T, newCSHAKE func(N, S []byte) *cng.SHAKE, } func TestCSHAKELargeS(t *testing.T) { + if !cng.SupportsSHAKE128() { + t.Skip("skipping: not supported") + } const s = (1<<32)/8 + 1000 // s * 8 > 2^32 S := make([]byte, s) rnd := cng.NewSHAKE128() @@ -345,6 +367,9 @@ func TestCSHAKELargeS(t *testing.T) { func TestCSHAKESum(t *testing.T) { const testString = "hello world" t.Run("CSHAKE128", func(t *testing.T) { + if !cng.SupportsSHAKE128() { + t.Skip("skipping: not supported") + } h := cng.NewCSHAKE128(nil, nil) h.Write([]byte(testString[:5])) h.Write([]byte(testString[5:])) @@ -356,6 +381,9 @@ func TestCSHAKESum(t *testing.T) { } }) t.Run("CSHAKE256", func(t *testing.T) { + if !cng.SupportsSHAKE256() { + t.Skip("skipping: not supported") + } h := cng.NewCSHAKE256(nil, nil) h.Write([]byte(testString[:5])) h.Write([]byte(testString[5:])) From 77726dda6fa9564f6bc6491e472e91331c0226b9 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Wed, 18 Dec 2024 18:52:33 +0100 Subject: [PATCH 04/10] improve skips --- cng/hash_test.go | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/cng/hash_test.go b/cng/hash_test.go index c49a4d2..3af6045 100644 --- a/cng/hash_test.go +++ b/cng/hash_test.go @@ -229,17 +229,25 @@ var testShakes = map[string]struct { "CSHAKE256": {cng.NewCSHAKE256, "CSHAKE256", "CustomString"}, } +func skipCSHAKEIfNotSupported(t *testing.T, algo string) { + switch algo { + case "SHAKE128", "CSHAKE128": + if !cng.SupportsSHAKE128() { + t.Skip("skipping: not supported") + } + case "SHAKE256", "CSHAKE256": + if !cng.SupportsSHAKE256() { + t.Skip("skipping: not supported") + } + } +} + // TestCSHAKESqueezing checks that squeezing the full output a single time produces // the same output as repeatedly squeezing the instance. func TestCSHAKESqueezing(t *testing.T) { const testString = "brekeccakkeccak koax koax" for algo, v := range testShakes { - if algo == "SHAKE128" && !cng.SupportsSHAKE128() { - t.Skip("skipping: not supported") - } - if algo == "SHAKE256" && !cng.SupportsSHAKE256() { - t.Skip("skipping: not supported") - } + skipCSHAKEIfNotSupported(t, algo) d0 := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr)) d0.Write([]byte(testString)) @@ -276,12 +284,8 @@ func TestCSHAKEReset(t *testing.T) { out2 := make([]byte, 32) for algo, v := range testShakes { - if algo == "SHAKE128" && !cng.SupportsSHAKE128() { - t.Skip("skipping: not supported") - } - if algo == "SHAKE256" && !cng.SupportsSHAKE256() { - t.Skip("skipping: not supported") - } + skipCSHAKEIfNotSupported(t, algo) + // Calculate hash for the first time c := v.constructor(nil, []byte{0x99, 0x98}) c.Write(sequentialBytes(0x100)) From 029af5574930d5787817f11f6786ee9d910ce9a8 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Thu, 19 Dec 2024 11:39:55 +0100 Subject: [PATCH 05/10] refactor pure SHA3 API --- cng/hash.go | 197 ---------------------------- cng/hash_test.go | 249 +---------------------------------- cng/sha3.go | 332 +++++++++++++++++++++++++++++++++++++++++++++++ cng/sha3_test.go | 253 ++++++++++++++++++++++++++++++++++++ 4 files changed, 591 insertions(+), 440 deletions(-) create mode 100644 cng/sha3.go create mode 100644 cng/sha3_test.go diff --git a/cng/hash.go b/cng/hash.go index b87e7e6..35a9467 100644 --- a/cng/hash.go +++ b/cng/hash.go @@ -11,7 +11,6 @@ import ( "crypto" "hash" "runtime" - "slices" "unsafe" "github.com/microsoft/go-crypto-winnative/internal/bcrypt" @@ -85,27 +84,6 @@ func SHA512(p []byte) (sum [64]byte) { return } -func SHA3_256(p []byte) (sum [32]byte) { - if err := hashOneShot(bcrypt.SHA3_256_ALGORITHM, p, sum[:]); err != nil { - panic("bcrypt: SHA3_256 failed") - } - return -} - -func SHA3_384(p []byte) (sum [48]byte) { - if err := hashOneShot(bcrypt.SHA3_384_ALGORITHM, p, sum[:]); err != nil { - panic("bcrypt: SHA3_384 failed") - } - return -} - -func SHA3_512(p []byte) (sum [64]byte) { - if err := hashOneShot(bcrypt.SHA3_512_ALGORITHM, p, sum[:]); err != nil { - panic("bcrypt: SHA3_512 failed") - } - return -} - // NewMD4 returns a new MD4 hash. func NewMD4() hash.Hash { return newHashX(bcrypt.MD4_ALGORITHM, bcrypt.ALG_NONE_FLAG, nil) @@ -136,21 +114,6 @@ func NewSHA512() hash.Hash { return newHashX(bcrypt.SHA512_ALGORITHM, bcrypt.ALG_NONE_FLAG, nil) } -// NewSHA3_256 returns a new SHA256 hash. -func NewSHA3_256() hash.Hash { - return newHashX(bcrypt.SHA3_256_ALGORITHM, bcrypt.ALG_NONE_FLAG, nil) -} - -// NewSHA3_384 returns a new SHA384 hash. -func NewSHA3_384() hash.Hash { - return newHashX(bcrypt.SHA3_384_ALGORITHM, bcrypt.ALG_NONE_FLAG, nil) -} - -// NewSHA3_512 returns a new SHA512 hash. -func NewSHA3_512() hash.Hash { - return newHashX(bcrypt.SHA3_512_ALGORITHM, bcrypt.ALG_NONE_FLAG, nil) -} - type hashAlgorithm struct { handle bcrypt.ALG_HANDLE id string @@ -305,163 +268,3 @@ func (h *hashX) Sum(in []byte) []byte { } return append(in, h.buf...) } - -// SupportsSHAKE128 returns true if the SHAKE128 extendable output function is -// supported. -func SupportsSHAKE128() bool { - _, err := loadHash(bcrypt.CSHAKE128_ALGORITHM, bcrypt.ALG_NONE_FLAG) - return err == nil -} - -// SupportsSHAKE256 returns true if the SHAKE256 extendable output function is -// supported. -func SupportsSHAKE256() bool { - _, err := loadHash(bcrypt.CSHAKE256_ALGORITHM, bcrypt.ALG_NONE_FLAG) - return err == nil -} - -// SumSHAKE128 applies the SHAKE128 extendable output function to data and -// returns an output of the given length in bytes. -func SumSHAKE128(data []byte, length int) []byte { - out := make([]byte, length) - if err := hashOneShot(bcrypt.CSHAKE128_ALGORITHM, data, out); err != nil { - panic("bcrypt: CSHAKE128_ALGORITHM failed") - } - return out -} - -// SumSHAKE256 applies the SHAKE256 extendable output function to data and -// returns an output of the given length in bytes. -func SumSHAKE256(data []byte, length int) []byte { - out := make([]byte, length) - if err := hashOneShot(bcrypt.CSHAKE256_ALGORITHM, data, out); err != nil { - panic("bcrypt: CSHAKE256_ALGORITHM failed") - } - return out -} - -// SHAKE is an instance of a SHAKE extendable output function. -type SHAKE struct { - alg *hashAlgorithm - ctx bcrypt.HASH_HANDLE - n, s []byte -} - -func newShake(id string, N, S []byte) *SHAKE { - alg, err := loadHash(id, bcrypt.ALG_NONE_FLAG) - if err != nil { - panic(err) - } - h := &SHAKE{alg: alg, n: slices.Clone(N), s: slices.Clone(S)} - err = bcrypt.CreateHash(h.alg.handle, &h.ctx, nil, nil, 0) - if err != nil { - panic(err) - } - if len(N) != 0 { - if err := bcrypt.SetProperty(bcrypt.HANDLE(h.ctx), utf16PtrFromString(bcrypt.FUNCTION_NAME_STRING), N, 0); err != nil { - panic(err) - } - } - if len(S) != 0 { - if err := bcrypt.SetProperty(bcrypt.HANDLE(h.ctx), utf16PtrFromString(bcrypt.CUSTOMIZATION_STRING), S, 0); err != nil { - panic(err) - } - } - runtime.SetFinalizer(h, (*SHAKE).finalize) - return h -} - -// NewSHAKE128 creates a new SHAKE128 XOF. -func NewSHAKE128() *SHAKE { - return newShake(bcrypt.CSHAKE128_ALGORITHM, nil, nil) -} - -// NewSHAKE256 creates a new SHAKE256 XOF. -func NewSHAKE256() *SHAKE { - return newShake(bcrypt.CSHAKE256_ALGORITHM, nil, nil) -} - -// NewCSHAKE128 creates a new cSHAKE128 XOF. -// -// N is used to define functions based on cSHAKE, it can be empty when plain -// cSHAKE is desired. S is a customization byte string used for domain -// separation. When N and S are both empty, this is equivalent to NewSHAKE128. -func NewCSHAKE128(N, S []byte) *SHAKE { - return newShake(bcrypt.CSHAKE128_ALGORITHM, N, S) -} - -// NewCSHAKE256 creates a new cSHAKE256 XOF. -// -// N is used to define functions based on cSHAKE, it can be empty when plain -// cSHAKE is desired. S is a customization byte string used for domain -// separation. When N and S are both empty, this is equivalent to NewSHAKE256. -func NewCSHAKE256(N, S []byte) *SHAKE { - return newShake(bcrypt.CSHAKE256_ALGORITHM, N, S) -} - -func (h *SHAKE) finalize() { - bcrypt.DestroyHash(h.ctx) -} - -// Write absorbs more data into the XOF's state. -// -// It panics if any output has already been read. -func (s *SHAKE) Write(p []byte) (n int, err error) { - if len(p) == 0 { - return 0, nil - } - defer runtime.KeepAlive(s) - for n < len(p) && err == nil { - nn := len32(p[n:]) - err = bcrypt.HashData(s.ctx, p[n:n+nn], 0) - n += nn - } - if err != nil { - panic(err) - } - return len(p), nil -} - -// Read squeezes more output from the XOF. -// -// Any call to Write after a call to Read will panic. -func (s *SHAKE) Read(p []byte) (n int, err error) { - if len(p) == 0 { - return 0, nil - } - defer runtime.KeepAlive(s) - for n < len(p) && err == nil { - nn := len32(p[n:]) - err = bcrypt.FinishHash(s.ctx, p[n:n+nn], bcrypt.HASH_DONT_RESET_FLAG) - n += nn - } - if err != nil { - panic(err) - } - return len(p), nil -} - -// Reset resets the XOF to its initial state. -func (s *SHAKE) Reset() { - defer runtime.KeepAlive(s) - bcrypt.DestroyHash(s.ctx) - err := bcrypt.CreateHash(s.alg.handle, &s.ctx, nil, nil, 0) - if err != nil { - panic(err) - } - if len(s.n) != 0 { - if err := bcrypt.SetProperty(bcrypt.HANDLE(s.ctx), utf16PtrFromString(bcrypt.FUNCTION_NAME_STRING), s.n, 0); err != nil { - panic(err) - } - } - if len(s.s) != 0 { - if err := bcrypt.SetProperty(bcrypt.HANDLE(s.ctx), utf16PtrFromString(bcrypt.CUSTOMIZATION_STRING), s.s, 0); err != nil { - panic(err) - } - } -} - -// BlockSize returns the rate of the XOF. -func (s *SHAKE) BlockSize() int { - return int(s.alg.blockSize) -} diff --git a/cng/hash_test.go b/cng/hash_test.go index 3af6045..33f6549 100644 --- a/cng/hash_test.go +++ b/cng/hash_test.go @@ -9,10 +9,8 @@ package cng_test import ( "bytes" "crypto" - "encoding/hex" "hash" "io" - "math/rand" "testing" "github.com/microsoft/go-crypto-winnative/cng" @@ -34,11 +32,11 @@ func cryptoToHash(h crypto.Hash) func() hash.Hash { case crypto.SHA512: return cng.NewSHA512 case crypto.SHA3_256: - return cng.NewSHA3_256 + return func() hash.Hash { return cng.NewSHA3_256() } case crypto.SHA3_384: - return cng.NewSHA3_384 + return func() hash.Hash { return cng.NewSHA3_384() } case crypto.SHA3_512: - return cng.NewSHA3_512 + return func() hash.Hash { return cng.NewSHA3_512() } } return nil } @@ -158,15 +156,15 @@ func TestHash_OneShot(t *testing.T) { return b[:] }}, {crypto.SHA3_256, func(p []byte) []byte { - b := cng.SHA3_256(p) + b := cng.SumSHA3_256(p) return b[:] }}, {crypto.SHA3_384, func(p []byte) []byte { - b := cng.SHA3_384(p) + b := cng.SumSHA3_384(p) return b[:] }}, {crypto.SHA3_512, func(p []byte) []byte { - b := cng.SHA3_512(p) + b := cng.SumSHA3_512(p) return b[:] }}, } @@ -214,238 +212,3 @@ func BenchmarkSHA256_OneShot(b *testing.B) { cng.SHA256(buf) } } - -// testShakes contains functions that return *sha3.SHAKE instances for -// with output-length equal to the KAT length. -var testShakes = map[string]struct { - constructor func(N []byte, S []byte) *cng.SHAKE - defAlgoName string - defCustomStr string -}{ - // NewCSHAKE without customization produces same result as SHAKE - "SHAKE128": {cng.NewCSHAKE128, "", ""}, - "SHAKE256": {cng.NewCSHAKE256, "", ""}, - "CSHAKE128": {cng.NewCSHAKE128, "CSHAKE128", "CustomString"}, - "CSHAKE256": {cng.NewCSHAKE256, "CSHAKE256", "CustomString"}, -} - -func skipCSHAKEIfNotSupported(t *testing.T, algo string) { - switch algo { - case "SHAKE128", "CSHAKE128": - if !cng.SupportsSHAKE128() { - t.Skip("skipping: not supported") - } - case "SHAKE256", "CSHAKE256": - if !cng.SupportsSHAKE256() { - t.Skip("skipping: not supported") - } - } -} - -// TestCSHAKESqueezing checks that squeezing the full output a single time produces -// the same output as repeatedly squeezing the instance. -func TestCSHAKESqueezing(t *testing.T) { - const testString = "brekeccakkeccak koax koax" - for algo, v := range testShakes { - skipCSHAKEIfNotSupported(t, algo) - - d0 := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr)) - d0.Write([]byte(testString)) - ref := make([]byte, 32) - d0.Read(ref) - - d1 := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr)) - d1.Write([]byte(testString)) - var multiple []byte - for range ref { - d1.Read(make([]byte, 0)) - one := make([]byte, 1) - d1.Read(one) - multiple = append(multiple, one...) - } - if !bytes.Equal(ref, multiple) { - t.Errorf("%s: squeezing %d bytes one at a time failed", algo, len(ref)) - } - } -} - -// sequentialBytes produces a buffer of size consecutive bytes 0x00, 0x01, ..., used for testing. -func sequentialBytes(size int) []byte { - alignmentOffset := rand.Intn(8) - result := make([]byte, size+alignmentOffset)[alignmentOffset:] - for i := range result { - result[i] = byte(i) - } - return result -} - -func TestCSHAKEReset(t *testing.T) { - out1 := make([]byte, 32) - out2 := make([]byte, 32) - - for algo, v := range testShakes { - skipCSHAKEIfNotSupported(t, algo) - - // Calculate hash for the first time - c := v.constructor(nil, []byte{0x99, 0x98}) - c.Write(sequentialBytes(0x100)) - c.Read(out1) - - // Calculate hash again - c.Reset() - c.Write(sequentialBytes(0x100)) - c.Read(out2) - - if !bytes.Equal(out1, out2) { - t.Error("\nExpected:\n", out1, "\ngot:\n", out2) - } - } -} - -func TestCSHAKEAccumulated(t *testing.T) { - t.Run("CSHAKE128", func(t *testing.T) { - if !cng.SupportsSHAKE128() { - t.Skip("skipping: not supported") - } - testCSHAKEAccumulated(t, cng.NewCSHAKE128, (1600-256)/8, - "bb14f8657c6ec5403d0b0e2ef3d3393497e9d3b1a9a9e8e6c81dbaa5fd809252") - }) - t.Run("CSHAKE256", func(t *testing.T) { - if !cng.SupportsSHAKE256() { - t.Skip("skipping: not supported") - } - testCSHAKEAccumulated(t, cng.NewCSHAKE256, (1600-512)/8, - "0baaf9250c6e25f0c14ea5c7f9bfde54c8a922c8276437db28f3895bdf6eeeef") - }) -} - -func testCSHAKEAccumulated(t *testing.T, newCSHAKE func(N, S []byte) *cng.SHAKE, rate int64, exp string) { - rnd := newCSHAKE(nil, nil) - acc := newCSHAKE(nil, nil) - for n := 0; n < 200; n++ { - N := make([]byte, n) - rnd.Read(N) - for s := 0; s < 200; s++ { - S := make([]byte, s) - rnd.Read(S) - - c := newCSHAKE(N, S) - io.CopyN(c, rnd, 100 /* < rate */) - io.CopyN(acc, c, 200) - - c.Reset() - io.CopyN(c, rnd, rate) - io.CopyN(acc, c, 200) - - c.Reset() - io.CopyN(c, rnd, 200 /* > rate */) - io.CopyN(acc, c, 200) - } - } - out := make([]byte, 32) - acc.Read(out) - if got := hex.EncodeToString(out); got != exp { - t.Errorf("got %s, want %s", got, exp) - } -} - -func TestCSHAKELargeS(t *testing.T) { - if !cng.SupportsSHAKE128() { - t.Skip("skipping: not supported") - } - const s = (1<<32)/8 + 1000 // s * 8 > 2^32 - S := make([]byte, s) - rnd := cng.NewSHAKE128() - rnd.Read(S) - c := cng.NewCSHAKE128(nil, S) - io.CopyN(c, rnd, 1000) - out := make([]byte, 32) - c.Read(out) - - exp := "2cb9f237767e98f2614b8779cf096a52da9b3a849280bbddec820771ae529cf0" - if got := hex.EncodeToString(out); got != exp { - t.Errorf("got %s, want %s", got, exp) - } -} - -func TestCSHAKESum(t *testing.T) { - const testString = "hello world" - t.Run("CSHAKE128", func(t *testing.T) { - if !cng.SupportsSHAKE128() { - t.Skip("skipping: not supported") - } - h := cng.NewCSHAKE128(nil, nil) - h.Write([]byte(testString[:5])) - h.Write([]byte(testString[5:])) - want := make([]byte, 32) - h.Read(want) - got := cng.SumSHAKE128([]byte(testString), 32) - if !bytes.Equal(got, want) { - t.Errorf("got:%x want:%x", got, want) - } - }) - t.Run("CSHAKE256", func(t *testing.T) { - if !cng.SupportsSHAKE256() { - t.Skip("skipping: not supported") - } - h := cng.NewCSHAKE256(nil, nil) - h.Write([]byte(testString[:5])) - h.Write([]byte(testString[5:])) - want := make([]byte, 32) - h.Read(want) - got := cng.SumSHAKE256([]byte(testString), 32) - if !bytes.Equal(got, want) { - t.Errorf("got:%x want:%x", got, want) - } - }) -} - -// benchmarkHash tests the speed to hash num buffers of buflen each. -func benchmarkHash(b *testing.B, h hash.Hash, size, num int) { - b.StopTimer() - h.Reset() - data := sequentialBytes(size) - b.SetBytes(int64(size * num)) - b.StartTimer() - - var state []byte - for i := 0; i < b.N; i++ { - for j := 0; j < num; j++ { - h.Write(data) - } - state = h.Sum(state[:0]) - } - b.StopTimer() - h.Reset() -} - -// benchmarkCSHAKE is specialized to the Shake instances, which don't -// require a copy on reading output. -func benchmarkCSHAKE(b *testing.B, h *cng.SHAKE, size, num int) { - b.StopTimer() - h.Reset() - data := sequentialBytes(size) - d := make([]byte, 32) - - b.SetBytes(int64(size * num)) - b.StartTimer() - - for i := 0; i < b.N; i++ { - h.Reset() - for j := 0; j < num; j++ { - h.Write(data) - } - h.Read(d) - } -} - -func BenchmarkSHA3_512_MTU(b *testing.B) { benchmarkHash(b, cng.NewSHA3_512(), 1350, 1) } -func BenchmarkSHA3_384_MTU(b *testing.B) { benchmarkHash(b, cng.NewSHA3_384(), 1350, 1) } -func BenchmarkSHA3_256_MTU(b *testing.B) { benchmarkHash(b, cng.NewSHA3_256(), 1350, 1) } - -func BenchmarkCSHAKE128_MTU(b *testing.B) { benchmarkCSHAKE(b, cng.NewSHAKE128(), 1350, 1) } -func BenchmarkCSHAKE256_MTU(b *testing.B) { benchmarkCSHAKE(b, cng.NewSHAKE256(), 1350, 1) } -func BenchmarkCSHAKE256_16x(b *testing.B) { benchmarkCSHAKE(b, cng.NewSHAKE256(), 16, 1024) } -func BenchmarkCSHAKE256_1MiB(b *testing.B) { benchmarkCSHAKE(b, cng.NewSHAKE256(), 1024, 1024) } - -func BenchmarkCSHA3_512_1MiB(b *testing.B) { benchmarkHash(b, cng.NewSHA3_512(), 1024, 1024) } diff --git a/cng/sha3.go b/cng/sha3.go new file mode 100644 index 0000000..becda66 --- /dev/null +++ b/cng/sha3.go @@ -0,0 +1,332 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//go:build windows +// +build windows + +package cng + +import ( + "hash" + "runtime" + "slices" + "unsafe" + + "github.com/microsoft/go-crypto-winnative/internal/bcrypt" +) + +// SumSHA3_256 returns the SHA3-256 checksum of the data. +func SumSHA3_256(p []byte) (sum [32]byte) { + if err := hashOneShot(bcrypt.SHA3_256_ALGORITHM, p, sum[:]); err != nil { + panic("bcrypt: SHA3_256 failed") + } + return +} + +// SumSHA3_384 returns the SHA3-384 checksum of the data. +func SumSHA3_384(p []byte) (sum [48]byte) { + if err := hashOneShot(bcrypt.SHA3_384_ALGORITHM, p, sum[:]); err != nil { + panic("bcrypt: SHA3_384 failed") + } + return +} + +// SumSHA3_512 returns the SHA3-512 checksum of the data. +func SumSHA3_512(p []byte) (sum [64]byte) { + if err := hashOneShot(bcrypt.SHA3_512_ALGORITHM, p, sum[:]); err != nil { + panic("bcrypt: SHA3_512 failed") + } + return +} + +// SumSHAKE128 applies the SHAKE128 extendable output function to data and +// returns an output of the given length in bytes. +func SumSHAKE128(data []byte, length int) []byte { + out := make([]byte, length) + if err := hashOneShot(bcrypt.CSHAKE128_ALGORITHM, data, out); err != nil { + panic("bcrypt: CSHAKE128_ALGORITHM failed") + } + return out +} + +// SumSHAKE256 applies the SHAKE256 extendable output function to data and +// returns an output of the given length in bytes. +func SumSHAKE256(data []byte, length int) []byte { + out := make([]byte, length) + if err := hashOneShot(bcrypt.CSHAKE256_ALGORITHM, data, out); err != nil { + panic("bcrypt: CSHAKE256_ALGORITHM failed") + } + return out +} + +// SupportsSHAKE128 returns true if the SHAKE128 extendable output function is +// supported. +func SupportsSHAKE128() bool { + _, err := loadHash(bcrypt.CSHAKE128_ALGORITHM, bcrypt.ALG_NONE_FLAG) + return err == nil +} + +var _ hash.Hash = (*DigestSHA3)(nil) + +// DigestSHA3 is the [sha3.SHA3] implementation using the CNG API. +type DigestSHA3 struct { + alg *hashAlgorithm + ctx bcrypt.HASH_HANDLE +} + +// newDigestSHA3 returns a new hash.Hash using the specified algorithm. +func newDigestSHA3(id string) *DigestSHA3 { + alg, err := loadHash(id, bcrypt.ALG_NONE_FLAG) + if err != nil { + panic(err) + } + h := &DigestSHA3{alg: alg} + // Don't call bcrypt.CreateHash yet, it would be wasteful + // if the caller only wants to know the hash type. This + // is a common pattern in this package, as some functions + // accept a `func() hash.Hash` parameter and call it just + // to know the hash type. + return h +} + +func (h *DigestSHA3) finalize() { + bcrypt.DestroyHash(h.ctx) +} + +func (h *DigestSHA3) init() { + defer runtime.KeepAlive(h) + if h.ctx != 0 { + return + } + err := bcrypt.CreateHash(h.alg.handle, &h.ctx, nil, nil, 0) + if err != nil { + panic(err) + } + runtime.SetFinalizer(h, (*DigestSHA3).finalize) +} + +func (h *DigestSHA3) Clone() (hash.Hash, error) { + defer runtime.KeepAlive(h) + h2 := &DigestSHA3{alg: h.alg} + if h.ctx != 0 { + err := bcrypt.DuplicateHash(h.ctx, &h2.ctx, nil, 0) + if err != nil { + return nil, err + } + runtime.SetFinalizer(h2, (*DigestSHA3).finalize) + } + return h2, nil +} + +func (h *DigestSHA3) Reset() { + defer runtime.KeepAlive(h) + if h.ctx != 0 { + bcrypt.DestroyHash(h.ctx) + h.ctx = 0 + runtime.SetFinalizer(h, nil) + } +} + +func (h *DigestSHA3) Write(p []byte) (n int, err error) { + defer runtime.KeepAlive(h) + h.init() + for n < len(p) && err == nil { + nn := len32(p[n:]) + err = bcrypt.HashData(h.ctx, p[n:n+nn], 0) + n += nn + } + if err != nil { + // hash.Hash interface mandates Write should never return an error. + panic(err) + } + return len(p), nil +} + +func (h *DigestSHA3) WriteString(s string) (int, error) { + defer runtime.KeepAlive(h) + return h.Write(unsafe.Slice(unsafe.StringData(s), len(s))) +} + +func (h *DigestSHA3) WriteByte(c byte) error { + defer runtime.KeepAlive(h) + h.init() + err := bcrypt.HashDataRaw(h.ctx, &c, 1, 0) + if err != nil { + // hash.Hash interface mandates Write should never return an error. + panic(err) + } + return nil +} + +func (h *DigestSHA3) Size() int { + return int(h.alg.size) +} + +func (h *DigestSHA3) BlockSize() int { + return int(h.alg.blockSize) +} + +func (h *DigestSHA3) Sum(in []byte) []byte { + defer runtime.KeepAlive(h) + h.init() + var ctx2 bcrypt.HASH_HANDLE + err := bcrypt.DuplicateHash(h.ctx, &ctx2, nil, 0) + if err != nil { + panic(err) + } + defer bcrypt.DestroyHash(ctx2) + buf := make([]byte, h.alg.size, 64) // explicit cap to allow stack allocation + err = bcrypt.FinishHash(ctx2, buf, 0) + if err != nil { + panic(err) + } + return append(in, buf...) +} + +// NewSHA3_256 returns a new SHA256 hash. +func NewSHA3_256() *DigestSHA3 { + return newDigestSHA3(bcrypt.SHA3_256_ALGORITHM) +} + +// NewSHA3_384 returns a new SHA384 hash. +func NewSHA3_384() *DigestSHA3 { + return newDigestSHA3(bcrypt.SHA3_384_ALGORITHM) +} + +// NewSHA3_512 returns a new SHA512 hash. +func NewSHA3_512() *DigestSHA3 { + return newDigestSHA3(bcrypt.SHA3_512_ALGORITHM) +} + +// SupportsSHAKE256 returns true if the SHAKE256 extendable output function is +// supported. +func SupportsSHAKE256() bool { + _, err := loadHash(bcrypt.CSHAKE256_ALGORITHM, bcrypt.ALG_NONE_FLAG) + return err == nil +} + +// SHAKE is an instance of a SHAKE extendable output function. +type SHAKE struct { + alg *hashAlgorithm + ctx bcrypt.HASH_HANDLE + n, s []byte +} + +func newShake(id string, N, S []byte) *SHAKE { + alg, err := loadHash(id, bcrypt.ALG_NONE_FLAG) + if err != nil { + panic(err) + } + h := &SHAKE{alg: alg, n: slices.Clone(N), s: slices.Clone(S)} + err = bcrypt.CreateHash(h.alg.handle, &h.ctx, nil, nil, 0) + if err != nil { + panic(err) + } + if len(N) != 0 { + if err := bcrypt.SetProperty(bcrypt.HANDLE(h.ctx), utf16PtrFromString(bcrypt.FUNCTION_NAME_STRING), N, 0); err != nil { + panic(err) + } + } + if len(S) != 0 { + if err := bcrypt.SetProperty(bcrypt.HANDLE(h.ctx), utf16PtrFromString(bcrypt.CUSTOMIZATION_STRING), S, 0); err != nil { + panic(err) + } + } + runtime.SetFinalizer(h, (*SHAKE).finalize) + return h +} + +// NewSHAKE128 creates a new SHAKE128 XOF. +func NewSHAKE128() *SHAKE { + return newShake(bcrypt.CSHAKE128_ALGORITHM, nil, nil) +} + +// NewSHAKE256 creates a new SHAKE256 XOF. +func NewSHAKE256() *SHAKE { + return newShake(bcrypt.CSHAKE256_ALGORITHM, nil, nil) +} + +// NewCSHAKE128 creates a new cSHAKE128 XOF. +// +// N is used to define functions based on cSHAKE, it can be empty when plain +// cSHAKE is desired. S is a customization byte string used for domain +// separation. When N and S are both empty, this is equivalent to NewSHAKE128. +func NewCSHAKE128(N, S []byte) *SHAKE { + return newShake(bcrypt.CSHAKE128_ALGORITHM, N, S) +} + +// NewCSHAKE256 creates a new cSHAKE256 XOF. +// +// N is used to define functions based on cSHAKE, it can be empty when plain +// cSHAKE is desired. S is a customization byte string used for domain +// separation. When N and S are both empty, this is equivalent to NewSHAKE256. +func NewCSHAKE256(N, S []byte) *SHAKE { + return newShake(bcrypt.CSHAKE256_ALGORITHM, N, S) +} + +func (h *SHAKE) finalize() { + bcrypt.DestroyHash(h.ctx) +} + +// Write absorbs more data into the XOF's state. +// +// It panics if any output has already been read. +func (s *SHAKE) Write(p []byte) (n int, err error) { + if len(p) == 0 { + return 0, nil + } + defer runtime.KeepAlive(s) + for n < len(p) && err == nil { + nn := len32(p[n:]) + err = bcrypt.HashData(s.ctx, p[n:n+nn], 0) + n += nn + } + if err != nil { + panic(err) + } + return len(p), nil +} + +// Read squeezes more output from the XOF. +// +// Any call to Write after a call to Read will panic. +func (s *SHAKE) Read(p []byte) (n int, err error) { + if len(p) == 0 { + return 0, nil + } + defer runtime.KeepAlive(s) + for n < len(p) && err == nil { + nn := len32(p[n:]) + err = bcrypt.FinishHash(s.ctx, p[n:n+nn], bcrypt.HASH_DONT_RESET_FLAG) + n += nn + } + if err != nil { + panic(err) + } + return len(p), nil +} + +// Reset resets the XOF to its initial state. +func (s *SHAKE) Reset() { + defer runtime.KeepAlive(s) + bcrypt.DestroyHash(s.ctx) + err := bcrypt.CreateHash(s.alg.handle, &s.ctx, nil, nil, 0) + if err != nil { + panic(err) + } + if len(s.n) != 0 { + if err := bcrypt.SetProperty(bcrypt.HANDLE(s.ctx), utf16PtrFromString(bcrypt.FUNCTION_NAME_STRING), s.n, 0); err != nil { + panic(err) + } + } + if len(s.s) != 0 { + if err := bcrypt.SetProperty(bcrypt.HANDLE(s.ctx), utf16PtrFromString(bcrypt.CUSTOMIZATION_STRING), s.s, 0); err != nil { + panic(err) + } + } +} + +// BlockSize returns the rate of the XOF. +func (s *SHAKE) BlockSize() int { + return int(s.alg.blockSize) +} diff --git a/cng/sha3_test.go b/cng/sha3_test.go new file mode 100644 index 0000000..121efea --- /dev/null +++ b/cng/sha3_test.go @@ -0,0 +1,253 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//go:build windows +// +build windows + +package cng_test + +import ( + "bytes" + "encoding/hex" + "hash" + "io" + "math/rand" + "testing" + + "github.com/microsoft/go-crypto-winnative/cng" +) + +// testShakes contains functions that return *sha3.SHAKE instances for +// with output-length equal to the KAT length. +var testShakes = map[string]struct { + constructor func(N []byte, S []byte) *cng.SHAKE + defAlgoName string + defCustomStr string +}{ + // NewCSHAKE without customization produces same result as SHAKE + "SHAKE128": {cng.NewCSHAKE128, "", ""}, + "SHAKE256": {cng.NewCSHAKE256, "", ""}, + "CSHAKE128": {cng.NewCSHAKE128, "CSHAKE128", "CustomString"}, + "CSHAKE256": {cng.NewCSHAKE256, "CSHAKE256", "CustomString"}, +} + +func skipCSHAKEIfNotSupported(t *testing.T, algo string) { + switch algo { + case "SHAKE128", "CSHAKE128": + if !cng.SupportsSHAKE128() { + t.Skip("skipping: not supported") + } + case "SHAKE256", "CSHAKE256": + if !cng.SupportsSHAKE256() { + t.Skip("skipping: not supported") + } + } +} + +// TestCSHAKESqueezing checks that squeezing the full output a single time produces +// the same output as repeatedly squeezing the instance. +func TestCSHAKESqueezing(t *testing.T) { + const testString = "brekeccakkeccak koax koax" + for algo, v := range testShakes { + skipCSHAKEIfNotSupported(t, algo) + + d0 := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr)) + d0.Write([]byte(testString)) + ref := make([]byte, 32) + d0.Read(ref) + + d1 := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr)) + d1.Write([]byte(testString)) + var multiple []byte + for range ref { + d1.Read(make([]byte, 0)) + one := make([]byte, 1) + d1.Read(one) + multiple = append(multiple, one...) + } + if !bytes.Equal(ref, multiple) { + t.Errorf("%s: squeezing %d bytes one at a time failed", algo, len(ref)) + } + } +} + +// sequentialBytes produces a buffer of size consecutive bytes 0x00, 0x01, ..., used for testing. +func sequentialBytes(size int) []byte { + alignmentOffset := rand.Intn(8) + result := make([]byte, size+alignmentOffset)[alignmentOffset:] + for i := range result { + result[i] = byte(i) + } + return result +} + +func TestCSHAKEReset(t *testing.T) { + out1 := make([]byte, 32) + out2 := make([]byte, 32) + + for algo, v := range testShakes { + skipCSHAKEIfNotSupported(t, algo) + + // Calculate hash for the first time + c := v.constructor(nil, []byte{0x99, 0x98}) + c.Write(sequentialBytes(0x100)) + c.Read(out1) + + // Calculate hash again + c.Reset() + c.Write(sequentialBytes(0x100)) + c.Read(out2) + + if !bytes.Equal(out1, out2) { + t.Error("\nExpected:\n", out1, "\ngot:\n", out2) + } + } +} + +func TestCSHAKEAccumulated(t *testing.T) { + t.Run("CSHAKE128", func(t *testing.T) { + if !cng.SupportsSHAKE128() { + t.Skip("skipping: not supported") + } + testCSHAKEAccumulated(t, cng.NewCSHAKE128, (1600-256)/8, + "bb14f8657c6ec5403d0b0e2ef3d3393497e9d3b1a9a9e8e6c81dbaa5fd809252") + }) + t.Run("CSHAKE256", func(t *testing.T) { + if !cng.SupportsSHAKE256() { + t.Skip("skipping: not supported") + } + testCSHAKEAccumulated(t, cng.NewCSHAKE256, (1600-512)/8, + "0baaf9250c6e25f0c14ea5c7f9bfde54c8a922c8276437db28f3895bdf6eeeef") + }) +} + +func testCSHAKEAccumulated(t *testing.T, newCSHAKE func(N, S []byte) *cng.SHAKE, rate int64, exp string) { + rnd := newCSHAKE(nil, nil) + acc := newCSHAKE(nil, nil) + for n := 0; n < 200; n++ { + N := make([]byte, n) + rnd.Read(N) + for s := 0; s < 200; s++ { + S := make([]byte, s) + rnd.Read(S) + + c := newCSHAKE(N, S) + io.CopyN(c, rnd, 100 /* < rate */) + io.CopyN(acc, c, 200) + + c.Reset() + io.CopyN(c, rnd, rate) + io.CopyN(acc, c, 200) + + c.Reset() + io.CopyN(c, rnd, 200 /* > rate */) + io.CopyN(acc, c, 200) + } + } + out := make([]byte, 32) + acc.Read(out) + if got := hex.EncodeToString(out); got != exp { + t.Errorf("got %s, want %s", got, exp) + } +} + +func TestCSHAKELargeS(t *testing.T) { + if !cng.SupportsSHAKE128() { + t.Skip("skipping: not supported") + } + const s = (1<<32)/8 + 1000 // s * 8 > 2^32 + S := make([]byte, s) + rnd := cng.NewSHAKE128() + rnd.Read(S) + c := cng.NewCSHAKE128(nil, S) + io.CopyN(c, rnd, 1000) + out := make([]byte, 32) + c.Read(out) + + exp := "2cb9f237767e98f2614b8779cf096a52da9b3a849280bbddec820771ae529cf0" + if got := hex.EncodeToString(out); got != exp { + t.Errorf("got %s, want %s", got, exp) + } +} + +func TestCSHAKESum(t *testing.T) { + const testString = "hello world" + t.Run("CSHAKE128", func(t *testing.T) { + if !cng.SupportsSHAKE128() { + t.Skip("skipping: not supported") + } + h := cng.NewCSHAKE128(nil, nil) + h.Write([]byte(testString[:5])) + h.Write([]byte(testString[5:])) + want := make([]byte, 32) + h.Read(want) + got := cng.SumSHAKE128([]byte(testString), 32) + if !bytes.Equal(got, want) { + t.Errorf("got:%x want:%x", got, want) + } + }) + t.Run("CSHAKE256", func(t *testing.T) { + if !cng.SupportsSHAKE256() { + t.Skip("skipping: not supported") + } + h := cng.NewCSHAKE256(nil, nil) + h.Write([]byte(testString[:5])) + h.Write([]byte(testString[5:])) + want := make([]byte, 32) + h.Read(want) + got := cng.SumSHAKE256([]byte(testString), 32) + if !bytes.Equal(got, want) { + t.Errorf("got:%x want:%x", got, want) + } + }) +} + +// benchmarkHash tests the speed to hash num buffers of buflen each. +func benchmarkHash(b *testing.B, h hash.Hash, size, num int) { + b.StopTimer() + h.Reset() + data := sequentialBytes(size) + b.SetBytes(int64(size * num)) + b.StartTimer() + + var state []byte + for i := 0; i < b.N; i++ { + for j := 0; j < num; j++ { + h.Write(data) + } + state = h.Sum(state[:0]) + } + b.StopTimer() + h.Reset() +} + +// benchmarkCSHAKE is specialized to the Shake instances, which don't +// require a copy on reading output. +func benchmarkCSHAKE(b *testing.B, h *cng.SHAKE, size, num int) { + b.StopTimer() + h.Reset() + data := sequentialBytes(size) + d := make([]byte, 32) + + b.SetBytes(int64(size * num)) + b.StartTimer() + + for i := 0; i < b.N; i++ { + h.Reset() + for j := 0; j < num; j++ { + h.Write(data) + } + h.Read(d) + } +} + +func BenchmarkSHA3_512_MTU(b *testing.B) { benchmarkHash(b, cng.NewSHA3_512(), 1350, 1) } +func BenchmarkSHA3_384_MTU(b *testing.B) { benchmarkHash(b, cng.NewSHA3_384(), 1350, 1) } +func BenchmarkSHA3_256_MTU(b *testing.B) { benchmarkHash(b, cng.NewSHA3_256(), 1350, 1) } + +func BenchmarkCSHAKE128_MTU(b *testing.B) { benchmarkCSHAKE(b, cng.NewSHAKE128(), 1350, 1) } +func BenchmarkCSHAKE256_MTU(b *testing.B) { benchmarkCSHAKE(b, cng.NewSHAKE256(), 1350, 1) } +func BenchmarkCSHAKE256_16x(b *testing.B) { benchmarkCSHAKE(b, cng.NewSHAKE256(), 16, 1024) } +func BenchmarkCSHAKE256_1MiB(b *testing.B) { benchmarkCSHAKE(b, cng.NewSHAKE256(), 1024, 1024) } + +func BenchmarkCSHA3_512_1MiB(b *testing.B) { benchmarkHash(b, cng.NewSHA3_512(), 1024, 1024) } From be8db5975729adde59e28ccf129fbb8373604e04 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Thu, 19 Dec 2024 15:09:52 +0100 Subject: [PATCH 06/10] speed-up sha3 reset --- cng/sha3.go | 44 +++++++++++++------------------ internal/bcrypt/bcrypt_windows.go | 1 + 2 files changed, 20 insertions(+), 25 deletions(-) diff --git a/cng/sha3.go b/cng/sha3.go index becda66..6136b2a 100644 --- a/cng/sha3.go +++ b/cng/sha3.go @@ -9,12 +9,14 @@ package cng import ( "hash" "runtime" - "slices" "unsafe" "github.com/microsoft/go-crypto-winnative/internal/bcrypt" ) +// maxSHA3Size is the size of SHA3_512, the largest SHA3 hash we support. +const maxSHA3Size = 64 + // SumSHA3_256 returns the SHA3-256 checksum of the data. func SumSHA3_256(p []byte) (sum [32]byte) { if err := hashOneShot(bcrypt.SHA3_256_ALGORITHM, p, sum[:]); err != nil { @@ -98,7 +100,7 @@ func (h *DigestSHA3) init() { if h.ctx != 0 { return } - err := bcrypt.CreateHash(h.alg.handle, &h.ctx, nil, nil, 0) + err := bcrypt.CreateHash(h.alg.handle, &h.ctx, nil, nil, bcrypt.HASH_REUSABLE_FLAG) if err != nil { panic(err) } @@ -121,9 +123,13 @@ func (h *DigestSHA3) Clone() (hash.Hash, error) { func (h *DigestSHA3) Reset() { defer runtime.KeepAlive(h) if h.ctx != 0 { - bcrypt.DestroyHash(h.ctx) - h.ctx = 0 - runtime.SetFinalizer(h, nil) + // bcrypt.FinishHash expects the output buffer to match the hash size. + // We don't care about the output, so we just pass a stack-allocated buffer + // that is large enough to hold the largest hash size we support. + var discard [maxSHA3Size]byte + if err := bcrypt.FinishHash(h.ctx, discard[:h.Size()], 0); err != nil { + panic(err) + } } } @@ -175,7 +181,7 @@ func (h *DigestSHA3) Sum(in []byte) []byte { panic(err) } defer bcrypt.DestroyHash(ctx2) - buf := make([]byte, h.alg.size, 64) // explicit cap to allow stack allocation + buf := make([]byte, h.alg.size, maxSHA3Size) // explicit cap to allow stack allocation err = bcrypt.FinishHash(ctx2, buf, 0) if err != nil { panic(err) @@ -207,9 +213,8 @@ func SupportsSHAKE256() bool { // SHAKE is an instance of a SHAKE extendable output function. type SHAKE struct { - alg *hashAlgorithm - ctx bcrypt.HASH_HANDLE - n, s []byte + ctx bcrypt.HASH_HANDLE + blockSize uint32 } func newShake(id string, N, S []byte) *SHAKE { @@ -217,8 +222,8 @@ func newShake(id string, N, S []byte) *SHAKE { if err != nil { panic(err) } - h := &SHAKE{alg: alg, n: slices.Clone(N), s: slices.Clone(S)} - err = bcrypt.CreateHash(h.alg.handle, &h.ctx, nil, nil, 0) + h := &SHAKE{blockSize: alg.blockSize} + err = bcrypt.CreateHash(alg.handle, &h.ctx, nil, nil, bcrypt.HASH_REUSABLE_FLAG) if err != nil { panic(err) } @@ -309,24 +314,13 @@ func (s *SHAKE) Read(p []byte) (n int, err error) { // Reset resets the XOF to its initial state. func (s *SHAKE) Reset() { defer runtime.KeepAlive(s) - bcrypt.DestroyHash(s.ctx) - err := bcrypt.CreateHash(s.alg.handle, &s.ctx, nil, nil, 0) - if err != nil { + var discard [1]byte + if err := bcrypt.FinishHash(s.ctx, discard[:], 0); err != nil { panic(err) } - if len(s.n) != 0 { - if err := bcrypt.SetProperty(bcrypt.HANDLE(s.ctx), utf16PtrFromString(bcrypt.FUNCTION_NAME_STRING), s.n, 0); err != nil { - panic(err) - } - } - if len(s.s) != 0 { - if err := bcrypt.SetProperty(bcrypt.HANDLE(s.ctx), utf16PtrFromString(bcrypt.CUSTOMIZATION_STRING), s.s, 0); err != nil { - panic(err) - } - } } // BlockSize returns the rate of the XOF. func (s *SHAKE) BlockSize() int { - return int(s.alg.blockSize) + return int(s.blockSize) } diff --git a/internal/bcrypt/bcrypt_windows.go b/internal/bcrypt/bcrypt_windows.go index e3255e2..a31b83a 100644 --- a/internal/bcrypt/bcrypt_windows.go +++ b/internal/bcrypt/bcrypt_windows.go @@ -119,6 +119,7 @@ const ( const ( HASH_DONT_RESET_FLAG = 0x00000001 + HASH_REUSABLE_FLAG = 0x00000020 ) const ( From 74f51ccd1c54283bfed2e1b46e3d7392d560a75f Mon Sep 17 00:00:00 2001 From: qmuntal Date: Fri, 20 Dec 2024 10:54:17 +0100 Subject: [PATCH 07/10] deduplicate code --- cng/hash.go | 142 ++++++++++++++++++++++++++++++---------------------- cng/sha3.go | 66 +++++------------------- 2 files changed, 95 insertions(+), 113 deletions(-) diff --git a/cng/hash.go b/cng/hash.go index 35a9467..de06c90 100644 --- a/cng/hash.go +++ b/cng/hash.go @@ -16,6 +16,9 @@ import ( "github.com/microsoft/go-crypto-winnative/internal/bcrypt" ) +// maxHashSize is the size of SHA52 and SHA3_512, the largest hashes we support. +const maxHashSize = 64 + // SupportsHash returns true if a hash.Hash implementation is supported for h. func SupportsHash(h crypto.Hash) bool { switch h { @@ -145,11 +148,11 @@ func hashToID(h hash.Hash) string { return hx.alg.id } +// hashX implements [hash.Hash]. type hashX struct { - alg *hashAlgorithm - _ctx bcrypt.HASH_HANDLE // access it using withCtx + alg *hashAlgorithm + ctx bcrypt.HASH_HANDLE - buf []byte key []byte } @@ -160,37 +163,34 @@ func newHashX(id string, flag bcrypt.AlgorithmProviderFlags, key []byte) *hashX panic(err) } h := &hashX{alg: alg, key: bytes.Clone(key)} - // Don't allocate hx.buf nor call bcrypt.CreateHash yet, - // which would be wasteful if the caller only wants to know - // the hash type. This is a common pattern in this package, - // as some functions accept a `func() hash.Hash` parameter - // and call it just to know the hash type. - runtime.SetFinalizer(h, (*hashX).finalize) + // Don't call bcrypt.CreateHash yet, it would be wasteful + // if the caller only wants to know the hash type. This + // is a common pattern in this package, as some functions + // accept a `func() hash.Hash` parameter and call it just + // to know the hash type. return h } func (h *hashX) finalize() { - if h._ctx != 0 { - bcrypt.DestroyHash(h._ctx) - } + bcrypt.DestroyHash(h.ctx) } -func (h *hashX) withCtx(fn func(ctx bcrypt.HASH_HANDLE) error) error { +func (h *hashX) init() { defer runtime.KeepAlive(h) - if h._ctx == 0 { - err := bcrypt.CreateHash(h.alg.handle, &h._ctx, nil, h.key, 0) - if err != nil { - panic(err) - } + if h.ctx != 0 { + return + } + err := bcrypt.CreateHash(h.alg.handle, &h.ctx, nil, h.key, bcrypt.HASH_REUSABLE_FLAG) + if err != nil { + panic(err) } - return fn(h._ctx) + runtime.SetFinalizer(h, (*hashX).finalize) } func (h *hashX) Clone() (hash.Hash, error) { + defer runtime.KeepAlive(h) h2 := &hashX{alg: h.alg, key: bytes.Clone(h.key)} - err := h.withCtx(func(ctx bcrypt.HASH_HANDLE) error { - return bcrypt.DuplicateHash(ctx, &h2._ctx, nil, 0) - }) + err := bcrypt.DuplicateHash(h.ctx, &h2.ctx, nil, 0) if err != nil { return nil, err } @@ -199,49 +199,37 @@ func (h *hashX) Clone() (hash.Hash, error) { } func (h *hashX) Reset() { - if h._ctx != 0 { - bcrypt.DestroyHash(h._ctx) - h._ctx = 0 + defer runtime.KeepAlive(h) + if h.ctx != 0 { + hashReset(h.ctx, h.Size()) } } func (h *hashX) Write(p []byte) (n int, err error) { - err = h.withCtx(func(ctx bcrypt.HASH_HANDLE) error { - for n < len(p) && err == nil { - nn := len32(p[n:]) - err = bcrypt.HashData(h._ctx, p[n:n+nn], 0) - n += nn - } - return err - }) - if err != nil { - // hash.Hash interface mandates Write should never return an error. - panic(err) - } + defer runtime.KeepAlive(h) + h.init() + hashData(h.ctx, p) return len(p), nil } func (h *hashX) WriteString(s string) (int, error) { - // TODO: use unsafe.StringData once we drop support - // for go1.19 and earlier. - hdr := (*struct { - Data *byte - Len int - })(unsafe.Pointer(&s)) - return h.Write(unsafe.Slice(hdr.Data, len(s))) + defer runtime.KeepAlive(h) + return h.Write(unsafe.Slice(unsafe.StringData(s), len(s))) } func (h *hashX) WriteByte(c byte) error { - err := h.withCtx(func(ctx bcrypt.HASH_HANDLE) error { - return bcrypt.HashDataRaw(h._ctx, &c, 1, 0) - }) - if err != nil { - // hash.Hash interface mandates Write should never return an error. - panic(err) - } + defer runtime.KeepAlive(h) + h.init() + hashByte(h.ctx, c) return nil } +func (h *hashX) Sum(in []byte) []byte { + defer runtime.KeepAlive(h) + h.init() + return hashSum(h.ctx, h.Size(), in) +} + func (h *hashX) Size() int { return int(h.alg.size) } @@ -250,21 +238,55 @@ func (h *hashX) BlockSize() int { return int(h.alg.blockSize) } -func (h *hashX) Sum(in []byte) []byte { +// hashData writes p to ctx. It panics on error. +func hashData(ctx bcrypt.HASH_HANDLE, p []byte) { + var n int + var err error + for n < len(p) && err == nil { + nn := len32(p[n:]) + err = bcrypt.HashData(ctx, p[n:n+nn], 0) + n += nn + } + if err != nil { + panic(err) + } +} + +// hashByte writes c to ctx. It panics on error. +func hashByte(ctx bcrypt.HASH_HANDLE, c byte) { + err := bcrypt.HashDataRaw(ctx, &c, 1, 0) + if err != nil { + panic(err) + } +} + +// hashSum writes the hash of ctx to in and returns the result. +// size is the size of the hash output. +// It panics on error. +func hashSum(ctx bcrypt.HASH_HANDLE, size int, in []byte) []byte { var ctx2 bcrypt.HASH_HANDLE - err := h.withCtx(func(ctx bcrypt.HASH_HANDLE) error { - return bcrypt.DuplicateHash(ctx, &ctx2, nil, 0) - }) + err := bcrypt.DuplicateHash(ctx, &ctx2, nil, 0) if err != nil { panic(err) } defer bcrypt.DestroyHash(ctx2) - if h.buf == nil { - h.buf = make([]byte, h.alg.size) - } - err = bcrypt.FinishHash(ctx2, h.buf, 0) + buf := make([]byte, size, maxHashSize) // explicit cap to allow stack allocation + err = bcrypt.FinishHash(ctx2, buf, 0) if err != nil { panic(err) } - return append(in, h.buf...) + return append(in, buf...) +} + +// hashReset resets the hash state of ctx. +// size is the size of the hash output. +// It panics on error. +func hashReset(ctx bcrypt.HASH_HANDLE, size int) { + // bcrypt.FinishHash expects the output buffer to match the hash size. + // We don't care about the output, so we just pass a stack-allocated buffer + // that is large enough to hold the largest hash size we support. + var discard [maxHashSize]byte + if err := bcrypt.FinishHash(ctx, discard[:size], 0); err != nil { + panic(err) + } } diff --git a/cng/sha3.go b/cng/sha3.go index 6136b2a..f3d4986 100644 --- a/cng/sha3.go +++ b/cng/sha3.go @@ -14,9 +14,6 @@ import ( "github.com/microsoft/go-crypto-winnative/internal/bcrypt" ) -// maxSHA3Size is the size of SHA3_512, the largest SHA3 hash we support. -const maxSHA3Size = 64 - // SumSHA3_256 returns the SHA3-256 checksum of the data. func SumSHA3_256(p []byte) (sum [32]byte) { if err := hashOneShot(bcrypt.SHA3_256_ALGORITHM, p, sum[:]); err != nil { @@ -123,28 +120,14 @@ func (h *DigestSHA3) Clone() (hash.Hash, error) { func (h *DigestSHA3) Reset() { defer runtime.KeepAlive(h) if h.ctx != 0 { - // bcrypt.FinishHash expects the output buffer to match the hash size. - // We don't care about the output, so we just pass a stack-allocated buffer - // that is large enough to hold the largest hash size we support. - var discard [maxSHA3Size]byte - if err := bcrypt.FinishHash(h.ctx, discard[:h.Size()], 0); err != nil { - panic(err) - } + hashReset(h.ctx, h.Size()) } } func (h *DigestSHA3) Write(p []byte) (n int, err error) { defer runtime.KeepAlive(h) h.init() - for n < len(p) && err == nil { - nn := len32(p[n:]) - err = bcrypt.HashData(h.ctx, p[n:n+nn], 0) - n += nn - } - if err != nil { - // hash.Hash interface mandates Write should never return an error. - panic(err) - } + hashData(h.ctx, p) return len(p), nil } @@ -156,14 +139,16 @@ func (h *DigestSHA3) WriteString(s string) (int, error) { func (h *DigestSHA3) WriteByte(c byte) error { defer runtime.KeepAlive(h) h.init() - err := bcrypt.HashDataRaw(h.ctx, &c, 1, 0) - if err != nil { - // hash.Hash interface mandates Write should never return an error. - panic(err) - } + hashByte(h.ctx, c) return nil } +func (h *DigestSHA3) Sum(in []byte) []byte { + defer runtime.KeepAlive(h) + h.init() + return hashSum(h.ctx, h.Size(), in) +} + func (h *DigestSHA3) Size() int { return int(h.alg.size) } @@ -172,23 +157,6 @@ func (h *DigestSHA3) BlockSize() int { return int(h.alg.blockSize) } -func (h *DigestSHA3) Sum(in []byte) []byte { - defer runtime.KeepAlive(h) - h.init() - var ctx2 bcrypt.HASH_HANDLE - err := bcrypt.DuplicateHash(h.ctx, &ctx2, nil, 0) - if err != nil { - panic(err) - } - defer bcrypt.DestroyHash(ctx2) - buf := make([]byte, h.alg.size, maxSHA3Size) // explicit cap to allow stack allocation - err = bcrypt.FinishHash(ctx2, buf, 0) - if err != nil { - panic(err) - } - return append(in, buf...) -} - // NewSHA3_256 returns a new SHA256 hash. func NewSHA3_256() *DigestSHA3 { return newDigestSHA3(bcrypt.SHA3_256_ALGORITHM) @@ -281,14 +249,7 @@ func (s *SHAKE) Write(p []byte) (n int, err error) { return 0, nil } defer runtime.KeepAlive(s) - for n < len(p) && err == nil { - nn := len32(p[n:]) - err = bcrypt.HashData(s.ctx, p[n:n+nn], 0) - n += nn - } - if err != nil { - panic(err) - } + hashData(s.ctx, p) return len(p), nil } @@ -314,10 +275,9 @@ func (s *SHAKE) Read(p []byte) (n int, err error) { // Reset resets the XOF to its initial state. func (s *SHAKE) Reset() { defer runtime.KeepAlive(s) - var discard [1]byte - if err := bcrypt.FinishHash(s.ctx, discard[:], 0); err != nil { - panic(err) - } + // SHAKE has a variable size, CNG doesn't change the size of the hash + // when resetting, so we can pass a small value here. + hashReset(s.ctx, 1) } // BlockSize returns the rate of the XOF. From 9b21315fdc56cae1805239ef3028ec2eace2fc16 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Fri, 20 Dec 2024 17:11:21 +0100 Subject: [PATCH 08/10] fix hashX.Clone --- cng/hash.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/cng/hash.go b/cng/hash.go index de06c90..d8e951d 100644 --- a/cng/hash.go +++ b/cng/hash.go @@ -190,11 +190,13 @@ func (h *hashX) init() { func (h *hashX) Clone() (hash.Hash, error) { defer runtime.KeepAlive(h) h2 := &hashX{alg: h.alg, key: bytes.Clone(h.key)} - err := bcrypt.DuplicateHash(h.ctx, &h2.ctx, nil, 0) - if err != nil { - return nil, err + if h.ctx != 0 { + err := bcrypt.DuplicateHash(h.ctx, &h2.ctx, nil, 0) + if err != nil { + return nil, err + } + runtime.SetFinalizer(h2, (*hashX).finalize) } - runtime.SetFinalizer(h2, (*hashX).finalize) return h2, nil } From 34746fec85c5d85f6e1924928da45912462d6d1c Mon Sep 17 00:00:00 2001 From: Quim Muntal Date: Fri, 3 Jan 2025 15:49:09 +0100 Subject: [PATCH 09/10] Update cng/hash.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- cng/hash.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cng/hash.go b/cng/hash.go index d8e951d..b1b42f9 100644 --- a/cng/hash.go +++ b/cng/hash.go @@ -16,7 +16,7 @@ import ( "github.com/microsoft/go-crypto-winnative/internal/bcrypt" ) -// maxHashSize is the size of SHA52 and SHA3_512, the largest hashes we support. +// maxHashSize is the size of SHA512 and SHA3_512, the largest hashes we support. const maxHashSize = 64 // SupportsHash returns true if a hash.Hash implementation is supported for h. From b09ee976be67c1f832ac1bcfef3d787c30cff5ad Mon Sep 17 00:00:00 2001 From: qmuntal Date: Tue, 7 Jan 2025 09:49:04 +0100 Subject: [PATCH 10/10] PR comments --- cng/sha3.go | 14 +++++++------- cng/sha3_test.go | 3 +++ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/cng/sha3.go b/cng/sha3.go index f3d4986..1504b9b 100644 --- a/cng/sha3.go +++ b/cng/sha3.go @@ -65,6 +65,13 @@ func SupportsSHAKE128() bool { return err == nil } +// SupportsSHAKE256 returns true if the SHAKE256 extendable output function is +// supported. +func SupportsSHAKE256() bool { + _, err := loadHash(bcrypt.CSHAKE256_ALGORITHM, bcrypt.ALG_NONE_FLAG) + return err == nil +} + var _ hash.Hash = (*DigestSHA3)(nil) // DigestSHA3 is the [sha3.SHA3] implementation using the CNG API. @@ -172,13 +179,6 @@ func NewSHA3_512() *DigestSHA3 { return newDigestSHA3(bcrypt.SHA3_512_ALGORITHM) } -// SupportsSHAKE256 returns true if the SHAKE256 extendable output function is -// supported. -func SupportsSHAKE256() bool { - _, err := loadHash(bcrypt.CSHAKE256_ALGORITHM, bcrypt.ALG_NONE_FLAG) - return err == nil -} - // SHAKE is an instance of a SHAKE extendable output function. type SHAKE struct { ctx bcrypt.HASH_HANDLE diff --git a/cng/sha3_test.go b/cng/sha3_test.go index 121efea..c52f48b 100644 --- a/cng/sha3_test.go +++ b/cng/sha3_test.go @@ -72,6 +72,9 @@ func TestCSHAKESqueezing(t *testing.T) { } // sequentialBytes produces a buffer of size consecutive bytes 0x00, 0x01, ..., used for testing. +// +// The alignment of each slice is intentionally randomized to detect alignment +// issues in the implementation. See https://golang.org/issue/37644. func sequentialBytes(size int) []byte { alignmentOffset := rand.Intn(8) result := make([]byte, size+alignmentOffset)[alignmentOffset:]