diff --git a/config/http_config.go b/config/http_config.go index cae49b9d..71280746 100644 --- a/config/http_config.go +++ b/config/http_config.go @@ -112,11 +112,20 @@ func (u URL) MarshalYAML() (interface{}, error) { // OAuth2 is the oauth2 client configuration. type OAuth2 struct { - ClientID string `yaml:"client_id"` - ClientSecret Secret `yaml:"client_secret"` - Scopes []string `yaml:"scopes,omitempty"` - TokenURL string `yaml:"token_url"` - EndpointParams map[string]string `yaml:"endpoint_params,omitempty"` + ClientID string `yaml:"client_id"` + ClientSecret Secret `yaml:"client_secret"` + ClientSecretFile string `yaml:"client_secret_file"` + Scopes []string `yaml:"scopes,omitempty"` + TokenURL string `yaml:"token_url"` + EndpointParams map[string]string `yaml:"endpoint_params,omitempty"` +} + +// SetDirectory joins any relative file paths with dir. +func (a *OAuth2) SetDirectory(dir string) { + if a == nil { + return + } + a.ClientSecretFile = JoinDir(dir, a.ClientSecretFile) } // HTTPClientConfig configures an HTTP client. @@ -151,6 +160,7 @@ func (c *HTTPClientConfig) SetDirectory(dir string) { c.TLSConfig.SetDirectory(dir) c.BasicAuth.SetDirectory(dir) c.Authorization.SetDirectory(dir) + c.OAuth2.SetDirectory(dir) c.BearerTokenFile = JoinDir(dir, c.BearerTokenFile) } @@ -196,8 +206,13 @@ func (c *HTTPClientConfig) Validate() error { c.BearerTokenFile = "" } } - if c.BasicAuth != nil && c.OAuth2 != nil { - return fmt.Errorf("at most one of basic_auth, oauth2 & authorization must be configured") + if c.OAuth2 != nil { + if c.BasicAuth != nil { + return fmt.Errorf("at most one of basic_auth, oauth2 & authorization must be configured") + } + if len(c.OAuth2.ClientSecret) > 0 && len(c.OAuth2.ClientSecretFile) > 0 { + return fmt.Errorf("at most one of oauth2 client_secret & client_secret_file must be configured") + } } return nil } @@ -347,7 +362,7 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT } if cfg.OAuth2 != nil { - rt = cfg.OAuth2.NewOAuth2RoundTripper(context.Background(), rt) + rt = NewOAuth2RoundTripper(cfg.OAuth2, rt) } // Return a new configured RoundTripper. return rt, nil @@ -462,20 +477,72 @@ func (rt *basicAuthRoundTripper) CloseIdleConnections() { } } -func (c *OAuth2) NewOAuth2RoundTripper(ctx context.Context, next http.RoundTripper) http.RoundTripper { - config := &clientcredentials.Config{ - ClientID: c.ClientID, - ClientSecret: string(c.ClientSecret), - Scopes: c.Scopes, - TokenURL: c.TokenURL, - EndpointParams: mapToValues(c.EndpointParams), +type oauth2RoundTripper struct { + config *OAuth2 + rt http.RoundTripper + next http.RoundTripper + secret string + mtx sync.RWMutex +} + +func NewOAuth2RoundTripper(config *OAuth2, next http.RoundTripper) http.RoundTripper { + return &oauth2RoundTripper{ + config: config, + next: next, } +} - tokenSource := config.TokenSource(ctx) +func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + var ( + secret string + changed bool + ) - return &oauth2.Transport{ - Base: next, - Source: tokenSource, + if rt.config.ClientSecretFile != "" { + data, err := ioutil.ReadFile(rt.config.ClientSecretFile) + if err != nil { + return nil, fmt.Errorf("unable to read oauth2 client secret file %s: %s", rt.config.ClientSecretFile, err) + } + secret = strings.TrimSpace(string(data)) + rt.mtx.RLock() + changed = secret != rt.secret + rt.mtx.RUnlock() + } + + if changed || rt.rt == nil { + if rt.config.ClientSecret != "" { + secret = string(rt.config.ClientSecret) + } + + config := &clientcredentials.Config{ + ClientID: rt.config.ClientID, + ClientSecret: secret, + Scopes: rt.config.Scopes, + TokenURL: rt.config.TokenURL, + EndpointParams: mapToValues(rt.config.EndpointParams), + } + + tokenSource := config.TokenSource(context.Background()) + + rt.mtx.Lock() + rt.secret = secret + rt.rt = &oauth2.Transport{ + Base: rt.next, + Source: tokenSource, + } + rt.mtx.Unlock() + } + + rt.mtx.RLock() + currentRT := rt.rt + rt.mtx.RUnlock() + return currentRT.RoundTrip(req) +} + +func (rt *oauth2RoundTripper) CloseIdleConnections() { + // OAuth2 RT does not support CloseIdleConnections() but the next RT might. + if ci, ok := rt.next.(closeIdler); ok { + ci.CloseIdleConnections() } } diff --git a/config/http_config_test.go b/config/http_config_test.go index 4810ea5e..173bf870 100644 --- a/config/http_config_test.go +++ b/config/http_config_test.go @@ -103,6 +103,10 @@ var invalidHTTPClientConfigs = []struct { httpClientConfigFile: "testdata/http.conf.auth-creds-no-basic.bad.yaml", errMsg: `authorization type cannot be set to "basic", use "basic_auth" instead`, }, + { + httpClientConfigFile: "testdata/http.conf.oauth2-secret-and-file-set.bad.yml", + errMsg: "at most one of oauth2 client_secret & client_secret_file must be configured", + }, } func newTestServer(handler func(w http.ResponseWriter, r *http.Request)) (*httptest.Server, error) { @@ -1136,7 +1140,7 @@ endpoint_params: t.Fatalf("Got unmarshalled config %q, expected %q", unmarshalledConfig, expectedConfig) } - rt := expectedConfig.NewOAuth2RoundTripper(context.Background(), http.DefaultTransport) + rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport) client := http.Client{ Transport: rt, @@ -1148,3 +1152,115 @@ endpoint_params: t.Fatalf("Expected authorization header to be 'Bearer 12345', got '%s'", authorization) } } + +func TestOAuth2WithFile(t *testing.T) { + var expectedAuth *string + var previousAuth string + tokenTS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if auth != *expectedAuth { + t.Fatalf("bad auth, expected %s, got %s", *expectedAuth, auth) + } + if auth == previousAuth { + t.Fatal("token endpoint called twice") + } + previousAuth = auth + res, _ := json.Marshal(testServerResponse{ + AccessToken: "12345", + TokenType: "Bearer", + }) + w.Header().Add("Content-Type", "application/json") + _, _ = w.Write(res) + })) + defer tokenTS.Close() + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if auth != "Bearer 12345" { + t.Fatalf("bad auth, expected %s, got %s", "Bearer 12345", auth) + } + fmt.Fprintln(w, "Hello, client") + })) + defer ts.Close() + + secretFile, err := ioutil.TempFile("", "oauth2_secret") + if err != nil { + t.Fatal(err) + } + defer os.Remove(secretFile.Name()) + + var yamlConfig = fmt.Sprintf(` +client_id: 1 +client_secret_file: %s +scopes: + - A + - B +token_url: %s +endpoint_params: + hi: hello +`, secretFile.Name(), tokenTS.URL) + expectedConfig := OAuth2{ + ClientID: "1", + ClientSecretFile: secretFile.Name(), + Scopes: []string{"A", "B"}, + EndpointParams: map[string]string{"hi": "hello"}, + TokenURL: tokenTS.URL, + } + + var unmarshalledConfig OAuth2 + err = yaml.Unmarshal([]byte(yamlConfig), &unmarshalledConfig) + if err != nil { + t.Fatalf("Expected no error unmarshalling yaml, got %v", err) + } + if !reflect.DeepEqual(unmarshalledConfig, expectedConfig) { + t.Fatalf("Got unmarshalled config %q, expected %q", unmarshalledConfig, expectedConfig) + } + + rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport) + + client := http.Client{ + Transport: rt, + } + + tk := "Basic MToxMjM0NTY=" + expectedAuth = &tk + if _, err := secretFile.Write([]byte("123456")); err != nil { + t.Fatal(err) + } + resp, err := client.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + + authorization := resp.Request.Header.Get("Authorization") + if authorization != "Bearer 12345" { + t.Fatalf("Expected authorization header to be 'Bearer 12345', got '%s'", authorization) + } + + // Making a second request with the same file content should not re-call the token API. + resp, err = client.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + + tk = "Basic MToxMjM0NTY3" + expectedAuth = &tk + if _, err := secretFile.Write([]byte("7")); err != nil { + t.Fatal(err) + } + + _, err = client.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + + // Making a second request with the same file content should not re-call the token API. + _, err = client.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + + authorization = resp.Request.Header.Get("Authorization") + if authorization != "Bearer 12345" { + t.Fatalf("Expected authorization header to be 'Bearer 12345', got '%s'", authorization) + } +} diff --git a/config/testdata/http.conf.oauth2-secret-and-file-set.bad.yml b/config/testdata/http.conf.oauth2-secret-and-file-set.bad.yml new file mode 100644 index 00000000..d9fc6f84 --- /dev/null +++ b/config/testdata/http.conf.oauth2-secret-and-file-set.bad.yml @@ -0,0 +1,3 @@ +oauth2: + client_secret: "mysecret" + client_secret_file: "mysecret"