diff --git a/sdk/internal/go.mod b/sdk/internal/go.mod index 8c500ce1c435..c23772f28d9d 100644 --- a/sdk/internal/go.mod +++ b/sdk/internal/go.mod @@ -2,4 +2,9 @@ module github.com/Azure/azure-sdk-for-go/sdk/internal go 1.14 -require golang.org/x/net v0.0.0-20201010224723-4f7140c49acb +require ( + github.com/dnaeon/go-vcr v1.1.0 + github.com/stretchr/testify v1.7.0 + golang.org/x/net v0.0.0-20201010224723-4f7140c49acb + gopkg.in/yaml.v2 v2.4.0 +) diff --git a/sdk/internal/go.sum b/sdk/internal/go.sum index c59642cdfa12..7c8192467fcd 100644 --- a/sdk/internal/go.sum +++ b/sdk/internal/go.sum @@ -1,3 +1,13 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dnaeon/go-vcr v1.1.0 h1:ReYa/UBrRyQdant9B4fNHGoCNKw6qh6P0fsdGmZpR7c= +github.com/dnaeon/go-vcr v1.1.0/go.mod h1:M7tiix8f0r6mKKJ3Yq/kqU1OYf3MnfmBWVbPx/yU9ko= +github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= @@ -10,3 +20,10 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/sdk/internal/testframework/recording.go b/sdk/internal/testframework/recording.go new file mode 100644 index 000000000000..0f1462812c59 --- /dev/null +++ b/sdk/internal/testframework/recording.go @@ -0,0 +1,414 @@ +// +build go1.13 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package testframework + +import ( + "errors" + "fmt" + "io/ioutil" + "math/rand" + "net/http" + "os" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/internal/uuid" + "github.com/dnaeon/go-vcr/cassette" + "github.com/dnaeon/go-vcr/recorder" + "gopkg.in/yaml.v2" +) + +type Recording struct { + SessionName string + RecordingFile string + VariablesFile string + Mode RecordMode + variables map[string]*string `yaml:"variables"` + previousSessionVariables map[string]*string `yaml:"variables"` + recorder *recorder.Recorder + src rand.Source + now *time.Time + sanitizer *RecordingSanitizer + c TestContext +} + +const ( + alphanumericBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890" + alphanumericLowercaseBytes = "abcdefghijklmnopqrstuvwxyz1234567890" + randomSeedVariableName = "randomSeed" + nowVariableName = "now" + ModeEnvironmentVariableName = "AZURE_TEST_MODE" +) + +// Inspired by https://stackoverflow.com/questions/22892120/how-to-generate-a-random-string-of-a-fixed-length-in-go +const ( + letterIdxBits = 6 // 6 bits to represent a letter index + letterIdxMask = 1< 0 { + // Merge values from previousVariables that are not in variables to variables + for k, v := range r.previousSessionVariables { + if _, ok := r.variables[k]; ok { + // skip variables that were new in the current session + continue + } + r.variables[k] = v + } + + // Marshal to YAML and save variables + data, err := yaml.Marshal(r.variables) + if err != nil { + return err + } + + f, err := r.createVariablesFileIfNotExists() + if err != nil { + return err + } + + defer f.Close() + + // http://www.yaml.org/spec/1.2/spec.html#id2760395 + _, err = f.Write([]byte("---\n")) + if err != nil { + return err + } + + _, err = f.Write(data) + if err != nil { + return err + } + } + return nil +} + +func (r *Recording) Now() time.Time { + r.initNow() + + return *r.now +} + +func (r *Recording) UUID() uuid.UUID { + r.initRandomSource() + + return uuid.FromSource(r.src) +} + +// GenerateAlphaNumericID will generate a recorded random alpha numeric id +// if the recording has a randomSeed already set, the value will be generated from that seed, else a new random seed will be used +func (r *Recording) GenerateAlphaNumericID(prefix string, length int, lowercaseOnly bool) (string, error) { + + if length <= len(prefix) { + return "", errors.New("length must be greater than prefix") + } + + r.initRandomSource() + + sb := strings.Builder{} + sb.Grow(length) + sb.WriteString(prefix) + i := length - len(prefix) - 1 + // A src.Int63() generates 63 random bits, enough for letterIdxMax characters! + for cache, remain := r.src.Int63(), letterIdxMax; i >= 0; { + if remain == 0 { + cache, remain = r.src.Int63(), letterIdxMax + } + if lowercaseOnly { + if idx := int(cache & letterIdxMask); idx < len(alphanumericLowercaseBytes) { + sb.WriteByte(alphanumericLowercaseBytes[idx]) + i-- + } + } else { + if idx := int(cache & letterIdxMask); idx < len(alphanumericBytes) { + sb.WriteByte(alphanumericBytes[idx]) + i-- + } + } + cache >>= letterIdxBits + remain-- + } + str := sb.String() + return str, nil +} + +// getRequiredEnv gets an environment variable by name and returns an error if it is not found +func getRequiredEnv(name string) (*string, error) { + env, ok := os.LookupEnv(name) + if ok { + return &env, nil + } else { + return nil, errors.New(envNotExistsError(name)) + } +} + +// getOptionalEnv gets an environment variable by name and returns the defaultValue if not found +func getOptionalEnv(name string, defaultValue string) *string { + env, ok := os.LookupEnv(name) + if ok { + return &env + } else { + return &defaultValue + } +} + +func (r *Recording) matchRequest(req *http.Request, rec cassette.Request) bool { + isMatch := compareMethods(req, rec, r.c) && + compareURLs(req, rec, r.c) && + compareHeaders(req, rec, r.c) && + compareBodies(req, rec, r.c) + + return isMatch +} + +func missingRequestError(req *http.Request) string { + reqUrl := req.URL.String() + return fmt.Sprintf("\nNo matching recorded request found.\nRequest: [%s] %s\n", req.Method, reqUrl) +} + +func envNotExistsError(varName string) string { + return "Required environment variable not set: " + varName +} + +// applyVariableOptions applies the VariableType transform to the value +// If variableType is not provided or Default, return result +// If variableType is Secret_String, return SanitizedValue +// If variableType isSecret_Base64String return SanitizedBase64Value +func applyVariableOptions(val *string, variableType VariableType) *string { + var ret string + + switch variableType { + case Secret_String: + ret = SanitizedValue + return &ret + case Secret_Base64String: + ret = SanitizedBase64Value + return &ret + default: + return val + } +} + +// initRandomSource initializes the Source to be used for random value creation in this Recording +func (r *Recording) initRandomSource() { + // if we already have a Source generated, return immediately + if r.src != nil { + return + } + + var seed int64 + var err error + + // check to see if we already have a random seed stored, use that if so + seedString, ok := r.previousSessionVariables[randomSeedVariableName] + if ok { + seed, err = strconv.ParseInt(*seedString, 10, 64) + } + + // We did not have a random seed already stored; create a new one + if !ok || err != nil || r.Mode == Live { + seed = time.Now().Unix() + val := strconv.FormatInt(seed, 10) + r.variables[randomSeedVariableName] = &val + } + + // create a Source with the seed + r.src = rand.NewSource(seed) +} + +// initNow initializes the Source to be used for random value creation in this Recording +func (r *Recording) initNow() { + // if we already have a now generated, return immediately + if r.now != nil { + return + } + + var err error + var nowStr *string + var newNow time.Time + + // check to see if we already have a random seed stored, use that if so + nowStr, ok := r.previousSessionVariables[nowVariableName] + if ok { + newNow, err = time.Parse(time.RFC3339Nano, *nowStr) + } + + // We did not have a random seed already stored; create a new one + if !ok || err != nil || r.Mode == Live { + newNow = time.Now() + nowStr = new(string) + *nowStr = newNow.Format(time.RFC3339Nano) + r.variables[nowVariableName] = nowStr + } + + // save the now value. + r.now = &newNow +} + +// getFilePaths returns (recordingFilePath, variablesFilePath) +func getFilePaths(name string) (string, string) { + recPath := "recordings/" + name + varPath := fmt.Sprintf("%s-variables.yaml", recPath) + return recPath, varPath +} + +// createVariablesFileIfNotExists calls os.Create on the VariablesFile and creates it if it or the path does not exist +// Callers must call Close on the result +func (r *Recording) createVariablesFileIfNotExists() (*os.File, error) { + f, err := os.Create(r.VariablesFile) + if err != nil { + if !os.IsNotExist(err) { + return nil, err + } + // Create directory for the variables if missing + variablesDir := filepath.Dir(r.VariablesFile) + if _, err := os.Stat(variablesDir); os.IsNotExist(err) { + if err = os.MkdirAll(variablesDir, 0755); err != nil { + return nil, err + } + } + + f, err = os.Create(r.VariablesFile) + if err != nil { + return nil, err + } + } + + return f, nil +} + +func (r *Recording) unmarshalVariablesFile(out interface{}) error { + data, err := ioutil.ReadFile(r.VariablesFile) + if err != nil { + // If the file or dir do not exist, this is not an error to report + if os.IsNotExist(err) { + r.c.Log(fmt.Sprintf("Did not find recording for test '%s'", r.RecordingFile)) + return nil + } else { + return err + } + } else { + err = yaml.Unmarshal(data, out) + } + return nil +} + +func (r *Recording) initVariables() error { + return r.unmarshalVariablesFile(r.previousSessionVariables) +} + +var modeMap = map[RecordMode]recorder.Mode{ + Record: recorder.ModeRecording, + Live: recorder.ModeDisabled, + Playback: recorder.ModeReplaying, +} diff --git a/sdk/internal/testframework/recording_sanitizer.go b/sdk/internal/testframework/recording_sanitizer.go new file mode 100644 index 000000000000..ac4311ef46a2 --- /dev/null +++ b/sdk/internal/testframework/recording_sanitizer.go @@ -0,0 +1,83 @@ +// +build go1.13 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package testframework + +import ( + "net/http" + + "github.com/dnaeon/go-vcr/cassette" + "github.com/dnaeon/go-vcr/recorder" +) + +type RecordingSanitizer struct { + recorder *recorder.Recorder + headersToSanitize map[string]*string + urlSanitizer StringSanitizer + bodySanitizer StringSanitizer +} + +type StringSanitizer func(*string) + +const SanitizedValue string = "sanitized" +const SanitizedBase64Value string = "Kg==" + +var sanitizedValueSlice = []string{SanitizedValue} + +func DefaultSanitizer(recorder *recorder.Recorder) *RecordingSanitizer { + // The default sanitizer sanitizes the Authorization header + s := &RecordingSanitizer{headersToSanitize: map[string]*string{"Authorization": nil}, recorder: recorder, urlSanitizer: DefaultStringSanitizer} + recorder.AddSaveFilter(s.applySaveFilter) + + return s +} + +// AddSanitizedHeaders adds the supplied header names to the list of headers to be sanitized on request and response recordings. +func (s *RecordingSanitizer) AddSanitizedHeaders(headers ...string) { + for _, headerName := range headers { + s.headersToSanitize[headerName] = nil + } +} + +// AddBodysanitizer configures the supplied StringSanitizer to sanitize recording request and response bodies +func (s *RecordingSanitizer) AddBodysanitizer(sanitizer StringSanitizer) { + s.bodySanitizer = sanitizer +} + +// AddUriSanitizer configures the supplied StringSanitizer to sanitize recording request and response URLs +func (s *RecordingSanitizer) AddUrlSanitizer(sanitizer StringSanitizer) { + s.urlSanitizer = sanitizer +} + +func (s *RecordingSanitizer) sanitizeHeaders(header http.Header) { + for headerName := range s.headersToSanitize { + if _, ok := header[headerName]; ok { + header[headerName] = sanitizedValueSlice + } + } +} + +func (s *RecordingSanitizer) sanitizeBodies(body *string) { + s.bodySanitizer(body) +} + +func (s *RecordingSanitizer) sanitizeURL(url *string) { + s.urlSanitizer(url) +} + +func (s *RecordingSanitizer) applySaveFilter(i *cassette.Interaction) error { + s.sanitizeHeaders(i.Request.Headers) + s.sanitizeHeaders(i.Response.Headers) + s.sanitizeURL(&i.Request.URL) + if len(i.Request.Body) > 0 { + s.sanitizeBodies(&i.Request.Body) + } + if len(i.Response.Body) > 0 { + s.sanitizeBodies(&i.Response.Body) + } + return nil +} + +func DefaultStringSanitizer(s *string) {} diff --git a/sdk/internal/testframework/recording_sanitizer_test.go b/sdk/internal/testframework/recording_sanitizer_test.go new file mode 100644 index 000000000000..570dfb3b005b --- /dev/null +++ b/sdk/internal/testframework/recording_sanitizer_test.go @@ -0,0 +1,157 @@ +// +build go1.13 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package testframework + +import ( + "net/http" + "os" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" + "github.com/dnaeon/go-vcr/cassette" + "github.com/dnaeon/go-vcr/recorder" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type recordingSanitizerTests struct { + suite.Suite +} + +const authHeader string = "Authorization" +const customHeader1 string = "Fooheader" +const customHeader2 string = "Barheader" +const nonSanitizedHeader string = "notsanitized" + +func TestRecordingSanitizer(t *testing.T) { + suite.Run(t, new(recordingSanitizerTests)) +} + +func (s *recordingSanitizerTests) TestDefaultSanitizerSanitizesAuthHeader() { + assert := assert.New(s.T()) + server, cleanup := mock.NewServer() + server.SetResponse() + defer cleanup() + rt := NewMockRoundTripper(server) + r, _ := recorder.NewAsMode(getTestFileName(s.T(), false), recorder.ModeRecording, rt) + + DefaultSanitizer(r) + + req, _ := http.NewRequest(http.MethodPost, server.URL(), nil) + req.Header.Add(authHeader, "superSecret") + + r.RoundTrip(req) + r.Stop() + + assert.Equal(SanitizedValue, req.Header.Get(authHeader)) + + rec, err := cassette.Load(getTestFileName(s.T(), false)) + assert.Nil(err) + + for _, i := range rec.Interactions { + assert.Equal(SanitizedValue, i.Request.Headers.Get(authHeader)) + } +} + +func (s *recordingSanitizerTests) TestAddSanitizedHeadersSanitizes() { + assert := assert.New(s.T()) + server, cleanup := mock.NewServer() + server.SetResponse() + defer cleanup() + rt := NewMockRoundTripper(server) + r, _ := recorder.NewAsMode(getTestFileName(s.T(), false), recorder.ModeRecording, rt) + + target := DefaultSanitizer(r) + target.AddSanitizedHeaders(customHeader1, customHeader2) + + req, _ := http.NewRequest(http.MethodPost, server.URL(), nil) + req.Header.Add(customHeader1, "superSecret") + req.Header.Add(customHeader2, "verySecret") + safeValue := "safeValue" + req.Header.Add(nonSanitizedHeader, safeValue) + + r.RoundTrip(req) + r.Stop() + + assert.Equal(SanitizedValue, req.Header.Get(customHeader1)) + assert.Equal(SanitizedValue, req.Header.Get(customHeader2)) + assert.Equal(safeValue, req.Header.Get(nonSanitizedHeader)) + + rec, err := cassette.Load(getTestFileName(s.T(), false)) + assert.Nil(err) + + for _, i := range rec.Interactions { + assert.Equal(SanitizedValue, i.Request.Headers.Get(customHeader1)) + assert.Equal(SanitizedValue, i.Request.Headers.Get(customHeader2)) + assert.Equal(safeValue, i.Request.Headers.Get(nonSanitizedHeader)) + } +} + +func (s *recordingSanitizerTests) TestAddUrlSanitizerSanitizes() { + assert := assert.New(s.T()) + secret := "secretvalue" + secretBody := "some body content that contains a " + secret + server, cleanup := mock.NewServer() + server.SetResponse(mock.WithStatusCode(http.StatusCreated), mock.WithBody([]byte(secretBody))) + defer cleanup() + rt := NewMockRoundTripper(server) + r, _ := recorder.NewAsMode(getTestFileName(s.T(), false), recorder.ModeRecording, rt) + + baseUrl := server.URL() + "/" + + target := DefaultSanitizer(r) + target.AddUrlSanitizer(func(url *string) { + *url = strings.Replace(*url, secret, SanitizedValue, -1) + }) + target.AddBodysanitizer(func(body *string) { + *body = strings.Replace(*body, secret, SanitizedValue, -1) + }) + + req, _ := http.NewRequest(http.MethodPost, baseUrl+secret, closerFromString(secretBody)) + + r.RoundTrip(req) + r.Stop() + + rec, err := cassette.Load(getTestFileName(s.T(), false)) + assert.Nil(err) + + for _, i := range rec.Interactions { + assert.NotContains(i.Response.Body, secret) + assert.NotContains(i.Request.URL, secret) + assert.NotContains(i.Request.Body, secret) + assert.Contains(i.Request.URL, SanitizedValue) + assert.Contains(i.Request.Body, SanitizedValue) + assert.Contains(i.Response.Body, SanitizedValue) + } +} + +func (s *recordingSanitizerTests) TearDownSuite() { + assert := assert.New(s.T()) + // cleanup test files + err := os.RemoveAll("testfiles") + assert.Nil(err) +} + +func getTestFileName(t *testing.T, addSuffix bool) string { + name := "testfiles/" + t.Name() + if addSuffix { + name = name + ".yaml" + } + return name +} + +type mockRoundTripper struct { + server *mock.Server +} + +func NewMockRoundTripper(server *mock.Server) *mockRoundTripper { + return &mockRoundTripper{server: server} +} + +func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return m.server.Do(req) +} diff --git a/sdk/internal/testframework/recording_test.go b/sdk/internal/testframework/recording_test.go new file mode 100644 index 000000000000..4ed73f5805ce --- /dev/null +++ b/sdk/internal/testframework/recording_test.go @@ -0,0 +1,356 @@ +// +build go1.13 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package testframework + +import ( + "fmt" + "io/ioutil" + "net/http" + "os" + "strings" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" + "github.com/dnaeon/go-vcr/cassette" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type recordingTests struct { + suite.Suite +} + +func TestRecording(t *testing.T) { + suite.Run(t, new(recordingTests)) +} + +func (s *recordingTests) TestInitializeRecording() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + + expectedMode := Playback + + target, err := NewRecording(context, expectedMode) + assert.Nil(err) + assert.NotNil(target.RecordingFile) + assert.NotNil(target.VariablesFile) + assert.Equal(expectedMode, target.Mode) + + err = target.Stop() + assert.Nil(err) +} + +func (s *recordingTests) TestStopDoesNotSaveVariablesWhenNoVariablesExist() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + + target, err := NewRecording(context, Playback) + assert.Nil(err) + + err = target.Stop() + assert.Nil(err) + + _, err = ioutil.ReadFile(target.VariablesFile) + assert.Equal(true, os.IsNotExist(err)) +} + +func (s *recordingTests) TestRecordedVariables() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { s.T().Log(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + + nonExistingEnvVar := "nonExistingEnvVar" + expectedVariableValue := "foobar" + variablesMap := map[string]string{} + + target, err := NewRecording(context, Playback) + assert.Nil(err) + + // optional variables always succeed. + assert.Equal(expectedVariableValue, target.GetOptionalRecordedVariable(nonExistingEnvVar, expectedVariableValue, Default)) + + // non existent variables return an error + val, err := target.GetRecordedVariable(nonExistingEnvVar, Default) + // mark test as succeeded + assert.Equal(envNotExistsError(nonExistingEnvVar), err.Error()) + + // now create the env variable and check that it can be fetched + os.Setenv(nonExistingEnvVar, expectedVariableValue) + defer os.Unsetenv(nonExistingEnvVar) + val, err = target.GetRecordedVariable(nonExistingEnvVar, Default) + assert.Equal(expectedVariableValue, val) + + err = target.Stop() + assert.Nil(err) + + // check that a variables file was created with the correct variable + target.unmarshalVariablesFile(variablesMap) + actualValue, ok := variablesMap[nonExistingEnvVar] + assert.Equal(true, ok) + assert.Equal(expectedVariableValue, actualValue) +} + +func (s *recordingTests) TestRecordedVariablesSanitized() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + + SanitizedStringVar := "sanitizedvar" + SanitizedBase64StrigVar := "sanitizedbase64var" + secret := "secretstring" + secretBase64 := "asdfasdf==" + variablesMap := map[string]string{} + + target, err := NewRecording(context, Playback) + assert.Nil(err) + + // call GetOptionalRecordedVariable with the Secret_String VariableType arg + assert.Equal(secret, target.GetOptionalRecordedVariable(SanitizedStringVar, secret, Secret_String)) + + // call GetOptionalRecordedVariable with the Secret_Base64String VariableType arg + assert.Equal(secretBase64, target.GetOptionalRecordedVariable(SanitizedBase64StrigVar, secretBase64, Secret_Base64String)) + + // Calling Stop will save the variables and apply the sanitization options + err = target.Stop() + assert.Nil(err) + + // check that a variables file was created with the correct variables + target.unmarshalVariablesFile(variablesMap) + actualValue, ok := variablesMap[SanitizedStringVar] + assert.Equal(true, ok) + // the saved value is sanitized + assert.Equal(SanitizedValue, actualValue) + + target.unmarshalVariablesFile(variablesMap) + actualValue, ok = variablesMap[SanitizedBase64StrigVar] + assert.Equal(true, ok) + // the saved value is sanitized + assert.Equal(SanitizedBase64Value, actualValue) +} + +func (s *recordingTests) TestStopSavesVariablesIfExistAndReadsPreviousVariables() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + + expectedVariableName := "someVariable" + expectedVariableValue := "foobar" + addedVariableName := "addedVariable" + addedVariableValue := "fizzbuzz" + variablesMap := map[string]string{} + + target, err := NewRecording(context, Playback) + assert.Nil(err) + + target.GetOptionalRecordedVariable(expectedVariableName, expectedVariableValue, Default) + + err = target.Stop() + assert.Nil(err) + + // check that a variables file was created with the correct variable + target.unmarshalVariablesFile(variablesMap) + actualValue, ok := variablesMap[expectedVariableName] + assert.True(ok) + assert.Equal(expectedVariableValue, actualValue) + + variablesMap = map[string]string{} + target2, err := NewRecording(context, Playback) + assert.Nil(err) + + // add a new variable to the existing batch + target2.GetOptionalRecordedVariable(addedVariableName, addedVariableValue, Default) + + err = target2.Stop() + assert.Nil(err) + + // check that a variables file was created with the variables loaded from the previous recording + target2.unmarshalVariablesFile(variablesMap) + actualValue, ok = variablesMap[addedVariableName] + assert.Truef(ok, fmt.Sprintf("Should have found %s", addedVariableName)) + assert.Equal(addedVariableValue, actualValue) + actualValue, ok = variablesMap[expectedVariableName] + assert.Truef(ok, fmt.Sprintf("Should have found %s", expectedVariableName)) + assert.Equal(expectedVariableValue, actualValue) +} + +func (s *recordingTests) TestUUID() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + + target, err := NewRecording(context, Playback) + assert.Nil(err) + + recordedUUID1 := target.UUID() + recordedUUID1a := target.UUID() + assert.NotEqual(recordedUUID1.String(), recordedUUID1a.String()) + + err = target.Stop() + assert.Nil(err) + + target2, err := NewRecording(context, Playback) + assert.Nil(err) + + recordedUUID2 := target2.UUID() + + // The two generated UUIDs should be the same since target2 loaded the saved random seed from target + assert.Equal(recordedUUID1.String(), recordedUUID2.String()) + + err = target.Stop() + assert.Nil(err) +} + +func (s *recordingTests) TestNow() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + + target, err := NewRecording(context, Playback) + assert.Nil(err) + + recordedNow1 := target.Now() + + time.Sleep(time.Millisecond * 100) + + recordedNow1a := target.Now() + assert.Equal(recordedNow1.UnixNano(), recordedNow1a.UnixNano()) + + err = target.Stop() + assert.Nil(err) + + target2, err := NewRecording(context, Playback) + assert.Nil(err) + + recordedNow2 := target2.Now() + + // The two generated nows should be the same since target2 loaded the saved random seed from target + assert.Equal(recordedNow1.UnixNano(), recordedNow2.UnixNano()) + + err = target.Stop() + assert.Nil(err) +} + +func (s *recordingTests) TestGenerateAlphaNumericID() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + + prefix := "myprefix" + + target, err := NewRecording(context, Playback) + assert.Nil(err) + + generated1, err := target.GenerateAlphaNumericID(prefix, 10, true) + + assert.Equal(10, len(generated1)) + assert.Equal(true, strings.HasPrefix(generated1, prefix)) + + generated1a, err := target.GenerateAlphaNumericID(prefix, 10, true) + assert.NotEqual(generated1, generated1a) + + err = target.Stop() + assert.Nil(err) + + target2, err := NewRecording(context, Playback) + assert.Nil(err) + + generated2, err := target2.GenerateAlphaNumericID(prefix, 10, true) + + // The two generated Ids should be the same since target2 loaded the saved random seed from target + assert.Equal(generated2, generated1) + + err = target.Stop() + assert.Nil(err) +} + +func (s *recordingTests) TestRecordRequestsAndDoMatching() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + server, cleanup := mock.NewServer() + server.SetResponse() + defer cleanup() + rt := NewMockRoundTripper(server) + + target, err := NewRecording(context, Playback) + target.recorder.SetTransport(rt) + + path, err := target.GenerateAlphaNumericID("", 5, true) + reqUrl := server.URL() + "/" + path + + req, _ := http.NewRequest(http.MethodPost, reqUrl, nil) + + // record the request + target.Do(req) + err = target.Stop() + assert.Nil(err) + + rec, err := cassette.Load(target.SessionName) + assert.Nil(err) + + for _, i := range rec.Interactions { + assert.Equal(reqUrl, i.Request.URL) + } + + // re-initialize the recording + target, err = NewRecording(context, Playback) + target.recorder.SetTransport(rt) + + // re-create the random url using the recorded variables + path, err = target.GenerateAlphaNumericID("", 5, true) + reqUrl = server.URL() + "/" + path + req, _ = http.NewRequest(http.MethodPost, reqUrl, nil) + + // playback the request + target.Do(req) + err = target.Stop() + assert.Nil(err) +} + +func (s *recordingTests) TestRecordRequestsAndFailMatchingForMissingRecording() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { s.T().Log(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + server, cleanup := mock.NewServer() + server.SetResponse() + defer cleanup() + rt := NewMockRoundTripper(server) + + target, err := NewRecording(context, Playback) + target.recorder.SetTransport(rt) + + path, err := target.GenerateAlphaNumericID("", 5, true) + reqUrl := server.URL() + "/" + path + + req, _ := http.NewRequest(http.MethodPost, reqUrl, nil) + + // record the request + target.Do(req) + err = target.Stop() + assert.Nil(err) + + rec, err := cassette.Load(target.SessionName) + assert.Nil(err) + + for _, i := range rec.Interactions { + assert.Equal(reqUrl, i.Request.URL) + } + + // re-initialize the recording + target, err = NewRecording(context, Playback) + target.recorder.SetTransport(rt) + + // re-create the random url using the recorded variables + reqUrl = server.URL() + "/" + "mismatchedRequest" + req, _ = http.NewRequest(http.MethodPost, reqUrl, nil) + + // playback the request + _, err = target.Do(req) + assert.Equal(missingRequestError(req), err.Error()) + // mark succeeded + err = target.Stop() + assert.Nil(err) +} + +func (s *recordingTests) TearDownSuite() { + + // cleanup test files + err := os.RemoveAll("recordings") + assert.Nil(s.T(), err) +} diff --git a/sdk/internal/testframework/request_matcher.go b/sdk/internal/testframework/request_matcher.go new file mode 100644 index 000000000000..38997a5ffa50 --- /dev/null +++ b/sdk/internal/testframework/request_matcher.go @@ -0,0 +1,111 @@ +// +build go1.13 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package testframework + +import ( + "bytes" + "fmt" + "io/ioutil" + "net/http" + "reflect" + + "github.com/dnaeon/go-vcr/cassette" +) + +type RequestMatcher struct { + ignoredHeaders map[string]*string +} + +var ignoredHeaders = map[string]*string{ + "Date": nil, + "X-Ms-Date": nil, + "x-ms-date": nil, + "x-ms-client-request-id": nil, + "User-Agent": nil, + "Request-Id": nil, + "traceparent": nil, + "Authorization": nil, +} + +var recordingHeaderMissing = "Test recording headers do not match. Header '%s' is present in request but not in recording." +var requestHeaderMissing = "Test recording headers do not match. Header '%s' is present in recording but not in request." +var headerValuesMismatch = "Test recording header '%s' does not match. request: %s, recording: %s" +var methodMismatch = "Test recording methods do not match. request: %s, recording: %s" +var urlMismatch = "Test recording URLs do not match. request: %s, recording: %s" +var bodiesMismatch = "Test recording bodies do not match.\nrequest: %s\nrecording: %s" + +func compareBodies(r *http.Request, i cassette.Request, c TestContext) bool { + body := bytes.Buffer{} + if r.Body != nil { + _, err := body.ReadFrom(r.Body) + if err != nil { + return false + } + r.Body = ioutil.NopCloser(&body) + } + bodiesMatch := body.String() == i.Body + if !bodiesMatch { + c.Log(fmt.Sprintf(bodiesMismatch, body.String(), i.Body)) + } + return bodiesMatch +} + +func compareURLs(r *http.Request, i cassette.Request, c TestContext) bool { + if r.URL.String() != i.URL { + c.Log(fmt.Sprintf(urlMismatch, r.URL.String(), i.URL)) + return false + } + return true +} + +func compareMethods(r *http.Request, i cassette.Request, c TestContext) bool { + if r.Method != i.Method { + c.Log(fmt.Sprintf(methodMismatch, r.Method, i.Method)) + return false + } + return true +} + +func compareHeaders(r *http.Request, i cassette.Request, c TestContext) bool { + unVisitedCassetteKeys := make(map[string]*string, len(i.Headers)) + // clone the cassette keys to track which we have seen + for k := range i.Headers { + if _, ignore := ignoredHeaders[k]; ignore { + // don't copy ignored headers + continue + } + unVisitedCassetteKeys[k] = nil + } + //iterate through all the request headers to compare them to cassette headers + for key, requestHeader := range r.Header { + if _, ignore := ignoredHeaders[key]; ignore { + // this is an ignorable header + continue + } + delete(unVisitedCassetteKeys, key) + if recordedHeader, foundMatch := i.Headers[key]; foundMatch { + headersMatch := reflect.DeepEqual(requestHeader, recordedHeader) + if !headersMatch { + // headers don't match + c.Log(fmt.Sprintf(headerValuesMismatch, key, requestHeader, recordedHeader)) + return false + } + + } else { + // header not found + c.Log(fmt.Sprintf(recordingHeaderMissing, key)) + return false + } + } + if len(unVisitedCassetteKeys) > 0 { + // headers exist in the recording that do not exist in the request + for headerName := range unVisitedCassetteKeys { + c.Log(fmt.Sprintf(requestHeaderMissing, headerName)) + } + return false + } + return true +} diff --git a/sdk/internal/testframework/request_matcher_test.go b/sdk/internal/testframework/request_matcher_test.go new file mode 100644 index 000000000000..90c945458917 --- /dev/null +++ b/sdk/internal/testframework/request_matcher_test.go @@ -0,0 +1,193 @@ +// +build go1.13 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package testframework + +import ( + "io" + "io/ioutil" + "net/http" + "net/url" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/internal/uuid" + "github.com/dnaeon/go-vcr/cassette" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type requestMatcherTests struct { + suite.Suite +} + +func TestRequestMatcher(t *testing.T) { + suite.Run(t, new(requestMatcherTests)) +} + +const matchedBody string = "Matching body." +const unMatchedBody string = "This body does not match." + +func (s *requestMatcherTests) TestCompareBodies() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + + req := http.Request{Body: closerFromString(matchedBody)} + recReq := cassette.Request{Body: matchedBody} + + isMatch := compareBodies(&req, recReq, context) + + assert.Equal(true, isMatch) + + // make the requests mis-match + req.Body = closerFromString((unMatchedBody)) + + isMatch = compareBodies(&req, recReq, context) + + assert.False(isMatch) +} + +func (s *requestMatcherTests) TestCompareHeadersIgnoresIgnoredHeaders() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + + // populate only ignored headers that do not match + reqHeaders := make(http.Header) + recordedHeaders := make(http.Header) + for headerName := range ignoredHeaders { + reqHeaders[headerName] = []string{uuid.New().String()} + recordedHeaders[headerName] = []string{uuid.New().String()} + } + + req := http.Request{Header: reqHeaders} + recReq := cassette.Request{Headers: recordedHeaders} + + // All headers match + assert.True(compareHeaders(&req, recReq, context)) +} + +func (s *requestMatcherTests) TestCompareHeadersMatchesHeaders() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + + // populate only ignored headers that do not match + reqHeaders := make(http.Header) + recordedHeaders := make(http.Header) + header1 := "header1" + headerValue := []string{"some value"} + + reqHeaders[header1] = headerValue + recordedHeaders[header1] = headerValue + + req := http.Request{Header: reqHeaders} + recReq := cassette.Request{Headers: recordedHeaders} + + assert.True(compareHeaders(&req, recReq, context)) +} + +func (s *requestMatcherTests) TestCompareHeadersFailsMissingRecHeader() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + + // populate only ignored headers that do not match + reqHeaders := make(http.Header) + recordedHeaders := make(http.Header) + header1 := "header1" + header2 := "header2" + headerValue := []string{"some value"} + + reqHeaders[header1] = headerValue + recordedHeaders[header1] = headerValue + + req := http.Request{Header: reqHeaders} + recReq := cassette.Request{Headers: recordedHeaders} + + // add a new header to the just req + reqHeaders[header2] = headerValue + + assert.False(compareHeaders(&req, recReq, context)) +} + +func (s *requestMatcherTests) TestCompareHeadersFailsMissingReqHeader() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + + // populate only ignored headers that do not match + reqHeaders := make(http.Header) + recordedHeaders := make(http.Header) + header1 := "header1" + header2 := "header2" + headerValue := []string{"some value"} + + reqHeaders[header1] = headerValue + recordedHeaders[header1] = headerValue + + req := http.Request{Header: reqHeaders} + recReq := cassette.Request{Headers: recordedHeaders} + + // add a new header to just the recording + recordedHeaders[header2] = headerValue + + assert.False(compareHeaders(&req, recReq, context)) +} + +func (s *requestMatcherTests) TestCompareHeadersFailsMismatchedValues() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + + // populate only ignored headers that do not match + reqHeaders := make(http.Header) + recordedHeaders := make(http.Header) + header1 := "header1" + header2 := "header2" + headerValue := []string{"some value"} + mismatch := []string{"mismatch"} + + reqHeaders[header1] = headerValue + recordedHeaders[header1] = headerValue + + req := http.Request{Header: reqHeaders} + recReq := cassette.Request{Headers: recordedHeaders} + + // header names match but values are different + recordedHeaders[header2] = headerValue + reqHeaders[header2] = mismatch + + assert.False(compareHeaders(&req, recReq, context)) +} + +func (s *requestMatcherTests) TestCompareURLs() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + scheme := "https" + host := "foo.bar" + req := http.Request{URL: &url.URL{Scheme: scheme, Host: host}} + recReq := cassette.Request{URL: scheme + "://" + host} + + assert.True(compareURLs(&req, recReq, context)) + + req.URL.Path = "noMatch" + + assert.False(compareURLs(&req, recReq, context)) +} + +func (s *requestMatcherTests) TestCompareMethods() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + methodPost := "POST" + methodPatch := "PATCH" + req := http.Request{Method: methodPost} + recReq := cassette.Request{Method: methodPost} + + assert.True(compareMethods(&req, recReq, context)) + + req.Method = methodPatch + + assert.False(compareMethods(&req, recReq, context)) +} + +func closerFromString(content string) io.ReadCloser { + return ioutil.NopCloser(strings.NewReader(content)) +} diff --git a/sdk/internal/testframework/testcontext.go b/sdk/internal/testframework/testcontext.go new file mode 100644 index 000000000000..97bcc132e5ef --- /dev/null +++ b/sdk/internal/testframework/testcontext.go @@ -0,0 +1,50 @@ +// +build go1.13 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package testframework + +type TestContext interface { + Fail(string) + Log(string) + Name() string + IsFailed() bool +} + +type testContext struct { + failed bool + fail Failer + log Logger + name string +} + +type Failer func(string) +type Logger func(string) +type Name func() string + +// NewTestContext initializes a new TestContext +func NewTestContext(failer Failer, logger Logger, name Name) TestContext { + return &testContext{fail: failer, log: logger, name: name()} +} + +// Fail calls the Failer func and makes IsFailed return true. +func (c *testContext) Fail(msg string) { + c.failed = true + c.fail(msg) +} + +// Log calls the Logger func. +func (c *testContext) Log(msg string) { + c.log(msg) +} + +// Name calls the Name func and returns the result. +func (c *testContext) Name() string { + return c.name +} + +// IsFailed returns true if the Failer has been called. +func (c *testContext) IsFailed() bool { + return c.failed +} diff --git a/sdk/internal/uuid/uuid.go b/sdk/internal/uuid/uuid.go index 4b288d81fecd..2f3c55d0e633 100644 --- a/sdk/internal/uuid/uuid.go +++ b/sdk/internal/uuid/uuid.go @@ -41,6 +41,20 @@ func New() UUID { return u } +// FromSource returns a new uuid based on the supplied rand.Source as a seed. +func FromSource(src rand.Source) UUID { + u := UUID{} + // Set all bits to randomly (or pseudo-randomly) chosen values. + // math/rand.Read() is no-fail so we omit any error checking. + rnd := rand.New(src) + rnd.Read(u[:]) + u[8] = (u[8] | reservedRFC4122) & 0x7F // u.setVariant(ReservedRFC4122) + + var version byte = 4 + u[6] = (u[6] & 0xF) | (version << 4) // u.setVersion(4) + return u +} + // String returns an unparsed version of the generated UUID sequence. func (u UUID) String() string { return fmt.Sprintf("%x-%x-%x-%x-%x", u[0:4], u[4:6], u[6:8], u[8:10], u[10:]) diff --git a/sdk/tables/aztables/tableClient.go b/sdk/tables/aztables/tableClient.go index edc74ece6075..d24ae0a14931 100644 --- a/sdk/tables/aztables/tableClient.go +++ b/sdk/tables/aztables/tableClient.go @@ -5,6 +5,7 @@ package aztables import ( "context" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" ) diff --git a/sdk/tables/aztables/tableClient_test.go b/sdk/tables/aztables/tableClient_test.go index 0b76e48fd15d..45b97edc2d47 100644 --- a/sdk/tables/aztables/tableClient_test.go +++ b/sdk/tables/aztables/tableClient_test.go @@ -4,9 +4,15 @@ package aztables import ( - chk "gopkg.in/check.v1" // go get gopkg.in/check.v1 + "testing" ) -func (s *aztestsSuite) TestContainerCreateAccessContainer(c *chk.C) { +func TestContainerCreateAccessContainer(t *testing.T) { // TODO + cred, err := NewSharedKeyCredential("foo", "Kg==") + if err != nil { + t.Fatal(err) + } + + NewTableClient("https://foo", cred, &TableClientOptions{}) } diff --git a/sdk/tables/aztables/zc_client_options.go b/sdk/tables/aztables/zc_client_options.go index cf24278fd2d2..f0a7200fd141 100644 --- a/sdk/tables/aztables/zc_client_options.go +++ b/sdk/tables/aztables/zc_client_options.go @@ -23,7 +23,7 @@ func (o *TableClientOptions) getConnectionOptions() *connectionOptions { return &connectionOptions{ HTTPClient: o.HTTPClient, - Retry: o.Retry, - Telemetry: o.Telemetry, + Retry: o.Retry, + Telemetry: o.Telemetry, } -} \ No newline at end of file +} diff --git a/sdk/tables/aztables/zz_generated_connection.go b/sdk/tables/aztables/zz_generated_connection.go index 990122c2b6e3..325fa053b1bd 100644 --- a/sdk/tables/aztables/zz_generated_connection.go +++ b/sdk/tables/aztables/zz_generated_connection.go @@ -14,6 +14,7 @@ import ( const scope = "foo" const telemetryInfo = "azsdk-go-tables/" + // connectionOptions contains configuration settings for the connection's pipeline. // All zero-value fields will be initialized with their default values. type connectionOptions struct { @@ -67,7 +68,6 @@ func (c *connection) Endpoint() string { } // Pipeline returns the connection's pipeline. -func (c *connection) Pipeline() (azcore.Pipeline) { +func (c *connection) Pipeline() azcore.Pipeline { return c.p } -