Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Go] adding metadata to models, etc. #71

Merged
merged 4 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go/ai/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,5 @@ type EmbedRequest struct {
// RegisterEmbedder registers the actions for a specific embedder.
func RegisterEmbedder(name string, embedder Embedder) {
genkit.RegisterAction(genkit.ActionTypeEmbedder, name,
genkit.NewAction(name, embedder.Embed))
genkit.NewAction(name, nil, embedder.Embed))
}
35 changes: 32 additions & 3 deletions go/ai/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,39 @@ type Generator interface {
Generate(context.Context, *GenerateRequest, genkit.StreamingCallback[*Candidate]) (*GenerateResponse, error)
}

// GeneratorCapabilities describes various capabilities of the generator.
// RegisterGenerator publishes each field as a boolean entry in the
// "supports" map of the model's action metadata.
type GeneratorCapabilities struct {
// Multiturn is published as "multiturn". Presumably reports whether the
// generator accepts multi-turn (chat-style) input; confirm against consumers.
Multiturn bool
// Media is published as "media". Presumably reports whether the generator
// accepts multimodal input such as images; confirm against consumers.
Media bool
// Tools is published as "tools". Presumably reports whether the generator
// supports tool use; confirm against consumers.
Tools bool
// SystemRole is published as "systemRole". Presumably reports whether the
// generator accepts system-role messages; confirm against consumers.
SystemRole bool
}

// GeneratorMetadata is the metadata of a generator: a user-visible label
// and the capabilities the generator supports.
type GeneratorMetadata struct {
// Label is the user-visible name. RegisterGenerator publishes it under
// the "label" key only when it is non-empty.
Label string
// Supports lists the generator's capabilities. RegisterGenerator
// publishes them under the "supports" key.
Supports GeneratorCapabilities
}

// RegisterGenerator registers the generator in the global registry.
func RegisterGenerator(name string, generator Generator) {
genkit.RegisterAction(genkit.ActionTypeModel, name,
genkit.NewStreamingAction(name, generator.Generate))
func RegisterGenerator(provider, name string, metadata *GeneratorMetadata, generator Generator) {
metadataMap := map[string]any{}
if metadata != nil {
if metadata.Label != "" {
metadataMap["label"] = metadata.Label
}
supports := map[string]bool{
"media": metadata.Supports.Media,
"multiturn": metadata.Supports.Multiturn,
"systemRole": metadata.Supports.SystemRole,
"tools": metadata.Supports.Tools,
}
metadataMap["supports"] = supports
}
genkit.RegisterAction(genkit.ActionTypeModel, provider,
genkit.NewStreamingAction(name, map[string]any{
"model": metadataMap,
}, generator.Generate))
}

// generatorActionType is the instantiated genkit.Action type registered
Expand Down
4 changes: 2 additions & 2 deletions go/ai/retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ type RetrieverResponse struct {
// RegisterRetriever registers the actions for a specific retriever.
func RegisterRetriever(name string, retriever Retriever) {
genkit.RegisterAction(genkit.ActionTypeRetriever, name,
genkit.NewAction(name, retriever.Retrieve))
genkit.NewAction(name, nil, retriever.Retrieve))

genkit.RegisterAction(genkit.ActionTypeIndexer, name,
genkit.NewAction(name, func(ctx context.Context, req *IndexerRequest) (struct{}, error) {
genkit.NewAction(name, nil, func(ctx context.Context, req *IndexerRequest) (struct{}, error) {
err := retriever.Index(ctx, req)
return struct{}{}, err
}))
Expand Down
7 changes: 4 additions & 3 deletions go/genkit/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,21 +65,22 @@ type Action[I, O, S any] struct {
// See js/common/src/types.ts

// NewAction creates a new Action with the given name and non-streaming function.
func NewAction[I, O any](name string, fn func(context.Context, I) (O, error)) *Action[I, O, struct{}] {
return NewStreamingAction(name, func(ctx context.Context, in I, cb NoStream) (O, error) {
func NewAction[I, O any](name string, metadata map[string]any, fn func(context.Context, I) (O, error)) *Action[I, O, struct{}] {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do you see metadata being used? Who is the producer and who is the consumer? Typically in Go we would use specific types here (and make the metadata parameter type any), but if this is truly general, or user-controlled, then a map makes sense.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At the action level metadata is not strictly defined, it is an arbitrary map. But at the specific action type level (model, embedder, retriever, etc.) metadata is expected to have specific structure (if it wants to be understood). This is an example of model metadata:
image

It is important because, for example, the model playground needs to know whether the model supports multiturn input (if it does, it will render a chat-style interface) or whether it supports multimodal input (in which case it will allow inserting images).

But it's not just the Dev UI: for example, the generate function will check the model's supported features and throw errors if the developer is trying to do something that's not supported by that specific model:

if (!model.__action.metadata?.model.supports?.tools) {

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That suggests that the metadata parameter to NewAction should be type any, and the actual action will type assert to the desired value.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the action schema it is defined as map[string]any (Record<string, any>):

.record(z.string(), CustomAnySchema)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without getting into too much detail, I think the question here comes down to whether this value will ever be used in Go code, or will it just be written to JSON for the benefit of the tooling?

If no Go code needs to access the metadata except to JSON-marshal it, then Ian's way is more type-safe and Go-like. With an any parameter, a user can still pass in a map[string]any, or they can pass in a struct that is JSON-marshalable (like the GeneratorCapabilities one you've defined). Both a map[string]any and a struct with the right field names or struct tags will marshal to the same JSON, but the latter will be easier for Go programmers to write and won't need to be hand-converted to a map in the code.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Go code might need/want to use it. This example:

if (!model.__action.metadata?.model.supports?.tools) {

Maybe in that case it could be added to the Generator interface?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using specific types for generators, as you've already written, is consistent with having the metadata parameter of NewAction be of type any. Generators can type assert the metadata as needed. Perhaps I misunderstand.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it should be possible to get the GeneratorCapabilities from a Generator without going into the action.

Copy link
Collaborator Author

@pavelgj pavelgj May 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried changing the metadata type on NewAction to any, but I don't think it's possible/easy... I'll leave that to you. :)

Basically:
https://github.com/firebase/genkit/blob/main/go/genkit/action.go#L62
and
https://github.com/firebase/genkit/blob/main/go/genkit/action.go#L173

return NewStreamingAction(name, metadata, func(ctx context.Context, in I, cb NoStream) (O, error) {
return fn(ctx, in)
})
}

// NewStreamingAction creates a new Action with the given name and streaming function.
func NewStreamingAction[I, O, S any](name string, fn Func[I, O, S]) *Action[I, O, S] {
func NewStreamingAction[I, O, S any](name string, metadata map[string]any, fn Func[I, O, S]) *Action[I, O, S] {
var i I
var o O
return &Action[I, O, S]{
name: name,
fn: fn,
inputSchema: inferJSONSchema(i),
outputSchema: inferJSONSchema(o),
Metadata: metadata,
}
}

Expand Down
6 changes: 3 additions & 3 deletions go/genkit/action_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func inc(_ context.Context, x int) (int, error) {
}

func TestActionRun(t *testing.T) {
a := NewAction("inc", inc)
a := NewAction("inc", nil, inc)
got, err := a.Run(context.Background(), 3, nil)
if err != nil {
t.Fatal(err)
Expand All @@ -37,7 +37,7 @@ func TestActionRun(t *testing.T) {
}

func TestActionRunJSON(t *testing.T) {
a := NewAction("inc", inc)
a := NewAction("inc", nil, inc)
input := []byte("3")
want := []byte("4")
got, err := a.runJSON(context.Background(), input, nil)
Expand All @@ -63,7 +63,7 @@ func count(ctx context.Context, n int, cb StreamingCallback[int]) (int, error) {

func TestActionStreaming(t *testing.T) {
ctx := context.Background()
a := NewStreamingAction("count", count)
a := NewStreamingAction("count", nil, count)
const n = 3

// Non-streaming.
Expand Down
13 changes: 8 additions & 5 deletions go/genkit/dev_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,12 @@ func TestDevServer(t *testing.T) {
if err != nil {
t.Fatal(err)
}
r.registerAction("test", "devServer", NewAction("inc", inc))
r.registerAction("test", "devServer", NewAction("dec", dec))
r.registerAction("test", "devServer", NewAction("inc", map[string]any{
"foo": "bar",
}, inc))
r.registerAction("test", "devServer", NewAction("dec", map[string]any{
"bar": "baz",
}, dec))
srv := httptest.NewServer(newDevServerMux(r))
defer srv.Close()

Expand Down Expand Up @@ -78,10 +82,9 @@ func TestDevServer(t *testing.T) {
if err != nil {
t.Fatal(err)
}
md := map[string]any{"inputSchema": nil, "outputSchema": nil}
want := map[string]actionDesc{
"/test/devServer/dec": {Key: "/test/devServer/dec", Name: "dec", Metadata: md},
"/test/devServer/inc": {Key: "/test/devServer/inc", Name: "inc", Metadata: md},
"/test/devServer/inc": {Key: "/test/devServer/inc", Name: "inc", Metadata: map[string]any{"inputSchema": nil, "outputSchema": nil, "foo": "bar"}},
"/test/devServer/dec": {Key: "/test/devServer/dec", Name: "dec", Metadata: map[string]any{"inputSchema": nil, "outputSchema": nil, "bar": "baz"}},
}
if !maps.EqualFunc(got, want, actionDesc.equal) {
t.Errorf("\n got %v\nwant %v", got, want)
Expand Down
4 changes: 2 additions & 2 deletions go/genkit/dotprompt/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,9 @@ func (p *Prompt) Action() (*genkit.Action[*ActionInput, *ai.GenerateResponse, st
name += "." + p.Variant
}

a := genkit.NewAction(name, p.Execute)
a := genkit.NewAction(name, nil, p.Execute)
a.Metadata = map[string]any{
"type": "prompt",
"type": "prompt",
"prompt": p,
}
return a, nil
Expand Down
2 changes: 1 addition & 1 deletion go/genkit/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ type FlowResult[O any] struct {

// action creates an action for the flow. See the comment at the top of this file for more information.
func (f *Flow[I, O, S]) action() *Action[*flowInstruction[I], *flowState[I, O], S] {
return NewStreamingAction(f.name, func(ctx context.Context, inst *flowInstruction[I], cb StreamingCallback[S]) (*flowState[I, O], error) {
return NewStreamingAction(f.name, nil, func(ctx context.Context, inst *flowInstruction[I], cb StreamingCallback[S]) (*flowState[I, O], error) {
spanMetaKey.fromContext(ctx).SetAttr("flow:wrapperAction", "true")
return f.runInstruction(ctx, inst, cb)
})
Expand Down
2 changes: 1 addition & 1 deletion go/genkit/tracing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func TestSpanMetadata(t *testing.T) {
func TestTracing(t *testing.T) {
ctx := context.Background()
const actionName = "TestTracing-inc"
a := NewAction(actionName, inc)
a := NewAction(actionName, nil, inc)
if _, err := a.Run(context.Background(), 3, nil); err != nil {
t.Fatal(err)
}
Expand Down
8 changes: 6 additions & 2 deletions go/plugins/googleai/googleai.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,12 +204,16 @@ func Init(ctx context.Context, model, apiKey string) error {
return err
}
ai.RegisterEmbedder("google-genai", e)

g, err := NewGenerator(ctx, model, apiKey)
if err != nil {
return err
}
ai.RegisterGenerator("google-genai", g)
ai.RegisterGenerator("google-genai", model, &ai.GeneratorMetadata{
Label: "Google AI - " + model,
Supports: ai.GeneratorCapabilities{
Multiturn: true,
},
}, g)

return nil
}
Expand Down
7 changes: 6 additions & 1 deletion go/plugins/vertexai/vertexai.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,12 @@ func Init(ctx context.Context, model, projectID, location string) error {
if err != nil {
return err
}
ai.RegisterGenerator("google-vertexai", g)
ai.RegisterGenerator("google-vertexai", model, &ai.GeneratorMetadata{
Label: "Vertex AI - " + model,
Supports: ai.GeneratorCapabilities{
Multiturn: true,
},
}, g)

return nil
}
Expand Down
16 changes: 8 additions & 8 deletions go/samples/coffee-shop/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ type customerTimeAndHistoryInput struct {
}

type testAllCoffeeFlowsOutput struct {
Pass bool `json:"pass"`
Pass bool `json:"pass"`
Replies []string `json:"replies,omitempty"`
Error string `json:"error,omitempty"`
Error string `json:"error,omitempty"`
}

func main() {
Expand All @@ -79,13 +79,13 @@ func main() {
os.Exit(1)
}

if err := googleai.Init(context.Background(), "gemini-pro", apiKey); err != nil {
if err := googleai.Init(context.Background(), "gemini-1.0-pro", apiKey); err != nil {
log.Fatal(err)
}

simpleGreetingPrompt, err := dotprompt.Define("simpleGreeting",
&dotprompt.Frontmatter{
Name: "simpleGreeting",
Name: "simpleGreeting",
Model: "google-genai",
Input: dotprompt.FrontmatterInput{
Schema: jsonschema.Reflect(simpleGreetingInput{}),
Expand All @@ -106,7 +106,7 @@ func main() {
if err != nil {
return "", err
}
ai := &dotprompt.ActionInput{ Variables: vars }
ai := &dotprompt.ActionInput{Variables: vars}
resp, err := simpleGreetingPrompt.Execute(ctx, ai)
if err != nil {
return "", err
Expand All @@ -120,7 +120,7 @@ func main() {

greetingWithHistoryPrompt, err := dotprompt.Define("greetingWithHistory",
&dotprompt.Frontmatter{
Name: "greetingWithHistory",
Name: "greetingWithHistory",
Model: "google-genai",
Input: dotprompt.FrontmatterInput{
Schema: jsonschema.Reflect(customerTimeAndHistoryInput{}),
Expand All @@ -141,7 +141,7 @@ func main() {
if err != nil {
return "", err
}
ai := &dotprompt.ActionInput{ Variables: vars }
ai := &dotprompt.ActionInput{Variables: vars}
resp, err := greetingWithHistoryPrompt.Execute(ctx, ai)
if err != nil {
return "", err
Expand Down Expand Up @@ -177,7 +177,7 @@ func main() {
return out, nil
}
out := &testAllCoffeeFlowsOutput{
Pass: true,
Pass: true,
Replies: []string{
test1,
test2,
Expand Down
Loading