Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(auth): add idtoken package #8580

Merged
merged 8 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions auth/idtoken/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package idtoken

import (
"context"
"encoding/json"
"fmt"
"net/http"
"strconv"
"strings"
"sync"
"time"
)

type cachingClient struct {
client *http.Client

// clock optionally specifies a func to return the current time.
// If nil, time.Now is used.
clock func() time.Time

mu sync.Mutex
certs map[string]*cachedResponse
}

func newCachingClient(client *http.Client) *cachingClient {
return &cachingClient{
client: client,
certs: make(map[string]*cachedResponse, 2),
}
}

type cachedResponse struct {
resp *certResponse
exp time.Time
}

func (c *cachingClient) getCert(ctx context.Context, url string) (*certResponse, error) {
if response, ok := c.get(url); ok {
return response, nil
}
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
}
req = req.WithContext(ctx)
resp, err := c.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("idtoken: unable to retrieve cert, got status code %d", resp.StatusCode)
}

certResp := &certResponse{}
if err := json.NewDecoder(resp.Body).Decode(certResp); err != nil {
return nil, err

}
c.set(url, certResp, resp.Header)
return certResp, nil
}

func (c *cachingClient) now() time.Time {
if c.clock != nil {
return c.clock()
}
return time.Now()
}

func (c *cachingClient) get(url string) (*certResponse, bool) {
c.mu.Lock()
defer c.mu.Unlock()
cachedResp, ok := c.certs[url]
if !ok {
return nil, false
}
if c.now().After(cachedResp.exp) {
return nil, false
}
return cachedResp.resp, true
}

func (c *cachingClient) set(url string, resp *certResponse, headers http.Header) {
exp := c.calculateExpireTime(headers)
c.mu.Lock()
c.certs[url] = &cachedResponse{resp: resp, exp: exp}
c.mu.Unlock()
}

// calculateExpireTime will determine the expire time for the cache based on
// HTTP headers. If there is any difficulty reading the headers the fallback is
// to set the cache to expire now.
func (c *cachingClient) calculateExpireTime(headers http.Header) time.Time {
var maxAge int
cc := strings.Split(headers.Get("cache-control"), ",")
for _, v := range cc {
if strings.Contains(v, "max-age") {
ss := strings.Split(v, "=")
if len(ss) < 2 {
return c.now()
}
ma, err := strconv.Atoi(ss[1])
if err != nil {
return c.now()
}
maxAge = ma
}
}
a := headers.Get("age")
if a == "" {
return c.now().Add(time.Duration(maxAge) * time.Second)
}
age, err := strconv.Atoi(a)
if err != nil {
return c.now()
}
return c.now().Add(time.Duration(maxAge-age) * time.Second)
}
82 changes: 82 additions & 0 deletions auth/idtoken/cache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package idtoken

import (
"net/http"
"sync"
"testing"
"time"
)

type fakeClock struct {
mu sync.Mutex
t time.Time
}

func (c *fakeClock) Now() time.Time {
c.mu.Lock()
defer c.mu.Unlock()
return c.t
}

func (c *fakeClock) Sleep(d time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
c.t = c.t.Add(d)
}

func TestCacheHit(t *testing.T) {
clock := &fakeClock{t: time.Now()}
fakeResp := &certResponse{
Keys: []jwk{
{
Kid: "123",
},
},
}
cache := newCachingClient(nil)
cache.clock = clock.Now

// Cache should be empty
cert, ok := cache.get(googleSACertsURL)
if ok || cert != nil {
t.Fatal("cache for SA certs should be empty")
}

// Add an item, but make it expire now
cache.set(googleSACertsURL, fakeResp, make(http.Header))
clock.Sleep(time.Nanosecond) // it expires when current time is > expiration, not >=
cert, ok = cache.get(googleSACertsURL)
if ok || cert != nil {
t.Fatal("cache for SA certs should be expired")
}

// Add an item that expires in 1 seconds
h := make(http.Header)
h.Set("age", "0")
h.Set("cache-control", "public, max-age=1, must-revalidate, no-transform")
cache.set(googleSACertsURL, fakeResp, h)
cert, ok = cache.get(googleSACertsURL)
if !ok || cert == nil || cert.Keys[0].Kid != "123" {
t.Fatal("cache for SA certs have a resp")
}
// Wait
clock.Sleep(2 * time.Second)
cert, ok = cache.get(googleSACertsURL)
if ok || cert != nil {
t.Fatal("cache for SA certs should be expired")
}
}
77 changes: 77 additions & 0 deletions auth/idtoken/compute.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package idtoken

import (
"context"
"fmt"
"net/url"
"time"

"cloud.google.com/go/auth"
"cloud.google.com/go/auth/internal"
"cloud.google.com/go/compute/metadata"
)

const identitySuffix = "instance/service-accounts/default/identity"

// computeTokenProvider checks if this code is being run on GCE. If it is, it
// will use the metadata service to build a TokenProvider that fetches ID
// tokens.
func computeTokenProvider(opts *Options) (auth.TokenProvider, error) {
if opts.CustomClaims != nil {
return nil, fmt.Errorf("idtoken: Options.CustomClaims can't be used with the metadata service, please provide a service account if you would like to use this feature")
}
tp := computeIDTokenProvider{
audience: opts.Audience,
format: opts.ComputeTokenFormat,
client: *metadata.NewClient(opts.client()),
}
return auth.NewCachedTokenProvider(tp, &auth.CachedTokenProviderOptions{
ExpireEarly: 5 * time.Minute,
}), nil
}

type computeIDTokenProvider struct {
audience string
format ComputeTokenFormat
client metadata.Client
}

func (c computeIDTokenProvider) Token(ctx context.Context) (*auth.Token, error) {
v := url.Values{}
v.Set("audience", c.audience)
if c.format != ComputeTokenFormatStandard {
codyoss marked this conversation as resolved.
Show resolved Hide resolved
v.Set("format", "full")
}
if c.format == ComputeTokenFormatFullWithLicense {
v.Set("licenses", "TRUE")
codyoss marked this conversation as resolved.
Show resolved Hide resolved
}
urlSuffix := identitySuffix + "?" + v.Encode()
res, err := c.client.Get(urlSuffix)
if err != nil {
return nil, err
}
if res == "" {
return nil, fmt.Errorf("idtoken: invalid empty response from metadata service")
}
return &auth.Token{
Value: res,
Type: internal.TokenTypeBearer,
// Compute tokens are valid for one hour:
// https://cloud.google.com/iam/docs/create-short-lived-credentials-direct#create-id
Expiry: time.Now().Add(1 * time.Hour),
}, nil
}
102 changes: 102 additions & 0 deletions auth/idtoken/compute_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package idtoken

import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
)

const metadataHostEnv = "GCE_METADATA_HOST"

func TestComputeTokenSource(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !strings.Contains(r.URL.Path, identitySuffix) {
t.Errorf("got %q, want contains %q", r.URL.Path, identitySuffix)
}
if got, want := r.URL.Query().Get("audience"), "aud"; got != want {
t.Errorf("got %q, want %q", got, want)
}
if got, want := r.URL.Query().Get("format"), "full"; got != want {
t.Errorf("got %q, want %q", got, want)
}
if got, want := r.URL.Query().Get("licenses"), "TRUE"; got != want {
t.Errorf("got %q, want %q", got, want)
}
w.Write([]byte(`fake_token`))
}))
defer ts.Close()
t.Setenv(metadataHostEnv, strings.TrimPrefix(ts.URL, "http://"))
tp, err := computeTokenProvider(&Options{
Audience: "aud",
ComputeTokenFormat: ComputeTokenFormatFullWithLicense,
})
if err != nil {
t.Fatalf("computeTokenProvider() = %v", err)
}
tok, err := tp.Token(context.Background())
if err != nil {
t.Fatalf("tp.Token() = %v", err)
}
if want := "fake_token"; tok.Value != want {
t.Errorf("got %q, want %q", tok.Value, want)
}
}

func TestComputeTokenSource_Standard(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !strings.Contains(r.URL.Path, identitySuffix) {
t.Errorf("got %q, want contains %q", r.URL.Path, identitySuffix)
}
if got, want := r.URL.Query().Get("audience"), "aud"; got != want {
t.Errorf("got %q, want %q", got, want)
}
if got, want := r.URL.Query().Get("format"), ""; got != want {
t.Errorf("got %q, want %q", got, want)
}
if got, want := r.URL.Query().Get("licenses"), ""; got != want {
t.Errorf("got %q, want %q", got, want)
}
w.Write([]byte(`fake_token`))
}))
defer ts.Close()
t.Setenv(metadataHostEnv, strings.TrimPrefix(ts.URL, "http://"))
tp, err := computeTokenProvider(&Options{
Audience: "aud",
ComputeTokenFormat: ComputeTokenFormatStandard,
})
if err != nil {
t.Fatalf("computeTokenProvider() = %v", err)
}
tok, err := tp.Token(context.Background())
if err != nil {
t.Fatalf("tp.Token() = %v", err)
}
if want := "fake_token"; tok.Value != want {
t.Errorf("got %q, want %q", tok.Value, want)
}
}

func TestComputeTokenSource_Invalid(t *testing.T) {
if _, err := computeTokenProvider(&Options{
Audience: "aud",
CustomClaims: map[string]interface{}{"foo": "bar"},
}); err == nil {
t.Fatal("computeTokenProvider() = nil, expected non-nil error", err)
}
}
Loading