Commit dc2626d

add: implemented vertex ai with go-langchain

namwoam committed Aug 8, 2024
1 parent 06fd227 commit dc2626d
Showing 7 changed files with 420 additions and 71 deletions.
80 changes: 80 additions & 0 deletions ai/vertexai/v0/client.go
@@ -0,0 +1,80 @@
package vertexai

import (
	"context"
	"encoding/json"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/googleai"
	"github.com/tmc/langchaingo/llms/googleai/vertex"
	"go.uber.org/zap"
)

type VertexAIClient struct {
	langchainClient *vertex.Vertex
}

func NewClient(setup VertexAISetup, logger *zap.Logger) *VertexAIClient {
	ctx := context.Background()
	langchainClient, err := vertex.New(ctx,
		googleai.WithCloudLocation(setup.Location),
		googleai.WithCloudProject(setup.ProjectID),
		googleai.WithCredentialsJSON(json.RawMessage(setup.Cred)))
	if err != nil {
		logger.Fatal("failed to create langchain client", zap.Error(err))
	}
	return &VertexAIClient{
		langchainClient: langchainClient,
	}
}

type EmbedRequest struct {
	Text  string `json:"text"`
	Model string `json:"model"`
}

type EmbedResponse struct {
	Embedding []float32 `json:"embedding"`
	Tokens    int       `json:"tokens"`
}

func (c *VertexAIClient) Embed(req EmbedRequest) (EmbedResponse, error) {
	resp := EmbedResponse{}
	ctx := context.Background()
	// Note: this only supports generating embeddings with PaLM, not other
	// models; req.Model is not passed through to CreateEmbedding.
	content, err := c.langchainClient.CreateEmbedding(ctx, []string{req.Text})
	if err != nil {
		return EmbedResponse{}, err
	}
	resp.Embedding = content[0]
	// Token usage is not reported by CreateEmbedding, so it is left at zero.
	resp.Tokens = 0
	return resp, nil
}

type ChatRequest struct {
	Messages    []llms.MessageContent `json:"messages"`
	Model       string                `json:"model"`
	MaxTokens   int                   `json:"max_tokens"`
	Temperature float64               `json:"temperature"`
	TopP        float64               `json:"top_p"`
	TopK        int                   `json:"top_k"`
	Seed        int                   `json:"seed"`
}

type ChatResponse struct {
	Text         string `json:"text"`
	InputTokens  int    `json:"input_tokens"`
	OutputTokens int    `json:"output_tokens"`
}

func (c *VertexAIClient) Chat(req ChatRequest) (ChatResponse, error) {
	resp := ChatResponse{}
	ctx := context.Background()
	content, err := c.langchainClient.GenerateContent(ctx, req.Messages,
		llms.WithModel(req.Model),
		llms.WithMaxTokens(req.MaxTokens),
		llms.WithTemperature(req.Temperature),
		llms.WithTopP(req.TopP),
		llms.WithTopK(req.TopK),
		llms.WithSeed(req.Seed))
	if err != nil {
		return ChatResponse{}, err
	}
	resp.Text = content.Choices[0].Content
	// Token usage is not yet parsed out of GenerationInfo, so both counts
	// are left at zero.
	resp.InputTokens = 0
	resp.OutputTokens = 0
	return resp, nil
}
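
For context, a minimal usage sketch of the new client. This assumes the package is imported from this repository's ai/vertexai/v0 path; the key file path, project ID, region, and model name are hypothetical placeholders, and a no-op zap logger stands in for the component's logger:

package main

import (
	"fmt"
	"os"

	vertexai "github.com/instill-ai/component/ai/vertexai/v0"
	"github.com/tmc/langchaingo/llms"
	"go.uber.org/zap"
)

func main() {
	// Hypothetical service-account key exported from GCP.
	cred, err := os.ReadFile("service-account.json")
	if err != nil {
		panic(err)
	}
	client := vertexai.NewClient(vertexai.VertexAISetup{
		ProjectID: "my-project",  // hypothetical project ID
		Location:  "us-central1", // hypothetical region
		Cred:      string(cred),
	}, zap.NewNop())
	resp, err := client.Chat(vertexai.ChatRequest{
		Model: "gemini-1.5-flash", // hypothetical model name
		Messages: []llms.MessageContent{
			llms.TextParts(llms.ChatMessageTypeHuman, "Say hello."),
		},
		MaxTokens:   256,
		Temperature: 0.7,
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(resp.Text)
}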
8 changes: 5 additions & 3 deletions ai/vertexai/v0/config/tasks.json
@@ -323,9 +323,10 @@
"model": {
"enum": [
"gecko",
"gecko-multilingual"
"gecko-multilingual",
"palm"
],
"example": "gecko",
"example": "palm",
"description": "The Vertex AI embedding model to be used",
"instillAcceptFormats": [
"string"
@@ -341,7 +342,8 @@
"instillCredentialMap": {
"values": [
"gecko",
"gecko-multilingual"
"gecko-multilingual",
"palm"
],
"targets": [
"setup.api-key"
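
With "palm" added to the model enum, the embeddings task can be routed to the new client. Continuing the usage sketch above (note that, per the comment in client.go, Embed does not actually pass the Model field through yet):

	embResp, err := client.Embed(vertexai.EmbedRequest{
		Text:  "hello world",
		Model: "palm",
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(len(embResp.Embedding))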
99 changes: 99 additions & 0 deletions ai/vertexai/v0/main.go
@@ -0,0 +1,99 @@
//go:generate compogen readme ./config ./README.mdx
package vertexai

import (
	"context"
	_ "embed"
	"fmt"
	"sync"

	"google.golang.org/protobuf/types/known/structpb"

	"github.com/instill-ai/component/base"
)

const (
	TaskTextGenerationChat = "TASK_TEXT_GENERATION_CHAT"
	TaskTextEmbeddings     = "TASK_TEXT_EMBEDDINGS"
)

var (
	//go:embed config/definition.json
	definitionJSON []byte
	//go:embed config/setup.json
	setupJSON []byte
	//go:embed config/tasks.json
	tasksJSON []byte

	once sync.Once
	comp *component
)

type component struct {
	base.Component
}

func Init(bc base.Component) *component {
	once.Do(func() {
		comp = &component{Component: bc}
		err := comp.LoadDefinition(definitionJSON, setupJSON, tasksJSON, nil)
		if err != nil {
			panic(err)
		}
	})
	return comp
}

type VertexAIClientInterface interface {
	Chat(ChatRequest) (ChatResponse, error)
	Embed(EmbedRequest) (EmbedResponse, error)
}

type VertexAISetup struct {
	ProjectID string `json:"project-id"`
	Cred      string `json:"cred"`
	Location  string `json:"location"`
}

type execution struct {
	base.ComponentExecution
	client  VertexAIClientInterface
	execute func(*structpb.Struct) (*structpb.Struct, error)
}

func (e *execution) Execute(_ context.Context, inputs []*structpb.Struct) ([]*structpb.Struct, error) {
	outputs := make([]*structpb.Struct, len(inputs))

	// The execution takes an array of inputs and returns an array of
	// outputs, processed sequentially.
	for i, input := range inputs {
		output, err := e.execute(input)
		if err != nil {
			return nil, err
		}

		outputs[i] = output
	}

	return outputs, nil
}

func (c *component) CreateExecution(x base.ComponentExecution) (base.IExecution, error) {
	setupStruct := VertexAISetup{}
	// ConvertFromStructpb unmarshals into its target, so it needs a pointer.
	if err := base.ConvertFromStructpb(x.Setup, &setupStruct); err != nil {
		return nil, fmt.Errorf("error parsing setup, %v", err)
	}

	e := &execution{
		ComponentExecution: x,
		client:             NewClient(setupStruct, c.Logger),
	}
	switch x.Task {
	case TaskTextGenerationChat:
		e.execute = e.TaskTextGenerationChat
	case TaskTextEmbeddings:
		e.execute = e.TaskTextEmbeddings
	default:
		return nil, fmt.Errorf("unsupported task")
	}
	return e, nil
}
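
Because execution holds the client behind VertexAIClientInterface, the Vertex call can be faked in a package-local test. A sketch under the assumption that base.ConvertToStructpb keys the output fields by their JSON tags (hypothetical file tasks_test.go; the model name is a placeholder):

package vertexai

import (
	"testing"

	"google.golang.org/protobuf/types/known/structpb"
)

// fakeClient satisfies VertexAIClientInterface without network access.
type fakeClient struct{}

func (fakeClient) Chat(ChatRequest) (ChatResponse, error) {
	return ChatResponse{Text: "stubbed reply"}, nil
}

func (fakeClient) Embed(EmbedRequest) (EmbedResponse, error) {
	return EmbedResponse{Embedding: []float32{0.1, 0.2}}, nil
}

func TestTaskTextGenerationChat(t *testing.T) {
	e := &execution{client: fakeClient{}}
	in, err := structpb.NewStruct(map[string]any{
		"prompt":     "Say hello.",
		"model-name": "gemini-1.5-flash", // hypothetical model name
	})
	if err != nil {
		t.Fatal(err)
	}
	out, err := e.TaskTextGenerationChat(in)
	if err != nil {
		t.Fatal(err)
	}
	if got := out.Fields["text"].GetStringValue(); got != "stubbed reply" {
		t.Errorf("unexpected text: %q", got)
	}
}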
145 changes: 145 additions & 0 deletions ai/vertexai/v0/tasks.go
@@ -0,0 +1,145 @@
package vertexai

import (
	"fmt"

	"github.com/instill-ai/component/base"
	"github.com/tmc/langchaingo/llms"
	"google.golang.org/protobuf/types/known/structpb"
)

type ChatMessage struct {
	Role    string              `json:"role"`
	Content []MultiModalContent `json:"content"`
}

type URL struct {
	URL string `json:"url"`
}

type MultiModalContent struct {
	ImageURL URL    `json:"image-url"`
	Text     string `json:"text"`
	Type     string `json:"type"`
}

type TaskTextGenerationChatInput struct {
	ChatHistory  []ChatMessage `json:"chat-history"`
	MaxNewTokens int           `json:"max-new-tokens"`
	ModelName    string        `json:"model-name"`
	Prompt       string        `json:"prompt"`
	PromptImages []string      `json:"prompt-images"`
	Seed         int           `json:"seed"`
	SystemMsg    string        `json:"system-message"`
	Temperature  float64       `json:"temperature"`
	TopK         int           `json:"top-k"`
	TopP         float64       `json:"top-p"`
	Safe         bool          `json:"safe"`
}

type TaskTextGenerationChatUsage struct {
	InputTokens  int `json:"input-tokens"`
	OutputTokens int `json:"output-tokens"`
}

type TaskTextGenerationChatOutput struct {
	Text  string                      `json:"text"`
	Usage TaskTextGenerationChatUsage `json:"usage"`
}

type TaskTextEmbeddingsInput struct {
	Text      string `json:"text"`
	ModelName string `json:"model-name"`
}

type TaskTextEmbeddingsUsage struct {
	Tokens int `json:"tokens"`
}

type TaskTextEmbeddingsOutput struct {
	Embedding []float64               `json:"embedding"`
	Usage     TaskTextEmbeddingsUsage `json:"usage"`
}

func (e *execution) TaskTextGenerationChat(in *structpb.Struct) (*structpb.Struct, error) {
	inputStruct := TaskTextGenerationChatInput{}
	err := base.ConvertFromStructpb(in, &inputStruct)
	if err != nil {
		return nil, fmt.Errorf("error generating input struct: %v", err)
	}

	messages := []llms.MessageContent{}

	if inputStruct.SystemMsg != "" {
		messages = append(messages, llms.MessageContent{
			// langchaingo does define a system role constant,
			// llms.ChatMessageTypeSystem ("system").
			Role: llms.ChatMessageTypeSystem,
			Parts: []llms.ContentPart{
				llms.TextPart(inputStruct.SystemMsg),
			},
		})
	}

	for _, chatMessage := range inputStruct.ChatHistory {
		messageContent := []llms.ContentPart{}
		for _, content := range chatMessage.Content {
			if content.Type == "text" {
				messageContent = append(messageContent, llms.TextPart(content.Text))
			} else if content.Type == "image" {
				messageContent = append(messageContent, llms.ImageURLPart(content.ImageURL.URL))
			}
		}
		if len(messageContent) == 0 {
			continue
		}
		// The role is passed through as-is, so history entries must use
		// langchaingo's role values (e.g. "human", "ai").
		messages = append(messages, llms.MessageContent{
			Role:  llms.ChatMessageType(chatMessage.Role),
			Parts: messageContent,
		})
	}

	promptContent := []llms.ContentPart{}

	for _, content := range inputStruct.PromptImages {
		promptContent = append(promptContent, llms.ImageURLPart(content))
	}

	promptContent = append(promptContent, llms.TextPart(inputStruct.Prompt))

	// langchaingo's user-equivalent role is llms.ChatMessageTypeHuman
	// ("human"); a bare "user" value is not one of its defined roles.
	messages = append(messages, llms.MessageContent{
		Role:  llms.ChatMessageTypeHuman,
		Parts: promptContent,
	})

	req := ChatRequest{
		Messages:    messages,
		Model:       inputStruct.ModelName,
		MaxTokens:   inputStruct.MaxNewTokens,
		Temperature: inputStruct.Temperature,
		TopP:        inputStruct.TopP,
		TopK:        inputStruct.TopK,
		Seed:        inputStruct.Seed,
	}
	resp, err := e.client.Chat(req)
	if err != nil {
		return nil, fmt.Errorf("error calling Chat: %v", err)
	}
	outputStruct := TaskTextGenerationChatOutput{
		Text: resp.Text,
		Usage: TaskTextGenerationChatUsage{
			InputTokens:  resp.InputTokens,
			OutputTokens: resp.OutputTokens,
		},
	}
	return base.ConvertToStructpb(outputStruct)
}

func (e *execution) TaskTextEmbeddings(in *structpb.Struct) (*structpb.Struct, error) {
	inputStruct := TaskTextEmbeddingsInput{}
	err := base.ConvertFromStructpb(in, &inputStruct)
	if err != nil {
		return nil, fmt.Errorf("error generating input struct: %v", err)
	}
	resp, err := e.client.Embed(EmbedRequest{Text: inputStruct.Text, Model: inputStruct.ModelName})
	if err != nil {
		return nil, fmt.Errorf("error calling Embed: %v", err)
	}
	// The client returns float32 values; the task output schema uses float64.
	embedding := make([]float64, len(resp.Embedding))
	for i, v := range resp.Embedding {
		embedding[i] = float64(v)
	}
	outputStruct := TaskTextEmbeddingsOutput{Embedding: embedding, Usage: TaskTextEmbeddingsUsage{Tokens: resp.Tokens}}
	return base.ConvertToStructpb(outputStruct)
}
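
For reference, a sketch of the input shape TaskTextGenerationChat expects, built with structpb. The keys follow the JSON tags of TaskTextGenerationChatInput above; the model name and image URL are hypothetical:

package vertexai

import "google.golang.org/protobuf/types/known/structpb"

// sampleChatInput builds a multimodal input matching the json tags of
// TaskTextGenerationChatInput; model name and image URL are placeholders.
func sampleChatInput() (*structpb.Struct, error) {
	return structpb.NewStruct(map[string]any{
		"prompt":     "Describe this picture.",
		"model-name": "gemini-1.5-flash",
		"chat-history": []any{
			map[string]any{
				"role": "human", // langchaingo role value
				"content": []any{
					map[string]any{"type": "text", "text": "Hi!"},
					map[string]any{"type": "image", "image-url": map[string]any{"url": "https://example.com/cat.png"}},
				},
			},
		},
		"max-new-tokens": 512,
		"temperature":    0.7,
	})
}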