From 1e2c189599f2d8acf84281e93cf9ba27db18a1a9 Mon Sep 17 00:00:00 2001 From: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Date: Fri, 8 Nov 2024 23:12:34 +0000 Subject: [PATCH] feat: support requesting a single tool --- internal/server/api.go | 25 ++++++++- internal/server/api_test.go | 108 +++++++++++++++++++++++++++++++++++- 2 files changed, 129 insertions(+), 4 deletions(-) diff --git a/internal/server/api.go b/internal/server/api.go index c7f20f326..9fb52510a 100644 --- a/internal/server/api.go +++ b/internal/server/api.go @@ -21,6 +21,7 @@ import ( "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/render" + "github.com/googleapis/genai-toolbox/internal/tools" ) // apiRouter creates a router that represents the routes under /api @@ -28,12 +29,14 @@ func apiRouter(s *Server) (chi.Router, error) { r := chi.NewRouter() r.Use(middleware.AllowContentType("application/json")) + r.Use(middleware.StripSlashes) r.Use(render.SetContentType(render.ContentTypeJSON)) - r.Get("/toolset/", func(w http.ResponseWriter, r *http.Request) { toolsetHandler(s, w, r) }) + r.Get("/toolset", func(w http.ResponseWriter, r *http.Request) { toolsetHandler(s, w, r) }) r.Get("/toolset/{toolsetName}", func(w http.ResponseWriter, r *http.Request) { toolsetHandler(s, w, r) }) r.Route("/tool/{toolName}", func(r chi.Router) { + r.Get("/", func(w http.ResponseWriter, r *http.Request) { toolGetHandler(s, w, r) }) r.Post("/invoke", func(w http.ResponseWriter, r *http.Request) { toolInvokeHandler(s, w, r) }) }) @@ -51,6 +54,26 @@ func toolsetHandler(s *Server, w http.ResponseWriter, r *http.Request) { render.JSON(w, r, toolset.Manifest) } +// toolGetHandler handles requests for a single Tool. +func toolGetHandler(s *Server, w http.ResponseWriter, r *http.Request) { + toolName := chi.URLParam(r, "toolName") + tool, ok := s.tools[toolName] + if !ok { + err := fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) + _ = render.Render(w, r, newErrResponse(err, http.StatusNotFound)) + return + } + // TODO: this can be optimized later with some caching + m := tools.ToolsetManifest{ + ServerVersion: s.conf.Version, + ToolsManifest: map[string]tools.Manifest{ + toolName: tool.Manifest(), + }, + } + + render.JSON(w, r, m) +} + // toolInvokeHandler handles the API request to invoke a specific Tool. func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) { toolName := chi.URLParam(r, "toolName") diff --git a/internal/server/api_test.go b/internal/server/api_test.go index 1af9c0968..74e448296 100644 --- a/internal/server/api_test.go +++ b/internal/server/api_test.go @@ -64,15 +64,15 @@ func TestToolsetEndpoint(t *testing.T) { toolsets[name] = m } - server := Server{tools: toolsMap, toolsets: toolsets} + server := Server{conf: ServerConfig{}, tools: toolsMap, toolsets: toolsets} r, err := apiRouter(&server) if err != nil { - t.Fatalf("unable to initalize router: %s", err) + t.Fatalf("unable to initialize router: %s", err) } ts := httptest.NewServer(r) defer ts.Close() - // wantRepsonse is a struct for checks against test cases + // wantResponse is a struct for checks against test cases type wantResponse struct { statusCode int isErr bool @@ -160,6 +160,108 @@ func TestToolsetEndpoint(t *testing.T) { }) } } +func TestToolGetEndpoint(t *testing.T) { + // Set up resources to test against + tool1 := MockTool{ + Name: "no_params", + Params: []tools.Parameter{}, + } + tool2 := MockTool{ + Name: "some_params", + Params: tools.Parameters{ + tools.NewIntParameter("param1", "This is the first parameter."), + tools.NewIntParameter("param2", "This is the second parameter."), + }, + } + toolsMap := map[string]tools.Tool{tool1.Name: tool1, tool2.Name: tool2} + + server := Server{conf: ServerConfig{Version: "0.0.0"}, tools: toolsMap} + r, err := apiRouter(&server) + if err != nil { + t.Fatalf("unable to initialize router: %s", err) + } + ts := httptest.NewServer(r) + defer ts.Close() + + // wantResponse is a struct for checks against test cases + type wantResponse struct { + statusCode int + isErr bool + version string + tools []string + } + + testCases := []struct { + name string + toolName string + want wantResponse + }{ + { + name: "tool1", + toolName: tool1.Name, + want: wantResponse{ + statusCode: http.StatusOK, + version: "0.0.0", + tools: []string{tool1.Name}, + }, + }, + { + name: "tool2", + toolName: tool2.Name, + want: wantResponse{ + statusCode: http.StatusOK, + version: "0.0.0", + tools: []string{tool2.Name}, + }, + }, + { + name: "invalid tool", + toolName: "some_imaginary_tool", + want: wantResponse{ + statusCode: http.StatusNotFound, + isErr: true, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + resp, body, err := testRequest(ts, http.MethodGet, fmt.Sprintf("/tool/%s", tc.toolName), nil) + if err != nil { + t.Fatalf("unexpected error during request: %s", err) + } + + if contentType := resp.Header.Get("Content-type"); contentType != "application/json" { + t.Fatalf("unexpected content-type header: want %s, got %s", "application/json", contentType) + } + + if resp.StatusCode != tc.want.statusCode { + t.Logf("response body: %s", body) + t.Fatalf("unexpected status code: want %d, got %d", tc.want.statusCode, resp.StatusCode) + } + if tc.want.isErr { + // skip the rest of the checks if this is an error case + return + } + var m tools.ToolsetManifest + err = json.Unmarshal(body, &m) + if err != nil { + t.Fatalf("unable to parse ToolsetManifest: %s", err) + } + // Check the version is correct + if m.ServerVersion != tc.want.version { + t.Fatalf("unexpected ServerVersion: want %q, got %q", tc.want.version, m.ServerVersion) + } + // validate that the tools in the toolset are correct + for _, name := range tc.want.tools { + _, ok := m.ToolsManifest[name] + if !ok { + t.Errorf("%q tool not found in manfiest", name) + } + } + }) + } +} func testRequest(ts *httptest.Server, method, path string, body io.Reader) (*http.Response, []byte, error) { req, err := http.NewRequest(method, ts.URL+path, body)