From 14592f3d21e58fbd038cffdb6c4f67d7e3526302 Mon Sep 17 00:00:00 2001 From: Eno Compton Date: Tue, 6 Jun 2023 12:23:59 -0600 Subject: [PATCH] feat: add support for WithOneOffDialFunc (#558) Fixes #551. --- dialer.go | 6 +++++- dialer_test.go | 33 +++++++++++++++++++++++++++++++++ options.go | 18 ++++++++++++++---- 3 files changed, 52 insertions(+), 5 deletions(-) diff --git a/dialer.go b/dialer.go index 6e4601b2..6f4c7be1 100644 --- a/dialer.go +++ b/dialer.go @@ -223,7 +223,11 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption) ctx, connectEnd = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn/internal.Connect") defer func() { connectEnd(err) }() addr = net.JoinHostPort(addr, serverProxyPort) - conn, err = d.dialFunc(ctx, "tcp", addr) + f := d.dialFunc + if cfg.dialFunc != nil { + f = cfg.dialFunc + } + conn, err = f(ctx, "tcp", addr) if err != nil { // refresh the instance info in case it caused the connection failure i.ForceRefresh() diff --git a/dialer_test.go b/dialer_test.go index f9240e33..7346faf4 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -567,3 +567,36 @@ func TestDialerRemovesInvalidInstancesFromCache(t *testing.T) { t.Fatal("performRefresh should not be running") } } + +func TestDialerSupportsOneOffDialFunction(t *testing.T) { + ctx := context.Background() + inst := mock.NewFakeCSQLInstance("p", "r", "i") + svc, cleanup, err := mock.NewSQLAdminService( + context.Background(), + mock.InstanceGetSuccess(inst, 1), + mock.CreateEphemeralSuccess(inst, 1), + ) + if err != nil { + t.Fatalf("failed to init SQLAdminService: %v", err) + } + d, err := NewDialer(ctx, WithTokenSource(mock.EmptyTokenSource{})) + if err != nil { + t.Fatal(err) + } + d.sqladmin = svc + defer func() { + if err := d.Close(); err != nil { + t.Log(err) + } + _ = cleanup() + }() + + sentinelErr := errors.New("dial func was called") + f := func(context.Context, string, string) (net.Conn, error) { + return nil, sentinelErr + } + + if _, err := d.Dial(ctx, "p:r:i", WithOneOffDialFunc(f)); !errors.Is(err, sentinelErr) { + t.Fatal("one-off dial func was not called") + } +} diff --git a/options.go b/options.go index 7fca8252..ef974d82 100644 --- a/options.go +++ b/options.go @@ -188,7 +188,8 @@ func WithQuotaProject(p string) Option { // WithDialFunc configures the function used to connect to the address on the // named network. This option is generally unnecessary except for advanced -// use-cases. +// use-cases. The function is used for all invocations of Dial. To configure +// a dial function per individual calls to dial, use WithOneOffDialFunc. func WithDialFunc(dial func(ctx context.Context, network, addr string) (net.Conn, error)) Option { return func(d *dialerConfig) { d.dialFunc = dial @@ -212,10 +213,10 @@ func WithIAMAuthN() Option { type DialOption func(d *dialCfg) type dialCfg struct { - tcpKeepAlive time.Duration + dialFunc func(ctx context.Context, network, addr string) (net.Conn, error) ipType string - - refreshCfg cloudsql.RefreshCfg + tcpKeepAlive time.Duration + refreshCfg cloudsql.RefreshCfg } // DialOptions turns a list of DialOption instances into an DialOption. @@ -227,6 +228,15 @@ func DialOptions(opts ...DialOption) DialOption { } } +// WithOneOffDialFunc configures the dial function on a one-off basis for an +// individual call to Dial. To configure a dial function across all invocations +// of Dial, use WithDialFunc. +func WithOneOffDialFunc(dial func(ctx context.Context, network, addr string) (net.Conn, error)) DialOption { + return func(c *dialCfg) { + c.dialFunc = dial + } +} + // WithTCPKeepAlive returns a DialOption that specifies the tcp keep alive period for the connection returned by Dial. func WithTCPKeepAlive(d time.Duration) DialOption { return func(cfg *dialCfg) {