forked from instill-ai/component
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into namwoam/fireworks
- Loading branch information
Showing
128 changed files
with
25,502 additions
and
851 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
package groq | ||
|
||
import ( | ||
"fmt" | ||
|
||
"github.com/instill-ai/component/internal/util/httpclient" | ||
"go.uber.org/zap" | ||
"google.golang.org/protobuf/types/known/structpb" | ||
) | ||
|
||
const ( | ||
Endpoint = "https://api.groq.com" | ||
) | ||
|
||
// reference: https://console.groq.com/docs/api-reference on 2024-08-05 | ||
|
||
type errBody struct { | ||
Error struct { | ||
Message string `json:"message"` | ||
} `json:"error"` | ||
} | ||
|
||
func (e errBody) Message() string { | ||
return e.Error.Message | ||
} | ||
|
||
type GroqClient struct { | ||
httpClient *httpclient.Client | ||
} | ||
|
||
func NewClient(token string, logger *zap.Logger) *GroqClient { | ||
c := httpclient.New("Groq", Endpoint, httpclient.WithLogger(logger), | ||
httpclient.WithEndUserError(new(errBody))) | ||
c.SetAuthToken(token) | ||
return &GroqClient{httpClient: c} | ||
} | ||
|
||
type GroqChatMessageInterface interface { | ||
} | ||
|
||
type GroqChatMessage struct { | ||
Role string `json:"role"` | ||
Content []GroqChatContent `json:"content"` | ||
} | ||
|
||
type GroqSystemMessage struct { | ||
Role string `json:"role"` | ||
Content string `json:"content"` | ||
} | ||
|
||
type GroqChatContent struct { | ||
ImageURL *GroqURL `json:"image_url,omitempty"` | ||
Text string `json:"text"` | ||
Type GroqChatContentType `json:"type,omitempty"` | ||
} | ||
|
||
type GroqChatContentType string | ||
|
||
const ( | ||
GroqChatContentTypeText GroqChatContentType = "text" | ||
GroqChatContentTypeImage GroqChatContentType = "image" | ||
) | ||
|
||
type GroqURL struct { | ||
URL string `json:"url"` | ||
Detail string `json:"detail,omitempty"` | ||
} | ||
|
||
type ChatRequest struct { | ||
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` | ||
MaxTokens int `json:"max_tokens"` | ||
Model string `json:"model"` | ||
Messages []GroqChatMessageInterface `json:"messages"` | ||
N int `json:"n,omitempty"` | ||
PresencePenalty float32 `json:"presence_penalty,omitempty"` | ||
ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"` | ||
Seed int `json:"seed,omitempty"` | ||
Stop []string `json:"stop"` | ||
Stream bool `json:"stream,omitempty"` | ||
Temperature float32 `json:"temperature,omitempty"` | ||
TopP float32 `json:"top_p,omitempty"` | ||
User string `json:"user,omitempty"` | ||
} | ||
|
||
type ChatResponse struct { | ||
ID string `json:"id"` | ||
Object string `json:"object"` | ||
Created int `json:"created"` | ||
Model string `json:"model"` | ||
Choices []GroqChoice `json:"choices"` | ||
Usage GroqUsage `json:"usage"` | ||
} | ||
|
||
type GroqChoice struct { | ||
Index int `json:"index"` | ||
Message GroqResponseMessage `json:"message"` | ||
FinishReason string `json:"finish_reason"` | ||
} | ||
|
||
type GroqResponseMessage struct { | ||
Role string `json:"role"` | ||
Content string `json:"content"` | ||
} | ||
|
||
type GroqUsage struct { | ||
PromptTokens int `json:"prompt_tokens"` | ||
CompletionTokens int `json:"completion_tokens"` | ||
TotalTokens int `json:"total_tokens"` | ||
PromptTime float32 `json:"prompt_time"` | ||
CompletionTime float32 `json:"completion_time"` | ||
TotalTime float32 `json:"total_time"` | ||
} | ||
|
||
func (c *GroqClient) Chat(request ChatRequest) (ChatResponse, error) { | ||
response := ChatResponse{} | ||
req := c.httpClient.R().SetResult(&response).SetBody(request) | ||
if resp, err := req.Post("/openai/v1/chat/completions"); err != nil { | ||
if resp != nil { | ||
respString := string(resp.Body()) | ||
return response, fmt.Errorf("error when sending chat request %v: %s", err, respString) | ||
} | ||
return response, fmt.Errorf("error when sending chat request %v", err) | ||
} | ||
return response, nil | ||
} | ||
|
||
func getAPIKey(setup *structpb.Struct) string { | ||
return setup.GetFields()[cfgAPIKey].GetStringValue() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
package groq | ||
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
"fmt" | ||
"testing" | ||
|
||
"github.com/gojuno/minimock/v3" | ||
"go.uber.org/zap" | ||
"google.golang.org/protobuf/types/known/structpb" | ||
|
||
qt "github.com/frankban/quicktest" | ||
|
||
"github.com/instill-ai/component/base" | ||
) | ||
|
||
const ( | ||
MockAPIKey = "### Mock API Key ###" | ||
) | ||
|
||
func TestComponent_Execute(t *testing.T) { | ||
c := qt.New(t) | ||
|
||
bc := base.Component{Logger: zap.NewNop()} | ||
cmp := Init(bc) | ||
|
||
c.Run("ok - supported task", func(c *qt.C) { | ||
task := TaskTextGenerationChat | ||
|
||
_, err := cmp.CreateExecution(base.ComponentExecution{ | ||
Component: cmp, | ||
Task: task, | ||
}) | ||
c.Check(err, qt.IsNil) | ||
}) | ||
|
||
c.Run("nok - unsupported task", func(c *qt.C) { | ||
task := "FOOBAR" | ||
|
||
_, err := cmp.CreateExecution(base.ComponentExecution{ | ||
Component: cmp, | ||
Task: task, | ||
}) | ||
c.Check(err, qt.ErrorMatches, "unsupported task") | ||
}) | ||
} | ||
|
||
func TestComponent_Tasks(t *testing.T) { | ||
mc := minimock.NewController(t) | ||
c := qt.New(t) | ||
bc := base.Component{Logger: zap.NewNop()} | ||
connector := Init(bc) | ||
ctx := context.Background() | ||
|
||
GroqClientMock := NewGroqClientInterfaceMock(mc) | ||
GroqClientMock.ChatMock. | ||
When(ChatRequest{ | ||
Model: "llama-3.1-405b-reasoning", | ||
Messages: []GroqChatMessageInterface{ | ||
GroqChatMessage{ | ||
Role: "user", | ||
Content: []GroqChatContent{ | ||
{ | ||
Text: "Tell me a joke", | ||
Type: GroqChatContentTypeText, | ||
}, | ||
}, | ||
}, | ||
}, | ||
N: 1, | ||
Stop: []string{}, | ||
}). | ||
Then(ChatResponse{ | ||
ID: "34a9110d-c39d-423b-9ab9-9c748747b204", | ||
Object: "chat.completion", | ||
Model: "llama-3.1-405b-reasoning", | ||
Created: 1708045122, | ||
Usage: GroqUsage{ | ||
PromptTokens: 24, | ||
CompletionTokens: 377, | ||
TotalTokens: 401, | ||
PromptTime: 0.009, | ||
CompletionTime: 0.774, | ||
TotalTime: 0.783, | ||
}, | ||
Choices: []GroqChoice{ | ||
{ | ||
Index: 0, | ||
FinishReason: "stop", | ||
Message: GroqResponseMessage{ | ||
Role: "assistant", | ||
Content: "\nWhy did the tomato turn red?\nAnswer: Because it saw the salad dressing", | ||
}, | ||
}, | ||
}, | ||
}, nil) | ||
GroqClientMock.ChatMock. | ||
When(ChatRequest{ | ||
Model: "gemini", | ||
Messages: []GroqChatMessageInterface{ | ||
GroqChatMessage{ | ||
Role: "user", | ||
Content: []GroqChatContent{ | ||
{ | ||
Text: "Tell me a joke", | ||
Type: GroqChatContentTypeText, | ||
}, | ||
}, | ||
}, | ||
}, | ||
N: 1, | ||
Stop: []string{}, | ||
}). | ||
Then(ChatResponse{}, fmt.Errorf("error when sending chat request %s", `no access to "gemini"`)) | ||
|
||
c.Run("ok - task text generation", func(c *qt.C) { | ||
setup, err := structpb.NewStruct(map[string]any{ | ||
"api-key": MockAPIKey, | ||
}) | ||
c.Assert(err, qt.IsNil) | ||
e := &execution{ | ||
ComponentExecution: base.ComponentExecution{Component: connector, SystemVariables: nil, Setup: setup, Task: TaskTextGenerationChat}, | ||
client: GroqClientMock, | ||
} | ||
e.execute = e.TaskTextGenerationChat | ||
|
||
pbIn, err := base.ConvertToStructpb(map[string]any{"model": "llama-3.1-405b-reasoning", "prompt": "Tell me a joke"}) | ||
c.Assert(err, qt.IsNil) | ||
|
||
got, err := e.Execute(ctx, []*structpb.Struct{pbIn}) | ||
c.Assert(err, qt.IsNil) | ||
|
||
wantJSON, err := json.Marshal(TaskTextGenerationChatOuput{Text: "\nWhy did the tomato turn red?\nAnswer: Because it saw the salad dressing", Usage: TaskTextGenerationChatUsage{InputTokens: 24, OutputTokens: 377}}) | ||
c.Assert(err, qt.IsNil) | ||
c.Check(wantJSON, qt.JSONEquals, got[0].AsMap()) | ||
}) | ||
|
||
c.Run("nok - task text generation", func(c *qt.C) { | ||
setup, err := structpb.NewStruct(map[string]any{ | ||
"api-key": MockAPIKey, | ||
}) | ||
c.Assert(err, qt.IsNil) | ||
e := &execution{ | ||
ComponentExecution: base.ComponentExecution{Component: connector, SystemVariables: nil, Setup: setup, Task: TaskTextGenerationChat}, | ||
client: GroqClientMock, | ||
} | ||
e.execute = e.TaskTextGenerationChat | ||
|
||
pbIn, err := base.ConvertToStructpb(map[string]any{"model": "gemini", "prompt": "Tell me a joke"}) | ||
c.Assert(err, qt.IsNil) | ||
|
||
_, err = e.Execute(ctx, []*structpb.Struct{pbIn}) | ||
c.Assert(err, qt.ErrorMatches, `error when sending chat request no access to "gemini"`) | ||
}) | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
{ | ||
"availableTasks": [ | ||
"TASK_TEXT_GENERATION_CHAT" | ||
], | ||
"documentationUrl": "https://www.instill.tech/docs/component/ai/groq", | ||
"icon": "assets/groq.svg", | ||
"id": "groq", | ||
"public": true, | ||
"title": "Groq", | ||
"description": "Connect the AI models served on GroqCloud", | ||
"type": "COMPONENT_TYPE_AI", | ||
"uid": "d5e64e5c-2dd2-4358-82dd-0e3a035c2157", | ||
"vendor": "Groq", | ||
"vendorAttributes": {}, | ||
"version": "0.1.0", | ||
"sourceUrl": "https://github.com/instill-ai/component/blob/main/ai/groq/v0", | ||
"releaseStage": "RELEASE_STAGE_ALPHA" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
{ | ||
"$schema": "http://json-schema.org/draft-07/schema#", | ||
"additionalProperties": true, | ||
"properties": { | ||
"api-key": { | ||
"description": "Fill in your GroqCloud API key. To find your keys, visit the GroqCloud API Keys page.", | ||
"instillUpstreamTypes": [ | ||
"reference" | ||
], | ||
"instillAcceptFormats": [ | ||
"string" | ||
], | ||
"instillSecret": true, | ||
"instillCredential": true, | ||
"instillUIOrder": 0, | ||
"title": "API Key", | ||
"type": "string" | ||
} | ||
}, | ||
"required": [ | ||
"api-key" | ||
], | ||
"instillEditOnNodeFields": [ | ||
"api-key" | ||
], | ||
"title": "GroqCloud Connection", | ||
"type": "object" | ||
} |
Oops, something went wrong.