diff --git a/go/ai/embedder.go b/go/ai/embedder.go index 5126898c2..622b05581 100644 --- a/go/ai/embedder.go +++ b/go/ai/embedder.go @@ -36,5 +36,5 @@ type EmbedRequest struct { // RegisterEmbedder registers the actions for a specific embedder. func RegisterEmbedder(name string, embedder Embedder) { genkit.RegisterAction(genkit.ActionTypeEmbedder, name, - genkit.NewAction(name, embedder.Embed)) + genkit.NewAction(name, nil, embedder.Embed)) } diff --git a/go/ai/generator.go b/go/ai/generator.go index 745ff2663..8833cd009 100644 --- a/go/ai/generator.go +++ b/go/ai/generator.go @@ -33,10 +33,39 @@ type Generator interface { Generate(context.Context, *GenerateRequest, genkit.StreamingCallback[*Candidate]) (*GenerateResponse, error) } +// GeneratorCapabilities describes various capabilities of the generator. +type GeneratorCapabilities struct { + Multiturn bool + Media bool + Tools bool + SystemRole bool +} + +// GeneratorMetadata is the metadata of the generator, specifying things like nice user visible label, capabilities, etc. +type GeneratorMetadata struct { + Label string + Supports GeneratorCapabilities +} + // RegisterGenerator registers the generator in the global registry. -func RegisterGenerator(name string, generator Generator) { - genkit.RegisterAction(genkit.ActionTypeModel, name, - genkit.NewStreamingAction(name, generator.Generate)) +func RegisterGenerator(provider, name string, metadata *GeneratorMetadata, generator Generator) { + metadataMap := map[string]any{} + if metadata != nil { + if metadata.Label != "" { + metadataMap["label"] = metadata.Label + } + supports := map[string]bool{ + "media": metadata.Supports.Media, + "multiturn": metadata.Supports.Multiturn, + "systemRole": metadata.Supports.SystemRole, + "tools": metadata.Supports.Tools, + } + metadataMap["supports"] = supports + } + genkit.RegisterAction(genkit.ActionTypeModel, provider, + genkit.NewStreamingAction(name, map[string]any{ + "model": metadataMap, + }, generator.Generate)) } // generatorActionType is the instantiated genkit.Action type registered diff --git a/go/ai/retriever.go b/go/ai/retriever.go index b47558c22..9bb1fc029 100644 --- a/go/ai/retriever.go +++ b/go/ai/retriever.go @@ -52,10 +52,10 @@ type RetrieverResponse struct { // RegisterRetriever registers the actions for a specific retriever. func RegisterRetriever(name string, retriever Retriever) { genkit.RegisterAction(genkit.ActionTypeRetriever, name, - genkit.NewAction(name, retriever.Retrieve)) + genkit.NewAction(name, nil, retriever.Retrieve)) genkit.RegisterAction(genkit.ActionTypeIndexer, name, - genkit.NewAction(name, func(ctx context.Context, req *IndexerRequest) (struct{}, error) { + genkit.NewAction(name, nil, func(ctx context.Context, req *IndexerRequest) (struct{}, error) { err := retriever.Index(ctx, req) return struct{}{}, err })) diff --git a/go/genkit/action.go b/go/genkit/action.go index 556902ebb..e05220792 100644 --- a/go/genkit/action.go +++ b/go/genkit/action.go @@ -65,14 +65,14 @@ type Action[I, O, S any] struct { // See js/common/src/types.ts // NewAction creates a new Action with the given name and non-streaming function. -func NewAction[I, O any](name string, fn func(context.Context, I) (O, error)) *Action[I, O, struct{}] { - return NewStreamingAction(name, func(ctx context.Context, in I, cb NoStream) (O, error) { +func NewAction[I, O any](name string, metadata map[string]any, fn func(context.Context, I) (O, error)) *Action[I, O, struct{}] { + return NewStreamingAction(name, metadata, func(ctx context.Context, in I, cb NoStream) (O, error) { return fn(ctx, in) }) } // NewStreamingAction creates a new Action with the given name and streaming function. -func NewStreamingAction[I, O, S any](name string, fn Func[I, O, S]) *Action[I, O, S] { +func NewStreamingAction[I, O, S any](name string, metadata map[string]any, fn Func[I, O, S]) *Action[I, O, S] { var i I var o O return &Action[I, O, S]{ @@ -80,6 +80,7 @@ func NewStreamingAction[I, O, S any](name string, fn Func[I, O, S]) *Action[I, O fn: fn, inputSchema: inferJSONSchema(i), outputSchema: inferJSONSchema(o), + Metadata: metadata, } } diff --git a/go/genkit/action_test.go b/go/genkit/action_test.go index 7b074e8a4..82c638b54 100644 --- a/go/genkit/action_test.go +++ b/go/genkit/action_test.go @@ -26,7 +26,7 @@ func inc(_ context.Context, x int) (int, error) { } func TestActionRun(t *testing.T) { - a := NewAction("inc", inc) + a := NewAction("inc", nil, inc) got, err := a.Run(context.Background(), 3, nil) if err != nil { t.Fatal(err) @@ -37,7 +37,7 @@ func TestActionRun(t *testing.T) { } func TestActionRunJSON(t *testing.T) { - a := NewAction("inc", inc) + a := NewAction("inc", nil, inc) input := []byte("3") want := []byte("4") got, err := a.runJSON(context.Background(), input, nil) @@ -63,7 +63,7 @@ func count(ctx context.Context, n int, cb StreamingCallback[int]) (int, error) { func TestActionStreaming(t *testing.T) { ctx := context.Background() - a := NewStreamingAction("count", count) + a := NewStreamingAction("count", nil, count) const n = 3 // Non-streaming. diff --git a/go/genkit/dev_server_test.go b/go/genkit/dev_server_test.go index a6ad7ea6b..679f39cb5 100644 --- a/go/genkit/dev_server_test.go +++ b/go/genkit/dev_server_test.go @@ -37,8 +37,12 @@ func TestDevServer(t *testing.T) { if err != nil { t.Fatal(err) } - r.registerAction("test", "devServer", NewAction("inc", inc)) - r.registerAction("test", "devServer", NewAction("dec", dec)) + r.registerAction("test", "devServer", NewAction("inc", map[string]any{ + "foo": "bar", + }, inc)) + r.registerAction("test", "devServer", NewAction("dec", map[string]any{ + "bar": "baz", + }, dec)) srv := httptest.NewServer(newDevServerMux(r)) defer srv.Close() @@ -78,10 +82,9 @@ func TestDevServer(t *testing.T) { if err != nil { t.Fatal(err) } - md := map[string]any{"inputSchema": nil, "outputSchema": nil} want := map[string]actionDesc{ - "/test/devServer/dec": {Key: "/test/devServer/dec", Name: "dec", Metadata: md}, - "/test/devServer/inc": {Key: "/test/devServer/inc", Name: "inc", Metadata: md}, + "/test/devServer/inc": {Key: "/test/devServer/inc", Name: "inc", Metadata: map[string]any{"inputSchema": nil, "outputSchema": nil, "foo": "bar"}}, + "/test/devServer/dec": {Key: "/test/devServer/dec", Name: "dec", Metadata: map[string]any{"inputSchema": nil, "outputSchema": nil, "bar": "baz"}}, } if !maps.EqualFunc(got, want, actionDesc.equal) { t.Errorf("\n got %v\nwant %v", got, want) diff --git a/go/genkit/dotprompt/genkit.go b/go/genkit/dotprompt/genkit.go index ee5a7ad8c..a44cf2c25 100644 --- a/go/genkit/dotprompt/genkit.go +++ b/go/genkit/dotprompt/genkit.go @@ -134,9 +134,9 @@ func (p *Prompt) Action() (*genkit.Action[*ActionInput, *ai.GenerateResponse, st name += "." + p.Variant } - a := genkit.NewAction(name, p.Execute) + a := genkit.NewAction(name, nil, p.Execute) a.Metadata = map[string]any{ - "type": "prompt", + "type": "prompt", "prompt": p, } return a, nil diff --git a/go/genkit/flow.go b/go/genkit/flow.go index a4993492e..0dcab4b78 100644 --- a/go/genkit/flow.go +++ b/go/genkit/flow.go @@ -227,7 +227,7 @@ type FlowResult[O any] struct { // action creates an action for the flow. See the comment at the top of this file for more information. func (f *Flow[I, O, S]) action() *Action[*flowInstruction[I], *flowState[I, O], S] { - return NewStreamingAction(f.name, func(ctx context.Context, inst *flowInstruction[I], cb StreamingCallback[S]) (*flowState[I, O], error) { + return NewStreamingAction(f.name, nil, func(ctx context.Context, inst *flowInstruction[I], cb StreamingCallback[S]) (*flowState[I, O], error) { spanMetaKey.fromContext(ctx).SetAttr("flow:wrapperAction", "true") return f.runInstruction(ctx, inst, cb) }) diff --git a/go/genkit/tracing_test.go b/go/genkit/tracing_test.go index 52bdf272b..b3f117871 100644 --- a/go/genkit/tracing_test.go +++ b/go/genkit/tracing_test.go @@ -56,7 +56,7 @@ func TestSpanMetadata(t *testing.T) { func TestTracing(t *testing.T) { ctx := context.Background() const actionName = "TestTracing-inc" - a := NewAction(actionName, inc) + a := NewAction(actionName, nil, inc) if _, err := a.Run(context.Background(), 3, nil); err != nil { t.Fatal(err) } diff --git a/go/plugins/googleai/googleai.go b/go/plugins/googleai/googleai.go index fcb826a2d..96109d868 100644 --- a/go/plugins/googleai/googleai.go +++ b/go/plugins/googleai/googleai.go @@ -204,12 +204,16 @@ func Init(ctx context.Context, model, apiKey string) error { return err } ai.RegisterEmbedder("google-genai", e) - g, err := NewGenerator(ctx, model, apiKey) if err != nil { return err } - ai.RegisterGenerator("google-genai", g) + ai.RegisterGenerator("google-genai", model, &ai.GeneratorMetadata{ + Label: "Google AI - " + model, + Supports: ai.GeneratorCapabilities{ + Multiturn: true, + }, + }, g) return nil } diff --git a/go/plugins/vertexai/vertexai.go b/go/plugins/vertexai/vertexai.go index 8dd053d82..301832eff 100644 --- a/go/plugins/vertexai/vertexai.go +++ b/go/plugins/vertexai/vertexai.go @@ -133,7 +133,12 @@ func Init(ctx context.Context, model, projectID, location string) error { if err != nil { return err } - ai.RegisterGenerator("google-vertexai", g) + ai.RegisterGenerator("google-vertexai", model, &ai.GeneratorMetadata{ + Label: "Vertex AI - " + model, + Supports: ai.GeneratorCapabilities{ + Multiturn: true, + }, + }, g) return nil } diff --git a/go/samples/coffee-shop/main.go b/go/samples/coffee-shop/main.go index 87d3a2d88..f17186a47 100644 --- a/go/samples/coffee-shop/main.go +++ b/go/samples/coffee-shop/main.go @@ -66,9 +66,9 @@ type customerTimeAndHistoryInput struct { } type testAllCoffeeFlowsOutput struct { - Pass bool `json:"pass"` + Pass bool `json:"pass"` Replies []string `json:"replies,omitempty"` - Error string `json:"error,omitempty"` + Error string `json:"error,omitempty"` } func main() { @@ -79,13 +79,13 @@ func main() { os.Exit(1) } - if err := googleai.Init(context.Background(), "gemini-pro", apiKey); err != nil { + if err := googleai.Init(context.Background(), "gemini-1.0-pro", apiKey); err != nil { log.Fatal(err) } simpleGreetingPrompt, err := dotprompt.Define("simpleGreeting", &dotprompt.Frontmatter{ - Name: "simpleGreeting", + Name: "simpleGreeting", Model: "google-genai", Input: dotprompt.FrontmatterInput{ Schema: jsonschema.Reflect(simpleGreetingInput{}), @@ -106,7 +106,7 @@ func main() { if err != nil { return "", err } - ai := &dotprompt.ActionInput{ Variables: vars } + ai := &dotprompt.ActionInput{Variables: vars} resp, err := simpleGreetingPrompt.Execute(ctx, ai) if err != nil { return "", err @@ -120,7 +120,7 @@ func main() { greetingWithHistoryPrompt, err := dotprompt.Define("greetingWithHistory", &dotprompt.Frontmatter{ - Name: "greetingWithHistory", + Name: "greetingWithHistory", Model: "google-genai", Input: dotprompt.FrontmatterInput{ Schema: jsonschema.Reflect(customerTimeAndHistoryInput{}), @@ -141,7 +141,7 @@ func main() { if err != nil { return "", err } - ai := &dotprompt.ActionInput{ Variables: vars } + ai := &dotprompt.ActionInput{Variables: vars} resp, err := greetingWithHistoryPrompt.Execute(ctx, ai) if err != nil { return "", err @@ -177,7 +177,7 @@ func main() { return out, nil } out := &testAllCoffeeFlowsOutput{ - Pass: true, + Pass: true, Replies: []string{ test1, test2,