diff --git a/credentials/tls/certprovider/distributor.go b/credentials/tls/certprovider/distributor.go index 05ac1f7c7106..6a9116f1cb44 100644 --- a/credentials/tls/certprovider/distributor.go +++ b/credentials/tls/certprovider/distributor.go @@ -20,8 +20,6 @@ package certprovider import ( "context" - "crypto/tls" - "crypto/x509" "sync" "google.golang.org/grpc/internal/grpcsync" @@ -38,11 +36,13 @@ import ( // by the provider. // - When users of the provider call Close(), the channel returned by the // Done() method will be closed. So, provider implementations can select on -// the channel returned by Done() to perform cleanup work. +// the channel returned by Done() to perform cleanup work, or override +// Close(), in which case they must invoke Distributor.Close() when the +// provider is closed. type Distributor struct { - mu sync.Mutex - certs []tls.Certificate - roots *x509.CertPool + mu sync.Mutex + km *KeyMaterial + ready *grpcsync.Event closed *grpcsync.Event } @@ -58,8 +58,7 @@ func NewDistributor() *Distributor { // Set updates the key material in the distributor with km. func (d *Distributor) Set(km *KeyMaterial) { d.mu.Lock() - d.certs = km.Certs - d.roots = km.Roots + d.km = km d.ready.Fire() d.mu.Unlock() } @@ -70,7 +69,7 @@ func (d *Distributor) Set(km *KeyMaterial) { // arrives. func (d *Distributor) KeyMaterial(ctx context.Context, opts KeyMaterialOptions) (*KeyMaterial, error) { if d.closed.HasFired() { - return nil, ErrProviderClosed + return nil, errProviderClosed } if d.ready.HasFired() { @@ -81,7 +80,7 @@ func (d *Distributor) KeyMaterial(ctx context.Context, opts KeyMaterialOptions) case <-ctx.Done(): return nil, ctx.Err() case <-d.closed.Done(): - return nil, ErrProviderClosed + return nil, errProviderClosed case <-d.ready.Done(): return d.keyMaterial(), nil } @@ -89,7 +88,7 @@ func (d *Distributor) KeyMaterial(ctx context.Context, opts KeyMaterialOptions) func (d *Distributor) keyMaterial() *KeyMaterial { d.mu.Lock() - km := &KeyMaterial{Certs: d.certs, Roots: d.roots} + km := d.km d.mu.Unlock() return km } diff --git a/credentials/tls/certprovider/distributor_test.go b/credentials/tls/certprovider/distributor_test.go index 1f6f660edd3d..8d58318d3b95 100644 --- a/credentials/tls/certprovider/distributor_test.go +++ b/credentials/tls/certprovider/distributor_test.go @@ -94,12 +94,12 @@ func (s) TestDistributor(t *testing.T) { } proceedCh <- struct{}{} - // This call to KeyMaterial() should eventually return ErrProviderClosed + // This call to KeyMaterial() should eventually return errProviderClosed // error. ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() for { - if _, err := dist.KeyMaterial(ctx, KeyMaterialOptions{}); err == ErrProviderClosed { + if _, err := dist.KeyMaterial(ctx, KeyMaterialOptions{}); err == errProviderClosed { doneCh := dist.Done() if _, ok := <-doneCh; ok { errCh <- errors.New("distributor done channel not closed") diff --git a/credentials/tls/certprovider/provider.go b/credentials/tls/certprovider/provider.go index da254a51bc48..def10ffc7c04 100644 --- a/credentials/tls/certprovider/provider.go +++ b/credentials/tls/certprovider/provider.go @@ -29,13 +29,12 @@ import ( "crypto/tls" "crypto/x509" "errors" - "strings" ) var ( - // ErrProviderClosed may be returned from a call to KeyMaterial() when the + // errProviderClosed may be returned from a call to KeyMaterial() when the // underlying provider instance is closed. - ErrProviderClosed = errors.New("provider instance is closed") + errProviderClosed = errors.New("provider instance is closed") // m is a map from name to provider builder. m = make(map[string]Builder) @@ -44,13 +43,13 @@ var ( // Register registers the provider builder, whose name as returned by its // Name() method will be used as the name registered with this builder. func Register(b Builder) { - m[strings.ToLower(b.Name())] = b + m[b.Name()] = b } -// Get returns the provider builder registered with the given name. +// getBuilder returns the provider builder registered with the given name. // If no builder is registered with the provided name, nil will be returned. -func Get(name string) Builder { - if b, ok := m[strings.ToLower(name)]; ok { +func getBuilder(name string) Builder { + if b, ok := m[name]; ok { return b } return nil @@ -64,6 +63,8 @@ type Builder interface { // ParseConfig converts config input in a format specific to individual // implementations and returns an implementation of the StableConfig // interface. + // Equivalent configurations should return StableConfig types whose + // Canonical() method returns the same output. ParseConfig(interface{}) (StableConfig, error) // Name returns the name of providers built by this builder. diff --git a/credentials/tls/certprovider/store.go b/credentials/tls/certprovider/store.go index f998df276999..024e621c52a2 100644 --- a/credentials/tls/certprovider/store.go +++ b/credentials/tls/certprovider/store.go @@ -90,7 +90,7 @@ func (ps *Store) GetProvider(key Key) Provider { return wp } - b := Get(key.Name) + b := getBuilder(key.Name) if b == nil { return nil } diff --git a/credentials/tls/certprovider/store_test.go b/credentials/tls/certprovider/store_test.go index 9485567d505c..878292842409 100644 --- a/credentials/tls/certprovider/store_test.go +++ b/credentials/tls/certprovider/store_test.go @@ -135,7 +135,7 @@ func makeProvider(t *testing.T, name, config string) (Provider, *fakeProvider) { t.Helper() // Grab the provider builder. - b := Get(name) + b := getBuilder(name) if b == nil { t.Fatalf("no provider builder found for name : %s", name) } @@ -195,8 +195,8 @@ func (s) TestStoreWithSingleProvider(t *testing.T) { // Close the provider and retry the KeyMaterial() call, and expect it to // fail with a known error. prov.Close() - if _, err := prov.KeyMaterial(ctx, KeyMaterialOptions{}); err != ErrProviderClosed { - t.Fatalf("provider.KeyMaterial() = %v, wantErr: %v", err, ErrProviderClosed) + if _, err := prov.KeyMaterial(ctx, KeyMaterialOptions{}); err != errProviderClosed { + t.Fatalf("provider.KeyMaterial() = %v, wantErr: %v", err, errProviderClosed) } } @@ -234,8 +234,8 @@ func (s) TestStoreWithSingleProviderWithSharing(t *testing.T) { } prov2.Close() - if _, err := prov2.KeyMaterial(ctx, KeyMaterialOptions{}); err != ErrProviderClosed { - t.Fatalf("provider.KeyMaterial() = %v, wantErr: %v", err, ErrProviderClosed) + if _, err := prov2.KeyMaterial(ctx, KeyMaterialOptions{}); err != errProviderClosed { + t.Fatalf("provider.KeyMaterial() = %v, wantErr: %v", err, errProviderClosed) } } @@ -302,8 +302,8 @@ func (s) TestStoreWithSingleProviderWithoutSharing(t *testing.T) { } prov2.Close() - if _, err := prov2.KeyMaterial(ctx, KeyMaterialOptions{}); err != ErrProviderClosed { - t.Fatalf("provider.KeyMaterial() = %v, wantErr: %v", err, ErrProviderClosed) + if _, err := prov2.KeyMaterial(ctx, KeyMaterialOptions{}); err != errProviderClosed { + t.Fatalf("provider.KeyMaterial() = %v, wantErr: %v", err, errProviderClosed) } } @@ -351,7 +351,7 @@ func (s) TestStoreWithMultipleProviders(t *testing.T) { } prov2.Close() - if _, err := prov2.KeyMaterial(ctx, KeyMaterialOptions{}); err != ErrProviderClosed { - t.Fatalf("provider.KeyMaterial() = %v, wantErr: %v", err, ErrProviderClosed) + if _, err := prov2.KeyMaterial(ctx, KeyMaterialOptions{}); err != errProviderClosed { + t.Fatalf("provider.KeyMaterial() = %v, wantErr: %v", err, errProviderClosed) } }