diff --git a/dialer.go b/dialer.go index 855fab87..8fd93bc2 100644 --- a/dialer.go +++ b/dialer.go @@ -63,18 +63,47 @@ var ( //go:embed version.txt versionString string userAgent = "alloydb-go-connector/" + strings.TrimSpace(versionString) - - // defaultKey is the default RSA public/private keypair used by the clients. - defaultKey *rsa.PrivateKey - defaultKeyErr error - keyOnce sync.Once ) -func getDefaultKeys() (*rsa.PrivateKey, error) { - keyOnce.Do(func() { - defaultKey, defaultKeyErr = rsa.GenerateKey(rand.Reader, 2048) - }) - return defaultKey, defaultKeyErr +// keyGenerator encapsulates the details of RSA key generation to provide lazy +// generation, custom keys, or a default RSA generator. +type keyGenerator struct { + once sync.Once + key *rsa.PrivateKey + err error + genFunc func() (*rsa.PrivateKey, error) +} + +// newKeyGenerator initializes a keyGenerator that will (in order): +// - always return the RSA key if one is provided, or +// - generate an RSA key lazily when it's requested, or +// - (default) immediately generate an RSA key as part of the initializer. +func newKeyGenerator( + k *rsa.PrivateKey, lazy bool, genFunc func() (*rsa.PrivateKey, error), +) (*keyGenerator, error) { + g := &keyGenerator{genFunc: genFunc} + switch { + case k != nil: + // If the caller has provided a key, initialize the key and consume the + // sync.Once now. + g.once.Do(func() { g.key, g.err = k, nil }) + case lazy: + // If lazy refresh is enabled, do nothing and wait for the call to + // rsaKey. + default: + // If no key has been provided and lazy refresh isn't enabled, generate + // the key and consume the sync.Once now. + g.once.Do(func() { g.key, g.err = g.genFunc() }) + } + return g, g.err +} + +// rsaKey will generate an RSA key if one is not already cached. Otherwise, it +// will return the cached key. +func (g *keyGenerator) rsaKey() (*rsa.PrivateKey, error) { + g.once.Do(func() { g.key, g.err = g.genFunc() }) + + return g.key, g.err } type connectionInfoCache interface { @@ -96,7 +125,7 @@ type monitoredCache struct { type Dialer struct { lock sync.RWMutex cache map[alloydb.InstanceURI]monitoredCache - key *rsa.PrivateKey + keyGenerator *keyGenerator refreshTimeout time.Duration // closed reports if the dialer has been closed. closed chan struct{} @@ -158,14 +187,6 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) { // Add this to the end to make sure it's not overridden cfg.adminOpts = append(cfg.adminOpts, option.WithUserAgent(userAgent)) - if cfg.rsaKey == nil { - key, err := getDefaultKeys() - if err != nil { - return nil, fmt.Errorf("failed to generate RSA keys: %v", err) - } - cfg.rsaKey = key - } - // If no token source is configured, use ADC's token source. ts := cfg.tokenSource if ts == nil { @@ -192,12 +213,19 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) { if err := trace.InitMetrics(); err != nil { return nil, err } + g, err := newKeyGenerator(cfg.rsaKey, cfg.lazyRefresh, + func() (*rsa.PrivateKey, error) { + return rsa.GenerateKey(rand.Reader, 2048) + }) + if err != nil { + return nil, err + } d := &Dialer{ closed: make(chan struct{}), cache: make(map[alloydb.InstanceURI]monitoredCache), lazyRefresh: cfg.lazyRefresh, staticConnInfo: cfg.staticConnInfo, - key: cfg.rsaKey, + keyGenerator: g, refreshTimeout: cfg.refreshTimeout, client: client, logger: cfg.logger, @@ -570,13 +598,17 @@ func (d *Dialer) connectionInfoCache( "[%v] Connection info added to cache", uri.String(), ) + k, err := d.keyGenerator.rsaKey() + if err != nil { + return monitoredCache{}, err + } var cache connectionInfoCache switch { case d.lazyRefresh: cache = alloydb.NewLazyRefreshCache( uri, d.logger, - d.client, d.key, + d.client, k, d.refreshTimeout, d.dialerID, ) case d.staticConnInfo != nil: @@ -593,7 +625,7 @@ func (d *Dialer) connectionInfoCache( cache = alloydb.NewRefreshAheadCache( uri, d.logger, - d.client, d.key, + d.client, k, d.refreshTimeout, d.dialerID, ) } diff --git a/key_gen_test.go b/key_gen_test.go new file mode 100644 index 00000000..8ccc4fc0 --- /dev/null +++ b/key_gen_test.go @@ -0,0 +1,160 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package alloydbconn + +import ( + "crypto/rsa" + "errors" + "testing" +) + +func TestKeyGenerator(t *testing.T) { + custom := &rsa.PrivateKey{} + generated := &rsa.PrivateKey{} + + var ( + defaultCount int + lazyCount int + ) + + tcs := []struct { + desc string + key *rsa.PrivateKey + lazy bool + genFunc func() (*rsa.PrivateKey, error) + wantKey *rsa.PrivateKey + // whether key generation should happen in the initializer or the call + // to rsaKey + wantLazy bool + }{ + { + desc: "by default a key is generated", + genFunc: func() (*rsa.PrivateKey, error) { + return generated, nil + }, + wantKey: generated, + }, + { + desc: "a custom key skips the generator", + key: custom, + genFunc: func() (*rsa.PrivateKey, error) { + return nil, errors.New("generator should not be called") + }, + wantKey: custom, + }, + { + desc: "lazy generates keys on first request", + lazy: true, + genFunc: func() (*rsa.PrivateKey, error) { + if defaultCount > 0 { + return nil, errors.New("genFunc was called twice") + } + defaultCount++ + return generated, nil + }, + wantKey: generated, + wantLazy: true, + }, + { + desc: "key generation happens only once", + genFunc: func() (*rsa.PrivateKey, error) { + if lazyCount > 0 { + return nil, errors.New("genFunc was called twice") + } + lazyCount++ + return generated, nil + }, + wantKey: generated, + }, + } + + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + g, err := newKeyGenerator(tc.key, tc.lazy, tc.genFunc) + if err != nil { + t.Fatal(err) + } + if tc.wantLazy && g.key != nil { + t.Fatal("want RSA key to be lazily generated, but it wasn't") + } + k, err := g.rsaKey() + if err != nil { + t.Fatal(err) + } + if tc.wantKey != k { + t.Fatalf("want = %v, got = %v", tc.wantKey, k) + } + // Ensure a second call doesn't trigger a new key generation + _, err = g.rsaKey() + if err != nil { + t.Fatal(err) + } + }) + } +} + +func TestKeyGeneratorErrors(t *testing.T) { + sentinel := errors.New("sentinel error") + tcs := []struct { + desc string + key *rsa.PrivateKey + lazy bool + genFunc func() (*rsa.PrivateKey, error) + wantInitError error + wantKeyError error + }{ + { + desc: "generator returns errors", + genFunc: func() (*rsa.PrivateKey, error) { + return nil, sentinel + }, + wantInitError: sentinel, + wantKeyError: sentinel, + }, + { + desc: "custom keys never error", + key: &rsa.PrivateKey{}, + genFunc: func() (*rsa.PrivateKey, error) { + return nil, errors.New("generator should not be called") + }, + wantInitError: nil, + wantKeyError: nil, + }, + { + desc: "lazy generation returns errors", + lazy: true, + genFunc: func() (*rsa.PrivateKey, error) { + return nil, sentinel + }, + // initialization should succeed + wantInitError: nil, + // but requesting the key later should fail + wantKeyError: sentinel, + }, + } + + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + g, err := newKeyGenerator(tc.key, tc.lazy, tc.genFunc) + if err != tc.wantInitError { + t.Fatal("initialization should fail, but did not") + } + _, err = g.rsaKey() + if err != tc.wantKeyError { + t.Fatal("rsaKey should fail but didn't") + } + }) + } +}