Skip to content

Commit

Permalink
Merge pull request #82 from microsoft/shakesup
Browse files Browse the repository at this point in the history
Deduplicate SupportsSHAKE
  • Loading branch information
qmuntal authored Jan 9, 2025
2 parents b49854c + 0fc0aaf commit 0e4a51c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 27 deletions.
24 changes: 13 additions & 11 deletions cng/sha3.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,19 @@ func SumSHAKE256(data []byte, length int) []byte {
return out
}

// SupportsSHAKE128 returns true if the SHAKE128 extendable output function is
// supported.
func SupportsSHAKE128() bool {
_, err := loadHash(bcrypt.CSHAKE128_ALGORITHM, bcrypt.ALG_NONE_FLAG)
return err == nil
}

// SupportsSHAKE256 returns true if the SHAKE256 extendable output function is
// supported.
func SupportsSHAKE256() bool {
_, err := loadHash(bcrypt.CSHAKE256_ALGORITHM, bcrypt.ALG_NONE_FLAG)
// SupportsSHAKE returns true if the SHAKE and CSHAKE extendable output functions
// with the given securityBits are supported.
func SupportsSHAKE(securityBits int) bool {
var id string
switch securityBits {
case 128:
id = bcrypt.CSHAKE128_ALGORITHM
case 256:
id = bcrypt.CSHAKE256_ALGORITHM
default:
return false
}
_, err := loadHash(id, bcrypt.ALG_NONE_FLAG)
return err == nil
}

Expand Down
32 changes: 16 additions & 16 deletions cng/sha3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@ var testShakes = map[string]struct {
}

func skipCSHAKEIfNotSupported(t *testing.T, algo string) {
var supported bool
switch algo {
case "SHAKE128", "CSHAKE128":
if !cng.SupportsSHAKE128() {
t.Skip("skipping: not supported")
}
supported = cng.SupportsSHAKE(128)
case "SHAKE256", "CSHAKE256":
if !cng.SupportsSHAKE256() {
t.Skip("skipping: not supported")
}
supported = cng.SupportsSHAKE(256)
}
if !supported {
t.Skip("skipping: not supported")
}
}

Expand Down Expand Up @@ -109,14 +109,14 @@ func TestCSHAKEReset(t *testing.T) {

func TestCSHAKEAccumulated(t *testing.T) {
t.Run("CSHAKE128", func(t *testing.T) {
if !cng.SupportsSHAKE128() {
if !cng.SupportsSHAKE(128) {
t.Skip("skipping: not supported")
}
testCSHAKEAccumulated(t, cng.NewCSHAKE128, (1600-256)/8,
"bb14f8657c6ec5403d0b0e2ef3d3393497e9d3b1a9a9e8e6c81dbaa5fd809252")
})
t.Run("CSHAKE256", func(t *testing.T) {
if !cng.SupportsSHAKE256() {
if !cng.SupportsSHAKE(256) {
t.Skip("skipping: not supported")
}
testCSHAKEAccumulated(t, cng.NewCSHAKE256, (1600-512)/8,
Expand Down Expand Up @@ -155,7 +155,7 @@ func testCSHAKEAccumulated(t *testing.T, newCSHAKE func(N, S []byte) *cng.SHAKE,
}

func TestCSHAKELargeS(t *testing.T) {
if !cng.SupportsSHAKE128() {
if !cng.SupportsSHAKE(128) {
t.Skip("skipping: not supported")
}
const s = (1<<32)/8 + 1000 // s * 8 > 2^32
Expand All @@ -173,13 +173,13 @@ func TestCSHAKELargeS(t *testing.T) {
}
}

func TestCSHAKESum(t *testing.T) {
func TestSHAKESum(t *testing.T) {
const testString = "hello world"
t.Run("CSHAKE128", func(t *testing.T) {
if !cng.SupportsSHAKE128() {
t.Run("SHAKE128", func(t *testing.T) {
if !cng.SupportsSHAKE(128) {
t.Skip("skipping: not supported")
}
h := cng.NewCSHAKE128(nil, nil)
h := cng.NewSHAKE128()
h.Write([]byte(testString[:5]))
h.Write([]byte(testString[5:]))
want := make([]byte, 32)
Expand All @@ -189,11 +189,11 @@ func TestCSHAKESum(t *testing.T) {
t.Errorf("got:%x want:%x", got, want)
}
})
t.Run("CSHAKE256", func(t *testing.T) {
if !cng.SupportsSHAKE256() {
t.Run("SHAKE256", func(t *testing.T) {
if !cng.SupportsSHAKE(256) {
t.Skip("skipping: not supported")
}
h := cng.NewCSHAKE256(nil, nil)
h := cng.NewSHAKE256()
h.Write([]byte(testString[:5]))
h.Write([]byte(testString[5:]))
want := make([]byte, 32)
Expand Down

0 comments on commit 0e4a51c

Please sign in to comment.