diff --git a/go/trace/opentracing.go b/go/trace/opentracing.go index 6e6c6c5bc4e..7aea62cb3bc 100644 --- a/go/trace/opentracing.go +++ b/go/trace/opentracing.go @@ -17,13 +17,11 @@ limitations under the License. package trace import ( - "strings" - otgrpc "github.com/opentracing-contrib/go-grpc" "github.com/opentracing/opentracing-go" "golang.org/x/net/context" "google.golang.org/grpc" - "vitess.io/vitess/go/vt/proto/vtrpc" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" ) @@ -86,19 +84,72 @@ func (jf openTracingService) New(parent Span, label string) Span { } func extractMapFromString(in string) (opentracing.TextMapCarrier, error) { - m := make(opentracing.TextMapCarrier) - items := strings.Split(in, ":") - if len(items) < 2 { - return nil, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "expected transmitted context to contain at least span id and trace id") + result := map[string]string{} + if in == "" { + return result, nil + } + + // the following code is a little state machine that uses these + // three variables to hold it's state as it scans the input string + readingKey := true + var currentKey string + var currentValue string + + addCurrentKV := func() error { + if readingKey { + return vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "every element in the context string has to be in the form key=value") + } + result[currentKey] = currentValue + currentKey = "" + currentValue = "" + readingKey = true + return nil } - for _, v := range items { - idx := strings.Index(v, "=") - if idx < 1 { - return nil, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "every element in the context string has to be in the form key=value") + addChar := func(char uint8) { + if readingKey { + currentKey += string(char) + } else { + currentValue += string(char) } - m[v[0:idx]] = v[idx+1:] } - return m, nil + size := len(in) + for i := 0; i < size; i++ { + atEnd := i == size-1 + char := in[i] + switch char { + case '=': + readingKey = false + case ':': + err := addCurrentKV() + if err != nil { + return nil, err + } + + case '\\': + if atEnd { // can't end with a trailing escape char + return nil, vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "malformed escaping - cannot end with an escape character ") + } + nextChar := in[i+1] + i++ + switch nextChar { + case '\\': + addChar('\\') + case ':': + addChar(':') + case '=': + addChar('=') + default: + return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "malformed escaping - [\\%s] is invalid", string(nextChar)) + } + default: + addChar(char) + } + } + err := addCurrentKV() + if err != nil { + return nil, err + } + return result, nil } func (jf openTracingService) NewFromString(parent, label string) (Span, error) { diff --git a/go/trace/opentracing_test.go b/go/trace/opentracing_test.go index 19bdbce9019..eebb066e22f 100644 --- a/go/trace/opentracing_test.go +++ b/go/trace/opentracing_test.go @@ -19,23 +19,63 @@ package trace import ( "testing" + "github.com/stretchr/testify/require" + "github.com/opentracing/opentracing-go" "github.com/stretchr/testify/assert" ) func TestExtractMapFromString(t *testing.T) { - expected := make(opentracing.TextMapCarrier) - expected["apa"] = "12" - expected["banan"] = "x-tracing-backend-12" - result, err := extractMapFromString("apa=12:banan=x-tracing-backend-12") - assert.NoError(t, err) - assert.Equal(t, expected, result) -} + type testCase struct { + str string + expected opentracing.TextMapCarrier + err bool + } -func TestErrorConditions(t *testing.T) { - _, err := extractMapFromString("") - assert.Error(t, err) + tests := []testCase{{ + str: "apa=12:banan=x-tracing-backend-12", + expected: map[string]string{ + "apa": "12", + "banan": "x-tracing-backend-12", + }, + }, { + str: `uber-trace-id=123\:456\:789\:1`, + expected: map[string]string{"uber-trace-id": "123:456:789:1"}, + }, { + str: `key:`, + err: true, + }, { + str: ``, + expected: map[string]string{}, + }, { + str: `=`, + expected: map[string]string{"": ""}, + }, { + str: `so\=confusing=42`, + expected: map[string]string{"so=confusing": "42"}, + }, { + str: `key=\=42\=`, + expected: map[string]string{"key": "=42="}, + }, { + str: `key=\\`, + expected: map[string]string{"key": `\`}, + }, { + str: `key=\r`, + err: true, + }, { + str: `key=r\`, + err: true, + }} - _, err = extractMapFromString("key=value:keywithnovalue") - assert.Error(t, err) + for _, tc := range tests { + t.Run(tc.str, func(t *testing.T) { + result, err := extractMapFromString(tc.str) + if tc.err { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } }