From 4cb523e80b4a388b37c8ce251a533a3b8d370029 Mon Sep 17 00:00:00 2001 From: Eno Compton Date: Wed, 17 Nov 2021 10:33:58 -0700 Subject: [PATCH] feat: allow for configuring the Dial func (#57) Fixes #56. --- dialer.go | 8 +++++++- dialer_test.go | 31 +++++++++++++++++++++++++++++++ options.go | 11 +++++++++++ 3 files changed, 49 insertions(+), 1 deletion(-) diff --git a/dialer.go b/dialer.go index c51f382e..0ed0bb14 100644 --- a/dialer.go +++ b/dialer.go @@ -81,6 +81,10 @@ type Dialer struct { // *only* when a client has configured OpenCensus exporters. dialerID string + // dialFunc is the function used to connect to the address on the named + // network. By default it is golang.org/x/net/proxy#Dial. + dialFunc func(cxt context.Context, network, addr string) (net.Conn, error) + // iamTokenSource supplies the OAuth2 token used for IAM DB Authn. If IAM DB // Authn is not enabled, iamTokenSource will be nil. iamTokenSource oauth2.TokenSource @@ -95,6 +99,7 @@ func NewDialer(ctx context.Context, opts ...DialerOption) (*Dialer, error) { cfg := &dialerConfig{ refreshTimeout: 30 * time.Second, sqladminOpts: []option.ClientOption{option.WithUserAgent(userAgent)}, + dialFunc: proxy.Dial, } for _, opt := range opts { opt(cfg) @@ -153,6 +158,7 @@ func NewDialer(ctx context.Context, opts ...DialerOption) (*Dialer, error) { defaultDialCfg: dialCfg, dialerID: uuid.New().String(), iamTokenSource: cfg.tokenSource, + dialFunc: cfg.dialFunc, } return d, nil } @@ -190,7 +196,7 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption) ctx, connectEnd = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn/internal.Connect") defer func() { connectEnd(err) }() addr = net.JoinHostPort(addr, serverProxyPort) - conn, err = proxy.Dial(ctx, "tcp", addr) + conn, err = d.dialFunc(ctx, "tcp", addr) if err != nil { // refresh the instance info in case it caused the connection failure i.ForceRefresh() diff --git a/dialer_test.go b/dialer_test.go index 3debc0b0..f03165fb 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -18,6 +18,8 @@ import ( "context" "errors" "io/ioutil" + "net" + "strings" "testing" "time" @@ -189,3 +191,32 @@ func TestIAMAuthn(t *testing.T) { } } } + +func TestDialerWithCustomDialFunc(t *testing.T) { + inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance") + svc, cleanup, err := mock.NewSQLAdminService( + context.Background(), + mock.InstanceGetSuccess(inst, 1), + mock.CreateEphemeralSuccess(inst, 1), + ) + d, err := NewDialer(context.Background(), + WithTokenSource(mock.EmptyTokenSource{}), + WithDialFunc(func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, errors.New("sentinel error") + }), + ) + if err != nil { + t.Fatalf("expected NewDialer to succeed, but got error: %v", err) + } + d.sqladmin = svc + defer func() { + if err := cleanup(); err != nil { + t.Fatalf("%v", err) + } + }() + + _, err = d.Dial(context.Background(), "my-project:my-region:my-instance") + if !strings.Contains(err.Error(), "sentinel error") { + t.Fatalf("want = sentinel error, got = %v", err) + } +} diff --git a/options.go b/options.go index 4409fe41..a114c14a 100644 --- a/options.go +++ b/options.go @@ -18,6 +18,7 @@ import ( "context" "crypto/rsa" "io/ioutil" + "net" "net/http" "time" @@ -36,6 +37,7 @@ type dialerConfig struct { rsaKey *rsa.PrivateKey sqladminOpts []apiopt.ClientOption dialOpts []DialOption + dialFunc func(ctx context.Context, network, addr string) (net.Conn, error) refreshTimeout time.Duration useIAMAuthN bool tokenSource oauth2.TokenSource @@ -120,6 +122,15 @@ func WithHTTPClient(client *http.Client) DialerOption { } } +// WithDialFunc configures the function used to connect to the address on the +// named network. This option is generally unnecessary except for advanced +// use-cases. +func WithDialFunc(dial func(ctx context.Context, network, addr string) (net.Conn, error)) DialerOption { + return func(d *dialerConfig) { + d.dialFunc = dial + } +} + // WithIAMAuthN enables automatic IAM Authentication. If no token source has // been configured (such as with WithTokenSource, WithCredentialsFile, etc), the // dialer will use the default token source as defined by