Skip to content

Commit

Permalink
feat: add Spanner source and tool (#90)
Browse files Browse the repository at this point in the history
Add Spanner source and tool.

Spanner source is initialize with the following config:
```
sources:
    my-spanner-source:
        kind: spanner
        project: my-project-name
        instance: my-instance-name
        database: my_db
        # dialect: postgresql # The default dialect is google_standard_sql.
```

Spanner tool (with gsql dialect) is initialize with the following
config.
```
tools:
    get_flight_by_id:
        kind: spanner
        source: my-cloud-sql-source
        description: >
            Use this tool to list all airports matching search criteria. Takes 
            at least one of country, city, name, or all and returns all matching
            airports. The agent can decide to return the results directly to 
            the user.
        statement: "SELECT * FROM flights WHERE id = @id"
        parameters:
        - name: id
          type: int
          description: 'id' represents the unique ID for each flight. 
```

Spanner tool (with postgresql dialect) is initialize with the following
config.
```
tools:
    get_flight_by_id:
        kind: spanner
        source: my-cloud-sql-source
        description: >
            Use this tool to list all airports matching search criteria. Takes 
            at least one of country, city, name, or all and returns all matching
            airports. The agent can decide to return the results directly to 
            the user.
        statement: "SELECT * FROM flights WHERE id = $1"
        parameters:
        - name: id
          type: int
          description: 'id' represents the unique ID for each flight. 
```

Note: the only difference in config for both dialects is the sql
statement.

---------

Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com>
  • Loading branch information
Yuan325 and kurtisvg authored Dec 7, 2024
1 parent 5528bec commit 890914a
Show file tree
Hide file tree
Showing 10 changed files with 2,034 additions and 11 deletions.
13 changes: 11 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,32 @@ go 1.22.2
require (
cloud.google.com/go/alloydbconn v1.13.0
cloud.google.com/go/cloudsqlconn v1.12.1
cloud.google.com/go/spanner v1.67.0
github.com/go-chi/chi/v5 v5.1.0
github.com/go-chi/httplog/v2 v2.1.1
github.com/go-chi/render v1.0.3
github.com/google/go-cmp v0.6.0
github.com/jackc/pgx/v5 v5.7.1
github.com/spf13/cobra v1.8.1
google.golang.org/api v0.199.0
gopkg.in/yaml.v3 v3.0.1
)

require (
cel.dev/expr v0.16.0 // indirect
cloud.google.com/go v0.115.1 // indirect
cloud.google.com/go/alloydb v1.12.1 // indirect
cloud.google.com/go/auth v0.9.5 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.4 // indirect
cloud.google.com/go/compute/metadata v0.5.2 // indirect
cloud.google.com/go/longrunning v0.6.0 // indirect
github.com/GoogleCloudPlatform/grpc-gcp-go/grpcgcp v1.5.0 // indirect
github.com/ajg/form v1.5.1 // indirect
github.com/census-instrumentation/opencensus-proto v0.4.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/cncf/xds/go v0.0.0-20240723142845-024c85f92f20 // indirect
github.com/envoyproxy/go-control-plane v0.13.0 // indirect
github.com/envoyproxy/protoc-gen-validate v1.1.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
Expand All @@ -34,7 +43,7 @@ require (
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect
github.com/rogpeppe/go-internal v1.13.1 // indirect
github.com/spf13/pflag v1.0.5 // indirect
go.opencensus.io v0.24.0 // indirect
Expand All @@ -50,7 +59,7 @@ require (
golang.org/x/sys v0.26.0 // indirect
golang.org/x/text v0.19.0 // indirect
golang.org/x/time v0.7.0 // indirect
google.golang.org/api v0.199.0 // indirect
golang.org/x/xerrors v0.0.0-20240716161551-93cc26a95ae9 // indirect
google.golang.org/genproto v0.0.0-20240903143218-8af14fe29dc1 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20241007155032-5fefd90f89a9 // indirect
Expand Down
1,446 changes: 1,445 additions & 1 deletion go.sum

Large diffs are not rendered by default.

14 changes: 14 additions & 0 deletions internal/server/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ import (
alloydbpgsrc "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
cloudsqlpgsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
postgressrc "github.com/googleapis/genai-toolbox/internal/sources/postgres"
spannersrc "github.com/googleapis/genai-toolbox/internal/sources/spanner"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/postgressql"
"github.com/googleapis/genai-toolbox/internal/tools/spanner"
"gopkg.in/yaml.v3"
)

Expand Down Expand Up @@ -138,6 +140,12 @@ func (c *SourceConfigs) UnmarshalYAML(node *yaml.Node) error {
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
}
(*c)[name] = actual
case spannersrc.SourceKind:
actual := spannersrc.Config{Name: name, Dialect: "google_standard_sql"}
if err := n.Decode(&actual); err != nil {
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
}
(*c)[name] = actual
default:
return fmt.Errorf("%q is not a valid kind of data source", k.Kind)
}
Expand Down Expand Up @@ -175,6 +183,12 @@ func (c *ToolConfigs) UnmarshalYAML(node *yaml.Node) error {
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
}
(*c)[name] = actual
case spanner.ToolKind:
actual := spanner.Config{Name: name}
if err := n.Decode(&actual); err != nil {
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
}
(*c)[name] = actual
default:
return fmt.Errorf("%q is not a valid kind of tool", k.Kind)
}
Expand Down
46 changes: 46 additions & 0 deletions internal/sources/dialect.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sources

import (
"fmt"
"strings"

"gopkg.in/yaml.v3"
)

// Dialect represents the dialect type of a database.
type Dialect string

func (i *Dialect) String() string {
if string(*i) != "" {
return strings.ToLower(string(*i))
}
return "google_standard_sql"
}

func (i *Dialect) UnmarshalYAML(node *yaml.Node) error {
var dialect string
if err := node.Decode(&dialect); err != nil {
return err
}
switch strings.ToLower(dialect) {
case "google_standard_sql", "postgresql":
*i = Dialect(strings.ToLower(dialect))
return nil
default:
return fmt.Errorf(`dialect invalid: must be one of "google_standard_sql", or "postgresql"`)
}
}
100 changes: 100 additions & 0 deletions internal/sources/spanner/spanner.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package spanner

import (
"context"
"fmt"

"cloud.google.com/go/spanner"
"github.com/googleapis/genai-toolbox/internal/sources"
)

const SourceKind string = "spanner"

// validate interface
var _ sources.SourceConfig = Config{}

type Config struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
Project string `yaml:"project"`
Instance string `yaml:"instance"`
Dialect sources.Dialect `yaml:"dialect"`
Database string `yaml:"database"`
}

func (r Config) SourceConfigKind() string {
return SourceKind
}

func (r Config) Initialize() (sources.Source, error) {
client, err := initSpannerClient(r.Project, r.Instance, r.Database)
if err != nil {
return nil, fmt.Errorf("unable to create client: %w", err)
}

s := &Source{
Name: r.Name,
Kind: SourceKind,
Client: client,
Dialect: r.Dialect.String(),
}
return s, nil
}

var _ sources.Source = &Source{}

type Source struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
Client *spanner.Client
Dialect string
}

func (s *Source) SourceKind() string {
return SourceKind
}

func (s *Source) SpannerClient() *spanner.Client {
return s.Client
}

func (s *Source) DatabaseDialect() string {
return s.Dialect
}

func initSpannerClient(project, instance, dbname string) (*spanner.Client, error) {
// Configure the connection to the database
db := fmt.Sprintf("projects/%s/instances/%s/databases/%s", project, instance, dbname)

// Configure session pool to automatically clean inactive transactions
sessionPoolConfig := spanner.SessionPoolConfig{
TrackSessionHandles: true,
InactiveTransactionRemovalOptions: spanner.InactiveTransactionRemovalOptions{
ActionOnInactiveTransaction: spanner.WarnAndClose,
},
}

// Create spanner client
ctx := context.Background()
client, err := spanner.NewClientWithConfig(ctx, db, spanner.ClientConfig{SessionPoolConfig: sessionPoolConfig})
if err != nil {
return nil, fmt.Errorf("unable to create new client: %w", err)
}
defer client.Close()

return client, nil
}
148 changes: 148 additions & 0 deletions internal/sources/spanner/spanner_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package spanner_test

import (
"testing"

"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/spanner"
"github.com/googleapis/genai-toolbox/internal/testutils"
"gopkg.in/yaml.v3"
)

func TestParseFromYamlSpannerDb(t *testing.T) {
tcs := []struct {
desc string
in string
want server.SourceConfigs
}{
{
desc: "basic example",
in: `
sources:
my-spanner-instance:
kind: spanner
project: my-project
instance: my-instance
database: my_db
`,
want: map[string]sources.SourceConfig{
"my-spanner-instance": spanner.Config{
Name: "my-spanner-instance",
Kind: spanner.SourceKind,
Project: "my-project",
Instance: "my-instance",
Dialect: "google_standard_sql",
Database: "my_db",
},
},
},
{
desc: "gsql dialect",
in: `
sources:
my-spanner-instance:
kind: spanner
project: my-project
instance: my-instance
dialect: Google_standard_sql
database: my_db
`,
want: map[string]sources.SourceConfig{
"my-spanner-instance": spanner.Config{
Name: "my-spanner-instance",
Kind: spanner.SourceKind,
Project: "my-project",
Instance: "my-instance",
Dialect: "google_standard_sql",
Database: "my_db",
},
},
},
{
desc: "postgresql dialect",
in: `
sources:
my-spanner-instance:
kind: spanner
project: my-project
instance: my-instance
dialect: postgresql
database: my_db
`,
want: map[string]sources.SourceConfig{
"my-spanner-instance": spanner.Config{
Name: "my-spanner-instance",
Kind: spanner.SourceKind,
Project: "my-project",
Instance: "my-instance",
Dialect: "postgresql",
Database: "my_db",
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}

}

func FailParseFromYamlSpanner(t *testing.T) {
tcs := []struct {
desc string
in string
}{
{
desc: "invalid dialect",
in: `
sources:
my-spanner-instance:
kind: spanner
project: my-project
instance: my-instance
dialect: fail
database: my_db
`,
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail: %s", err)
}
})
}
}
Loading

0 comments on commit 890914a

Please sign in to comment.