diff --git a/hasher.go b/hasher.go index cb5b3fb8..f0d174ff 100644 --- a/hasher.go +++ b/hasher.go @@ -154,9 +154,8 @@ func (n *NmtHasher) BlockSize() int { func (n *NmtHasher) EmptyRoot() []byte { n.baseHasher.Reset() - emptyNs := bytes.Repeat([]byte{0}, int(n.NamespaceLen)) h := n.baseHasher.Sum(nil) - digest := append(append(emptyNs, emptyNs...), h...) + digest := append(make([]byte, int(n.NamespaceLen)*2), h...) return digest } @@ -212,42 +211,64 @@ func (n *NmtHasher) MustHashLeaf(ndata []byte) []byte { return res } -// ValidateNodeFormat checks whether the supplied node conforms to the -// namespaced hash format and returns ErrInvalidNodeLen if not. -func (n *NmtHasher) ValidateNodeFormat(node []byte) (err error) { +// nsIDRange represents the range of namespace IDs with minimum and maximum values. +type nsIDRange struct { + Min, Max namespace.ID +} + +// tryFetchNodeNSRange attempts to return the min and max namespace ids. +// It will return an ErrInvalidNodeLen | ErrInvalidNodeNamespaceOrder +// if the supplied node does not conform to the namespaced hash format. +func (n *NmtHasher) tryFetchNodeNSRange(node []byte) (nsIDRange, error) { expectedNodeLen := n.Size() nodeLen := len(node) if nodeLen != expectedNodeLen { - return fmt.Errorf("%w: got: %v, want %v", ErrInvalidNodeLen, nodeLen, expectedNodeLen) + return nsIDRange{}, fmt.Errorf("%w: got: %v, want %v", ErrInvalidNodeLen, nodeLen, expectedNodeLen) } // check the namespace order minNID := namespace.ID(MinNamespace(node, n.NamespaceSize())) maxNID := namespace.ID(MaxNamespace(node, n.NamespaceSize())) if maxNID.Less(minNID) { - return fmt.Errorf("%w: max namespace ID %d is less than min namespace ID %d ", ErrInvalidNodeNamespaceOrder, maxNID, minNID) + return nsIDRange{}, fmt.Errorf("%w: max namespace ID %d is less than min namespace ID %d ", ErrInvalidNodeNamespaceOrder, maxNID, minNID) } - return nil + return nsIDRange{Min: minNID, Max: maxNID}, nil } -// validateSiblingsNamespaceOrder checks whether left and right as two sibling -// nodes in an NMT have correct namespace IDs relative to each other, more -// specifically, the maximum namespace ID of the left sibling should not exceed -// the minimum namespace ID of the right sibling. It returns ErrUnorderedSiblings error if the check fails. -func (n *NmtHasher) validateSiblingsNamespaceOrder(left, right []byte) (err error) { - if err := n.ValidateNodeFormat(left); err != nil { - return fmt.Errorf("%w: left node does not match the namesapce hash format", err) +// ValidateNodeFormat checks whether the supplied node conforms to the +// namespaced hash format and returns an error if not. +func (n *NmtHasher) ValidateNodeFormat(node []byte) error { + _, err := n.tryFetchNodeNSRange(node) + return err +} + +// tryFetchLeftAndRightNSRange attempts to return the min/max namespace ids of both +// the left and right nodes. It verifies whether left +// and right comply by the namespace hash format, and are correctly ordered +// according to their namespace IDs. +func (n *NmtHasher) tryFetchLeftAndRightNSRanges(left, right []byte) ( + nsIDRange, + nsIDRange, + error, +) { + var lNsRange nsIDRange + var rNsRange nsIDRange + var err error + + lNsRange, err = n.tryFetchNodeNSRange(left) + if err != nil { + return lNsRange, rNsRange, err } - if err := n.ValidateNodeFormat(right); err != nil { - return fmt.Errorf("%w: right node does not match the namesapce hash format", err) + rNsRange, err = n.tryFetchNodeNSRange(right) + if err != nil { + return lNsRange, rNsRange, err } - leftMaxNs := namespace.ID(MaxNamespace(left, n.NamespaceSize())) - rightMinNs := namespace.ID(MinNamespace(right, n.NamespaceSize())) // check the namespace range of the left and right children - if rightMinNs.Less(leftMaxNs) { - return fmt.Errorf("%w: the maximum namespace of the left child %x is greater than the min namespace of the right child %x", ErrUnorderedSiblings, leftMaxNs, rightMinNs) + if rNsRange.Min.Less(lNsRange.Max) { + err = fmt.Errorf("%w: the min namespace ID of the right child %d is less than the max namespace ID of the left child %d", ErrUnorderedSiblings, rNsRange.Min, lNsRange.Max) } - return nil + + return lNsRange, rNsRange, err } // ValidateNodes is a helper function to verify the @@ -255,13 +276,8 @@ func (n *NmtHasher) validateSiblingsNamespaceOrder(left, right []byte) (err erro // and right comply by the namespace hash format, and are correctly ordered // according to their namespace IDs. func (n *NmtHasher) ValidateNodes(left, right []byte) error { - if err := n.ValidateNodeFormat(left); err != nil { - return err - } - if err := n.ValidateNodeFormat(right); err != nil { - return err - } - return n.validateSiblingsNamespaceOrder(left, right) + _, _, err := n.tryFetchLeftAndRightNSRanges(left, right) + return err } // HashNode calculates a namespaced hash of a node using the supplied left and @@ -278,21 +294,19 @@ func (n *NmtHasher) ValidateNodes(left, right []byte) error { // If the namespace range of the right child is start=end=MAXNID, indicating that it represents the root of a subtree whose leaves all have the namespace ID of `MAXNID`, then exclude the right child from the namespace range calculation. Instead, // assign the namespace range of the left child as the parent's namespace range. func (n *NmtHasher) HashNode(left, right []byte) ([]byte, error) { - // validate the inputs - if err := n.ValidateNodes(left, right); err != nil { + // validate the inputs & fetch the namespace ranges + lRange, rRange, err := n.tryFetchLeftAndRightNSRanges(left, right) + if err != nil { return nil, err } h := n.baseHasher h.Reset() - leftMinNs, leftMaxNs := MinNamespace(left, n.NamespaceLen), MaxNamespace(left, n.NamespaceLen) - rightMinNs, rightMaxNs := MinNamespace(right, n.NamespaceLen), MaxNamespace(right, n.NamespaceLen) - // compute the namespace range of the parent node - minNs, maxNs := computeNsRange(leftMinNs, leftMaxNs, rightMinNs, rightMaxNs, n.ignoreMaxNs, n.precomputedMaxNs) + minNs, maxNs := computeNsRange(lRange.Min, lRange.Max, rRange.Min, rRange.Max, n.ignoreMaxNs, n.precomputedMaxNs) - res := make([]byte, 0) + res := make([]byte, 0, len(minNs)*2) res = append(res, minNs...) res = append(res, maxNs...) diff --git a/hasher_test.go b/hasher_test.go index 0bf1178e..8014b0b3 100644 --- a/hasher_test.go +++ b/hasher_test.go @@ -283,56 +283,6 @@ func TestHashNode_Error(t *testing.T) { } } -func TestValidateSiblings(t *testing.T) { - // create a dummy hash to use as the digest of the left and right child - randHash := createByteSlice(sha256.Size, 0x01) - - type children struct { - l []byte // namespace hash of the left child with the format of MinNs||MaxNs||h - r []byte // namespace hash of the right child with the format of MinNs||MaxNs||h - } - - tests := []struct { - name string - nidLen namespace.IDSize - children children - wantErr bool - }{ - { - "wrong left node format", 2, - children{concat([]byte{0, 0, 1, 1}, randHash[:len(randHash)-1]), concat([]byte{0, 0, 1, 1}, randHash)}, - true, - }, - { - "wrong right node format", 2, - children{concat([]byte{0, 0, 1, 1}, randHash), concat([]byte{0, 0, 1, 1}, randHash[:len(randHash)-1])}, - true, - }, - { - "left.maxNs>right.minNs", 2, - children{concat([]byte{0, 0, 1, 1}, randHash), concat([]byte{0, 0, 1, 1}, randHash)}, - true, - }, - { - "left.maxNs=right.minNs", 2, - children{concat([]byte{0, 0, 1, 1}, randHash), concat([]byte{1, 1, 2, 2}, randHash)}, - false, - }, - { - "left.maxNs