Skip to content

Commit

Permalink
feat: allow for configuring the Dial func (#57)
Browse files Browse the repository at this point in the history
Fixes #56.
  • Loading branch information
enocom authored Nov 17, 2021
1 parent de9e72e commit 4cb523e
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 1 deletion.
8 changes: 7 additions & 1 deletion dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ type Dialer struct {
// *only* when a client has configured OpenCensus exporters.
dialerID string

// dialFunc is the function used to connect to the address on the named
// network. By default it is golang.org/x/net/proxy#Dial.
dialFunc func(cxt context.Context, network, addr string) (net.Conn, error)

// iamTokenSource supplies the OAuth2 token used for IAM DB Authn. If IAM DB
// Authn is not enabled, iamTokenSource will be nil.
iamTokenSource oauth2.TokenSource
Expand All @@ -95,6 +99,7 @@ func NewDialer(ctx context.Context, opts ...DialerOption) (*Dialer, error) {
cfg := &dialerConfig{
refreshTimeout: 30 * time.Second,
sqladminOpts: []option.ClientOption{option.WithUserAgent(userAgent)},
dialFunc: proxy.Dial,
}
for _, opt := range opts {
opt(cfg)
Expand Down Expand Up @@ -153,6 +158,7 @@ func NewDialer(ctx context.Context, opts ...DialerOption) (*Dialer, error) {
defaultDialCfg: dialCfg,
dialerID: uuid.New().String(),
iamTokenSource: cfg.tokenSource,
dialFunc: cfg.dialFunc,
}
return d, nil
}
Expand Down Expand Up @@ -190,7 +196,7 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption)
ctx, connectEnd = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn/internal.Connect")
defer func() { connectEnd(err) }()
addr = net.JoinHostPort(addr, serverProxyPort)
conn, err = proxy.Dial(ctx, "tcp", addr)
conn, err = d.dialFunc(ctx, "tcp", addr)
if err != nil {
// refresh the instance info in case it caused the connection failure
i.ForceRefresh()
Expand Down
31 changes: 31 additions & 0 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
"context"
"errors"
"io/ioutil"
"net"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -189,3 +191,32 @@ func TestIAMAuthn(t *testing.T) {
}
}
}

func TestDialerWithCustomDialFunc(t *testing.T) {
inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance")
svc, cleanup, err := mock.NewSQLAdminService(
context.Background(),
mock.InstanceGetSuccess(inst, 1),
mock.CreateEphemeralSuccess(inst, 1),
)
d, err := NewDialer(context.Background(),
WithTokenSource(mock.EmptyTokenSource{}),
WithDialFunc(func(ctx context.Context, network, addr string) (net.Conn, error) {
return nil, errors.New("sentinel error")
}),
)
if err != nil {
t.Fatalf("expected NewDialer to succeed, but got error: %v", err)
}
d.sqladmin = svc
defer func() {
if err := cleanup(); err != nil {
t.Fatalf("%v", err)
}
}()

_, err = d.Dial(context.Background(), "my-project:my-region:my-instance")
if !strings.Contains(err.Error(), "sentinel error") {
t.Fatalf("want = sentinel error, got = %v", err)
}
}
11 changes: 11 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"crypto/rsa"
"io/ioutil"
"net"
"net/http"
"time"

Expand All @@ -36,6 +37,7 @@ type dialerConfig struct {
rsaKey *rsa.PrivateKey
sqladminOpts []apiopt.ClientOption
dialOpts []DialOption
dialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
refreshTimeout time.Duration
useIAMAuthN bool
tokenSource oauth2.TokenSource
Expand Down Expand Up @@ -120,6 +122,15 @@ func WithHTTPClient(client *http.Client) DialerOption {
}
}

// WithDialFunc configures the function used to connect to the address on the
// named network. This option is generally unnecessary except for advanced
// use-cases.
func WithDialFunc(dial func(ctx context.Context, network, addr string) (net.Conn, error)) DialerOption {
return func(d *dialerConfig) {
d.dialFunc = dial
}
}

// WithIAMAuthN enables automatic IAM Authentication. If no token source has
// been configured (such as with WithTokenSource, WithCredentialsFile, etc), the
// dialer will use the default token source as defined by
Expand Down

0 comments on commit 4cb523e

Please sign in to comment.