From 8520c3d938e5011eb77ff5b5dc08c4e94e691a16 Mon Sep 17 00:00:00 2001 From: Jack Wotherspoon Date: Thu, 23 Jan 2025 14:57:50 -0500 Subject: [PATCH] fix: correct default usage for `WithQuotaProject` and `WithUserAgent` (#920) --- dialer.go | 18 +++++++++++++--- e2e_postgres_test.go | 51 ++++++++++++++++++++++++++++++++++++++++++++ options.go | 3 ++- 3 files changed, 68 insertions(+), 4 deletions(-) diff --git a/dialer.go b/dialer.go index e20c0641..37a8def7 100644 --- a/dialer.go +++ b/dialer.go @@ -24,6 +24,7 @@ import ( "fmt" "io" "net" + "net/http" "os" "strings" "sync" @@ -220,9 +221,6 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) { return nil, errUseTokenSource } - // Add this to the end to make sure it's not overridden - cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithUserAgent(strings.Join(cfg.useragents, " "))) - // If callers have not provided a credential source, either explicitly with // WithTokenSource or implicitly with WithCredentialsJSON etc., then use // default credentials @@ -247,6 +245,14 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) { // For all credential paths, use auth library's built-in // httptransport.NewClient if cfg.authCredentials != nil { + // Set headers for auth client as below WithHTTPClient will ignore + // WithQuotaProject and WithUserAgent Options + headers := http.Header{} + headers.Set("User-Agent", strings.Join(cfg.useragents, " ")) + if cfg.quotaProject != "" { + headers.Set("X-Goog-User-Project", cfg.quotaProject) + } + authClient, err := httptransport.NewClient(&httptransport.Options{ Credentials: cfg.authCredentials, UniverseDomain: cfg.getClientUniverseDomain(), @@ -259,6 +265,12 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) { if !cfg.setHTTPClient { cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithHTTPClient(authClient)) } + } else { + // Add this to the end to make sure it's not overridden + cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithUserAgent(strings.Join(cfg.useragents, " "))) + if cfg.quotaProject != "" { + cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithQuotaProject(cfg.quotaProject)) + } } client, err := sqladmin.NewService(ctx, cfg.sqladminOpts...) diff --git a/e2e_postgres_test.go b/e2e_postgres_test.go index 06028f02..71e77379 100644 --- a/e2e_postgres_test.go +++ b/e2e_postgres_test.go @@ -47,6 +47,7 @@ var ( postgresCustomerCASPass = os.Getenv("POSTGRES_CUSTOMER_CAS_PASS") // Password for the database user for customer CAS instances; be careful when entering a password on the command line (it may go into your terminal's history). postgresDB = os.Getenv("POSTGRES_DB") // Name of the database to connect to. postgresUserIAM = os.Getenv("POSTGRES_USER_IAM") // Name of database IAM user. + project = os.Getenv("GOOGLE_CLOUD_PROJECT") // Name of the Google Cloud Platform project. ) func requirePostgresVars(t *testing.T) { @@ -69,6 +70,8 @@ func requirePostgresVars(t *testing.T) { t.Fatal("'POSTGRES_DB' env var not set") case postgresUserIAM: t.Fatal("'POSTGRES_USER_IAM' env var not set") + case project: + t.Fatal("'GOOGLE_CLOUD_PROJECT' env var not set") } } @@ -168,6 +171,54 @@ func TestPostgresCASConnect(t *testing.T) { t.Log(now) } +func TestPostgresConnectWithQuotaProject(t *testing.T) { + if testing.Short() { + t.Skip("skipping Postgres integration tests") + } + requirePostgresVars(t) + + ctx := context.Background() + + // Configure the driver to connect to the database + dsn := fmt.Sprintf("user=%s password=%s dbname=%s sslmode=disable", postgresUser, postgresPass, postgresDB) + config, err := pgxpool.ParseConfig(dsn) + if err != nil { + t.Fatalf("failed to parse pgx config: %v", err) + } + + // Create a new dialer with any options + d, err := cloudsqlconn.NewDialer(ctx, cloudsqlconn.WithQuotaProject(project)) + if err != nil { + t.Fatalf("failed to init Dialer: %v", err) + } + + // call cleanup when you're done with the database connection to close dialer + cleanup := func() error { return d.Close() } + + // Tell the driver to use the Cloud SQL Go Connector to create connections + // postgresConnName takes the form of 'project:region:instance'. + config.ConnConfig.DialFunc = func(ctx context.Context, _ string, _ string) (net.Conn, error) { + return d.Dial(ctx, postgresConnName) + } + + // Interact with the driver directly as you normally would + pool, err := pgxpool.NewWithConfig(ctx, config) + if err != nil { + t.Fatalf("failed to create pool: %s", err) + } + // ... etc + + defer cleanup() + defer pool.Close() + + var now time.Time + err = pool.QueryRow(context.Background(), "SELECT NOW()").Scan(&now) + if err != nil { + t.Fatalf("QueryRow failed: %s", err) + } + t.Log(now) +} + func TestPostgresCustomerCASConnect(t *testing.T) { if testing.Short() { t.Skip("skipping Postgres integration tests") diff --git a/options.go b/options.go index 1a89303b..5965a938 100644 --- a/options.go +++ b/options.go @@ -47,6 +47,7 @@ type dialerConfig struct { logger debug.ContextLogger lazyRefresh bool clientUniverseDomain string + quotaProject string authCredentials *auth.Credentials iamLoginTokenProvider auth.TokenProvider useragents []string @@ -210,7 +211,7 @@ func WithUniverseDomain(ud string) Option { // WithQuotaProject returns an Option that specifies the project used for quota and billing purposes. func WithQuotaProject(p string) Option { return func(cfg *dialerConfig) { - cfg.sqladminOpts = append(cfg.sqladminOpts, apiopt.WithQuotaProject(p)) + cfg.quotaProject = p } }