Skip to content

Commit

Permalink
Review comments #1.
Browse files Browse the repository at this point in the history
  • Loading branch information
easwars committed Jun 10, 2020
1 parent 2adae09 commit cab3482
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 30 deletions.
21 changes: 10 additions & 11 deletions credentials/tls/certprovider/distributor.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ package certprovider

import (
"context"
"crypto/tls"
"crypto/x509"
"sync"

"google.golang.org/grpc/internal/grpcsync"
Expand All @@ -38,11 +36,13 @@ import (
// by the provider.
// - When users of the provider call Close(), the channel returned by the
// Done() method will be closed. So, provider implementations can select on
// the channel returned by Done() to perform cleanup work.
// the channel returned by Done() to perform cleanup work, or override
// Close(), in which case they must invoke Distributor.Close() when the
// provider is closed.
type Distributor struct {
mu sync.Mutex
certs []tls.Certificate
roots *x509.CertPool
mu sync.Mutex
km *KeyMaterial

ready *grpcsync.Event
closed *grpcsync.Event
}
Expand All @@ -58,8 +58,7 @@ func NewDistributor() *Distributor {
// Set updates the key material in the distributor with km.
func (d *Distributor) Set(km *KeyMaterial) {
d.mu.Lock()
d.certs = km.Certs
d.roots = km.Roots
d.km = km
d.ready.Fire()
d.mu.Unlock()
}
Expand All @@ -70,7 +69,7 @@ func (d *Distributor) Set(km *KeyMaterial) {
// arrives.
func (d *Distributor) KeyMaterial(ctx context.Context, opts KeyMaterialOptions) (*KeyMaterial, error) {
if d.closed.HasFired() {
return nil, ErrProviderClosed
return nil, errProviderClosed
}

if d.ready.HasFired() {
Expand All @@ -81,15 +80,15 @@ func (d *Distributor) KeyMaterial(ctx context.Context, opts KeyMaterialOptions)
case <-ctx.Done():
return nil, ctx.Err()
case <-d.closed.Done():
return nil, ErrProviderClosed
return nil, errProviderClosed
case <-d.ready.Done():
return d.keyMaterial(), nil
}
}

func (d *Distributor) keyMaterial() *KeyMaterial {
d.mu.Lock()
km := &KeyMaterial{Certs: d.certs, Roots: d.roots}
km := d.km
d.mu.Unlock()
return km
}
Expand Down
4 changes: 2 additions & 2 deletions credentials/tls/certprovider/distributor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,12 @@ func (s) TestDistributor(t *testing.T) {
}
proceedCh <- struct{}{}

// This call to KeyMaterial() should eventually return ErrProviderClosed
// This call to KeyMaterial() should eventually return errProviderClosed
// error.
ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for {
if _, err := dist.KeyMaterial(ctx, KeyMaterialOptions{}); err == ErrProviderClosed {
if _, err := dist.KeyMaterial(ctx, KeyMaterialOptions{}); err == errProviderClosed {
doneCh := dist.Done()
if _, ok := <-doneCh; ok {
errCh <- errors.New("distributor done channel not closed")
Expand Down
15 changes: 8 additions & 7 deletions credentials/tls/certprovider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,12 @@ import (
"crypto/tls"
"crypto/x509"
"errors"
"strings"
)

var (
// ErrProviderClosed may be returned from a call to KeyMaterial() when the
// errProviderClosed may be returned from a call to KeyMaterial() when the
// underlying provider instance is closed.
ErrProviderClosed = errors.New("provider instance is closed")
errProviderClosed = errors.New("provider instance is closed")

// m is a map from name to provider builder.
m = make(map[string]Builder)
Expand All @@ -44,13 +43,13 @@ var (
// Register registers the provider builder, whose name as returned by its
// Name() method will be used as the name registered with this builder.
func Register(b Builder) {
m[strings.ToLower(b.Name())] = b
m[b.Name()] = b
}

// Get returns the provider builder registered with the given name.
// getBuilder returns the provider builder registered with the given name.
// If no builder is registered with the provided name, nil will be returned.
func Get(name string) Builder {
if b, ok := m[strings.ToLower(name)]; ok {
func getBuilder(name string) Builder {
if b, ok := m[name]; ok {
return b
}
return nil
Expand All @@ -64,6 +63,8 @@ type Builder interface {
// ParseConfig converts config input in a format specific to individual
// implementations and returns an implementation of the StableConfig
// interface.
// Equivalent configurations should return StableConfig types whose
// Canonical() method returns the same output.
ParseConfig(interface{}) (StableConfig, error)

// Name returns the name of providers built by this builder.
Expand Down
2 changes: 1 addition & 1 deletion credentials/tls/certprovider/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func (ps *Store) GetProvider(key Key) Provider {
return wp
}

b := Get(key.Name)
b := getBuilder(key.Name)
if b == nil {
return nil
}
Expand Down
18 changes: 9 additions & 9 deletions credentials/tls/certprovider/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ func makeProvider(t *testing.T, name, config string) (Provider, *fakeProvider) {
t.Helper()

// Grab the provider builder.
b := Get(name)
b := getBuilder(name)
if b == nil {
t.Fatalf("no provider builder found for name : %s", name)
}
Expand Down Expand Up @@ -195,8 +195,8 @@ func (s) TestStoreWithSingleProvider(t *testing.T) {
// Close the provider and retry the KeyMaterial() call, and expect it to
// fail with a known error.
prov.Close()
if _, err := prov.KeyMaterial(ctx, KeyMaterialOptions{}); err != ErrProviderClosed {
t.Fatalf("provider.KeyMaterial() = %v, wantErr: %v", err, ErrProviderClosed)
if _, err := prov.KeyMaterial(ctx, KeyMaterialOptions{}); err != errProviderClosed {
t.Fatalf("provider.KeyMaterial() = %v, wantErr: %v", err, errProviderClosed)
}
}

Expand Down Expand Up @@ -234,8 +234,8 @@ func (s) TestStoreWithSingleProviderWithSharing(t *testing.T) {
}

prov2.Close()
if _, err := prov2.KeyMaterial(ctx, KeyMaterialOptions{}); err != ErrProviderClosed {
t.Fatalf("provider.KeyMaterial() = %v, wantErr: %v", err, ErrProviderClosed)
if _, err := prov2.KeyMaterial(ctx, KeyMaterialOptions{}); err != errProviderClosed {
t.Fatalf("provider.KeyMaterial() = %v, wantErr: %v", err, errProviderClosed)
}
}

Expand Down Expand Up @@ -302,8 +302,8 @@ func (s) TestStoreWithSingleProviderWithoutSharing(t *testing.T) {
}

prov2.Close()
if _, err := prov2.KeyMaterial(ctx, KeyMaterialOptions{}); err != ErrProviderClosed {
t.Fatalf("provider.KeyMaterial() = %v, wantErr: %v", err, ErrProviderClosed)
if _, err := prov2.KeyMaterial(ctx, KeyMaterialOptions{}); err != errProviderClosed {
t.Fatalf("provider.KeyMaterial() = %v, wantErr: %v", err, errProviderClosed)
}
}

Expand Down Expand Up @@ -351,7 +351,7 @@ func (s) TestStoreWithMultipleProviders(t *testing.T) {
}

prov2.Close()
if _, err := prov2.KeyMaterial(ctx, KeyMaterialOptions{}); err != ErrProviderClosed {
t.Fatalf("provider.KeyMaterial() = %v, wantErr: %v", err, ErrProviderClosed)
if _, err := prov2.KeyMaterial(ctx, KeyMaterialOptions{}); err != errProviderClosed {
t.Fatalf("provider.KeyMaterial() = %v, wantErr: %v", err, errProviderClosed)
}
}

0 comments on commit cab3482

Please sign in to comment.