Skip to content

Commit

Permalink
fix: correct default usage for WithQuotaProject and WithUserAgent (
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwotherspoon authored Jan 23, 2025
1 parent 42d0019 commit 8520c3d
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 4 deletions.
18 changes: 15 additions & 3 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"fmt"
"io"
"net"
"net/http"
"os"
"strings"
"sync"
Expand Down Expand Up @@ -220,9 +221,6 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
return nil, errUseTokenSource
}

// Add this to the end to make sure it's not overridden
cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithUserAgent(strings.Join(cfg.useragents, " ")))

// If callers have not provided a credential source, either explicitly with
// WithTokenSource or implicitly with WithCredentialsJSON etc., then use
// default credentials
Expand All @@ -247,6 +245,14 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
// For all credential paths, use auth library's built-in
// httptransport.NewClient
if cfg.authCredentials != nil {
// Set headers for auth client as below WithHTTPClient will ignore
// WithQuotaProject and WithUserAgent Options
headers := http.Header{}
headers.Set("User-Agent", strings.Join(cfg.useragents, " "))
if cfg.quotaProject != "" {
headers.Set("X-Goog-User-Project", cfg.quotaProject)
}

authClient, err := httptransport.NewClient(&httptransport.Options{
Credentials: cfg.authCredentials,
UniverseDomain: cfg.getClientUniverseDomain(),
Expand All @@ -259,6 +265,12 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
if !cfg.setHTTPClient {
cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithHTTPClient(authClient))
}
} else {
// Add this to the end to make sure it's not overridden
cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithUserAgent(strings.Join(cfg.useragents, " ")))
if cfg.quotaProject != "" {
cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithQuotaProject(cfg.quotaProject))
}
}

client, err := sqladmin.NewService(ctx, cfg.sqladminOpts...)
Expand Down
51 changes: 51 additions & 0 deletions e2e_postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ var (
postgresCustomerCASPass = os.Getenv("POSTGRES_CUSTOMER_CAS_PASS") // Password for the database user for customer CAS instances; be careful when entering a password on the command line (it may go into your terminal's history).
postgresDB = os.Getenv("POSTGRES_DB") // Name of the database to connect to.
postgresUserIAM = os.Getenv("POSTGRES_USER_IAM") // Name of database IAM user.
project = os.Getenv("GOOGLE_CLOUD_PROJECT") // Name of the Google Cloud Platform project.
)

func requirePostgresVars(t *testing.T) {
Expand All @@ -69,6 +70,8 @@ func requirePostgresVars(t *testing.T) {
t.Fatal("'POSTGRES_DB' env var not set")
case postgresUserIAM:
t.Fatal("'POSTGRES_USER_IAM' env var not set")
case project:
t.Fatal("'GOOGLE_CLOUD_PROJECT' env var not set")
}
}

Expand Down Expand Up @@ -168,6 +171,54 @@ func TestPostgresCASConnect(t *testing.T) {
t.Log(now)
}

func TestPostgresConnectWithQuotaProject(t *testing.T) {
if testing.Short() {
t.Skip("skipping Postgres integration tests")
}
requirePostgresVars(t)

ctx := context.Background()

// Configure the driver to connect to the database
dsn := fmt.Sprintf("user=%s password=%s dbname=%s sslmode=disable", postgresUser, postgresPass, postgresDB)
config, err := pgxpool.ParseConfig(dsn)
if err != nil {
t.Fatalf("failed to parse pgx config: %v", err)
}

// Create a new dialer with any options
d, err := cloudsqlconn.NewDialer(ctx, cloudsqlconn.WithQuotaProject(project))
if err != nil {
t.Fatalf("failed to init Dialer: %v", err)
}

// call cleanup when you're done with the database connection to close dialer
cleanup := func() error { return d.Close() }

// Tell the driver to use the Cloud SQL Go Connector to create connections
// postgresConnName takes the form of 'project:region:instance'.
config.ConnConfig.DialFunc = func(ctx context.Context, _ string, _ string) (net.Conn, error) {
return d.Dial(ctx, postgresConnName)
}

// Interact with the driver directly as you normally would
pool, err := pgxpool.NewWithConfig(ctx, config)
if err != nil {
t.Fatalf("failed to create pool: %s", err)
}
// ... etc

defer cleanup()
defer pool.Close()

var now time.Time
err = pool.QueryRow(context.Background(), "SELECT NOW()").Scan(&now)
if err != nil {
t.Fatalf("QueryRow failed: %s", err)
}
t.Log(now)
}

func TestPostgresCustomerCASConnect(t *testing.T) {
if testing.Short() {
t.Skip("skipping Postgres integration tests")
Expand Down
3 changes: 2 additions & 1 deletion options.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ type dialerConfig struct {
logger debug.ContextLogger
lazyRefresh bool
clientUniverseDomain string
quotaProject string
authCredentials *auth.Credentials
iamLoginTokenProvider auth.TokenProvider
useragents []string
Expand Down Expand Up @@ -210,7 +211,7 @@ func WithUniverseDomain(ud string) Option {
// WithQuotaProject returns an Option that specifies the project used for quota and billing purposes.
func WithQuotaProject(p string) Option {
return func(cfg *dialerConfig) {
cfg.sqladminOpts = append(cfg.sqladminOpts, apiopt.WithQuotaProject(p))
cfg.quotaProject = p
}
}

Expand Down

0 comments on commit 8520c3d

Please sign in to comment.