Skip to content

Commit

Permalink
*: Add Digest struct to get bytes of digest (#1231)
Browse files Browse the repository at this point in the history
  • Loading branch information
crazycs520 authored May 25, 2021
1 parent f5c77b7 commit c37778a
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 19 deletions.
47 changes: 35 additions & 12 deletions digester.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ package parser
import (
"bytes"
"crypto/sha256"
"fmt"
"encoding/hex"
hash2 "hash"
"reflect"
"strings"
Expand All @@ -27,16 +27,39 @@ import (
"github.com/pingcap/parser/charset"
)

// Digest holds a digest in two forms: the raw bytes and their hex encoding.
// The hex string is computed once, in NewDigest, so String does not have to
// re-encode the bytes on every call.
type Digest struct {
	b   []byte
	str string
}

// NewDigest returns a new digest.
//
// b is stored as-is (it is not copied), so the caller should not modify it
// afterwards; otherwise Bytes and String would no longer agree. A nil b
// yields a Digest whose String is "" and whose Bytes is nil.
func NewDigest(b []byte) *Digest {
	return &Digest{
		b:   b,
		str: hex.EncodeToString(b),
	}
}

// String returns the digest hex string.
// It returns "" when called on a nil *Digest instead of panicking.
func (d *Digest) String() string {
	if d == nil {
		return ""
	}
	return d.str
}

// Bytes returns the digest byte slice.
// The returned slice is the one passed to NewDigest, not a copy.
// It returns nil when called on a nil *Digest instead of panicking.
func (d *Digest) Bytes() []byte {
	if d == nil {
		return nil
	}
	return d.b
}

// DigestHash generates the digest of statements.
// It generates a hash of the normalized form of the statement text,
// which removes general properties of a statement but keeps specific properties.
//
// for example: both DigestHash('select 1') and DigestHash('select 2') => e1c71d1661ae46e09b7aaec1c390957f0d6260410df4e4bc71b9c8d681021471
//
// Deprecated: It is logically consistent with NormalizeDigest.
func DigestHash(sql string) (result string) {
func DigestHash(sql string) (digest *Digest) {
d := digesterPool.Get().(*sqlDigester)
result = d.doDigest(sql)
digest = d.doDigest(sql)
digesterPool.Put(d)
return
}
Expand All @@ -48,9 +71,9 @@ func DigestHash(sql string) (result string) {
// for example: DigestNormalized('select ?')
// DigestNormalized should be called with a normalized SQL string (like 'select ?') generated by function Normalize.
// Do not call it with SQL that is not normalized: DigestNormalized('select 1') and DigestNormalized('select 2') are not the same.
func DigestNormalized(normalized string) (result string) {
func DigestNormalized(normalized string) (digest *Digest) {
d := digesterPool.Get().(*sqlDigester)
result = d.doDigestNormalized(normalized)
digest = d.doDigestNormalized(normalized)
digesterPool.Put(d)
return
}
Expand All @@ -68,7 +91,7 @@ func Normalize(sql string) (result string) {
}

// NormalizeDigest combines Normalize and DigestNormalized into one method.
func NormalizeDigest(sql string) (normalized, digest string) {
func NormalizeDigest(sql string) (normalized string, digest *Digest) {
d := digesterPool.Get().(*sqlDigester)
normalized, digest = d.doNormalizeDigest(sql)
digesterPool.Put(d)
Expand All @@ -92,24 +115,24 @@ type sqlDigester struct {
tokens tokenDeque
}

func (d *sqlDigester) doDigestNormalized(normalized string) (result string) {
func (d *sqlDigester) doDigestNormalized(normalized string) (digest *Digest) {
hdr := *(*reflect.StringHeader)(unsafe.Pointer(&normalized))
b := *(*[]byte)(unsafe.Pointer(&reflect.SliceHeader{
Data: hdr.Data,
Len: hdr.Len,
Cap: hdr.Len,
}))
d.hasher.Write(b)
result = fmt.Sprintf("%x", d.hasher.Sum(nil))
digest = NewDigest(d.hasher.Sum(nil))
d.hasher.Reset()
return
}

func (d *sqlDigester) doDigest(sql string) (result string) {
func (d *sqlDigester) doDigest(sql string) (digest *Digest) {
d.normalize(sql)
d.hasher.Write(d.buffer.Bytes())
d.buffer.Reset()
result = fmt.Sprintf("%x", d.hasher.Sum(nil))
digest = NewDigest(d.hasher.Sum(nil))
d.hasher.Reset()
return
}
Expand All @@ -121,12 +144,12 @@ func (d *sqlDigester) doNormalize(sql string) (result string) {
return
}

func (d *sqlDigester) doNormalizeDigest(sql string) (normalized, digest string) {
func (d *sqlDigester) doNormalizeDigest(sql string) (normalized string, digest *Digest) {
d.normalize(sql)
normalized = d.buffer.String()
d.hasher.Write(d.buffer.Bytes())
d.buffer.Reset()
digest = fmt.Sprintf("%x", d.hasher.Sum(nil))
digest = NewDigest(d.hasher.Sum(nil))
d.hasher.Reset()
return
}
Expand Down
51 changes: 44 additions & 7 deletions digester_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
package parser_test

import (
"crypto/sha256"
"encoding/hex"
"fmt"
"testing"

. "github.com/pingcap/check"
"github.com/pingcap/parser"
)
Expand Down Expand Up @@ -71,7 +76,7 @@ func (s *testSQLDigestSuite) TestNormalize(c *C) {

normalized2, digest2 := parser.NormalizeDigest(test.input)
c.Assert(normalized2, Equals, normalized)
c.Assert(digest2, Equals, digest, Commentf("%+v", test))
c.Assert(digest2.String(), Equals, digest.String(), Commentf("%+v", test))
}
}

Expand All @@ -86,12 +91,12 @@ func (s *testSQLDigestSuite) TestNormalizeDigest(c *C) {
for _, test := range tests {
normalized, digest := parser.NormalizeDigest(test.sql)
c.Assert(normalized, Equals, test.normalized)
c.Assert(digest, Equals, test.digest)
c.Assert(digest.String(), Equals, test.digest)

normalized = parser.Normalize(test.sql)
digest = parser.DigestNormalized(normalized)
c.Assert(normalized, Equals, test.normalized)
c.Assert(digest, Equals, test.digest)
c.Assert(digest.String(), Equals, test.digest)
}
}

Expand All @@ -106,10 +111,10 @@ func (s *testSQLDigestSuite) TestDigestHashEqForSimpleSQL(c *C) {
for _, sql := range sqlGroup {
dig := parser.DigestHash(sql)
if d == "" {
d = dig
d = dig.String()
continue
}
c.Assert(d, Equals, dig)
c.Assert(d, Equals, dig.String())
}
}
}
Expand All @@ -123,10 +128,42 @@ func (s *testSQLDigestSuite) TestDigestHashNotEqForSimpleSQL(c *C) {
for _, sql := range sqlGroup {
dig := parser.DigestHash(sql)
if d == "" {
d = dig
d = dig.String()
continue
}
c.Assert(d, Not(Equals), dig)
c.Assert(d, Not(Equals), dig.String())
}
}
}

// TestGenDigest checks that a Digest built by parser.NewDigest reports the
// hex encoding of its input from String and the input bytes from Bytes, and
// that a nil input gives an empty string and nil bytes.
func (s *testSQLDigestSuite) TestGenDigest(c *C) {
	raw := genRandDigest("abc")
	fromBytes := parser.NewDigest(raw)
	c.Assert(fromBytes.String(), Equals, fmt.Sprintf("%x", raw))
	c.Assert(fromBytes.Bytes(), DeepEquals, raw)

	// Nil input: no hex text and no bytes.
	fromNil := parser.NewDigest(nil)
	c.Assert(fromNil.String(), Equals, "")
	c.Assert(fromNil.Bytes(), IsNil)
}

// genRandDigest returns the SHA-256 checksum of str as a byte slice.
// Despite the name, the result is deterministic: the same str always gives
// the same bytes.
func genRandDigest(str string) []byte {
	sum := sha256.Sum256([]byte(str))
	return sum[:]
}

// digestHexEncodeSink receives the result of every benchmarked call so the
// result is used and the compiler cannot discard the call as dead code.
var digestHexEncodeSink string

// BenchmarkDigestHexEncode measures hex.EncodeToString on a SHA-256 digest.
// It is the counterpart of BenchmarkDigestSprintf, which formats the same
// bytes with fmt.Sprintf("%x", ...).
func BenchmarkDigestHexEncode(b *testing.B) {
	digest1 := genRandDigest("abc")
	// Exclude the digest setup above from the measured time.
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		digestHexEncodeSink = hex.EncodeToString(digest1)
	}
}

// digestSprintfSink receives the result of every benchmarked call. Assigning
// it keeps `go vet` (unusedresult) quiet about the fmt.Sprintf call and stops
// the compiler from discarding the call as dead code.
var digestSprintfSink string

// BenchmarkDigestSprintf measures fmt.Sprintf("%x", ...) on a SHA-256 digest.
// It is the counterpart of BenchmarkDigestHexEncode, which encodes the same
// bytes with hex.EncodeToString.
func BenchmarkDigestSprintf(b *testing.B) {
	digest1 := genRandDigest("abc")
	// Exclude the digest setup above from the measured time.
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		digestSprintfSink = fmt.Sprintf("%x", digest1)
	}
}

0 comments on commit c37778a

Please sign in to comment.