Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(mssql): fix mssql tool kind to mssql-sql #249

Merged
merged 3 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions internal/server/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import (
postgressrc "github.com/googleapis/genai-toolbox/internal/sources/postgres"
spannersrc "github.com/googleapis/genai-toolbox/internal/sources/spanner"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/mssql"
"github.com/googleapis/genai-toolbox/internal/tools/mssqlsql"
"github.com/googleapis/genai-toolbox/internal/tools/mysql"
neo4jtool "github.com/googleapis/genai-toolbox/internal/tools/neo4j"
"github.com/googleapis/genai-toolbox/internal/tools/postgressql"
Expand Down Expand Up @@ -272,8 +272,8 @@ func (c *ToolConfigs) UnmarshalYAML(unmarshal func(interface{}) error) error {
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
}
(*c)[name] = actual
case mssql.ToolKind:
actual := mssql.Config{Name: name}
case mssqlsql.ToolKind:
actual := mssqlsql.Config{Name: name}
if err := u.Unmarshal(&actual); err != nil {
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package mssql
package mssqlsql

import (
"database/sql"
Expand All @@ -24,7 +24,7 @@ import (
"github.com/googleapis/genai-toolbox/internal/tools"
)

const ToolKind string = "mssql"
const ToolKind string = "mssql-sql"

type compatibleSource interface {
MSSQLDB() *sql.DB
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package mssql_test
package mssqlsql_test

import (
"testing"
Expand All @@ -21,7 +21,7 @@ import (
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/mssql"
"github.com/googleapis/genai-toolbox/internal/tools/mssqlsql"
"gopkg.in/yaml.v3"
)

Expand All @@ -36,7 +36,7 @@ func TestParseFromYamlMssql(t *testing.T) {
in: `
tools:
example_tool:
kind: mssql
kind: mssql-sql
source: my-instance
description: some description
statement: |
Expand All @@ -55,9 +55,9 @@ func TestParseFromYamlMssql(t *testing.T) {
field: user_id
`,
want: server.ToolConfigs{
"example_tool": mssql.Config{
"example_tool": mssqlsql.Config{
Name: "example_tool",
Kind: mssql.ToolKind,
Kind: mssqlsql.ToolKind,
Source: "my-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
Expand Down
6 changes: 3 additions & 3 deletions tests/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func RunGoogleAuthenticatedParameterTest(t *testing.T, sourceConfig map[string]a
switch {
case strings.EqualFold(toolKind, "postgres-sql"):
statement = fmt.Sprintf("SELECT * FROM %s WHERE email = $1;", tableName)
case strings.EqualFold(toolKind, "mssql"):
case strings.EqualFold(toolKind, "mssql-sql"):
statement = fmt.Sprintf("SELECT * FROM %s WHERE email = @email;", tableName)
default:
t.Fatalf("invalid tool kind: %s", toolKind)
Expand Down Expand Up @@ -132,7 +132,7 @@ func RunGoogleAuthenticatedParameterTest(t *testing.T, sourceConfig map[string]a
// Tools using database/sql interface only outputs `int64` instead of `int32`
var wantString string
switch toolKind {
case "mssql":
case "mssql-sql":
wantString = fmt.Sprintf("Stub tool call for \"my-auth-tool\"! Parameters parsed: [{\"email\" \"%s\"}] \n Output: [%%!s(int64=1) Alice %s]", SERVICE_ACCOUNT_EMAIL, SERVICE_ACCOUNT_EMAIL)
default:
wantString = fmt.Sprintf("Stub tool call for \"my-auth-tool\"! Parameters parsed: [{\"email\" \"%s\"}] \n Output: [%%!s(int32=1) Alice %s]", SERVICE_ACCOUNT_EMAIL, SERVICE_ACCOUNT_EMAIL)
Expand Down Expand Up @@ -216,7 +216,7 @@ func RunAuthRequiredToolInvocationTest(t *testing.T, sourceConfig map[string]any
// Tools using database/sql interface only outputs `int64` instead of `int32`
var wantString string
switch toolKind {
case "mssql":
case "mssql-sql":
wantString = "Stub tool call for \"my-auth-tool\"! Parameters parsed: [] \n Output: [%!s(int64=1)]"
default:
wantString = "Stub tool call for \"my-auth-tool\"! Parameters parsed: [] \n Output: [%!s(int32=1)]"
Expand Down
8 changes: 4 additions & 4 deletions tests/cloud_sql_mssql_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func TestCloudSQLMssql(t *testing.T) {
},
"tools": map[string]any{
"my-simple-tool": map[string]any{
"kind": "mssql",
"kind": "mssql-sql",
"source": "my-instance",
"description": "Simple tool to test end to end functionality.",
"statement": "SELECT 1;",
Expand Down Expand Up @@ -300,7 +300,7 @@ func TestToolInvocationWithParams(t *testing.T) {
defer teardownTest(t)

// call generic invocation test helper
RunToolInvocationWithParamsTest(t, sourceConfig, "mssql", tableName)
RunToolInvocationWithParamsTest(t, sourceConfig, "mssql-sql", tableName)
}

// Set up auth test database table
Expand Down Expand Up @@ -362,7 +362,7 @@ func TestCloudSQLMssqlGoogleAuthenticatedParameter(t *testing.T) {
defer teardownTest(t)

// call generic auth test helper
RunGoogleAuthenticatedParameterTest(t, sourceConfig, "mssql", tableName)
RunGoogleAuthenticatedParameterTest(t, sourceConfig, "mssql-sql", tableName)

}

Expand All @@ -371,6 +371,6 @@ func TestCloudSQLMssqlAuthRequiredToolInvocation(t *testing.T) {
sourceConfig := requireCloudSQLMssqlVars(t)

// call generic auth test helper
RunAuthRequiredToolInvocationTest(t, sourceConfig, "mssql")
RunAuthRequiredToolInvocationTest(t, sourceConfig, "mssql-sql")

}
4 changes: 2 additions & 2 deletions tests/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ func RunToolInvocationWithParamsTest(t *testing.T, sourceConfig map[string]any,
switch toolKind {
case "postgres-sql":
statement = fmt.Sprintf("SELECT * FROM %s WHERE id = $1 OR name = $2;", tableName)
case "mssql":
case "mssql-sql":
statement = fmt.Sprintf("SELECT * FROM %s WHERE id = @id OR name = @p2;", tableName)
default:
t.Fatalf("invalid tool kind: %s", toolKind)
Expand All @@ -224,7 +224,7 @@ func RunToolInvocationWithParamsTest(t *testing.T, sourceConfig map[string]any,
// Tools using database/sql interface only outputs `int64` instead of `int32`
var wantString string
switch toolKind {
case "mssql":
case "mssql-sql":
wantString = "Stub tool call for \"my-tool\"! Parameters parsed: [{\"id\" '\\x03'} {\"name\" \"Alice\"}] \n Output: [%!s(int64=1) Alice][%!s(int64=3) Sid]"
default:
wantString = "Stub tool call for \"my-tool\"! Parameters parsed: [{\"id\" '\\x03'} {\"name\" \"Alice\"}] \n Output: [%!s(int32=1) Alice][%!s(int32=3) Sid]"
Expand Down
Loading