Skip to content
This repository has been archived by the owner on Jan 9, 2025. It is now read-only.

Commit

Permalink
fix: expose input and output for anthropic for instill credit (#190)
Browse files Browse the repository at this point in the history
Because

- calculator should rely on component package

This commit

- expose the input and output for the package
  • Loading branch information
chuang8511 authored Jul 1, 2024
1 parent 07ce4b9 commit a36e876
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 54 deletions.
4 changes: 2 additions & 2 deletions ai/anthropic/v0/component_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,10 @@ func TestComponent_Generation(t *testing.T) {

tc := struct {
input map[string]any
wantResp messagesOutput
wantResp MessagesOutput
}{
input: map[string]any{"prompt": "Hi! What's your name?", "chat-history": mockHistory},
wantResp: messagesOutput{
wantResp: MessagesOutput{
Text: "Hi! My name is Claude. (messageCount: 3)",
Usage: messagesUsage{
InputTokens: 10,
Expand Down
127 changes: 75 additions & 52 deletions ai/anthropic/v0/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,10 @@ package anthropic
import (
"context"
_ "embed"
"encoding/json"
"fmt"
"slices"
"sync"

"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/structpb"

"github.com/instill-ai/component/base"
Expand Down Expand Up @@ -72,7 +70,34 @@ type messagesReq struct {
TopP float32 `json:"top_p,omitempty"`
}

type messagesOutput struct {
type MessagesInput struct {
ChatHistory []ChatMessage `json:"chat-history"`
MaxNewTokens int `json:"max-new-tokens"`
ModelName string `json:"model-name"`
Prompt string `json:"prompt"`
PromptImages []string `json:"prompt-images"`
Seed int `json:"seed"`
SystemMsg string `json:"system-message"`
Temperature float32 `json:"temperature"`
TopK int `json:"top-k"`
}

type ChatMessage struct {
Role string `json:"role"`
Content []MultiModalContent `json:"content"`
}

type MultiModalContent struct {
ImageURL URL `json:"image-url"`
Text string `json:"text"`
Type string `json:"type"`
}

type URL struct {
URL string `json:"url"`
}

type MessagesOutput struct {
Text string `json:"text"`
Usage messagesUsage `json:"usage"`
}
Expand Down Expand Up @@ -203,68 +228,56 @@ func (e *execution) Execute(_ context.Context, inputs []*structpb.Struct) ([]*st
return outputs, nil
}

func retriveChatMessage(chatHistoryPbList *structpb.ListValue) []message {
messages := []message{}
for _, messagePbValue := range chatHistoryPbList.GetValues() {
contents := []content{}
for _, contentPbValue := range messagePbValue.GetStructValue().Fields["content"].GetListValue().GetValues() {
contentType := contentPbValue.GetStructValue().Fields["type"].GetStringValue()
// anthrothpic models does not support image urls
if contentType == "text" {
content := content{
Type: "text",
Text: contentPbValue.GetStructValue().Fields["text"].GetStringValue(),
}
contents = append(contents, content)
}
}
completeMessage := message{Role: messagePbValue.GetStructValue().Fields["role"].GetStringValue(), Content: contents}
messages = append(messages, completeMessage)
}
return messages
}

func (e *execution) generateText(in *structpb.Struct) (*structpb.Struct, error) {

prompt := in.Fields["prompt"].GetStringValue()
var inputStruct MessagesInput
err := base.ConvertFromStructpb(in, &inputStruct)
if err != nil {
return nil, err
}

prompt := inputStruct.Prompt

messages := []message{}

chatHistory := in.Fields["chat-history"].GetListValue()
chatHistory := inputStruct.ChatHistory

if chatHistory != nil {
messages = retriveChatMessage(chatHistory)
for _, chatMessage := range chatHistory {
contents := getContents(chatMessage)
message := message{Role: chatMessage.Role, Content: contents}
messages = append(messages, message)
}


finalMessage := message{
Role: "user",
Content: []content{{Type: "text", Text: prompt}},
}

if in.Fields["prompt-images"] != nil {
for _, image := range in.Fields["prompt-images"].GetListValue().GetValues() {
extension := base.GetBase64FileExtension(image.GetStringValue())
// check if the image extension is supported
if !slices.Contains(supportedImageExtensions, extension) {
return nil, fmt.Errorf("unsupported image extension, expected one of: %v , got %s", supportedImageExtensions, extension)
}
image := content{
Type: "image",
Source: &source{Type: "base64", MediaType: fmt.Sprintf("image/%s", extension), Data: base.TrimBase64Mime(image.GetStringValue())},
}
finalMessage.Content = append(finalMessage.Content, image)
promptImages := inputStruct.PromptImages
for _, image := range promptImages {
extension := base.GetBase64FileExtension(image)
// check if the image extension is supported
if !slices.Contains(supportedImageExtensions, extension) {
return nil, fmt.Errorf("unsupported image extension, expected one of: %v , got %s", supportedImageExtensions, extension)
}
image := content{
Type: "image",
Source: &source{Type: "base64", MediaType: fmt.Sprintf("image/%s", extension), Data: base.TrimBase64Mime(image)},
}
finalMessage.Content = append(finalMessage.Content, image)
}


messages = append(messages, finalMessage)

req := messagesReq{
Messages: messages,
Model: in.Fields["model-name"].GetStringValue(),
MaxTokens: int(in.Fields["max-new-tokens"].GetNumberValue()),
System: in.Fields["system-message"].GetStringValue(),
TopK: int(in.Fields["top-k"].GetNumberValue()),
Temperature: float32(in.Fields["temperature"].GetNumberValue()),
Model: inputStruct.ModelName,
MaxTokens: inputStruct.MaxNewTokens,
System: inputStruct.SystemMsg,
TopK: inputStruct.TopK,
Temperature: float32(inputStruct.Temperature),
}

resp, err := e.client.generateTextChat(req)
Expand All @@ -273,7 +286,7 @@ func (e *execution) generateText(in *structpb.Struct) (*structpb.Struct, error)
return nil, err
}

outputStruct := messagesOutput{
outputStruct := MessagesOutput{
Text: "",
Usage: messagesUsage{
InputTokens: resp.Usage.InputTokens,
Expand All @@ -284,14 +297,24 @@ func (e *execution) generateText(in *structpb.Struct) (*structpb.Struct, error)
outputStruct.Text += c.Text
}

outputJSON, err := json.Marshal(outputStruct)
output, err := base.ConvertToStructpb(outputStruct)
if err != nil {
return nil, err
}
output := structpb.Struct{}
err = protojson.Unmarshal(outputJSON, &output)
if err != nil {
return nil, err
return output, nil
}

func getContents(chatMessage ChatMessage) []content {
contents := []content{}
for _, multiModalContent := range chatMessage.Content {
if multiModalContent.Type == "text" {
contentReq := content{
Type: "text",
Text: multiModalContent.Text,
}
contents = append(contents, contentReq)
}
}
return &output, nil

return contents
}

0 comments on commit a36e876

Please sign in to comment.