Skip to content

Commit

Permalink
Use strings.Builder for grpc util methods
Browse files Browse the repository at this point in the history
  • Loading branch information
emcfarlane committed Jun 26, 2023
1 parent 291e3d3 commit 9aaddb0
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 29 deletions.
4 changes: 2 additions & 2 deletions error_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (w *ErrorWriter) writeConnectStreaming(response http.ResponseWriter, err er

func (w *ErrorWriter) writeGRPC(response http.ResponseWriter, err error) error {
trailers := make(http.Header, 2) // need space for at least code & message
grpcErrorToTrailer(w.bufferPool, trailers, w.protobuf, err)
grpcErrorToTrailer(trailers, w.protobuf, err)
// To make net/http reliably send trailers without a body, we must set the
// Trailers header rather than using http.TrailerPrefix. See
// https://github.com/golang/go/issues/54723.
Expand All @@ -162,7 +162,7 @@ func (w *ErrorWriter) writeGRPC(response http.ResponseWriter, err error) error {
func (w *ErrorWriter) writeGRPCWeb(response http.ResponseWriter, err error) error {
// This is a trailers-only response. To match the behavior of Envoy and
// protocol_grpc.go, put the trailers in the HTTP headers.
grpcErrorToTrailer(w.bufferPool, response.Header(), w.protobuf, err)
grpcErrorToTrailer(response.Header(), w.protobuf, err)
response.WriteHeader(http.StatusOK)
return nil
}
38 changes: 17 additions & 21 deletions protocol_grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ func (cc *grpcClientConn) Receive(msg any) error {
cc.responseTrailer,
cc.readTrailers(&cc.unmarshaler, cc.duplexCall),
)
serverErr := grpcErrorFromTrailer(cc.bufferPool, cc.protobuf, cc.responseTrailer)
serverErr := grpcErrorFromTrailer(cc.protobuf, cc.responseTrailer)
if serverErr != nil && (errors.Is(err, io.EOF) || !errors.Is(serverErr, errTrailersWithoutGRPCStatus)) {
// We've either:
// - Cleanly read until the end of the response body and *not* received
Expand Down Expand Up @@ -434,7 +434,6 @@ func (cc *grpcClientConn) validateResponse(response *http.Response) *Error {
cc.responseHeader,
cc.responseTrailer,
cc.compressionPools,
cc.bufferPool,
cc.protobuf,
); err != nil {
return err
Expand Down Expand Up @@ -524,7 +523,7 @@ func (hc *grpcHandlerConn) Close(err error) (retErr error) {
len(hc.responseTrailer)+2, // always make space for status & message
)
mergeHeaders(mergedTrailers, hc.responseTrailer)
grpcErrorToTrailer(hc.bufferPool, mergedTrailers, hc.protobuf, err)
grpcErrorToTrailer(mergedTrailers, hc.protobuf, err)
if hc.web && !hc.wroteToBody {
// We're using gRPC-Web and we haven't yet written to the body. Since we're
// not sending any response messages, the gRPC specification calls this a
Expand Down Expand Up @@ -661,7 +660,6 @@ func grpcValidateResponse(
response *http.Response,
header, trailer http.Header,
availableCompressors readOnlyCompressionPools,
bufferPool *bufferPool,
protobuf Codec,
) *Error {
if response.StatusCode != http.StatusOK {
Expand All @@ -683,7 +681,6 @@ func grpcValidateResponse(
// When there's no body, gRPC and gRPC-Web servers may send error information
// in the HTTP headers.
if err := grpcErrorFromTrailer(
bufferPool,
protobuf,
response.Header,
); err != nil && !errors.Is(err, errTrailersWithoutGRPCStatus) {
Expand Down Expand Up @@ -729,7 +726,7 @@ func grpcHTTPToCode(httpCode int) Code {
// binary Protobuf format, even if the messages in the request/response stream
// use a different codec. Consequently, this function needs a Protobuf codec to
// unmarshal error information in the headers.
func grpcErrorFromTrailer(bufferPool *bufferPool, protobuf Codec, trailer http.Header) *Error {
func grpcErrorFromTrailer(protobuf Codec, trailer http.Header) *Error {
codeHeader := getHeaderCanonical(trailer, grpcHeaderStatus)
if codeHeader == "" {
return NewError(CodeInternal, errTrailersWithoutGRPCStatus)
Expand All @@ -742,7 +739,7 @@ func grpcErrorFromTrailer(bufferPool *bufferPool, protobuf Codec, trailer http.H
if err != nil {
return errorf(CodeInternal, "gRPC protocol error: invalid error code %q", codeHeader)
}
message := grpcPercentDecode(bufferPool, getHeaderCanonical(trailer, grpcHeaderMessage))
message := grpcPercentDecode(getHeaderCanonical(trailer, grpcHeaderMessage))
retErr := NewWireError(Code(code), errors.New(message))

detailsBinaryEncoded := getHeaderCanonical(trailer, grpcHeaderDetails)
Expand Down Expand Up @@ -823,7 +820,7 @@ func grpcContentTypeFromCodecName(web bool, name string) string {
return grpcContentTypePrefix + name
}

func grpcErrorToTrailer(bufferPool *bufferPool, trailer http.Header, protobuf Codec, err error) {
func grpcErrorToTrailer(trailer http.Header, protobuf Codec, err error) {
if err == nil {
setHeaderCanonical(trailer, grpcHeaderStatus, "0") // zero is the gRPC OK status
setHeaderCanonical(trailer, grpcHeaderMessage, "")
Expand All @@ -842,7 +839,6 @@ func grpcErrorToTrailer(bufferPool *bufferPool, trailer http.Header, protobuf Co
trailer,
grpcHeaderMessage,
grpcPercentEncode(
bufferPool,
fmt.Sprintf("marshal protobuf status: %v", binErr),
),
)
Expand All @@ -852,7 +848,7 @@ func grpcErrorToTrailer(bufferPool *bufferPool, trailer http.Header, protobuf Co
mergeHeaders(trailer, connectErr.meta)
}
setHeaderCanonical(trailer, grpcHeaderStatus, code)
setHeaderCanonical(trailer, grpcHeaderMessage, grpcPercentEncode(bufferPool, status.Message))
setHeaderCanonical(trailer, grpcHeaderMessage, grpcPercentEncode(status.Message))
setHeaderCanonical(trailer, grpcHeaderDetails, EncodeBinaryHeader(bin))
}

Expand Down Expand Up @@ -881,48 +877,48 @@ func grpcStatusFromError(err error) *statusv1.Status {
//
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#responses
// https://datatracker.ietf.org/doc/html/rfc3986#section-2.1
func grpcPercentEncode(bufferPool *bufferPool, msg string) string {
func grpcPercentEncode(msg string) string {
for i := 0; i < len(msg); i++ {
// Characters that need to be escaped are defined in gRPC's HTTP/2 spec.
// They're different from the generic set defined in RFC 3986.
if c := msg[i]; c < ' ' || c > '~' || c == '%' {
return grpcPercentEncodeSlow(bufferPool, msg, i)
return grpcPercentEncodeSlow(msg, i)
}
}
return msg
}

// msg needs some percent-escaping. Bytes before offset don't require
// percent-encoding, so they can be copied to the output as-is.
func grpcPercentEncodeSlow(bufferPool *bufferPool, msg string, offset int) string {
out := bufferPool.Get()
defer bufferPool.Put(out)
func grpcPercentEncodeSlow(msg string, offset int) string {
var out strings.Builder
out.Grow(2 * len(msg))
out.WriteString(msg[:offset])
for i := offset; i < len(msg); i++ {
c := msg[i]
if c < ' ' || c > '~' || c == '%' {
out.WriteString(fmt.Sprintf("%%%02X", c))
fmt.Fprintf(&out, "%%%02X", c)
continue
}
out.WriteByte(c)
}
return out.String()
}

func grpcPercentDecode(bufferPool *bufferPool, encoded string) string {
func grpcPercentDecode(encoded string) string {
for i := 0; i < len(encoded); i++ {
if c := encoded[i]; c == '%' && i+2 < len(encoded) {
return grpcPercentDecodeSlow(bufferPool, encoded, i)
return grpcPercentDecodeSlow(encoded, i)
}
}
return encoded
}

// Similar to percentEncodeSlow: encoded is percent-encoded, and needs to be
// decoded byte-by-byte starting at offset.
func grpcPercentDecodeSlow(bufferPool *bufferPool, encoded string, offset int) string {
out := bufferPool.Get()
defer bufferPool.Put(out)
func grpcPercentDecodeSlow(encoded string, offset int) string {
var out strings.Builder
out.Grow(len(encoded))
out.WriteString(encoded[:offset])
for i := offset; i < len(encoded); i++ {
c := encoded[i]
Expand Down
34 changes: 28 additions & 6 deletions protocol_grpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,12 @@ func TestGRPCEncodeTimeoutQuick(t *testing.T) {

func TestGRPCPercentEncodingQuick(t *testing.T) {
t.Parallel()
pool := newBufferPool()
roundtrip := func(input string) bool {
if !utf8.ValidString(input) {
return true
}
encoded := grpcPercentEncode(pool, input)
decoded := grpcPercentDecode(pool, encoded)
encoded := grpcPercentEncode(input)
decoded := grpcPercentDecode(encoded)
return decoded == input
}
if err := quick.Check(roundtrip, nil /* config */); err != nil {
Expand All @@ -160,12 +159,11 @@ func TestGRPCPercentEncodingQuick(t *testing.T) {

func TestGRPCPercentEncoding(t *testing.T) {
t.Parallel()
pool := newBufferPool()
roundtrip := func(input string) {
assert.True(t, utf8.ValidString(input), assert.Sprintf("input invalid UTF-8"))
encoded := grpcPercentEncode(pool, input)
encoded := grpcPercentEncode(input)
t.Logf("%q encoded as %q", input, encoded)
decoded := grpcPercentDecode(pool, encoded)
decoded := grpcPercentDecode(encoded)
assert.Equal(t, decoded, input)
}

Expand Down Expand Up @@ -194,3 +192,27 @@ func TestGRPCWebTrailerMarshalling(t *testing.T) {
marshalled := responseWriter.Body.String()
assert.Equal(t, marshalled, "grpc-message: Foo\r\ngrpc-status: 0\r\nuser-provided: bar\r\n")
}

func BenchmarkGRPCPercentEncoding(b *testing.B) {
input := "Hello, 世界"
want := "Hello, %E4%B8%96%E7%95%8C"
b.ReportAllocs()
for i := 0; i < b.N; i++ {
got := grpcPercentEncode(input)
if got != want {
b.Fatalf("encodeGrpcMessage(%q) = %s, want %s", input, got, want)
}
}
}

func BenchmarkGRPCPercentDecoding(b *testing.B) {
input := "Hello, %E4%B8%96%E7%95%8C"
want := "Hello, 世界"
b.ReportAllocs()
for i := 0; i < b.N; i++ {
got := grpcPercentDecode(input)
if got != want {
b.Fatalf("decodeGrpcMessage(%q) = %s, want %s", input, got, want)
}
}
}

0 comments on commit 9aaddb0

Please sign in to comment.