diff --git a/pgtype/hstore.go b/pgtype/hstore.go index e4695819a..a8559d8a6 100644 --- a/pgtype/hstore.go +++ b/pgtype/hstore.go @@ -1,13 +1,11 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "errors" "fmt" - "unicode" - "unicode/utf8" + "strings" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -247,21 +245,11 @@ func (s scanPlanTextAnyToHstoreScanner) Scan(src []byte, dst any) error { // scanString does not return nil hstore values because string cannot be nil. func (scanPlanTextAnyToHstoreScanner) scanString(src string, scanner HstoreScanner) error { - keys, values, err := parseHstore(src) + hstore, err := parseHstore(src) if err != nil { return err } - - m := make(Hstore, len(keys)) - for i := range keys { - if values[i].Valid { - m[keys[i]] = &values[i].String - } else { - m[keys[i]] = nil - } - } - - return scanner.ScanHstore(m) + return scanner.ScanHstore(hstore) } func (c HstoreCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { @@ -281,187 +269,215 @@ func (c HstoreCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) ( return hstore, nil } -const ( - hsPre = iota - hsKey - hsSep - hsVal - hsNul - hsNext -) - type hstoreParser struct { - str string - pos int + str string + pos int + nextBackslash int } func newHSP(in string) *hstoreParser { return &hstoreParser{ - pos: 0, - str: in, + pos: 0, + str: in, + nextBackslash: strings.IndexByte(in, '\\'), } } -func (p *hstoreParser) Consume() (r rune, end bool) { +func (p *hstoreParser) atEnd() bool { + return p.pos >= len(p.str) +} + +// consume returns the next byte of the string and advances past it; at the end of the string it returns end == true (and b == 0) without advancing.
+func (p *hstoreParser) consume() (b byte, end bool) { if p.pos >= len(p.str) { - end = true - return + return 0, true } - r, w := utf8.DecodeRuneInString(p.str[p.pos:]) - p.pos += w - return + b = p.str[p.pos] + p.pos++ + return b, false } -func (p *hstoreParser) Peek() (r rune, end bool) { - if p.pos >= len(p.str) { - end = true - return +func unexpectedByteErr(actualB byte, expectedB byte) error { + return fmt.Errorf("expected '%c' ('%#v'); found '%c' ('%#v')", expectedB, expectedB, actualB, actualB) +} + +// consumeExpectedByte consumes the next byte and returns an error if it is not expectedB or if the string has ended. On a mismatch the unexpected byte has still been consumed. +func (p *hstoreParser) consumeExpectedByte(expectedB byte) error { + nextB, end := p.consume() + if end { + return fmt.Errorf("expected '%c' ('%#v'); found end", expectedB, expectedB) } - r, _ = utf8.DecodeRuneInString(p.str[p.pos:]) - return + if nextB != expectedB { + return unexpectedByteErr(nextB, expectedB) + } + return nil } -// parseHstore parses the string representation of an hstore column (the same -// you would get from an ordinary SELECT) into two slices of keys and values. it -// is used internally in the default parsing of hstores. -func parseHstore(s string) (k []string, v []Text, err error) { - if s == "" { - return +// consumeExpected2 consumes the bytes one and two, in that order, or returns an error without advancing p.pos. +// This was a bit faster than using a string argument (better inlining? Not sure). +func (p *hstoreParser) consumeExpected2(one byte, two byte) error { + if p.pos+2 > len(p.str) { + return errors.New("unexpected end of string") + } + if p.str[p.pos] != one { + return unexpectedByteErr(p.str[p.pos], one) + } + if p.str[p.pos+1] != two { + return unexpectedByteErr(p.str[p.pos+1], two) } + p.pos += 2 + return nil +} - buf := bytes.Buffer{} - keys := []string{} - values := []Text{} - p := newHSP(s) +var errEOSInQuoted = errors.New(`found end before closing double-quote ('"')`) - r, end := p.Consume() - state := hsPre +// consumeDoubleQuoted consumes a double-quoted string from p. 
The double quote must have been + parsed already. +func (p *hstoreParser) consumeDoubleQuoted() (string, error) { + // fast path: assume most keys/values do not contain escapes, so the next double quote closes the string + nextDoubleQuote := strings.IndexByte(p.str[p.pos:], '"') + if nextDoubleQuote == -1 { + return "", errEOSInQuoted + } + nextDoubleQuote += p.pos + if p.nextBackslash == -1 || p.nextBackslash > nextDoubleQuote { + // no escapes in this string: return the substring of p.str directly, without copying + s := p.str[p.pos:nextDoubleQuote] + p.pos = nextDoubleQuote + 1 + return s, nil + } - for !end { - switch state { - case hsPre: - if r == '"' { - state = hsKey - } else { - err = errors.New("String does not begin with \"") - } - case hsKey: - switch r { - case '"': //End of the key - keys = append(keys, buf.String()) - buf = bytes.Buffer{} - state = hsSep - case '\\': //Potential escaped character - n, end := p.Consume() - switch { - case end: - err = errors.New("Found EOS in key, expecting character or \"") - case n == '"', n == '\\': - buf.WriteRune(n) - default: - buf.WriteRune(r) - buf.WriteRune(n) - } - default: //Any other character - buf.WriteRune(r) - } - case hsSep: - if r == '=' { - r, end = p.Consume() - switch { - case end: - err = errors.New("Found EOS after '=', expecting '>'") - case r == '>': - r, end = p.Consume() - switch { - case end: - err = errors.New("Found EOS after '=>', expecting '\"' or 'NULL'") - case r == '"': - state = hsVal - case r == 'N': - state = hsNul - default: - err = fmt.Errorf("Invalid character '%c' after '=>', expecting '\"' or 'NULL'", r) - } - default: - err = fmt.Errorf("Invalid character after '=', expecting '>'") - } - } else { - err = fmt.Errorf("Invalid character '%c' after value, expecting '='", r) - } - case hsVal: - switch r { - case '"': //End of the value - values = append(values, Text{String: buf.String(), Valid: true}) - buf = bytes.Buffer{} - state = hsNext - case '\\': //Potential escaped character - n, end := p.Consume() - switch { - case end: - err = errors.New("Found EOS in key, expecting 
character or \"") - case n == '"', n == '\\': - buf.WriteRune(n) - default: - buf.WriteRune(r) - buf.WriteRune(n) - } - default: //Any other character - buf.WriteRune(r) - } - case hsNul: - nulBuf := make([]rune, 3) - nulBuf[0] = r - for i := 1; i < 3; i++ { - r, end = p.Consume() - if end { - err = errors.New("Found EOS in NULL value") - return - } - nulBuf[i] = r - } - if nulBuf[0] == 'U' && nulBuf[1] == 'L' && nulBuf[2] == 'L' { - values = append(values, Text{}) - state = hsNext - } else { - err = fmt.Errorf("Invalid NULL value: 'N%s'", string(nulBuf)) + } + + // slow path: string contains escapes; afterwards re-locate the next backslash at or after p.pos for later calls + s, err := p.consumeDoubleQuotedWithEscapes(p.nextBackslash) + p.nextBackslash = strings.IndexByte(p.str[p.pos:], '\\') + if p.nextBackslash != -1 { + p.nextBackslash += p.pos + } + return s, err +} + +// consumeDoubleQuotedWithEscapes consumes a double-quoted string containing escapes, starting +// at p.pos, and with the first backslash at firstBackslash. On success it returns the unescaped contents and leaves p.pos just past the closing double quote. +func (p *hstoreParser) consumeDoubleQuotedWithEscapes(firstBackslash int) (string, error) { + // copy the prefix that does not contain backslashes + var builder strings.Builder + builder.WriteString(p.str[p.pos:firstBackslash]) + + // skip to the backslash + p.pos = firstBackslash + + // copy bytes until the closing double quote, unescaping backslashes + for { + nextB, end := p.consume() + if end { + return "", errEOSInQuoted + } else if nextB == '"' { + break + } else if nextB == '\\' { + // escape: skip the backslash and copy the escaped byte (only a backslash or a double quote is accepted) + nextB, end = p.consume() + if end { + return "", errEOSInQuoted } - case hsNext: - if r == ',' { - r, end = p.Consume() - switch { - case end: - err = errors.New("Found EOS after ',', expecting space") - case (unicode.IsSpace(r)): - // after space is a doublequote to start the key - r, end = p.Consume() - if end { - err = errors.New("Found EOS after space, expecting \"") - return - } - if r != '"' { - err = fmt.Errorf("Invalid character '%c' after space, expecting \"", r) - return - } - state = hsKey - 
- default: - err = fmt.Errorf("Invalid character '%c' after ',', expecting space", r) - } - } else { - err = fmt.Errorf("Invalid character '%c' after value, expecting ','", r) + if !(nextB == '\\' || nextB == '"') { + return "", fmt.Errorf("unexpected escape in quoted string: found '%#v'", nextB) } + builder.WriteByte(nextB) + } else { + // normal byte: copy it + builder.WriteByte(nextB) } + } + return builder.String(), nil +} +// consumePairSeparator consumes the Hstore pair separator ", " or returns an error. +func (p *hstoreParser) consumePairSeparator() error { + return p.consumeExpected2(',', ' ') +} + +// consumeKVSeparator consumes the Hstore key/value separator "=>" or returns an error. +func (p *hstoreParser) consumeKVSeparator() error { + return p.consumeExpected2('=', '>') +} + +// consumeDoubleQuotedOrNull consumes an Hstore value: either the unquoted literal NULL (returned as an invalid Text) or a double-quoted string (returned as a valid Text), or returns an error. +func (p *hstoreParser) consumeDoubleQuotedOrNull() (Text, error) { + // peek at the next byte + if p.atEnd() { + return Text{}, errors.New("found end instead of value") + } + next := p.str[p.pos] + if next == 'N' { + // must be the exact string NULL: use consumeExpected2 twice + err := p.consumeExpected2('N', 'U') if err != nil { - return + return Text{}, err } - r, end = p.Consume() + err = p.consumeExpected2('L', 'L') + if err != nil { + return Text{}, err + } + return Text{String: "", Valid: false}, nil + } else if next != '"' { + return Text{}, unexpectedByteErr(next, '"') } - if state != hsNext { - err = errors.New("Improperly formatted hstore") - return + + // skip the opening double quote + p.pos += 1 + s, err := p.consumeDoubleQuoted() + if err != nil { + return Text{}, err } - k = keys - v = values - return + return Text{String: s, Valid: true}, nil +} + +func parseHstore(s string) (Hstore, error) { + p := newHSP(s) + + // This is an over-estimate of the number of key/value pairs. Use '>' because I am guessing it + // is less likely to occur in keys/values than '=' or ','. Every pair contains a '>' (from its key/value separator), so the count is never below the number of pairs.
+ numPairsEstimate := strings.Count(s, ">") + // Makes one allocation of strings for the entire Hstore, rather than one allocation per value; result stores pointers to elements of this slice. + valueStrings := make([]string, 0, numPairsEstimate) + result := make(Hstore, numPairsEstimate) + first := true + for !p.atEnd() { + if !first { + err := p.consumePairSeparator() + if err != nil { + return nil, err + } + } else { + first = false + } + + err := p.consumeExpectedByte('"') + if err != nil { + return nil, err + } + + key, err := p.consumeDoubleQuoted() + if err != nil { + return nil, err + } + + err = p.consumeKVSeparator() + if err != nil { + return nil, err + } + + value, err := p.consumeDoubleQuotedOrNull() + if err != nil { + return nil, err + } + if value.Valid { + valueStrings = append(valueStrings, value.String) + result[key] = &valueStrings[len(valueStrings)-1] + } else { + result[key] = nil + } + } + + return result, nil }