diff --git a/ai/anthropic/v0/component_test.go b/ai/anthropic/v0/component_test.go index ed1c313e..5e4679dc 100644 --- a/ai/anthropic/v0/component_test.go +++ b/ai/anthropic/v0/component_test.go @@ -188,10 +188,10 @@ func TestComponent_Generation(t *testing.T) { tc := struct { input map[string]any - wantResp messagesOutput + wantResp MessagesOutput }{ input: map[string]any{"prompt": "Hi! What's your name?", "chat-history": mockHistory}, - wantResp: messagesOutput{ + wantResp: MessagesOutput{ Text: "Hi! My name is Claude. (messageCount: 3)", Usage: messagesUsage{ InputTokens: 10, diff --git a/ai/anthropic/v0/main.go b/ai/anthropic/v0/main.go index c4d22d0f..2b935fbf 100644 --- a/ai/anthropic/v0/main.go +++ b/ai/anthropic/v0/main.go @@ -4,12 +4,10 @@ package anthropic import ( "context" _ "embed" - "encoding/json" "fmt" "slices" "sync" - "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/types/known/structpb" "github.com/instill-ai/component/base" @@ -72,7 +70,34 @@ type messagesReq struct { TopP float32 `json:"top_p,omitempty"` } -type messagesOutput struct { +type MessagesInput struct { + ChatHistory []ChatMessage `json:"chat-history"` + MaxNewTokens int `json:"max-new-tokens"` + ModelName string `json:"model-name"` + Prompt string `json:"prompt"` + PromptImages []string `json:"prompt-images"` + Seed int `json:"seed"` + SystemMsg string `json:"system-message"` + Temperature float32 `json:"temperature"` + TopK int `json:"top-k"` +} + +type ChatMessage struct { + Role string `json:"role"` + Content []MultiModalContent `json:"content"` +} + +type MultiModalContent struct { + ImageURL URL `json:"image-url"` + Text string `json:"text"` + Type string `json:"type"` +} + +type URL struct { + URL string `json:"url"` +} + +type MessagesOutput struct { Text string `json:"text"` Usage messagesUsage `json:"usage"` } @@ -203,68 +228,56 @@ func (e *execution) Execute(_ context.Context, inputs []*structpb.Struct) ([]*st return outputs, nil } -func retriveChatMessage(chatHistoryPbList *structpb.ListValue) []message { - messages := []message{} - for _, messagePbValue := range chatHistoryPbList.GetValues() { - contents := []content{} - for _, contentPbValue := range messagePbValue.GetStructValue().Fields["content"].GetListValue().GetValues() { - contentType := contentPbValue.GetStructValue().Fields["type"].GetStringValue() - // anthrothpic models does not support image urls - if contentType == "text" { - content := content{ - Type: "text", - Text: contentPbValue.GetStructValue().Fields["text"].GetStringValue(), - } - contents = append(contents, content) - } - } - completeMessage := message{Role: messagePbValue.GetStructValue().Fields["role"].GetStringValue(), Content: contents} - messages = append(messages, completeMessage) - } - return messages -} - func (e *execution) generateText(in *structpb.Struct) (*structpb.Struct, error) { - prompt := in.Fields["prompt"].GetStringValue() + var inputStruct MessagesInput + err := base.ConvertFromStructpb(in, &inputStruct) + if err != nil { + return nil, err + } + + prompt := inputStruct.Prompt messages := []message{} - chatHistory := in.Fields["chat-history"].GetListValue() + chatHistory := inputStruct.ChatHistory - if chatHistory != nil { - messages = retriveChatMessage(chatHistory) + for _, chatMessage := range chatHistory { + contents := getContents(chatMessage) + message := message{Role: chatMessage.Role, Content: contents} + messages = append(messages, message) } + finalMessage := message{ Role: "user", Content: []content{{Type: "text", Text: prompt}}, } - if in.Fields["prompt-images"] != nil { - for _, image := range in.Fields["prompt-images"].GetListValue().GetValues() { - extension := base.GetBase64FileExtension(image.GetStringValue()) - // check if the image extension is supported - if !slices.Contains(supportedImageExtensions, extension) { - return nil, fmt.Errorf("unsupported image extension, expected one of: %v , got %s", supportedImageExtensions, extension) - } - image := content{ - Type: "image", - Source: &source{Type: "base64", MediaType: fmt.Sprintf("image/%s", extension), Data: base.TrimBase64Mime(image.GetStringValue())}, - } - finalMessage.Content = append(finalMessage.Content, image) + promptImages := inputStruct.PromptImages + for _, image := range promptImages { + extension := base.GetBase64FileExtension(image) + // check if the image extension is supported + if !slices.Contains(supportedImageExtensions, extension) { + return nil, fmt.Errorf("unsupported image extension, expected one of: %v , got %s", supportedImageExtensions, extension) } + image := content{ + Type: "image", + Source: &source{Type: "base64", MediaType: fmt.Sprintf("image/%s", extension), Data: base.TrimBase64Mime(image)}, + } + finalMessage.Content = append(finalMessage.Content, image) } + messages = append(messages, finalMessage) req := messagesReq{ Messages: messages, - Model: in.Fields["model-name"].GetStringValue(), - MaxTokens: int(in.Fields["max-new-tokens"].GetNumberValue()), - System: in.Fields["system-message"].GetStringValue(), - TopK: int(in.Fields["top-k"].GetNumberValue()), - Temperature: float32(in.Fields["temperature"].GetNumberValue()), + Model: inputStruct.ModelName, + MaxTokens: inputStruct.MaxNewTokens, + System: inputStruct.SystemMsg, + TopK: inputStruct.TopK, + Temperature: float32(inputStruct.Temperature), } resp, err := e.client.generateTextChat(req) @@ -273,7 +286,7 @@ func (e *execution) generateText(in *structpb.Struct) (*structpb.Struct, error) return nil, err } - outputStruct := messagesOutput{ + outputStruct := MessagesOutput{ Text: "", Usage: messagesUsage{ InputTokens: resp.Usage.InputTokens, @@ -284,14 +297,24 @@ func (e *execution) generateText(in *structpb.Struct) (*structpb.Struct, error) outputStruct.Text += c.Text } - outputJSON, err := json.Marshal(outputStruct) + output, err := base.ConvertToStructpb(outputStruct) if err != nil { return nil, err } - output := structpb.Struct{} - err = protojson.Unmarshal(outputJSON, &output) - if err != nil { - return nil, err + return output, nil +} + +func getContents(chatMessage ChatMessage) []content { + contents := []content{} + for _, multiModalContent := range chatMessage.Content { + if multiModalContent.Type == "text" { + contentReq := content{ + Type: "text", + Text: multiModalContent.Text, + } + contents = append(contents, contentReq) + } } - return &output, nil + + return contents }