Skip to content

Commit

Permalink
allow setting the user-agent string sent from the client (#521)
Browse files Browse the repository at this point in the history
Signed-off-by: Jake Sanders <jsand@google.com>
  • Loading branch information
Jake Sanders authored Dec 3, 2021
1 parent 41dd848 commit 3278f72
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 1 deletion.
68 changes: 68 additions & 0 deletions pkg/client/options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright 2021 The Sigstore Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package client

import "net/http"

// Option is a functional option for customizing static signatures.
type Option func(*options)

type options struct {
UserAgent string
}

func makeOptions(opts ...Option) *options {
o := &options{
UserAgent: "",
}

for _, opt := range opts {
opt(o)
}

return o
}

// WithUserAgent sets the media type of the signature.
func WithUserAgent(userAgent string) Option {
return func(o *options) {
o.UserAgent = userAgent
}
}

type roundTripper struct {
http.RoundTripper
UserAgent string
}

// RoundTrip implements `http.RoundTripper`
func (rt *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
req.Header.Set("User-Agent", rt.UserAgent)
return rt.RoundTripper.RoundTrip(req)
}

func createRoundTripper(inner http.RoundTripper, o *options) http.RoundTripper {
if inner == nil {
inner = http.DefaultTransport
}
if o.UserAgent == "" {
// There's nothing to do...
return inner
}
return &roundTripper{
RoundTripper: inner,
UserAgent: o.UserAgent,
}
}
104 changes: 104 additions & 0 deletions pkg/client/options_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// Copyright 2021 The Sigstore Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package client

import (
"net/http"
"testing"

"github.com/google/go-cmp/cmp"
)

func TestMakeOptions(t *testing.T) {
tests := []struct {
desc string

opts []Option
want *options
}{{
desc: "no opts",
want: &options{},
}, {
desc: "WithUserAgent",
opts: []Option{WithUserAgent("test user agent")},
want: &options{UserAgent: "test user agent"},
}}
for _, tc := range tests {
t.Run(tc.desc, func(t *testing.T) {
got := makeOptions(tc.opts...)
if d := cmp.Diff(tc.want, got); d != "" {
t.Errorf("makeOptions() returned unexpected result (-want +got): %s", d)
}
})
}
}

type mockRoundTripper struct {
gotReqs []*http.Request

resp *http.Response
err error
}

// RoundTrip implements `http.RoundTripper`
func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
m.gotReqs = append(m.gotReqs, req)
return m.resp, m.err
}

func TestCreateRoundTripper(t *testing.T) {
t.Run("always returns non-nil", func(t *testing.T) {
got := createRoundTripper(nil, &options{})
if got == nil {
t.Errorf("createRoundTripper() should never return a nil `http.RoundTripper`")
}
})

testReq, err := http.NewRequest("GET", "http://www.example.com/test", nil)
if err != nil {
t.Fatalf("http.NewRequest() failed: %v", err)
}

testResp := &http.Response{
Status: "OK",
StatusCode: 200,
Request: testReq,
}

expectedUserAgent := "test UserAgent"

m := &mockRoundTripper{}
rt := createRoundTripper(m, &options{
UserAgent: expectedUserAgent,
})
m.resp = testResp

gotResp, err := rt.RoundTrip(testReq)
if err != nil {
t.Errorf("RoundTrip() returned error: %v", err)
}
if len(m.gotReqs) < 1 {
t.Fatalf("inner RoundTripper.RoundTrip() was not called")
}
gotReq := m.gotReqs[0]
gotReqUserAgent := gotReq.UserAgent()
if gotReqUserAgent != expectedUserAgent {
t.Errorf("rt.RoundTrip() did not set the User-Agent properly. Wanted: %q, got: %q", expectedUserAgent, gotReqUserAgent)
}

if testResp != gotResp {
t.Errorf("roundTripper.RoundTrip() should have returned exactly the response of the inner RoundTripper. Wanted %v, got %v", testResp, gotResp)
}
}
5 changes: 4 additions & 1 deletion pkg/client/rekor_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ import (
"github.com/spf13/viper"
)

func GetRekorClient(rekorServerURL string) (*client.Rekor, error) {
func GetRekorClient(rekorServerURL string, opts ...Option) (*client.Rekor, error) {
url, err := url.Parse(rekorServerURL)
if err != nil {
return nil, err
}
o := makeOptions(opts...)

rt := httptransport.New(url.Host, client.DefaultBasePath, []string{url.Scheme})
rt.Consumers["application/yaml"] = YamlConsumer()
Expand All @@ -43,6 +44,8 @@ func GetRekorClient(rekorServerURL string) (*client.Rekor, error) {
rt.DefaultAuthentication = httptransport.APIKeyAuth("apiKey", "query", viper.GetString("api-key"))
}

rt.Transport = createRoundTripper(rt.Transport, o)

registry := strfmt.Default
registry.Add("signedCheckpoint", &util.SignedNote{}, util.SignedCheckpointValidator)
return client.New(rt, registry), nil
Expand Down
22 changes: 22 additions & 0 deletions pkg/client/rekor_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,25 @@ func TestAPIKey(t *testing.T) {
}
_, _ = client.Pubkey.GetPublicKey(nil)
}

func TestGetRekorClientWithOptions(t *testing.T) {
expectedUserAgent := "test User-Agent"
testServer := httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
file := []byte{}

got := r.UserAgent()
if got != expectedUserAgent {
t.Errorf("wanted User-Agent %q, got %q", expectedUserAgent, got)
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write(file)
}))
defer testServer.Close()

client, err := GetRekorClient(testServer.URL, WithUserAgent(expectedUserAgent))
if err != nil {
t.Error(err)
}
_, _ = client.Tlog.GetLogInfo(nil)
}

0 comments on commit 3278f72

Please sign in to comment.