From ff5c846930b1a33fbffe4dd3a0fc7eb5cc4c9264 Mon Sep 17 00:00:00 2001 From: Shivaji Kharse <115525374+shivaji-kharse@users.noreply.github.com> Date: Thu, 23 Jan 2025 10:06:15 +0530 Subject: [PATCH] feat: add dgraph tool and source --- docs/en/resources/sources/dgraph.md | 49 ++++ docs/tools/README.md | 2 + docs/tools/dgraph-dql.md | 105 +++++++ go.mod | 2 + go.sum | 6 +- internal/server/config.go | 14 + internal/sources/dgraph/dgraph.go | 378 +++++++++++++++++++++++++ internal/sources/dgraph/dgraph_test.go | 76 +++++ internal/tools/dgraph/dgraph.go | 133 +++++++++ internal/tools/dgraph/dgraph_test.go | 96 +++++++ tests/dgraph.yaml | 31 ++ tests/dgraph_integration_test.go | 173 +++++++++++ 12 files changed, 1063 insertions(+), 2 deletions(-) create mode 100644 docs/en/resources/sources/dgraph.md create mode 100644 docs/tools/dgraph-dql.md create mode 100644 internal/sources/dgraph/dgraph.go create mode 100644 internal/sources/dgraph/dgraph_test.go create mode 100644 internal/tools/dgraph/dgraph.go create mode 100644 internal/tools/dgraph/dgraph_test.go create mode 100644 tests/dgraph.yaml create mode 100644 tests/dgraph_integration_test.go diff --git a/docs/en/resources/sources/dgraph.md b/docs/en/resources/sources/dgraph.md new file mode 100644 index 000000000..3a683f454 --- /dev/null +++ b/docs/en/resources/sources/dgraph.md @@ -0,0 +1,49 @@ +--- +title: "Dgraph" +type: docs +description: > + Dgraph is a horizontally scalable and distributed graph database. + +--- + +## About + +[Dgraph][dgraph-docs] is a horizontally scalable and distributed graph database. +It provides ACID transactions, consistent replication, and linearizable reads. + +This source can connect to either a self-managed Dgraph cluster or one hosted on Dgraph Cloud. +If you're new to Dgraph, the fastest way to get started is to [sign up for Dgraph Cloud][dgraph-login]. + +[dgraph-docs]: https://dgraph.io/docs +[dgraph-login]: https://cloud.dgraph.io/login + +## Requirements + +### Database User + +When **connecting to a hosted Dgraph database**, this source uses the API key for access. If you are using a dedicated environment, you will additionally need the namespace and user credentials for that namespace. + +For **connecting to a local or self-hosted Dgraph database**, use the namespace and user credentials for that namespace. + +## Example + +```yaml +sources: + my-dgraph-source: + dgraphUrl: "https://xxxx.cloud.dgraph.io" + user: "groot" + password: "password" + apiKey: abc123 + namepace : 0 +``` + +## Reference + +| **Field** | **Type** | **Required** | **Description** | +|-------------|:--------:|:------------:|--------------------------------------------------------------------------------------------------| +| kind | string | true | Must be "dgraph". | +| dgraphurl | string | true | Connection URI (e.g. "https://xxx.cloud.dgraph.io", "https://localhost:8080"). | +| user | string | false | Name of the Dgraph user to connect as (e.g., "groot"). | +| password | string | false | Password of the Dgraph user (e.g., "password"). | +| apiKey | string | false | API key to connect to a Dgraph Cloud instance. | +| namespace | uint64 | false | Dgraph namespace (not required for Dgraph Cloud Shared Clusters). | diff --git a/docs/tools/README.md b/docs/tools/README.md index 5e3c1148b..0ed9ee946 100644 --- a/docs/tools/README.md +++ b/docs/tools/README.md @@ -52,6 +52,8 @@ We currently support the following types of kinds of tools: statement againts Spanner database. * [neo4j-cypher](./neo4j-cypher.md) - Run a Cypher statement against a Neo4j database. +* [dgraph-dql](./dgraph-dql.md) - Run a DQL statement against a + Dgraph database. ## Specifying Parameters diff --git a/docs/tools/dgraph-dql.md b/docs/tools/dgraph-dql.md new file mode 100644 index 000000000..a151d0cb2 --- /dev/null +++ b/docs/tools/dgraph-dql.md @@ -0,0 +1,105 @@ +# Dgraph DQL Tool + + +A "dgraph-dql" tool executes a pre-defined DQL statement against a Dgraph database. It's compatible with any of the following +sources: +- [dgraph](../sources/dgraph.md) + +To run a statement as a query, you need to set the config isQuery=true. For upserts or mutations, set isQuery=false. +You can also configure timeout for a query. + +## Example + +### Query: + +```yaml +tools: + search_user: + kind: dgraph-dql + source: dgraph-user-instance + statement: | + query all($role: string){ + users(func: has(name)) @filter(eq(role, $role) AND ge(age, 30) AND le(age, 50)) { + uid + name + email + role + age + } + } + isQuery: true + timeout: 20s + description: | + Use this tool to retrieve the details of users who are admins and are between 30 and 50 years old. + The query returns the user's name, email, role, and age. + This can be helpful when you want to fetch admin users within a specific age range. + Example: Fetch admins aged between 30 and 50: + [ + { + "name": "Alice", + "role": "admin", + "age": 35 + }, + { + "name": "Bob", + "role": "admin", + "age": 45 + } + ] + parameters: + - name: $role + type: string + description: admin +``` + +### Mutation: + +```yaml +tools: + dgraph-manage-user-instance: + kind: dgraph-dql + source: dgraph-manage-user-instance + isQuery: false + statement: | + { + set { + _:user1 $user1 . + _:user1 $email1 . + _:user1 "admin" . + _:user1 "35" . + + _:user2 $user2 . + _:user2 $email2 . + _:user2 "admin" . + _:user2 "45" . + } + } + description: | + Use this tool to insert or update user data into the Dgraph database. + The mutation adds or updates user details like name, email, role, and age. + Example: Add users Alice and Bob as admins with specific ages. + parameters: + - name: $user1 + type: string + description: Alice + - name: $email1 + type: string + description: alice@email.com + - name: $user2 + type: string + description: Bob + - name: $email2 + type: string + description: bob@email.com +``` + +## Reference +| **field** | **type** | **required** | **description** | +|-------------|----------:|:------------:|----------------------------------------------------------------------------------------------------| +| kind | string | true | Must be "dgraph-dql". | +| source | string | true | Name of the source the dql query should execute on. | +| description | string | true | Description of the tool | +| statement | string | true | dql statement to execute | +| isQuery | boolean | false | To run statment as query set true otherwise false | +| timeout | string | false | To set timout for query | +| parameters | parameter | true | List of [parameters](README.md#specifying-parameters) that will be used with the dql statement. | diff --git a/go.mod b/go.mod index 3ff2ce4f8..2d882c7a1 100644 --- a/go.mod +++ b/go.mod @@ -51,6 +51,7 @@ require ( github.com/census-instrumentation/opencensus-proto v0.4.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cncf/xds/go v0.0.0-20240905190251-b4127c9b8d78 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/envoyproxy/go-control-plane v0.13.1 // indirect github.com/envoyproxy/protoc-gen-validate v1.1.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect @@ -69,6 +70,7 @@ require ( github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/microsoft/go-mssqldb v1.8.0 // indirect github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/spf13/pflag v1.0.5 // indirect go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect diff --git a/go.sum b/go.sum index f3e1c694c..3f0df2ef4 100644 --- a/go.sum +++ b/go.sum @@ -699,8 +699,9 @@ github.com/cncf/xds/go v0.0.0-20240905190251-b4127c9b8d78/go.mod h1:W+zGtBO5Y1Ig github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= @@ -951,8 +952,9 @@ github.com/pkg/sftp v1.10.1/go.mod h1:lYOWFsE0bwd1+KfKJaKeuokY15vzFx25BLbzYYoAxZ github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg= github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.3.0/go.mod h1:LDGWKZIo7rky3hgvBe+caln+Dr3dPggB5dvjtD7w9+w= diff --git a/internal/server/config.go b/internal/server/config.go index 16caa940b..15d0488b7 100644 --- a/internal/server/config.go +++ b/internal/server/config.go @@ -25,10 +25,12 @@ import ( cloudsqlmssqlsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql" cloudsqlmysqlsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql" cloudsqlpgsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" + dgraphsrc "github.com/googleapis/genai-toolbox/internal/sources/dgraph" neo4jrc "github.com/googleapis/genai-toolbox/internal/sources/neo4j" postgressrc "github.com/googleapis/genai-toolbox/internal/sources/postgres" spannersrc "github.com/googleapis/genai-toolbox/internal/sources/spanner" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/tools/dgraph" "github.com/googleapis/genai-toolbox/internal/tools/mssql" "github.com/googleapis/genai-toolbox/internal/tools/mysql" neo4jtool "github.com/googleapis/genai-toolbox/internal/tools/neo4j" @@ -181,6 +183,12 @@ func (c *SourceConfigs) UnmarshalYAML(unmarshal func(interface{}) error) error { return fmt.Errorf("unable to parse as %q: %w", k.Kind, err) } (*c)[name] = actual + case dgraphsrc.SourceKind: + actual := dgraphsrc.Config{Name: name} + if err := u.Unmarshal(&actual); err != nil { + return fmt.Errorf("unable to parse as %q: %w", k.Kind, err) + } + (*c)[name] = actual default: return fmt.Errorf("%q is not a valid kind of data source", k.Kind) } @@ -278,6 +286,12 @@ func (c *ToolConfigs) UnmarshalYAML(unmarshal func(interface{}) error) error { return fmt.Errorf("unable to parse as %q: %w", k.Kind, err) } (*c)[name] = actual + case dgraph.ToolKind: + actual := dgraph.Config{Name: name} + if err := u.Unmarshal(&actual); err != nil { + return fmt.Errorf("unable to parse as %q: %w", k.Kind, err) + } + (*c)[name] = actual default: return fmt.Errorf("%q is not a valid kind of tool", k.Kind) } diff --git a/internal/sources/dgraph/dgraph.go b/internal/sources/dgraph/dgraph.go new file mode 100644 index 000000000..28100d12b --- /dev/null +++ b/internal/sources/dgraph/dgraph.go @@ -0,0 +1,378 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dgraph + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + + "github.com/googleapis/genai-toolbox/internal/sources" + "go.opentelemetry.io/otel/trace" +) + +const SourceKind string = "dgraph" + +// validate interface +var _ sources.SourceConfig = Config{} +var httpClient = &http.Client{} + +// HttpToken stores credentials for making HTTP request +type HttpToken struct { + UserId string + Password string + AccessJwt string + RefreshToken string + Namespace uint64 +} + +type DgraphClient struct { + *HttpToken + baseUrl string + apiKey string +} + +type Config struct { + Name string `yaml:"name"` + Kind string `yaml:"kind"` + DgraphUrl string `yaml:"dgraphUrl"` + User string `yaml:"user"` + Password string `yaml:"password"` + Namespace uint64 `yaml:"namespace"` + ApiKey string `yaml:"apiKey"` +} + +func (r Config) SourceConfigKind() string { + return SourceKind +} + +func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) { + hc, err := initDgraphHttpClient(ctx, tracer, r) + if err != nil { + return nil, err + } + + if err := hc.healthCheck(); err != nil { + return nil, err + } + + s := &Source{ + Name: r.Name, + Kind: SourceKind, + Client: hc, + } + return s, nil +} + +var _ sources.Source = &Source{} + +type Source struct { + Name string `yaml:"name"` + Kind string `yaml:"kind"` + Client *DgraphClient `yaml:"client"` +} + +func (s *Source) SourceKind() string { + return SourceKind +} + +func (s *Source) DgraphClient() *DgraphClient { + return s.Client +} + +func initDgraphHttpClient(ctx context.Context, tracer trace.Tracer, r Config) (*DgraphClient, error) { + //nolint:all // Reassigned ctx + ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, r.Name) + defer span.End() + + if r.DgraphUrl == "" { + return nil, fmt.Errorf("dgraph url should not be empty") + } + + hc := &DgraphClient{ + baseUrl: r.DgraphUrl, + HttpToken: &HttpToken{ + UserId: r.User, + Namespace: r.Namespace, + Password: r.Password, + }, + apiKey: r.ApiKey, + } + + if r.User != "" || r.Password != "" { + if err := hc.loginWithCredentials(); err != nil { + return nil, err + } + } + + return hc, nil +} + +func (hc *DgraphClient) ExecuteQuery(query string, paramsMap map[string]interface{}, + isQuery bool, timeout string) ([]byte, error) { + if isQuery { + return hc.postDqlQuery(query, paramsMap, timeout) + } else { + return hc.mutate(query, paramsMap) + } +} + +// postDqlQuery sends a DQL query to the Dgraph server with query, parameters, and optional timeout. +// Returns the response body ([]byte) and an error, if any. +func (hc *DgraphClient) postDqlQuery(query string, paramsMap map[string]interface{}, timeout string) ([]byte, error) { + urlParams := url.Values{} + urlParams.Add("timeout", timeout) + url, err := getUrl(hc.baseUrl, "/query", urlParams) + if err != nil { + return nil, err + } + p := struct { + Query string `json:"query"` + Variables map[string]interface{} `json:"variables"` + }{ + Query: query, + Variables: paramsMap, + } + body, err := json.Marshal(p) + if err != nil { + return nil, fmt.Errorf("error marshlling json: %v", err) + } + + req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(body)) + if err != nil { + return nil, fmt.Errorf("error building req for endpoint [%v] :%v", url, err) + } + + req.Header.Add("Content-Type", "application/json") + + return hc.doReq(req) +} + +// mutate sends an RDF mutation to the Dgraph server with "commitNow: true", embedding parameters. +// Returns the server's response as a byte slice or an error if the mutation fails. +func (hc *DgraphClient) mutate(mutation string, paramsMap map[string]interface{}) ([]byte, error) { + mu := embedParamsIntoMutation(mutation, paramsMap) + params := url.Values{} + params.Add("commitNow", "true") + url, err := getUrl(hc.baseUrl, "/mutate", params) + if err != nil { + return nil, err + } + req, err := http.NewRequest(http.MethodPost, url, bytes.NewBufferString(mu)) + if err != nil { + return nil, fmt.Errorf("error building req for endpoint [%v] :%v", url, err) + } + + req.Header.Add("Content-Type", "application/rdf") + + return hc.doReq(req) +} + +func (hc *DgraphClient) doReq(req *http.Request) ([]byte, error) { + if hc.HttpToken != nil { + req.Header.Add("X-Dgraph-AccessToken", hc.AccessJwt) + } + if hc.apiKey != "" { + req.Header.Set("Dg-Auth", hc.apiKey) + } + + resp, err := httpClient.Do(req) + + if err != nil && !strings.Contains(err.Error(), "Token is expired") { + return nil, fmt.Errorf("error performing HTTP request: %w", err) + } else if err != nil && strings.Contains(err.Error(), "Token is expired") { + if errLogin := hc.loginWithToken(); errLogin != nil { + return nil, errLogin + } + if hc.HttpToken != nil { + req.Header.Add("X-Dgraph-AccessToken", hc.AccessJwt) + } + resp, err = httpClient.Do(req) + if err != nil { + return nil, err + } + } + + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("error reading response body: url: [%v], err: [%v]", req.URL, err) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("got non 200 resp: %v", string(respBody)) + } + + return respBody, nil +} + +func (hc *DgraphClient) loginWithCredentials() error { + credentials := map[string]interface{}{ + "userid": hc.UserId, + "password": hc.Password, + "namespace": hc.Namespace, + } + return hc.doLogin(credentials) +} + +func (hc *DgraphClient) loginWithToken() error { + credentials := map[string]interface{}{ + "refreshJWT": hc.RefreshToken, + "namespace": hc.Namespace, + } + return hc.doLogin(credentials) +} + +func (hc *DgraphClient) doLogin(creds map[string]interface{}) error { + url, err := getUrl(hc.baseUrl, "/login", nil) + if err != nil { + return err + } + payload, err := json.Marshal(creds) + if err != nil { + return fmt.Errorf("failed to marshal credentials: %v", err) + } + req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(payload)) + if err != nil { + return fmt.Errorf("error building req for endpoint [%v] : %v", url, err) + } + req.Header.Add("Content-Type", "application/json") + if hc.apiKey != "" { + req.Header.Set("Dg-Auth", hc.apiKey) + } + + resp, err := hc.doReq(req) + if err != nil { + if strings.Contains(err.Error(), "Token is expired") && + !strings.Contains(err.Error(), "unable to authenticate the refresh token") { + return hc.loginWithToken() + } + return err + } + + if err := CheckError(resp); err != nil { + return err + } + + var r struct { + Data struct { + AccessJWT string `json:"accessJWT"` + RefreshJWT string `json:"refreshJWT"` + } `json:"data"` + } + + if err := json.Unmarshal(resp, &r); err != nil { + return fmt.Errorf("failed to unmarshal response: %v", err) + } + + if r.Data.AccessJWT == "" { + return fmt.Errorf("no access JWT found in the response") + } + if r.Data.RefreshJWT == "" { + return fmt.Errorf("no refresh JWT found in the response") + } + + hc.AccessJwt = r.Data.AccessJWT + hc.RefreshToken = r.Data.RefreshJWT + return nil +} + +func (hc *DgraphClient) healthCheck() error { + url, err := getUrl(hc.baseUrl, "/health", nil) + if err != nil { + return err + } + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return fmt.Errorf("error creating request: %w", err) + } + + resp, err := httpClient.Do(req) + if err != nil { + return fmt.Errorf("error performing request: %w", err) + } + + defer resp.Body.Close() + data, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + var result []struct { + Instance string `json:"instance"` + Address string `json:"address"` + Status string `json:"status"` + } + + // Unmarshal response into the struct + if err := json.Unmarshal(data, &result); err != nil { + return fmt.Errorf("failed to unmarshal json: %v", err) + } + + if len(result) == 0 { + return fmt.Errorf("health info should not empty for: %v", url) + } + + var unhealthyErr error + for _, info := range result { + if info.Status != "healthy" { + unhealthyErr = fmt.Errorf("dgraph instance [%v] is not in healthy state, address is %v", + info.Instance, info.Address) + } else { + return nil + } + } + + return unhealthyErr +} + +func getUrl(baseUrl, resource string, params url.Values) (string, error) { + u, err := url.ParseRequestURI(baseUrl) + if err != nil { + return "", fmt.Errorf("failed to get url %v", err) + } + u.Path = resource + u.RawQuery = params.Encode() + return u.String(), nil +} + +func CheckError(resp []byte) error { + var errResp struct { + Errors []struct { + Message string `json:"message"` + } `json:"errors"` + } + + if err := json.Unmarshal(resp, &errResp); err != nil { + return fmt.Errorf("failed to unmarshal json: %v", err) + } + + if len(errResp.Errors) > 0 { + return fmt.Errorf("error : %v", errResp.Errors) + } + + return nil +} + +func embedParamsIntoMutation(mutation string, paramsMap map[string]interface{}) string { + for key, value := range paramsMap { + mutation = strings.ReplaceAll(mutation, key, fmt.Sprintf(`"%v"`, value)) + } + return mutation +} diff --git a/internal/sources/dgraph/dgraph_test.go b/internal/sources/dgraph/dgraph_test.go new file mode 100644 index 000000000..14b073ec0 --- /dev/null +++ b/internal/sources/dgraph/dgraph_test.go @@ -0,0 +1,76 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dgraph_test + +import ( + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/sources/dgraph" + "github.com/googleapis/genai-toolbox/internal/testutils" +) + +func TestParseFromYamlDgraph(t *testing.T) { + tcs := []struct { + desc string + in string + want server.SourceConfigs + }{ + { + desc: "basic example", + in: ` + sources: + my-dgraph-instance: + kind: dgraph + dgraphUrl: https://localhost:8080 + apiKey: abc123 + password: pass@123 + namespace: 0 + user: user123 + `, + want: server.SourceConfigs{ + "my-dgraph-instance": dgraph.Config{ + Name: "my-dgraph-instance", + Kind: dgraph.SourceKind, + DgraphUrl: "https://localhost:8080", + ApiKey: "abc123", + Password: "pass@123", + Namespace: 0, + User: "user123", + }, + }, + }, + } + + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Sources server.SourceConfigs `yaml:"sources"` + }{} + // Parse contents + err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + + if diff := cmp.Diff(tc.want, got.Sources); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } + +} diff --git a/internal/tools/dgraph/dgraph.go b/internal/tools/dgraph/dgraph.go new file mode 100644 index 000000000..2fea17d1b --- /dev/null +++ b/internal/tools/dgraph/dgraph.go @@ -0,0 +1,133 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dgraph + +import ( + "encoding/json" + "fmt" + + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/sources/dgraph" + "github.com/googleapis/genai-toolbox/internal/tools" +) + +const ToolKind string = "dgraph-dql" + +type compatibleSource interface { + DgraphClient() *dgraph.DgraphClient +} + +// validate compatible sources are still compatible +var _ compatibleSource = &dgraph.Source{} + +var compatibleSources = [...]string{dgraph.SourceKind} + +type Config struct { + Name string `yaml:"name"` + Kind string `yaml:"kind"` + Source string `yaml:"source"` + Description string `yaml:"description"` + Statement string `yaml:"statement"` + IsQuery bool `yaml:"isQuery"` + Timeout string `yaml:"timeout"` + Parameters tools.Parameters `yaml:"parameters"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigKind() string { + return ToolKind +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + // verify source exists + rawS, ok := srcs[cfg.Source] + if !ok { + return nil, fmt.Errorf("no source named %q configured", cfg.Source) + } + + // verify the source is compatible + s, ok := rawS.(compatibleSource) + if !ok { + return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", ToolKind, compatibleSources) + } + + // finish tool setup + t := Tool{ + Name: cfg.Name, + Kind: ToolKind, + Parameters: cfg.Parameters, + Statement: cfg.Statement, + DgraphClient: s.DgraphClient(), + IsQuery: cfg.IsQuery, + Timeout: cfg.Timeout, + manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()}, + } + return t, nil +} + +// validate interface +var _ tools.Tool = Tool{} + +type Tool struct { + Name string `yaml:"name"` + Kind string `yaml:"kind"` + Parameters tools.Parameters `yaml:"parameters"` + AuthRequired []string `yaml:"authRequired"` + DgraphClient *dgraph.DgraphClient + IsQuery bool + Timeout string + Statement string + manifest tools.Manifest +} + +func (t Tool) Invoke(params tools.ParamValues) (string, error) { + paramsMap := params.AsMap() + + resp, err := t.DgraphClient.ExecuteQuery(t.Statement, paramsMap, t.IsQuery, t.Timeout) + if err != nil { + return "", err + } + + if err := dgraph.CheckError(resp); err != nil { + return "", err + } + + var result struct { + Data map[string]interface{} `json:"data"` + } + + if err := json.Unmarshal(resp, &result); err != nil { + return "", fmt.Errorf("error parsing JSON: %v", err) + } + + return fmt.Sprintf( + "Stub tool call for %q! Parameters parsed: %q \n Output: %v", + t.Name, paramsMap, result.Data, + ), nil +} + +func (t Tool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (tools.ParamValues, error) { + return tools.ParseParams(t.Parameters, data, claimsMap) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) Authorized(verifiedAuthSources []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthSources) +} diff --git a/internal/tools/dgraph/dgraph_test.go b/internal/tools/dgraph/dgraph_test.go new file mode 100644 index 000000000..ee863b210 --- /dev/null +++ b/internal/tools/dgraph/dgraph_test.go @@ -0,0 +1,96 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dgraph_test + +import ( + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/dgraph" +) + +func TestParseFromYamlDgraph(t *testing.T) { + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic query example", + in: ` + tools: + example_tool: + kind: dgraph-dql + source: my-dgraph-instance + description: some tool description + isQuery: true + timeout: 20s + statement: | + query {q(func: eq(email, "example@email.com")) {email}} + `, + want: server.ToolConfigs{ + "example_tool": dgraph.Config{ + Name: "example_tool", + Kind: dgraph.ToolKind, + Source: "my-dgraph-instance", + Description: "some tool description", + IsQuery: true, + Timeout: "20s", + Statement: "query {q(func: eq(email, \"example@email.com\")) {email}}\n", + }, + }, + }, + { + desc: "basic mutation example", + in: ` + tools: + example_tool: + kind: dgraph-dql + source: my-dgraph-instance + description: some tool description + statement: | + mutation {set { _:a "a@email.com" . _:b "b@email.com" .}} + `, + want: server.ToolConfigs{ + "example_tool": dgraph.Config{ + Name: "example_tool", + Kind: dgraph.ToolKind, + Source: "my-dgraph-instance", + Description: "some tool description", + Statement: "mutation {set { _:a \"a@email.com\" . _:b \"b@email.com\" .}}\n", + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + // Parse contents + err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got.Tools); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } + +} diff --git a/tests/dgraph.yaml b/tests/dgraph.yaml new file mode 100644 index 000000000..56ffb4adf --- /dev/null +++ b/tests/dgraph.yaml @@ -0,0 +1,31 @@ +sources: + dgraph-manage-user-instance: + kind: "dgraph" + dgraphUrl: "https://green-feather-41401502.ap-south-1.aws.cloud.dgraph.io" + apiKey: "OTJmMTc0NDQ4MTJmZDk3MTlmMWY0ZjMzYmE2YjZkNzc=" + +tools: + search_user: + kind: dgraph-dql + source: dgraph-manage-user-instance + statement: | + query all($role: string){ + users(func: has(name)) @filter(eq(role, $role) AND ge(age, 30) AND le(age, 50)) { + uid + name + email + role + age + } + } + isQuery: true + timeout: 20s + description: | + Use this tool to insert or update user data into the Dgraph database. + The mutation adds or updates user details like name, email, role, and age. + Example: Add users Alice and Bob as admins with specific ages. + parameters: + - name: $role + type: string + description: admin + \ No newline at end of file diff --git a/tests/dgraph_integration_test.go b/tests/dgraph_integration_test.go new file mode 100644 index 000000000..9a821e0cc --- /dev/null +++ b/tests/dgraph_integration_test.go @@ -0,0 +1,173 @@ +//go:build integration && dgraph + +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "os" + "reflect" + "regexp" + "testing" + "time" +) + +var ( + DGRAPH_URL = os.Getenv("DGRAPH_URL") + DGRAPH_APIKEY = os.Getenv("DGRAPH_APIKEY") +) + +func requireDgraphVars(t *testing.T) { + switch "" { + case DGRAPH_URL: + t.Fatal("'DGRAPH_URL' not set") + case DGRAPH_APIKEY: + t.Fatal("'DGRAPH_APIKEY' not set") + } +} + +func TestDgraph(t *testing.T) { + requireDgraphVars(t) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + var args []string + + // Write config into a file and pass it to command + toolsFile := map[string]any{ + "sources": map[string]any{ + "my-dgraph-instance": map[string]any{ + "kind": "dgraph", + "dgraphUrl": DGRAPH_URL, + "apiKey": DGRAPH_APIKEY, + }, + }, + "tools": map[string]any{ + "my-simple-dql-tool": map[string]any{ + "kind": "dgraph-dql", + "source": "my-dgraph-instance", + "description": "Simple tool to test end to end functionality.", + "statement": "{result(func: uid(0x0)) {constant: math(1)}}", + "isQuery": true, + "timeout": "20s", + "parameters": []any{}, + }, + }, + } + cmd, cleanup, err := StartCmd(ctx, toolsFile, args...) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`)) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } + + // Test tool get endpoint + tcs := []struct { + name string + api string + want map[string]any + }{ + { + name: "get my-simple-tool", + api: "http://127.0.0.1:5000/api/tool/my-simple-dql-tool/", + want: map[string]any{ + "my-simple-dql-tool": map[string]any{ + "description": "Simple tool to test end to end functionality.", + "parameters": []any{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + resp, err := http.Get(tc.api) + if err != nil { + t.Fatalf("error when sending a request: %s", err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + t.Fatalf("response status code is not 200") + } + + var body map[string]interface{} + err = json.NewDecoder(resp.Body).Decode(&body) + if err != nil { + t.Fatalf("error parsing response body") + } + + got, ok := body["tools"] + if !ok { + t.Fatalf("unable to find tools in response body") + } + if !reflect.DeepEqual(got, tc.want) { + t.Fatalf("got %q, want %q", got, tc.want) + } + }) + } + + // Test tool invoke endpoint + invokeTcs := []struct { + name string + api string + requestBody io.Reader + want string + }{ + { + name: "invoke my-simple-dql-tool", + api: "http://127.0.0.1:5000/api/tool/my-simple-dql-tool/invoke", + requestBody: bytes.NewBuffer([]byte(`{}`)), + want: "Stub tool call for \"my-simple-dql-tool\"! Parameters parsed: map[]" + + " \n Output: map[result:[map[constant:1]]]", + }, + } + for _, tc := range invokeTcs { + t.Run(tc.name, func(t *testing.T) { + resp, err := http.Post(tc.api, "application/json", tc.requestBody) + if err != nil { + t.Fatalf("error when sending a request: %s", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) + } + + var body map[string]interface{} + err = json.NewDecoder(resp.Body).Decode(&body) + if err != nil { + t.Fatalf("error parsing response body") + } + got, ok := body["result"].(string) + if !ok { + t.Fatalf("unable to find result in response body") + } + if got != tc.want { + t.Fatalf("unexpected value: got %q, want %q", got, tc.want) + } + }) + } +}