Skip to content

Commit

Permalink
feat: add user agent to cloud databases (#244)
Browse files Browse the repository at this point in the history
Add user agent to cloud databases that provides us anonymized data
request count, number of users, number of projects, and other
environment settings.

User agent is using the format: `genai-toolbox/$version+metadata`
  • Loading branch information
Yuan325 authored Jan 30, 2025
1 parent 8152a98 commit 8452f8e
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 50 deletions.
4 changes: 4 additions & 0 deletions internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/googleapis/genai-toolbox/internal/log"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)
Expand Down Expand Up @@ -58,6 +59,9 @@ func NewServer(ctx context.Context, cfg ServerConfig, l log.Logger) (*Server, er
parentCtx, span := instrumentation.Tracer.Start(context.Background(), "toolbox/server/init")
defer span.End()

userAgent := fmt.Sprintf("genai-toolbox/%s", cfg.Version)
parentCtx = context.WithValue(parentCtx, util.UserAgentKey, userAgent)

// set up http serving
r := chi.NewRouter()
r.Use(middleware.Recoverer)
Expand Down
14 changes: 9 additions & 5 deletions internal/sources/alloydbpg/alloydb_pg.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

"cloud.google.com/go/alloydbconn"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/jackc/pgx/v5/pgxpool"
"go.opentelemetry.io/otel/trace"
)
Expand Down Expand Up @@ -83,15 +84,17 @@ func (s *Source) PostgresPool() *pgxpool.Pool {
return s.Pool
}

func getDialOpts(ipType string) ([]alloydbconn.DialOption, error) {
func getOpts(ipType, userAgent string) ([]alloydbconn.Option, error) {
opts := []alloydbconn.Option{alloydbconn.WithUserAgent(userAgent)}
switch strings.ToLower(ipType) {
case "private":
return []alloydbconn.DialOption{alloydbconn.WithPrivateIP()}, nil
opts = append(opts, alloydbconn.WithDefaultDialOptions(alloydbconn.WithPrivateIP()))
case "public":
return []alloydbconn.DialOption{alloydbconn.WithPublicIP()}, nil
opts = append(opts, alloydbconn.WithDefaultDialOptions(alloydbconn.WithPublicIP()))
default:
return nil, fmt.Errorf("invalid ipType %s", ipType)
}
return opts, nil
}

func initAlloyDBPgConnectionPool(ctx context.Context, tracer trace.Tracer, name, project, region, cluster, instance, ipType, user, pass, dbname string) (*pgxpool.Pool, error) {
Expand All @@ -107,11 +110,12 @@ func initAlloyDBPgConnectionPool(ctx context.Context, tracer trace.Tracer, name,
}

// Create a new dialer with options
dialOpts, err := getDialOpts(ipType)
userAgent := ctx.Value(util.UserAgentKey).(string)
opts, err := getOpts(ipType, userAgent)
if err != nil {
return nil, err
}
d, err := alloydbconn.NewDialer(context.Background(), alloydbconn.WithDefaultDialOptions(dialOpts...))
d, err := alloydbconn.NewDialer(context.Background(), opts...)
if err != nil {
return nil, fmt.Errorf("unable to parse connection uri: %w", err)
}
Expand Down
19 changes: 4 additions & 15 deletions internal/sources/cloudsqlmssql/cloud_sql_mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@ import (
"database/sql"
"fmt"
"slices"
"strings"

"cloud.google.com/go/cloudsqlconn"
"cloud.google.com/go/cloudsqlconn/sqlserver/mssql"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"go.opentelemetry.io/otel/trace"
)

Expand Down Expand Up @@ -91,17 +90,6 @@ func (s *Source) MSSQLDB() *sql.DB {
return s.Db
}

func getDialOpts(ipType string) ([]cloudsqlconn.DialOption, error) {
switch strings.ToLower(ipType) {
case "private":
return []cloudsqlconn.DialOption{cloudsqlconn.WithPrivateIP()}, nil
case "public":
return []cloudsqlconn.DialOption{cloudsqlconn.WithPublicIP()}, nil
default:
return nil, fmt.Errorf("invalid ipType %s", ipType)
}
}

func initCloudSQLMssqlConnection(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipAddress, ipType, user, pass, dbname string) (*sql.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
Expand All @@ -111,14 +99,15 @@ func initCloudSQLMssqlConnection(ctx context.Context, tracer trace.Tracer, name,
dsn := fmt.Sprintf("sqlserver://%s:%s@%s?database=%s&cloudsql=%s:%s:%s", user, pass, ipAddress, dbname, project, region, instance)

// Get dial options
dialOpts, err := getDialOpts(ipType)
userAgent := ctx.Value(util.UserAgentKey).(string)
opts, err := sources.GetCloudSQLOpts(ipType, userAgent)
if err != nil {
return nil, err
}

// Register sql server driver
if !slices.Contains(sql.Drivers(), "cloudsql-sqlserver-driver") {
_, err := mssql.RegisterDriver("cloudsql-sqlserver-driver", cloudsqlconn.WithDefaultDialOptions(dialOpts...))
_, err := mssql.RegisterDriver("cloudsql-sqlserver-driver", opts...)
if err != nil {
return nil, err
}
Expand Down
19 changes: 4 additions & 15 deletions internal/sources/cloudsqlmysql/cloud_sql_mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@ import (
"database/sql"
"fmt"
"slices"
"strings"

"cloud.google.com/go/cloudsqlconn"
"cloud.google.com/go/cloudsqlconn/mysql/mysql"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"go.opentelemetry.io/otel/trace"
)

Expand Down Expand Up @@ -83,30 +82,20 @@ func (s *Source) MySQLPool() *sql.DB {
return s.Pool
}

func getDialOpts(ipType string) ([]cloudsqlconn.DialOption, error) {
switch strings.ToLower(ipType) {
case "private":
return []cloudsqlconn.DialOption{cloudsqlconn.WithPrivateIP()}, nil
case "public":
return []cloudsqlconn.DialOption{cloudsqlconn.WithPublicIP()}, nil
default:
return nil, fmt.Errorf("invalid ipType %s", ipType)
}
}

func initCloudSQLMySQLConnectionPool(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipType, user, pass, dbname string) (*sql.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()

// Create a new dialer with options
dialOpts, err := getDialOpts(ipType)
userAgent := ctx.Value(util.UserAgentKey).(string)
opts, err := sources.GetCloudSQLOpts(ipType, userAgent)
if err != nil {
return nil, err
}

if !slices.Contains(sql.Drivers(), "cloudsql-mysql") {
_, err = mysql.RegisterDriver("cloudsql-mysql", cloudsqlconn.WithDefaultDialOptions(dialOpts...))
_, err = mysql.RegisterDriver("cloudsql-mysql", opts...)
if err != nil {
return nil, fmt.Errorf("unable to register driver: %w", err)
}
Expand Down
18 changes: 4 additions & 14 deletions internal/sources/cloudsqlpg/cloud_sql_pg.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ import (
"context"
"fmt"
"net"
"strings"

"cloud.google.com/go/cloudsqlconn"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/jackc/pgx/v5/pgxpool"
"go.opentelemetry.io/otel/trace"
)
Expand Down Expand Up @@ -82,17 +82,6 @@ func (s *Source) PostgresPool() *pgxpool.Pool {
return s.Pool
}

func getDialOpts(ipType string) ([]cloudsqlconn.DialOption, error) {
switch strings.ToLower(ipType) {
case "private":
return []cloudsqlconn.DialOption{cloudsqlconn.WithPrivateIP()}, nil
case "public":
return []cloudsqlconn.DialOption{cloudsqlconn.WithPublicIP()}, nil
default:
return nil, fmt.Errorf("invalid ipType %s", ipType)
}
}

func initCloudSQLPgConnectionPool(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipType, user, pass, dbname string) (*pgxpool.Pool, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
Expand All @@ -106,11 +95,12 @@ func initCloudSQLPgConnectionPool(ctx context.Context, tracer trace.Tracer, name
}

// Create a new dialer with options
dialOpts, err := getDialOpts(ipType)
userAgent := ctx.Value(util.UserAgentKey).(string)
opts, err := sources.GetCloudSQLOpts(ipType, userAgent)
if err != nil {
return nil, err
}
d, err := cloudsqlconn.NewDialer(context.Background(), cloudsqlconn.WithDefaultDialOptions(dialOpts...))
d, err := cloudsqlconn.NewDialer(context.Background(), opts...)
if err != nil {
return nil, fmt.Errorf("unable to parse connection uri: %w", err)
}
Expand Down
4 changes: 3 additions & 1 deletion internal/sources/spanner/spanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (

"cloud.google.com/go/spanner"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"go.opentelemetry.io/otel/trace"
)

Expand Down Expand Up @@ -94,7 +95,8 @@ func initSpannerClient(ctx context.Context, tracer trace.Tracer, name, project,
}

// Create spanner client
client, err := spanner.NewClientWithConfig(context.Background(), db, spanner.ClientConfig{SessionPoolConfig: sessionPoolConfig})
userAgent := ctx.Value(util.UserAgentKey).(string)
client, err := spanner.NewClientWithConfig(context.Background(), db, spanner.ClientConfig{SessionPoolConfig: sessionPoolConfig, UserAgent: userAgent})
if err != nil {
return nil, fmt.Errorf("unable to create new client: %w", err)
}
Expand Down
37 changes: 37 additions & 0 deletions internal/sources/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sources

import (
"fmt"
"strings"

"cloud.google.com/go/cloudsqlconn"
)

// GetCloudSQLDialOpts retrieve dial options with the right ip type and user agent for cloud sql
// databases.
func GetCloudSQLOpts(ipType, userAgent string) ([]cloudsqlconn.Option, error) {
opts := []cloudsqlconn.Option{cloudsqlconn.WithUserAgent(userAgent)}
switch strings.ToLower(ipType) {
case "private":
opts = append(opts, cloudsqlconn.WithDefaultDialOptions(cloudsqlconn.WithPrivateIP()))
case "public":
opts = append(opts, cloudsqlconn.WithDefaultDialOptions(cloudsqlconn.WithPublicIP()))
default:
return nil, fmt.Errorf("invalid ipType %s", ipType)
}
return opts, nil
}
5 changes: 5 additions & 0 deletions internal/util/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,8 @@ func (d *DelayedUnmarshaler) UnmarshalYAML(unmarshal func(interface{}) error) er
func (d *DelayedUnmarshaler) Unmarshal(v interface{}) error {
return d.unmarshal(v)
}

type contextKey string

// UserAgentKey is the key used to store userAgent within context
const UserAgentKey contextKey = "userAgent"

0 comments on commit 8452f8e

Please sign in to comment.