Skip to content

Commit

Permalink
chore: add ip_type to alloydb and cloudsql source
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuan325 committed Dec 2, 2024
1 parent 51c6b5b commit 8cf1cee
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 6 deletions.
23 changes: 20 additions & 3 deletions internal/sources/alloydbpg/alloydb_pg.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"fmt"
"net"
"strings"

"cloud.google.com/go/alloydbconn"
"github.com/googleapis/genai-toolbox/internal/sources"
Expand All @@ -36,6 +37,7 @@ type Config struct {
Region string `yaml:"region"`
Cluster string `yaml:"cluster"`
Instance string `yaml:"instance"`
Ip_type string `yaml:"ip_type"`
User string `yaml:"user"`
Password string `yaml:"password"`
Database string `yaml:"database"`
Expand All @@ -46,7 +48,7 @@ func (r Config) SourceConfigKind() string {
}

func (r Config) Initialize() (sources.Source, error) {
pool, err := initAlloyDBPgConnectionPool(r.Project, r.Region, r.Cluster, r.Instance, r.User, r.Password, r.Database)
pool, err := initAlloyDBPgConnectionPool(r.Project, r.Region, r.Cluster, r.Instance, r.Ip_type, r.User, r.Password, r.Database)
if err != nil {
return nil, fmt.Errorf("unable to create pool: %w", err)
}
Expand Down Expand Up @@ -80,7 +82,22 @@ func (s *Source) PostgresPool() *pgxpool.Pool {
return s.Pool
}

func initAlloyDBPgConnectionPool(project, region, cluster, instance, user, pass, dbname string) (*pgxpool.Pool, error) {
func getDialer(ip_type string) (*alloydbconn.Dialer, error) {
switch strings.ToLower(ip_type) {
case "private":
// alloydbconn create a dialer with private IP by default
return alloydbconn.NewDialer(context.Background())
default:
return alloydbconn.NewDialer(
context.Background(),
alloydbconn.WithDefaultDialOptions(
alloydbconn.WithPublicIP(),
),
)
}
}

func initAlloyDBPgConnectionPool(project, region, cluster, instance, ip_type, user, pass, dbname string) (*pgxpool.Pool, error) {
// Configure the driver to connect to the database
dsn := fmt.Sprintf("user=%s password=%s dbname=%s sslmode=disable", user, pass, dbname)
config, err := pgxpool.ParseConfig(dsn)
Expand All @@ -89,7 +106,7 @@ func initAlloyDBPgConnectionPool(project, region, cluster, instance, user, pass,
}

// Create a new dialer with any options
d, err := alloydbconn.NewDialer(context.Background())
d, err := getDialer(ip_type)
if err != nil {
return nil, fmt.Errorf("unable to parse connection uri: %w", err)
}
Expand Down
27 changes: 24 additions & 3 deletions internal/sources/cloudsqlpg/cloud_sql_pg.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"fmt"
"net"
"strings"

"cloud.google.com/go/cloudsqlconn"
"github.com/googleapis/genai-toolbox/internal/sources"
Expand All @@ -35,6 +36,7 @@ type Config struct {
Project string `yaml:"project"`
Region string `yaml:"region"`
Instance string `yaml:"instance"`
Ip_type string `yaml:"ip_type"`
User string `yaml:"user"`
Password string `yaml:"password"`
Database string `yaml:"database"`
Expand All @@ -45,7 +47,7 @@ func (r Config) SourceConfigKind() string {
}

func (r Config) Initialize() (sources.Source, error) {
pool, err := initCloudSQLPgConnectionPool(r.Project, r.Region, r.Instance, r.User, r.Password, r.Database)
pool, err := initCloudSQLPgConnectionPool(r.Project, r.Region, r.Instance, r.Ip_type, r.User, r.Password, r.Database)
if err != nil {
return nil, fmt.Errorf("unable to create pool: %w", err)
}
Expand Down Expand Up @@ -79,7 +81,26 @@ func (s *Source) PostgresPool() *pgxpool.Pool {
return s.Pool
}

func initCloudSQLPgConnectionPool(project, region, instance, user, pass, dbname string) (*pgxpool.Pool, error) {
func getDialer(ip_type string) (*cloudsqlconn.Dialer, error) {
switch strings.ToLower(ip_type) {
case "private":
return cloudsqlconn.NewDialer(
context.Background(),
cloudsqlconn.WithDefaultDialOptions(
cloudsqlconn.WithPrivateIP(),
),
)
default:
return cloudsqlconn.NewDialer(
context.Background(),
cloudsqlconn.WithDefaultDialOptions(
cloudsqlconn.WithPublicIP(),
),
)
}
}

func initCloudSQLPgConnectionPool(project, region, instance, ip_type, user, pass, dbname string) (*pgxpool.Pool, error) {
// Configure the driver to connect to the database
dsn := fmt.Sprintf("user=%s password=%s dbname=%s sslmode=disable", user, pass, dbname)
config, err := pgxpool.ParseConfig(dsn)
Expand All @@ -88,7 +109,7 @@ func initCloudSQLPgConnectionPool(project, region, instance, user, pass, dbname
}

// Create a new dialer with any options
d, err := cloudsqlconn.NewDialer(context.Background())
d, err := getDialer(ip_type)
if err != nil {
return nil, fmt.Errorf("unable to parse connection uri: %w", err)
}
Expand Down

0 comments on commit 8cf1cee

Please sign in to comment.