Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
jaedle committed Jun 27, 2024
1 parent 714147d commit fdbf1b2
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 31 deletions.
8 changes: 6 additions & 2 deletions samplers/aws/xray/internal/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func (c *xrayClient) getSamplingRules(ctx context.Context) (*getSamplingRulesOut
if err != nil {
return nil, fmt.Errorf("xray client: unable to retrieve sampling settings: %w", err)
}
defer output.Body.Close()
defer func() { _ = output.Body.Close() }()

if output.StatusCode != http.StatusOK {
return nil, fmt.Errorf("xray client: unable to retrieve sampling settings, expected response status code 200, got: %d", output.StatusCode)
Expand Down Expand Up @@ -172,7 +172,11 @@ func (c *xrayClient) getSamplingTargets(ctx context.Context, s []*samplingStatis
if err != nil {
return nil, fmt.Errorf("xray client: unable to retrieve sampling settings: %w", err)
}
defer output.Body.Close()
defer func() { _ = output.Body.Close() }()

if output.StatusCode != http.StatusOK {
return nil, fmt.Errorf("xray client: unable to retrieve sampling targets, expected response status code 200, got: %d", output.StatusCode)
}

var samplingTargetsOutput *getSamplingTargetsOutput
if err := json.NewDecoder(output.Body).Decode(&samplingTargetsOutput); err != nil {
Expand Down
84 changes: 55 additions & 29 deletions samplers/aws/xray/internal/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package internal

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
Expand All @@ -15,7 +16,12 @@ import (
)

func createTestClient(t *testing.T, body []byte) *xrayClient {
return createTestClientWithStatusCode(t, http.StatusOK, body)
}

func createTestClientWithStatusCode(t *testing.T, status int, body []byte) *xrayClient {
testServer := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, _ *http.Request) {
res.WriteHeader(status)
_, err := res.Write(body)
require.NoError(t, err)
}))
Expand All @@ -26,7 +32,6 @@ func createTestClient(t *testing.T, body []byte) *xrayClient {

client, err := newClient(*u)
require.NoError(t, err)

return client
}

Expand Down Expand Up @@ -222,22 +227,11 @@ func TestGetSamplingTargetsMissingValues(t *testing.T) {

client := createTestClient(t, body)

samplingTragets, err := client.getSamplingTargets(ctx, nil)
samplingTargets, err := client.getSamplingTargets(ctx, nil)
require.NoError(t, err)

assert.Nil(t, samplingTragets.SamplingTargetDocuments[0].Interval)
assert.Nil(t, samplingTragets.SamplingTargetDocuments[0].ReservoirQuota)
}

func TestNilContext(t *testing.T) {
client := createTestClient(t, []byte(``))
samplingRulesOutput, err := client.getSamplingRules(context.TODO())
require.Error(t, err)
require.Nil(t, samplingRulesOutput)

samplingTargetsOutput, err := client.getSamplingTargets(context.TODO(), nil)
require.Error(t, err)
require.Nil(t, samplingTargetsOutput)
assert.Nil(t, samplingTargets.SamplingTargetDocuments[0].Interval)
assert.Nil(t, samplingTargets.SamplingTargetDocuments[0].ReservoirQuota)
}

func TestNewClient(t *testing.T) {
Expand All @@ -258,25 +252,57 @@ func TestEndpointIsNotReachable(t *testing.T) {
client, err := newClient(*endpoint)
require.NoError(t, err)

_, err = client.getSamplingRules(context.Background())
actualRules, err := client.getSamplingRules(context.Background())
assert.Error(t, err)
assert.Nil(t, actualRules)

actualTargets, err := client.getSamplingTargets(context.Background(), nil)
assert.Error(t, err)
assert.Nil(t, actualTargets)
}

func TestRespondsWithErrorStatusCode(t *testing.T) {
testServer := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, _ *http.Request) {
res.WriteHeader(http.StatusForbidden)
_, err := res.Write([]byte(`{}`))
require.NoError(t, err)
}))
t.Cleanup(testServer.Close)
client := createTestClientWithStatusCode(t, http.StatusForbidden, []byte("{}"))

u, err := url.Parse(testServer.URL)
require.NoError(t, err)
actualRules, err := client.getSamplingRules(context.Background())
assert.Error(t, err)
assert.EqualError(t, err, fmt.Sprintf("xray client: unable to retrieve sampling settings, expected response status code 200, got: %d", http.StatusForbidden))
assert.Nil(t, actualRules)

client, err := newClient(*u)
require.NoError(t, err)
actualTargets, err := client.getSamplingTargets(context.Background(), nil)
assert.Error(t, err)
assert.EqualError(t, err, fmt.Sprintf("xray client: unable to retrieve sampling targets, expected response status code 200, got: %d", http.StatusForbidden))
assert.Nil(t, actualTargets)
}

samplingRules, err := client.getSamplingRules(context.Background())
require.Error(t, err)
require.Nil(t, samplingRules)
func TestInvalidResponseBody(t *testing.T) {
type scenarios struct {
name string
response string
}
for _, scenario := range []scenarios{
{
name: "empty response",
response: "",
},
{
name: "malformed json",
response: "",
},
} {
t.Run(scenario.name, func(t *testing.T) {
client := createTestClient(t, []byte(scenario.response))

actualRules, err := client.getSamplingRules(context.TODO())

assert.Error(t, err)
assert.Nil(t, actualRules)
assert.ErrorContains(t, err, "xray client: unable to unmarshal the response body:"+scenario.response)

actualTargets, err := client.getSamplingTargets(context.TODO(), nil)
assert.Error(t, err)
assert.Nil(t, actualTargets)
assert.ErrorContains(t, err, "xray client: unable to unmarshal the response body: "+scenario.response)
})
}
}

0 comments on commit fdbf1b2

Please sign in to comment.