diff --git a/auth/idtoken/cache.go b/auth/idtoken/cache.go new file mode 100644 index 000000000000..6eb6d3b4445f --- /dev/null +++ b/auth/idtoken/cache.go @@ -0,0 +1,133 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package idtoken + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strconv" + "strings" + "sync" + "time" +) + +type cachingClient struct { + client *http.Client + + // clock optionally specifies a func to return the current time. + // If nil, time.Now is used. + clock func() time.Time + + mu sync.Mutex + certs map[string]*cachedResponse +} + +func newCachingClient(client *http.Client) *cachingClient { + return &cachingClient{ + client: client, + certs: make(map[string]*cachedResponse, 2), + } +} + +type cachedResponse struct { + resp *certResponse + exp time.Time +} + +func (c *cachingClient) getCert(ctx context.Context, url string) (*certResponse, error) { + if response, ok := c.get(url); ok { + return response, nil + } + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + resp, err := c.client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("idtoken: unable to retrieve cert, got status code %d", resp.StatusCode) + } + + certResp := &certResponse{} + if err := json.NewDecoder(resp.Body).Decode(certResp); err != nil { + return nil, err + + } + c.set(url, certResp, resp.Header) + return certResp, nil +} + +func (c *cachingClient) now() time.Time { + if c.clock != nil { + return c.clock() + } + return time.Now() +} + +func (c *cachingClient) get(url string) (*certResponse, bool) { + c.mu.Lock() + defer c.mu.Unlock() + cachedResp, ok := c.certs[url] + if !ok { + return nil, false + } + if c.now().After(cachedResp.exp) { + return nil, false + } + return cachedResp.resp, true +} + +func (c *cachingClient) set(url string, resp *certResponse, headers http.Header) { + exp := c.calculateExpireTime(headers) + c.mu.Lock() + c.certs[url] = &cachedResponse{resp: resp, exp: exp} + c.mu.Unlock() +} + +// calculateExpireTime will determine the expire time for the cache based on +// HTTP headers. If there is any difficulty reading the headers the fallback is +// to set the cache to expire now. +func (c *cachingClient) calculateExpireTime(headers http.Header) time.Time { + var maxAge int + cc := strings.Split(headers.Get("cache-control"), ",") + for _, v := range cc { + if strings.Contains(v, "max-age") { + ss := strings.Split(v, "=") + if len(ss) < 2 { + return c.now() + } + ma, err := strconv.Atoi(ss[1]) + if err != nil { + return c.now() + } + maxAge = ma + } + } + a := headers.Get("age") + if a == "" { + return c.now().Add(time.Duration(maxAge) * time.Second) + } + age, err := strconv.Atoi(a) + if err != nil { + return c.now() + } + return c.now().Add(time.Duration(maxAge-age) * time.Second) +} diff --git a/auth/idtoken/cache_test.go b/auth/idtoken/cache_test.go new file mode 100644 index 000000000000..adfb5d665662 --- /dev/null +++ b/auth/idtoken/cache_test.go @@ -0,0 +1,82 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package idtoken + +import ( + "net/http" + "sync" + "testing" + "time" +) + +type fakeClock struct { + mu sync.Mutex + t time.Time +} + +func (c *fakeClock) Now() time.Time { + c.mu.Lock() + defer c.mu.Unlock() + return c.t +} + +func (c *fakeClock) Sleep(d time.Duration) { + c.mu.Lock() + defer c.mu.Unlock() + c.t = c.t.Add(d) +} + +func TestCacheHit(t *testing.T) { + clock := &fakeClock{t: time.Now()} + fakeResp := &certResponse{ + Keys: []jwk{ + { + Kid: "123", + }, + }, + } + cache := newCachingClient(nil) + cache.clock = clock.Now + + // Cache should be empty + cert, ok := cache.get(googleSACertsURL) + if ok || cert != nil { + t.Fatal("cache for SA certs should be empty") + } + + // Add an item, but make it expire now + cache.set(googleSACertsURL, fakeResp, make(http.Header)) + clock.Sleep(time.Nanosecond) // it expires when current time is > expiration, not >= + cert, ok = cache.get(googleSACertsURL) + if ok || cert != nil { + t.Fatal("cache for SA certs should be expired") + } + + // Add an item that expires in 1 seconds + h := make(http.Header) + h.Set("age", "0") + h.Set("cache-control", "public, max-age=1, must-revalidate, no-transform") + cache.set(googleSACertsURL, fakeResp, h) + cert, ok = cache.get(googleSACertsURL) + if !ok || cert == nil || cert.Keys[0].Kid != "123" { + t.Fatal("cache for SA certs have a resp") + } + // Wait + clock.Sleep(2 * time.Second) + cert, ok = cache.get(googleSACertsURL) + if ok || cert != nil { + t.Fatal("cache for SA certs should be expired") + } +} diff --git a/auth/idtoken/compute.go b/auth/idtoken/compute.go new file mode 100644 index 000000000000..e9ebc1c8162a --- /dev/null +++ b/auth/idtoken/compute.go @@ -0,0 +1,77 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package idtoken + +import ( + "context" + "fmt" + "net/url" + "time" + + "cloud.google.com/go/auth" + "cloud.google.com/go/auth/internal" + "cloud.google.com/go/compute/metadata" +) + +const identitySuffix = "instance/service-accounts/default/identity" + +// computeTokenProvider checks if this code is being run on GCE. If it is, it +// will use the metadata service to build a TokenProvider that fetches ID +// tokens. +func computeTokenProvider(opts *Options) (auth.TokenProvider, error) { + if opts.CustomClaims != nil { + return nil, fmt.Errorf("idtoken: Options.CustomClaims can't be used with the metadata service, please provide a service account if you would like to use this feature") + } + tp := computeIDTokenProvider{ + audience: opts.Audience, + format: opts.ComputeTokenFormat, + client: *metadata.NewClient(opts.client()), + } + return auth.NewCachedTokenProvider(tp, &auth.CachedTokenProviderOptions{ + ExpireEarly: 5 * time.Minute, + }), nil +} + +type computeIDTokenProvider struct { + audience string + format ComputeTokenFormat + client metadata.Client +} + +func (c computeIDTokenProvider) Token(ctx context.Context) (*auth.Token, error) { + v := url.Values{} + v.Set("audience", c.audience) + if c.format != ComputeTokenFormatStandard { + v.Set("format", "full") + } + if c.format == ComputeTokenFormatFullWithLicense { + v.Set("licenses", "TRUE") + } + urlSuffix := identitySuffix + "?" + v.Encode() + res, err := c.client.Get(urlSuffix) + if err != nil { + return nil, err + } + if res == "" { + return nil, fmt.Errorf("idtoken: invalid empty response from metadata service") + } + return &auth.Token{ + Value: res, + Type: internal.TokenTypeBearer, + // Compute tokens are valid for one hour: + // https://cloud.google.com/iam/docs/create-short-lived-credentials-direct#create-id + Expiry: time.Now().Add(1 * time.Hour), + }, nil +} diff --git a/auth/idtoken/compute_test.go b/auth/idtoken/compute_test.go new file mode 100644 index 000000000000..a0ff4fc8a1d9 --- /dev/null +++ b/auth/idtoken/compute_test.go @@ -0,0 +1,102 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package idtoken + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +const metadataHostEnv = "GCE_METADATA_HOST" + +func TestComputeTokenSource(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.Contains(r.URL.Path, identitySuffix) { + t.Errorf("got %q, want contains %q", r.URL.Path, identitySuffix) + } + if got, want := r.URL.Query().Get("audience"), "aud"; got != want { + t.Errorf("got %q, want %q", got, want) + } + if got, want := r.URL.Query().Get("format"), "full"; got != want { + t.Errorf("got %q, want %q", got, want) + } + if got, want := r.URL.Query().Get("licenses"), "TRUE"; got != want { + t.Errorf("got %q, want %q", got, want) + } + w.Write([]byte(`fake_token`)) + })) + defer ts.Close() + t.Setenv(metadataHostEnv, strings.TrimPrefix(ts.URL, "http://")) + tp, err := computeTokenProvider(&Options{ + Audience: "aud", + ComputeTokenFormat: ComputeTokenFormatFullWithLicense, + }) + if err != nil { + t.Fatalf("computeTokenProvider() = %v", err) + } + tok, err := tp.Token(context.Background()) + if err != nil { + t.Fatalf("tp.Token() = %v", err) + } + if want := "fake_token"; tok.Value != want { + t.Errorf("got %q, want %q", tok.Value, want) + } +} + +func TestComputeTokenSource_Standard(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.Contains(r.URL.Path, identitySuffix) { + t.Errorf("got %q, want contains %q", r.URL.Path, identitySuffix) + } + if got, want := r.URL.Query().Get("audience"), "aud"; got != want { + t.Errorf("got %q, want %q", got, want) + } + if got, want := r.URL.Query().Get("format"), ""; got != want { + t.Errorf("got %q, want %q", got, want) + } + if got, want := r.URL.Query().Get("licenses"), ""; got != want { + t.Errorf("got %q, want %q", got, want) + } + w.Write([]byte(`fake_token`)) + })) + defer ts.Close() + t.Setenv(metadataHostEnv, strings.TrimPrefix(ts.URL, "http://")) + tp, err := computeTokenProvider(&Options{ + Audience: "aud", + ComputeTokenFormat: ComputeTokenFormatStandard, + }) + if err != nil { + t.Fatalf("computeTokenProvider() = %v", err) + } + tok, err := tp.Token(context.Background()) + if err != nil { + t.Fatalf("tp.Token() = %v", err) + } + if want := "fake_token"; tok.Value != want { + t.Errorf("got %q, want %q", tok.Value, want) + } +} + +func TestComputeTokenSource_Invalid(t *testing.T) { + if _, err := computeTokenProvider(&Options{ + Audience: "aud", + CustomClaims: map[string]interface{}{"foo": "bar"}, + }); err == nil { + t.Fatal("computeTokenProvider() = nil, expected non-nil error", err) + } +} diff --git a/auth/idtoken/examples_test.go b/auth/idtoken/examples_test.go new file mode 100644 index 000000000000..2e605dc0aa00 --- /dev/null +++ b/auth/idtoken/examples_test.go @@ -0,0 +1,43 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package idtoken_test + +import ( + "context" + "net/http" + + "cloud.google.com/go/auth/httptransport" + "cloud.google.com/go/auth/idtoken" +) + +func ExampleNewTokenProvider_setAuthorizationHeader() { + ctx := context.Background() + audience := "http://example.com" + tp, err := idtoken.NewTokenProvider(&idtoken.Options{ + Audience: audience, + }) + if err != nil { + // Handle error. + } + token, err := tp.Token(ctx) + if err != nil { + // Handle error. + } + req, err := http.NewRequest(http.MethodGet, audience, nil) + if err != nil { + // Handle error. + } + httptransport.SetAuthHeader(token, req) +} diff --git a/auth/idtoken/file.go b/auth/idtoken/file.go new file mode 100644 index 000000000000..c904ba1094f1 --- /dev/null +++ b/auth/idtoken/file.go @@ -0,0 +1,110 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package idtoken + +import ( + "encoding/json" + "fmt" + "path/filepath" + "strings" + + "cloud.google.com/go/auth" + "cloud.google.com/go/auth/detect" + "cloud.google.com/go/auth/impersonate" + "cloud.google.com/go/auth/internal/internaldetect" +) + +const ( + jwtTokenURL = "https://oauth2.googleapis.com/token" + iamCredAud = "https://iamcredentials.googleapis.com/" +) + +var ( + defaultScopes = []string{ + "https://iamcredentials.googleapis.com/", + "https://www.googleapis.com/auth/cloud-platform", + } +) + +func tokenProviderFromBytes(b []byte, opts *Options) (auth.TokenProvider, error) { + t, err := internaldetect.ParseFileType(b) + if err != nil { + return nil, err + } + switch t { + case internaldetect.ServiceAccountKey: + f, err := internaldetect.ParseServiceAccount(b) + if err != nil { + return nil, err + } + opts2LO := &auth.Options2LO{ + Email: f.ClientEmail, + PrivateKey: []byte(f.PrivateKey), + PrivateKeyID: f.PrivateKeyID, + TokenURL: f.TokenURL, + UseIDToken: true, + } + if opts2LO.TokenURL == "" { + opts2LO.TokenURL = jwtTokenURL + } + + var customClaims map[string]interface{} + if opts != nil { + customClaims = opts.CustomClaims + } + if customClaims == nil { + customClaims = make(map[string]interface{}) + } + customClaims["target_audience"] = opts.Audience + + opts2LO.PrivateClaims = customClaims + tp, err := auth.New2LOTokenProvider(opts2LO) + if err != nil { + return nil, err + } + return auth.NewCachedTokenProvider(tp, nil), nil + case internaldetect.ImpersonatedServiceAccountKey, internaldetect.ExternalAccountKey: + type url struct { + ServiceAccountImpersonationURL string `json:"service_account_impersonation_url"` + } + var accountURL url + if err := json.Unmarshal(b, &accountURL); err != nil { + return nil, err + } + account := filepath.Base(accountURL.ServiceAccountImpersonationURL) + account = strings.Split(account, ":")[0] + + creds, err := detect.DefaultCredentials(&detect.Options{ + Scopes: defaultScopes, + CredentialsJSON: b, + Client: opts.client(), + UseSelfSignedJWT: true, + }) + if err != nil { + return nil, err + } + + config := impersonate.IDTokenOptions{ + Audience: opts.Audience, + TargetPrincipal: account, + IncludeEmail: true, + Client: opts.client(), + TokenProvider: creds, + } + return impersonate.NewIDTokenProvider(&config) + default: + return nil, fmt.Errorf("idtoken: unsupported credentials type: %v", t) + } +} diff --git a/auth/idtoken/idtoken.go b/auth/idtoken/idtoken.go new file mode 100644 index 000000000000..4439dce76fc9 --- /dev/null +++ b/auth/idtoken/idtoken.go @@ -0,0 +1,122 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package idtoken + +import ( + "fmt" + "net/http" + "os" + + "cloud.google.com/go/auth" + "cloud.google.com/go/auth/internal" + "cloud.google.com/go/auth/internal/internaldetect" + "cloud.google.com/go/compute/metadata" +) + +// ComputeTokenFormat dictates the the token format when requesting an ID token +// from the compute metadata service. +type ComputeTokenFormat int + +const ( + // ComputeTokenFormatDefault means the same as [ComputeTokenFormatFull]. + ComputeTokenFormatDefault ComputeTokenFormat = iota + // ComputeTokenFormatStandard mean only standard JWT fields will be included + // in the token. + ComputeTokenFormatStandard + // ComputeTokenFormatFull means the token will include claims about the + // virtual machine instance and its project. + ComputeTokenFormatFull + // ComputeTokenFormatFullWithLicense means the same as + // [ComputeTokenFormatFull] with the addition of claims about licenses + // associated with the instance. + ComputeTokenFormatFullWithLicense +) + +// Options for the configuration of creation of an ID token with +// [NewTokenProvider]. +type Options struct { + // Audience is the `aud` field for the token, such as an API endpoint the + // token will grant access to. Required. + Audience string + // ComputeTokenFormat dictates the the token format when requesting an ID + // token from the compute metadata service. Optional. + ComputeTokenFormat ComputeTokenFormat + // CustomClaims specifies private non-standard claims for an ID token. + // Optional. + CustomClaims map[string]interface{} + + // CredentialsFile overrides detection logic and sources a credential file + // from the provided filepath. Optional. + CredentialsFile string + // CredentialsJSON overrides detection logic and uses the JSON bytes as the + // source for the credential. Optional. + CredentialsJSON []byte + // Client configures the underlying client used to make network requests + // when fetching tokens. If provided this should be a fully authenticated + // client. Optional. + Client *http.Client +} + +func (o *Options) client() *http.Client { + if o == nil || o.Client == nil { + return internal.CloneDefaultClient() + } + return o.Client +} + +// NewTokenProvider creates a [cloud.google.com/go/auth.TokenProvider] that +// returns ID tokens configured by the opts provided. The parameter +// opts.Audience may not be empty. +func NewTokenProvider(opts *Options) (auth.TokenProvider, error) { + if opts == nil { + return nil, fmt.Errorf("idtoken: opts must be provided") + } + if opts.Audience == "" { + return nil, fmt.Errorf("idtoken: must supply a non-empty audience") + } + if b := opts.jsonBytes(); b != nil { + return tokenProviderFromBytes(b, opts) + } + if metadata.OnGCE() { + return computeTokenProvider(opts) + } + return nil, fmt.Errorf("idtoken: couldn't find any credentials") +} + +func (opts *Options) jsonBytes() []byte { + if opts.CredentialsJSON != nil { + return opts.CredentialsJSON + } + var fnOverride string + if opts != nil { + fnOverride = opts.CredentialsFile + } + filename := internaldetect.GetFileNameFromEnv(fnOverride) + if filename != "" { + b, _ := os.ReadFile(filename) + return b + } + return nil +} + +// Payload represents a decoded payload of an ID token. +type Payload struct { + Issuer string `json:"iss"` + Audience string `json:"aud"` + Expires int64 `json:"exp"` + IssuedAt int64 `json:"iat"` + Subject string `json:"sub,omitempty"` + Claims map[string]interface{} `json:"-"` +} diff --git a/auth/idtoken/idtoken_test.go b/auth/idtoken/idtoken_test.go new file mode 100644 index 000000000000..2e2a6b12e086 --- /dev/null +++ b/auth/idtoken/idtoken_test.go @@ -0,0 +1,106 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package idtoken + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "testing" + + "cloud.google.com/go/auth/internal" + "cloud.google.com/go/auth/internal/internaldetect" +) + +func TestNewTokenProvider_ServiceAccount(t *testing.T) { + wantTok, _ := createRS256JWT(t) + b, err := os.ReadFile("../internal/testdata/sa.json") + if err != nil { + t.Fatal(err) + } + f, err := internaldetect.ParseServiceAccount(b) + if err != nil { + t.Fatal(err) + } + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(fmt.Sprintf(`{"id_token": "%s"}`, wantTok))) + })) + defer ts.Close() + f.TokenURL = ts.URL + b, err = json.Marshal(f) + if err != nil { + t.Fatal(err) + } + + tp, err := NewTokenProvider(&Options{ + Audience: "aud", + CredentialsJSON: b, + CustomClaims: map[string]interface{}{ + "foo": "bar", + }, + }) + if err != nil { + t.Fatal(err) + } + tok, err := tp.Token(context.Background()) + if err != nil { + t.Fatalf("tp.Token() = %v", err) + } + if tok.Value != wantTok { + t.Errorf("got %q, want %q", tok.Value, wantTok) + } +} + +type mockTransport struct { + handler http.HandlerFunc +} + +func (m mockTransport) RoundTrip(r *http.Request) (*http.Response, error) { + rw := httptest.NewRecorder() + m.handler(rw, r) + return rw.Result(), nil +} + +func TestNewTokenProvider_ImpersonatedServiceAccount(t *testing.T) { + wantTok, _ := createRS256JWT(t) + client := internal.CloneDefaultClient() + client.Transport = mockTransport{ + handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(fmt.Sprintf(`{"token": %q}`, wantTok))) + }), + } + tp, err := NewTokenProvider(&Options{ + Audience: "aud", + CredentialsFile: "../internal/testdata/imp.json", + CustomClaims: map[string]interface{}{ + "foo": "bar", + }, + Client: client, + }) + if err != nil { + t.Fatal(err) + } + tok, err := tp.Token(context.Background()) + if err != nil { + t.Fatalf("tp.Token() = %v", err) + } + if tok.Value != wantTok { + t.Errorf("got %q, want %q", tok.Value, wantTok) + } +} diff --git a/auth/idtoken/integration_test.go b/auth/idtoken/integration_test.go new file mode 100644 index 000000000000..92dd451c2a57 --- /dev/null +++ b/auth/idtoken/integration_test.go @@ -0,0 +1,88 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package idtoken_test + +import ( + "context" + "log" + "net/http" + "os" + "strings" + "testing" + + "cloud.google.com/go/auth/httptransport" + "cloud.google.com/go/auth/idtoken" + "cloud.google.com/go/auth/internal/testutil" +) + +const ( + envCredentialFile = "GOOGLE_APPLICATION_CREDENTIALS" + aud = "http://example.com" +) + +func TestNewTokenProvider_CredentialsFile(t *testing.T) { + testutil.IntegrationTestCheck(t) + ctx := context.Background() + ts, err := idtoken.NewTokenProvider(&idtoken.Options{ + Audience: "http://example.com", + CredentialsFile: os.Getenv(envCredentialFile), + }) + if err != nil { + t.Fatalf("unable to create TokenSource: %v", err) + } + tok, err := ts.Token(ctx) + if err != nil { + t.Fatalf("unable to retrieve Token: %v", err) + } + req := &http.Request{Header: make(http.Header)} + httptransport.SetAuthHeader(tok, req) + if !strings.HasPrefix(req.Header.Get("Authorization"), "Bearer ") { + t.Fatalf("token should sign requests with Bearer Authorization header") + } + validTok, err := idtoken.Validate(context.Background(), tok.Value, aud) + if err != nil { + t.Fatalf("token validation failed: %v", err) + } + if validTok.Audience != aud { + t.Fatalf("got %q, want %q", validTok.Audience, aud) + } +} + +func TestNewTokenProvider_CredentialsJSON(t *testing.T) { + testutil.IntegrationTestCheck(t) + ctx := context.Background() + b, err := os.ReadFile(os.Getenv(envCredentialFile)) + if err != nil { + log.Fatal(err) + } + tp, err := idtoken.NewTokenProvider(&idtoken.Options{ + Audience: aud, + CredentialsJSON: b, + }) + if err != nil { + t.Fatalf("unable to create Client: %v", err) + } + tok, err := tp.Token(ctx) + if err != nil { + t.Fatalf("unable to retrieve Token: %v", err) + } + validTok, err := idtoken.Validate(context.Background(), tok.Value, aud) + if err != nil { + t.Fatalf("token validation failed: %v", err) + } + if validTok.Audience != aud { + t.Fatalf("got %q, want %q", validTok.Audience, aud) + } +} diff --git a/auth/idtoken/validate.go b/auth/idtoken/validate.go new file mode 100644 index 000000000000..d653bf2c1899 --- /dev/null +++ b/auth/idtoken/validate.go @@ -0,0 +1,269 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package idtoken + +import ( + "context" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "math/big" + "net/http" + "strings" + "time" + + "cloud.google.com/go/auth/internal" + "cloud.google.com/go/auth/internal/jwt" +) + +const ( + es256KeySize int = 32 + googleIAPCertsURL string = "https://www.gstatic.com/iap/verify/public_key-jwk" + googleSACertsURL string = "https://www.googleapis.com/oauth2/v3/certs" +) + +var ( + defaultValidator = &Validator{client: newCachingClient(internal.CloneDefaultClient())} + // now aliases time.Now for testing. + now = time.Now +) + +// certResponse represents a list jwks. It is the format returned from known +// Google cert endpoints. +type certResponse struct { + Keys []jwk `json:"keys"` +} + +// jwk is a simplified representation of a standard jwk. It only includes the +// fields used by Google's cert endpoints. +type jwk struct { + Alg string `json:"alg"` + Crv string `json:"crv"` + Kid string `json:"kid"` + Kty string `json:"kty"` + Use string `json:"use"` + E string `json:"e"` + N string `json:"n"` + X string `json:"x"` + Y string `json:"y"` +} + +// Validator provides a way to validate Google ID Tokens +type Validator struct { + client *cachingClient +} + +// ValidatorOptions provides a way to configure a [Validator]. +type ValidatorOptions struct { + // Client used to make requests to the certs URL. Optional. + Client *http.Client +} + +// NewValidator creates a Validator that uses the options provided to configure +// a the internal http.Client that will be used to make requests to fetch JWKs. +func NewValidator(opts *ValidatorOptions) (*Validator, error) { + var client *http.Client + if opts != nil && opts.Client != nil { + client = opts.Client + } else { + client = internal.CloneDefaultClient() + } + return &Validator{client: newCachingClient(client)}, nil +} + +// Validate is used to validate the provided idToken with a known Google cert +// URL. If audience is not empty the audience claim of the Token is validated. +// Upon successful validation a parsed token Payload is returned allowing the +// caller to validate any additional claims. +func (v *Validator) Validate(ctx context.Context, idToken string, audience string) (*Payload, error) { + return v.validate(ctx, idToken, audience) +} + +// Validate is used to validate the provided idToken with a known Google cert +// URL. If audience is not empty the audience claim of the Token is validated. +// Upon successful validation a parsed token Payload is returned allowing the +// caller to validate any additional claims. +func Validate(ctx context.Context, idToken string, audience string) (*Payload, error) { + return defaultValidator.validate(ctx, idToken, audience) +} + +// ParsePayload parses the given token and returns its payload. +// +// Warning: This function does not validate the token prior to parsing it. +// +// ParsePayload is primarily meant to be used to inspect a token's payload. This is +// useful when validation fails and the payload needs to be inspected. +// +// Note: A successful Validate() invocation with the same token will return an +// identical payload. +func ParsePayload(idToken string) (*Payload, error) { + _, payload, _, err := parseToken(idToken) + if err != nil { + return nil, err + } + return payload, nil +} + +func (v *Validator) validate(ctx context.Context, idToken string, audience string) (*Payload, error) { + header, payload, sig, err := parseToken(idToken) + if err != nil { + return nil, err + } + + if audience != "" && payload.Audience != audience { + return nil, fmt.Errorf("idtoken: audience provided does not match aud claim in the JWT") + } + + if now().Unix() > payload.Expires { + return nil, fmt.Errorf("idtoken: token expired: now=%v, expires=%v", now().Unix(), payload.Expires) + } + hashedContent := hashHeaderPayload(idToken) + switch header.Algorithm { + case jwt.HeaderAlgRSA256: + if err := v.validateRS256(ctx, header.KeyID, hashedContent, sig); err != nil { + return nil, err + } + case "ES256": + if err := v.validateES256(ctx, header.KeyID, hashedContent, sig); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("idtoken: expected JWT signed with RS256 or ES256 but found %q", header.Algorithm) + } + + return payload, nil +} + +func (v *Validator) validateRS256(ctx context.Context, keyID string, hashedContent []byte, sig []byte) error { + certResp, err := v.client.getCert(ctx, googleSACertsURL) + if err != nil { + return err + } + j, err := findMatchingKey(certResp, keyID) + if err != nil { + return err + } + dn, err := decode(j.N) + if err != nil { + return err + } + de, err := decode(j.E) + if err != nil { + return err + } + + pk := &rsa.PublicKey{ + N: new(big.Int).SetBytes(dn), + E: int(new(big.Int).SetBytes(de).Int64()), + } + return rsa.VerifyPKCS1v15(pk, crypto.SHA256, hashedContent, sig) +} + +func (v *Validator) validateES256(ctx context.Context, keyID string, hashedContent []byte, sig []byte) error { + certResp, err := v.client.getCert(ctx, googleIAPCertsURL) + if err != nil { + return err + } + j, err := findMatchingKey(certResp, keyID) + if err != nil { + return err + } + dx, err := decode(j.X) + if err != nil { + return err + } + dy, err := decode(j.Y) + if err != nil { + return err + } + + pk := &ecdsa.PublicKey{ + Curve: elliptic.P256(), + X: new(big.Int).SetBytes(dx), + Y: new(big.Int).SetBytes(dy), + } + r := big.NewInt(0).SetBytes(sig[:es256KeySize]) + s := big.NewInt(0).SetBytes(sig[es256KeySize:]) + if valid := ecdsa.Verify(pk, hashedContent, r, s); !valid { + return fmt.Errorf("idtoken: ES256 signature not valid") + } + return nil +} + +func findMatchingKey(response *certResponse, keyID string) (*jwk, error) { + if response == nil { + return nil, fmt.Errorf("idtoken: cert response is nil") + } + for _, v := range response.Keys { + if v.Kid == keyID { + return &v, nil + } + } + return nil, fmt.Errorf("idtoken: could not find matching cert keyId for the token provided") +} + +func parseToken(idToken string) (*jwt.Header, *Payload, []byte, error) { + segments := strings.Split(idToken, ".") + if len(segments) != 3 { + return nil, nil, nil, fmt.Errorf("idtoken: invalid token, token must have three segments; found %d", len(segments)) + } + // Header + dh, err := decode(segments[0]) + if err != nil { + return nil, nil, nil, fmt.Errorf("idtoken: unable to decode JWT header: %v", err) + } + var header *jwt.Header + err = json.Unmarshal(dh, &header) + if err != nil { + return nil, nil, nil, fmt.Errorf("idtoken: unable to unmarshal JWT header: %v", err) + } + + // Payload + dp, err := decode(segments[1]) + if err != nil { + return nil, nil, nil, fmt.Errorf("idtoken: unable to decode JWT claims: %v", err) + } + var payload *Payload + if err := json.Unmarshal(dp, &payload); err != nil { + return nil, nil, nil, fmt.Errorf("idtoken: unable to unmarshal JWT payload: %v", err) + } + if err := json.Unmarshal(dp, &payload.Claims); err != nil { + return nil, nil, nil, fmt.Errorf("idtoken: unable to unmarshal JWT payload claims: %v", err) + } + + // Signature + signature, err := decode(segments[2]) + if err != nil { + return nil, nil, nil, fmt.Errorf("idtoken: unable to decode JWT signature: %v", err) + } + return header, payload, signature, nil +} + +// hashHeaderPayload gets the SHA256 checksum for verification of the JWT. +func hashHeaderPayload(idtoken string) []byte { + // remove the sig from the token + content := idtoken[:strings.LastIndex(idtoken, ".")] + hashed := sha256.Sum256([]byte(content)) + return hashed[:] +} + +func decode(s string) ([]byte, error) { + return base64.RawURLEncoding.DecodeString(s) +} diff --git a/auth/idtoken/validate_test.go b/auth/idtoken/validate_test.go new file mode 100644 index 000000000000..afabe814c3a2 --- /dev/null +++ b/auth/idtoken/validate_test.go @@ -0,0 +1,354 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package idtoken + +import ( + "bytes" + "context" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "math/big" + "net/http" + "testing" + "time" + + "cloud.google.com/go/auth/internal/jwt" +) + +const ( + keyID = "1234" + testAudience = "test-audience" + expiry int64 = 233431200 +) + +var ( + beforeExp = func() time.Time { return time.Unix(expiry-1, 0) } + afterExp = func() time.Time { return time.Unix(expiry+1, 0) } +) + +func TestValidateRS256(t *testing.T) { + idToken, pk := createRS256JWT(t) + tests := []struct { + name string + keyID string + n *big.Int + e int + nowFunc func() time.Time + wantErr bool + }{ + { + name: "works", + keyID: keyID, + n: pk.N, + e: pk.E, + nowFunc: beforeExp, + wantErr: false, + }, + { + name: "no matching key", + keyID: "5678", + n: pk.N, + e: pk.E, + nowFunc: beforeExp, + wantErr: true, + }, + { + name: "sig does not match", + keyID: keyID, + n: new(big.Int).SetBytes([]byte("42")), + e: 42, + nowFunc: beforeExp, + wantErr: true, + }, + { + name: "token expired", + keyID: keyID, + n: pk.N, + e: pk.E, + nowFunc: afterExp, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := &http.Client{ + Transport: RoundTripFn(func(req *http.Request) *http.Response { + cr := certResponse{ + Keys: []jwk{ + { + Kid: tt.keyID, + N: base64.RawURLEncoding.EncodeToString(tt.n.Bytes()), + E: base64.RawURLEncoding.EncodeToString(new(big.Int).SetInt64(int64(tt.e)).Bytes()), + }, + }, + } + b, err := json.Marshal(&cr) + if err != nil { + t.Fatalf("unable to marshal response: %v", err) + } + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewReader(b)), + Header: make(http.Header), + } + }), + } + oldNow := now + defer func() { now = oldNow }() + now = tt.nowFunc + + v, err := NewValidator(&ValidatorOptions{ + Client: client, + }) + if err != nil { + t.Fatalf("NewValidator(...) = %q, want nil", err) + } + payload, err := v.Validate(context.Background(), idToken, testAudience) + if tt.wantErr && err != nil { + // Got the error we wanted. + return + } + if !tt.wantErr && err != nil { + t.Fatalf("Validate(ctx, %s, %s): got err %q, want nil", idToken, testAudience, err) + } + if tt.wantErr && err == nil { + t.Fatalf("Validate(ctx, %s, %s): got nil err, want err", idToken, testAudience) + } + if payload == nil { + t.Fatalf("Got nil payload, err: %v", err) + } + if payload.Audience != testAudience { + t.Fatalf("Validate(ctx, %s, %s): got %v, want %v", idToken, testAudience, payload.Audience, testAudience) + } + if len(payload.Claims) == 0 { + t.Fatalf("Validate(ctx, %s, %s): missing Claims map. payload.Claims = %+v", idToken, testAudience, payload.Claims) + } + if got, ok := payload.Claims["aud"]; !ok { + t.Fatalf("Validate(ctx, %s, %s): missing aud claim. payload.Claims = %+v", idToken, testAudience, payload.Claims) + } else { + got, ok := got.(string) + if !ok { + t.Fatalf("Validate(ctx, %s, %s): aud wasn't a string. payload.Claims = %+v", idToken, testAudience, payload.Claims) + } + if got != testAudience { + t.Fatalf("Validate(ctx, %s, %s): Payload[aud] want %v got %v", idToken, testAudience, testAudience, got) + } + } + }) + } +} + +func TestValidateES256(t *testing.T) { + idToken, pk := createES256JWT(t) + tests := []struct { + name string + keyID string + x *big.Int + y *big.Int + nowFunc func() time.Time + wantErr bool + }{ + { + name: "works", + keyID: keyID, + x: pk.X, + y: pk.Y, + nowFunc: beforeExp, + wantErr: false, + }, + { + name: "no matching key", + keyID: "5678", + x: pk.X, + y: pk.Y, + nowFunc: beforeExp, + wantErr: true, + }, + { + name: "sig does not match", + keyID: keyID, + x: new(big.Int), + y: new(big.Int), + nowFunc: beforeExp, + wantErr: true, + }, + { + name: "token expired", + keyID: keyID, + x: pk.X, + y: pk.Y, + nowFunc: afterExp, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := &http.Client{ + Transport: RoundTripFn(func(req *http.Request) *http.Response { + cr := certResponse{ + Keys: []jwk{ + { + Kid: tt.keyID, + X: base64.RawURLEncoding.EncodeToString(tt.x.Bytes()), + Y: base64.RawURLEncoding.EncodeToString(tt.y.Bytes()), + }, + }, + } + b, err := json.Marshal(&cr) + if err != nil { + t.Fatalf("unable to marshal response: %v", err) + } + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewReader(b)), + Header: make(http.Header), + } + }), + } + oldNow := now + defer func() { now = oldNow }() + now = tt.nowFunc + + v, err := NewValidator(&ValidatorOptions{ + Client: client, + }) + if err != nil { + t.Fatalf("NewValidator(...) = %q, want nil", err) + } + payload, err := v.Validate(context.Background(), idToken, testAudience) + if !tt.wantErr && err != nil { + t.Fatalf("Validate(ctx, %s, %s) = %q, want nil", idToken, testAudience, err) + } + if !tt.wantErr && payload.Audience != testAudience { + t.Fatalf("got %v, want %v", payload.Audience, testAudience) + } + }) + } +} + +func TestParsePayload(t *testing.T) { + idToken, _ := createRS256JWT(t) + tests := []struct { + name string + token string + wantPayloadAudience string + wantErr bool + }{{ + name: "valid token", + token: idToken, + wantPayloadAudience: testAudience, + }, { + name: "unparseable token", + token: "aaa.bbb.ccc", + wantErr: true, + }} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + payload, err := ParsePayload(tt.token) + gotErr := err != nil + if gotErr != tt.wantErr { + t.Errorf("ParsePayload(%q) got error %v, wantErr = %v", tt.token, err, tt.wantErr) + } + if tt.wantPayloadAudience != "" { + if payload == nil || payload.Audience != tt.wantPayloadAudience { + t.Errorf("ParsePayload(%q) got payload %+v, want payload with audience = %q", tt.token, payload, tt.wantPayloadAudience) + } + } + }) + } +} + +func createES256JWT(t *testing.T) (string, ecdsa.PublicKey) { + t.Helper() + header, claims := commonToken(t, "ES256") + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("unable to generate key: %v", err) + } + signedContent := header + "." + claims + hashed := sha256.Sum256([]byte(signedContent)) + hash := hashed[:] + r, s, err := ecdsa.Sign(rand.Reader, privateKey, hash) + if err != nil { + t.Fatalf("unable to sign content: %v", err) + } + rb := r.Bytes() + lPadded := make([]byte, es256KeySize) + copy(lPadded[es256KeySize-len(rb):], rb) + var sig []byte + sig = append(sig, lPadded...) + sig = append(sig, s.Bytes()...) + signature := base64.RawURLEncoding.EncodeToString(sig) + return fmt.Sprintf("%s.%s.%s", header, claims, signature), privateKey.PublicKey +} + +func createRS256JWT(t *testing.T) (string, rsa.PublicKey) { + t.Helper() + header, claims := commonToken(t, jwt.HeaderAlgRSA256) + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("unable to generate key: %v", err) + } + signedContent := header + "." + claims + hashed := sha256.Sum256([]byte(signedContent)) + hash := hashed[:] + sig, err := rsa.SignPKCS1v15(rand.Reader, privateKey, crypto.SHA256, hash) + if err != nil { + t.Fatalf("unable to sign content: %v", err) + } + signature := base64.RawURLEncoding.EncodeToString(sig) + return fmt.Sprintf("%s.%s.%s", header, claims, signature), privateKey.PublicKey +} + +// returns header and claims +func commonToken(t *testing.T, alg string) (string, string) { + t.Helper() + header := jwt.Header{ + KeyID: keyID, + Algorithm: alg, + Type: jwt.HeaderType, + } + payload := Payload{ + Issuer: "example.com", + Audience: testAudience, + Expires: expiry, + } + + hb, err := json.Marshal(&header) + if err != nil { + t.Fatalf("unable to marshall header: %v", err) + } + pb, err := json.Marshal(&payload) + if err != nil { + t.Fatalf("unable to marshall payload: %v", err) + } + eb := base64.RawURLEncoding.EncodeToString(hb) + ep := base64.RawURLEncoding.EncodeToString(pb) + return eb, ep +} + +type RoundTripFn func(req *http.Request) *http.Response + +func (f RoundTripFn) RoundTrip(req *http.Request) (*http.Response, error) { return f(req), nil } diff --git a/auth/impersonate/integration_test.go b/auth/impersonate/integration_test.go index 72ce4f5aa9e9..0d4bc7d356f9 100644 --- a/auth/impersonate/integration_test.go +++ b/auth/impersonate/integration_test.go @@ -25,6 +25,7 @@ import ( "time" "cloud.google.com/go/auth/detect" + "cloud.google.com/go/auth/idtoken" "cloud.google.com/go/auth/impersonate" "cloud.google.com/go/auth/internal/testutil" "cloud.google.com/go/auth/internal/testutil/testgcs" @@ -117,63 +118,61 @@ func TestCredentialsTokenSourceIntegration(t *testing.T) { } } -// TODO(codyoss): uncomment in #8580 - -// func TestIDTokenSourceIntegration(t *testing.T) { -// testutil.IntegrationTestCheck(t) +func TestIDTokenSourceIntegration(t *testing.T) { + testutil.IntegrationTestCheck(t) -// ctx := context.Background() -// tests := []struct { -// name string -// baseKeyFile string -// delegates []string -// }{ -// { -// name: "SA -> SA", -// baseKeyFile: readerKeyFile, -// }, -// { -// name: "SA -> Delegate -> SA", -// baseKeyFile: baseKeyFile, -// delegates: []string{readerEmail}, -// }, -// } + ctx := context.Background() + tests := []struct { + name string + baseKeyFile string + delegates []string + }{ + { + name: "SA -> SA", + baseKeyFile: readerKeyFile, + }, + { + name: "SA -> Delegate -> SA", + baseKeyFile: baseKeyFile, + delegates: []string{readerEmail}, + }, + } -// for _, tt := range tests { -// name := tt.name -// t.Run(name, func(t *testing.T) { -// creds, err := detect.DefaultCredentials(&detect.Options{ -// Scopes: []string{"https://www.googleapis.com/auth/cloud-platform"}, -// CredentialsFile: tt.baseKeyFile, -// }) -// if err != nil { -// t.Fatalf("detect.DefaultCredentials() = %v", err) -// } -// aud := "http://example.com/" -// tp, err := impersonate.NewIDTokenProvider(&impersonate.IDTokenOptions{ -// TargetPrincipal: writerEmail, -// Audience: aud, -// Delegates: tt.delegates, -// IncludeEmail: true, -// TokenProvider: creds, -// }) -// if err != nil { -// t.Fatalf("failed to create ts: %v", err) -// } -// tok, err := tp.Token(ctx) -// if err != nil { -// t.Fatalf("unable to retrieve Token: %v", err) -// } -// validTok, err := idtoken.Validate(ctx, tok.Value, aud) -// if err != nil { -// t.Fatalf("token validation failed: %v", err) -// } -// if validTok.Audience != aud { -// t.Fatalf("got %q, want %q", validTok.Audience, aud) -// } -// if validTok.Claims["email"] != writerEmail { -// t.Fatalf("got %q, want %q", validTok.Claims["email"], writerEmail) -// } -// }) -// } -// } + for _, tt := range tests { + name := tt.name + t.Run(name, func(t *testing.T) { + creds, err := detect.DefaultCredentials(&detect.Options{ + Scopes: []string{"https://www.googleapis.com/auth/cloud-platform"}, + CredentialsFile: tt.baseKeyFile, + }) + if err != nil { + t.Fatalf("detect.DefaultCredentials() = %v", err) + } + aud := "http://example.com/" + tp, err := impersonate.NewIDTokenProvider(&impersonate.IDTokenOptions{ + TargetPrincipal: writerEmail, + Audience: aud, + Delegates: tt.delegates, + IncludeEmail: true, + TokenProvider: creds, + }) + if err != nil { + t.Fatalf("failed to create ts: %v", err) + } + tok, err := tp.Token(ctx) + if err != nil { + t.Fatalf("unable to retrieve Token: %v", err) + } + validTok, err := idtoken.Validate(ctx, tok.Value, aud) + if err != nil { + t.Fatalf("token validation failed: %v", err) + } + if validTok.Audience != aud { + t.Fatalf("got %q, want %q", validTok.Audience, aud) + } + if validTok.Claims["email"] != writerEmail { + t.Fatalf("got %q, want %q", validTok.Claims["email"], writerEmail) + } + }) + } +}