forked from instill-ai/component
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add: implemented vertex ai with go-langchain
- Loading branch information
Showing
7 changed files
with
420 additions
and
71 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
package vertexai | ||
|
||
import (
	"context"
	"encoding/json"
	"fmt"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/googleai"
	"github.com/tmc/langchaingo/llms/googleai/vertex"
	"go.uber.org/zap"
)
|
||
// VertexAIClient wraps langchaingo's Vertex client and exposes the
// chat and embedding operations used by this component.
type VertexAIClient struct {
	// langchainClient is the underlying langchaingo Vertex AI client.
	langchainClient *vertex.Vertex
}
|
||
// NewClient creates a VertexAIClient for the project and location in setup,
// authenticating with the raw JSON credential string in setup.Cred.
//
// NOTE(review): logger.Fatal terminates the entire process when client
// creation fails (bad credential, project, or location). Consider returning
// an error instead so the caller can surface it gracefully.
func NewClient(setup VertexAISetup, logger *zap.Logger) *VertexAIClient {
	ctx := context.Background()
	langchainClient, err := vertex.New(ctx, googleai.WithCloudLocation(setup.Location), googleai.WithCloudProject(setup.ProjectID), googleai.WithCredentialsJSON(json.RawMessage(setup.Cred)))
	if err != nil {
		logger.Fatal("failed to create langchain client", zap.Error(err))
	}
	return &VertexAIClient{
		langchainClient: langchainClient,
	}
}
|
||
// EmbedRequest is the input to VertexAIClient.Embed.
type EmbedRequest struct {
	// Text is the string to embed.
	Text string `json:"text"`
	// Model is the requested embedding model name.
	// NOTE(review): currently unused by Embed — see comment there.
	Model string `json:"model"`
}

// EmbedResponse is the result of VertexAIClient.Embed.
type EmbedResponse struct {
	// Embedding is the embedding vector for the input text.
	Embedding []float32 `json:"embedding"`
	// Tokens is the token count consumed; always 0 today because the
	// underlying API call does not report usage.
	Tokens int `json:"tokens"`
}
|
||
func (c *VertexAIClient) Embed(req EmbedRequest) (EmbedResponse, error) { | ||
resp := EmbedResponse{} | ||
ctx := context.Background() | ||
// this function only support generate embedding with Palm, not with other models | ||
content, err := c.langchainClient.CreateEmbedding(ctx, []string{req.Text}) | ||
if err != nil { | ||
return EmbedResponse{}, err | ||
} | ||
resp.Embedding = content[0] | ||
resp.Tokens = 0 | ||
return resp, nil | ||
} | ||
|
||
// ChatRequest is the input to VertexAIClient.Chat. The sampling fields map
// one-to-one onto langchaingo llms.WithXxx call options.
type ChatRequest struct {
	// Messages is the full conversation (system/history/prompt) to send.
	Messages []llms.MessageContent `json:"messages"`
	// Model is the Vertex AI model name to generate with.
	Model string `json:"model"`
	// MaxTokens caps the number of generated tokens.
	MaxTokens int `json:"max_tokens"`
	// Temperature controls sampling randomness.
	Temperature float64 `json:"temperature"`
	// TopP is the nucleus-sampling probability mass.
	TopP float64 `json:"top_p"`
	// TopK limits sampling to the K most likely tokens.
	TopK int `json:"top_k"`
	// Seed makes sampling reproducible where the backend supports it.
	Seed int `json:"seed"`
}

// ChatResponse is the result of VertexAIClient.Chat.
type ChatResponse struct {
	// Text is the first choice's generated content.
	Text string `json:"text"`
	// InputTokens / OutputTokens are usage counts; currently always 0
	// because usage is not extracted from the generation info.
	InputTokens  int `json:"input_tokens"`
	OutputTokens int `json:"output_tokens"`
}
|
||
func (c *VertexAIClient) Chat(req ChatRequest) (ChatResponse, error) { | ||
resp := ChatResponse{} | ||
ctx := context.Background() | ||
content, err := c.langchainClient.GenerateContent(ctx, req.Messages, llms.WithModel(req.Model), llms.WithMaxTokens(req.MaxTokens), llms.WithTemperature(req.Temperature), llms.WithTopP(req.TopP), llms.WithTopK(req.TopK), llms.WithSeed(req.Seed)) | ||
if err != nil { | ||
return ChatResponse{}, err | ||
} | ||
resp.Text = content.Choices[0].Content | ||
print(content.Choices[0].GenerationInfo) | ||
resp.InputTokens = 0 | ||
resp.OutputTokens = 0 | ||
return resp, nil | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
//go:generate compogen readme ./config ./README.mdx | ||
package vertexai | ||
|
||
import ( | ||
"context" | ||
_ "embed" | ||
"fmt" | ||
"sync" | ||
|
||
"google.golang.org/protobuf/types/known/structpb" | ||
|
||
"github.com/instill-ai/component/base" | ||
) | ||
|
||
// Task identifiers supported by this component.
const (
	TaskTextGenerationChat = "TASK_TEXT_GENERATION_CHAT"
	TaskTextEmbeddings     = "TASK_TEXT_EMBEDDINGS"
)
|
||
var (
	// Component schemas embedded from the config directory at build time.
	//go:embed config/definition.json
	definitionJSON []byte
	//go:embed config/setup.json
	setupJSON []byte
	//go:embed config/tasks.json
	tasksJSON []byte

	// once guards the one-time initialization of the comp singleton in Init.
	once sync.Once
	comp *component
)
|
||
// component implements the Vertex AI component on top of the shared
// base.Component infrastructure.
type component struct {
	base.Component
}
|
||
func Init(bc base.Component) *component { | ||
once.Do(func() { | ||
comp = &component{Component: bc} | ||
err := comp.LoadDefinition(definitionJSON, setupJSON, tasksJSON, nil) | ||
if err != nil { | ||
panic(err) | ||
} | ||
}) | ||
return comp | ||
} | ||
|
||
// VertexAIClientInterface abstracts the Vertex AI client so executions can
// be tested with a fake implementation.
type VertexAIClientInterface interface {
	Chat(ChatRequest) (ChatResponse, error)
	Embed(EmbedRequest) (EmbedResponse, error)
}

// VertexAISetup holds the connection settings parsed from the component's
// setup struct.
type VertexAISetup struct {
	// ProjectID is the Google Cloud project to bill and route requests to.
	ProjectID string `json:"project-id"`
	// Cred is the service-account credential as a raw JSON string.
	Cred string `json:"cred"`
	// Location is the Google Cloud region (e.g. "us-central1").
	Location string `json:"location"`
}
|
||
// execution is a single configured run of one task, holding the client and
// the task-specific handler selected in CreateExecution.
type execution struct {
	base.ComponentExecution
	// client performs the actual Vertex AI calls.
	client VertexAIClientInterface
	// execute is the handler bound to the configured task.
	execute func(*structpb.Struct) (*structpb.Struct, error)
}
|
||
func (e *execution) Execute(_ context.Context, inputs []*structpb.Struct) ([]*structpb.Struct, error) { | ||
outputs := make([]*structpb.Struct, len(inputs)) | ||
|
||
// The execution takes a array of inputs and returns an array of outputs. The execution is done sequentially. | ||
for i, input := range inputs { | ||
output, err := e.execute(input) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
outputs[i] = output | ||
} | ||
|
||
return outputs, nil | ||
} | ||
|
||
func (c *component) CreateExecution(x base.ComponentExecution) (base.IExecution, error) { | ||
setupStruct := VertexAISetup{} | ||
if err := base.ConvertFromStructpb(x.Setup, setupStruct); err != nil { | ||
return nil, fmt.Errorf("error parsing setup, %v", err) | ||
} | ||
|
||
e := &execution{ | ||
ComponentExecution: x, | ||
client: NewClient(setupStruct, c.Logger), | ||
} | ||
switch x.Task { | ||
case TaskTextGenerationChat: | ||
e.execute = e.TaskTextGenerationChat | ||
case TaskTextEmbeddings: | ||
e.execute = e.TaskTextEmbeddings | ||
default: | ||
return nil, fmt.Errorf("unsupported task") | ||
} | ||
return e, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
package vertexai | ||
|
||
import ( | ||
"fmt" | ||
|
||
"github.com/instill-ai/component/base" | ||
"github.com/tmc/langchaingo/llms" | ||
"google.golang.org/protobuf/types/known/structpb" | ||
) | ||
|
||
// ChatMessage is one turn of chat history as supplied in the task input.
type ChatMessage struct {
	// Role is the speaker role string (e.g. "user"); passed through to
	// llms.ChatMessageType when building langchain messages.
	Role string `json:"role"`
	// Content is the list of text/image parts making up the turn.
	Content []MultiModalContent `json:"content"`
}

// URL wraps an image URL value.
type URL struct {
	URL string `json:"url"`
}

// MultiModalContent is one part of a chat message: either text or an image,
// discriminated by Type.
type MultiModalContent struct {
	// ImageURL holds the image location when Type == "image".
	ImageURL URL `json:"image-url"`
	// Text holds the text content when Type == "text".
	Text string `json:"text"`
	// Type is "text" or "image"; other values are ignored downstream.
	Type string `json:"type"`
}
|
||
// TaskTextGenerationChatInput is the structpb-decoded input of the
// TASK_TEXT_GENERATION_CHAT task.
type TaskTextGenerationChatInput struct {
	// ChatHistory is prior conversation turns, oldest first.
	ChatHistory []ChatMessage `json:"chat-history"`
	// MaxNewTokens caps the generated token count.
	MaxNewTokens int `json:"max-new-tokens"`
	// ModelName is the Vertex AI model to generate with.
	ModelName string `json:"model-name"`
	// Prompt is the current user prompt text.
	Prompt string `json:"prompt"`
	// PromptImages is a list of image URLs attached to the prompt.
	PromptImages []string `json:"prompt-images"`
	Seed         int      `json:"seed"`
	// SystemMsg, when non-empty, is prepended as a system-role message.
	SystemMsg   string  `json:"system-message"`
	Temperature float64 `json:"temperature"`
	TopK        int     `json:"top-k"`
	TopP        float64 `json:"top-p"`
	// Safe is declared in the schema but not currently applied to the
	// request — NOTE(review): confirm whether safety settings should be
	// wired through.
	Safe bool `json:"safe"`
}

// TaskTextGenerationChatUsage reports token usage for a chat generation.
type TaskTextGenerationChatUsage struct {
	InputTokens  int `json:"input-tokens"`
	OutputTokens int `json:"output-tokens"`
}

// TaskTextGenerationChatOutput is the structpb-encoded output of the
// TASK_TEXT_GENERATION_CHAT task.
type TaskTextGenerationChatOutput struct {
	Text  string                      `json:"text"`
	Usage TaskTextGenerationChatUsage `json:"usage"`
}

// TaskTextEmbeddingsInput is the structpb-decoded input of the
// TASK_TEXT_EMBEDDINGS task.
type TaskTextEmbeddingsInput struct {
	Text      string `json:"text"`
	ModelName string `json:"model-name"`
}

// TaskTextEmbeddingsUsage reports token usage for an embedding call.
type TaskTextEmbeddingsUsage struct {
	Tokens int `json:"tokens"`
}

// TaskTextEmbeddingsOutput is the structpb-encoded output of the
// TASK_TEXT_EMBEDDINGS task.
type TaskTextEmbeddingsOutput struct {
	Embedding []float64               `json:"embedding"`
	Usage     TaskTextEmbeddingsUsage `json:"usage"`
}
|
||
func (e *execution) TaskTextGenerationChat(in *structpb.Struct) (*structpb.Struct, error) { | ||
|
||
inputStruct := TaskTextGenerationChatInput{} | ||
err := base.ConvertFromStructpb(in, &inputStruct) | ||
if err != nil { | ||
return nil, fmt.Errorf("error generating input struct: %v", err) | ||
} | ||
|
||
messages := []llms.MessageContent{} | ||
|
||
if inputStruct.SystemMsg != "" { | ||
messages = append(messages, llms.MessageContent{ | ||
Role: llms.ChatMessageType("system"), // note: not sure if this is correct, go-langchain does not have a system role | ||
Parts: []llms.ContentPart{ | ||
llms.TextPart(inputStruct.SystemMsg), | ||
}, | ||
}) | ||
} | ||
|
||
for _, chatMessage := range inputStruct.ChatHistory { | ||
messageContent := []llms.ContentPart{} | ||
for _, content := range chatMessage.Content { | ||
if content.Type == "text" { | ||
messageContent = append(messageContent, llms.TextPart(content.Text)) | ||
} else if content.Type == "image" { | ||
messageContent = append(messageContent, llms.ImageURLPart(content.ImageURL.URL)) | ||
} | ||
} | ||
if len(messageContent) == 0 { | ||
continue | ||
} | ||
messages = append(messages, llms.MessageContent{ | ||
Role: llms.ChatMessageType(chatMessage.Role), | ||
Parts: messageContent, | ||
}) | ||
} | ||
|
||
promptContent := []llms.ContentPart{} | ||
|
||
for _, content := range inputStruct.PromptImages { | ||
promptContent = append(promptContent, llms.ImageURLPart(content)) | ||
} | ||
|
||
promptContent = append(promptContent, llms.TextPart(inputStruct.Prompt)) | ||
|
||
messages = append(messages, llms.MessageContent{ | ||
Role: llms.ChatMessageType("user"), | ||
Parts: promptContent, | ||
}) | ||
|
||
req := ChatRequest{ | ||
Messages: messages, | ||
Model: inputStruct.ModelName, | ||
MaxTokens: inputStruct.MaxNewTokens, | ||
Temperature: inputStruct.Temperature, | ||
TopP: inputStruct.TopP, | ||
TopK: inputStruct.TopK, | ||
Seed: inputStruct.Seed, | ||
} | ||
resp, err := e.client.Chat(req) | ||
if err != nil { | ||
return nil, fmt.Errorf("error calling Chat: %v", err) | ||
} | ||
outputStruct := TaskTextGenerationChatOutput{ | ||
Text: resp.Text, | ||
Usage: TaskTextGenerationChatUsage{ | ||
InputTokens: resp.InputTokens, | ||
OutputTokens: resp.OutputTokens, | ||
}, | ||
} | ||
return base.ConvertToStructpb(outputStruct) | ||
} | ||
|
||
func (e *execution) TaskTextEmbeddings(in *structpb.Struct) (*structpb.Struct, error) { | ||
|
||
inputStruct := TaskTextEmbeddingsInput{} | ||
err := base.ConvertFromStructpb(in, &inputStruct) | ||
if err != nil { | ||
return nil, fmt.Errorf("error generating input struct: %v", err) | ||
} | ||
outputStruct := TaskTextEmbeddingsOutput{} | ||
return base.ConvertToStructpb(outputStruct) | ||
} |
Oops, something went wrong.