-
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added AI -> Embedding service related functions (#14)
This pull request includes significant updates to the CI configuration and the addition of a new embedding generation feature in the `ai` package. The most important changes are detailed below: ### CI Configuration Updates: * Updated the `actions/checkout` action from version 3 to version 4 in the `.github/workflows/CI.yaml` file. [[1]](diffhunk://#diff-b268bb7dc66e8638921e223098a11eabb92b424875ec72a19d6145dc1efbe3f2L17-R23) [[2]](diffhunk://#diff-b268bb7dc66e8638921e223098a11eabb92b424875ec72a19d6145dc1efbe3f2L38-R43) [[3]](diffhunk://#diff-b268bb7dc66e8638921e223098a11eabb92b424875ec72a19d6145dc1efbe3f2L62-R62) * Updated the Go version from 1.20 to 1.21 in the `.github/workflows/CI.yaml` file. [[1]](diffhunk://#diff-b268bb7dc66e8638921e223098a11eabb92b424875ec72a19d6145dc1efbe3f2L17-R23) [[2]](diffhunk://#diff-b268bb7dc66e8638921e223098a11eabb92b424875ec72a19d6145dc1efbe3f2L38-R43) ### Embedding Generation Feature: * Added a new section in the `README.md` file to document how to generate embedding vectors using the provider-based approach. * Introduced the `ai` package, which includes utilities for generating text embeddings with various models and providers. * Implemented the `EmbeddingService` and `EmbeddingProvider` interfaces, along with supporting types like `EmbeddingModel`, `EmbeddingObject`, `Usage`, and `EmbeddingResponse` in the `ai/embedding.go` file. * Added unit tests for the `EmbeddingService` in the `ai/embedding_test.go` file to ensure proper functionality of the embedding generation feature.
- Loading branch information
1 parent
becb126
commit bd2cecf
Showing
5 changed files
with
318 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,3 +13,5 @@ | |
|
||
# Dependency directories (remove the comment below to include it) | ||
# vendor/ | ||
|
||
/.idea |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
// Package ai provides artificial intelligence utilities including embedding generation capabilities. | ||
// It offers a flexible interface for generating text embeddings using various models and providers. | ||
package ai | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"encoding/json" | ||
"fmt" | ||
"net/http" | ||
) | ||
|
||
// EmbeddingModel identifies which embedding model the service should use
// when generating embedding vectors.
type EmbeddingModel string

// Supported embedding models. Each offers a different trade-off between
// speed, embedding quality, and language coverage.
const (
	// EmbeddingModelAllMiniLML6V2 is a lightweight, general-purpose model
	// with a good balance between performance and quality.
	EmbeddingModelAllMiniLML6V2 EmbeddingModel = "all-MiniLM-L6-v2"

	// EmbeddingModelAllMpnetBaseV2 produces higher-quality embeddings at
	// the cost of increased computation time.
	EmbeddingModelAllMpnetBaseV2 EmbeddingModel = "all-mpnet-base-v2"

	// EmbeddingModelParaphraseMultilingualMiniLML12V2 targets multilingual
	// text, preserving semantic meaning across multiple languages.
	EmbeddingModelParaphraseMultilingualMiniLML12V2 EmbeddingModel = "paraphrase-multilingual-MiniLM-L12-v2"
)
|
||
// EmbeddingProvider defines the interface for services that can generate embeddings from text. | ||
// Implementations of this interface can connect to different embedding services or models. | ||
type EmbeddingProvider interface { | ||
// GenerateEmbedding creates an embedding vector from the provided input using the specified model. | ||
// The input can be a string or array of strings, and the response includes the embedding vectors | ||
// along with usage statistics. | ||
GenerateEmbedding(ctx context.Context, input interface{}, model string) (*EmbeddingResponse, error) | ||
} | ||
|
||
// EmbeddingObject is a single embedding result: the generated vector plus
// metadata locating it within the overall response.
type EmbeddingObject struct {
	// Object is the type tag of this response object.
	Object string `json:"object"`
	// Embedding holds the generated vector representation of the input text.
	Embedding []float64 `json:"embedding"`
	// Index is this embedding's position in the response array.
	Index int `json:"index"`
}
|
||
// Usage reports token accounting for an embedding generation request.
type Usage struct {
	// PromptTokens counts the tokens in the input text.
	PromptTokens int `json:"prompt_tokens"`
	// TotalTokens counts all tokens processed for the request.
	TotalTokens int `json:"total_tokens"`
}
|
||
// EmbeddingResponse represents the complete response from the embedding API. | ||
// It includes the generated embeddings and usage statistics. | ||
type EmbeddingResponse struct { | ||
// Object identifies the type of the response | ||
Object string `json:"object"` | ||
// Data contains the array of embedding results | ||
Data []EmbeddingObject `json:"data"` | ||
// Model identifies which embedding model was used | ||
Model EmbeddingModel `json:"model"` | ||
// Usage provides token usage statistics for the request | ||
Usage Usage `json:"usage"` | ||
} | ||
|
||
// EmbeddingService implements the EmbeddingProvider interface by calling a
// REST API endpoint to generate embeddings.
type EmbeddingService struct {
	// BaseURL is the root URL of the embedding API.
	BaseURL string
	// HTTPClient performs the outgoing HTTP requests.
	HTTPClient *http.Client
}
|
||
// NewEmbeddingService creates a new EmbeddingService with the specified base URL and HTTP client. | ||
// If httpClient is nil, it uses http.DefaultClient. | ||
// | ||
// Example usage: | ||
// | ||
// client := NewEmbeddingService("https://api.example.com", nil) | ||
// resp, err := client.GenerateEmbedding( | ||
// context.Background(), | ||
// "Hello, world!", | ||
// EmbeddingModelAllMiniLML6V2, | ||
// ) | ||
// if err != nil { | ||
// log.Fatal(err) | ||
// } | ||
// fmt.Printf("Generated embedding vector: %v\n", resp.Data[0].Embedding) | ||
func NewEmbeddingService(baseURL string, httpClient *http.Client) *EmbeddingService { | ||
if httpClient == nil { | ||
httpClient = http.DefaultClient | ||
} | ||
return &EmbeddingService{ | ||
BaseURL: baseURL, | ||
HTTPClient: httpClient, | ||
} | ||
} | ||
|
||
// embeddingRequest represents the request body sent to the embedding API. | ||
type embeddingRequest struct { | ||
// Input is the text to generate embeddings for (string or []string) | ||
Input interface{} `json:"input"` | ||
// Model specifies which embedding model to use | ||
Model EmbeddingModel `json:"model"` | ||
// EncodingFormat specifies the format of the output vectors | ||
EncodingFormat string `json:"encoding_format"` | ||
} | ||
|
||
// GenerateEmbedding generates embedding vectors for the provided input using the specified model. | ||
// The input can be a single string or an array of strings. The method returns the embedding | ||
// vectors along with usage statistics. | ||
// | ||
// Example usage: | ||
// | ||
// service := NewEmbeddingService("https://api.example.com", nil) | ||
// | ||
// // Generate embedding for a single string | ||
// resp, err := service.GenerateEmbedding( | ||
// context.Background(), | ||
// "Hello, world!", | ||
// EmbeddingModelAllMiniLML6V2, | ||
// ) | ||
// if err != nil { | ||
// log.Fatal(err) | ||
// } | ||
// | ||
// // Generate embeddings for multiple strings | ||
// texts := []string{"Hello", "World"} | ||
// resp, err = service.GenerateEmbedding( | ||
// context.Background(), | ||
// texts, | ||
// EmbeddingModelAllMpnetBaseV2, | ||
// ) | ||
// if err != nil { | ||
// log.Fatal(err) | ||
// } | ||
// | ||
// The method returns an error if: | ||
// - The request cannot be created or sent | ||
// - The server returns a non-200 status code | ||
// - The response cannot be decoded | ||
func (s *EmbeddingService) GenerateEmbedding(ctx context.Context, input interface{}, model EmbeddingModel) (*EmbeddingResponse, error) { | ||
reqBody := embeddingRequest{ | ||
Input: input, | ||
Model: model, | ||
EncodingFormat: "float", | ||
} | ||
|
||
jsonBody, err := json.Marshal(reqBody) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to marshal request: %w", err) | ||
} | ||
|
||
req, err := http.NewRequestWithContext(ctx, "POST", s.BaseURL+"/v1/embeddings", bytes.NewBuffer(jsonBody)) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to create request: %w", err) | ||
} | ||
req.Header.Set("Content-Type", "application/json") | ||
|
||
resp, err := s.HTTPClient.Do(req) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to send request: %w", err) | ||
} | ||
defer resp.Body.Close() | ||
|
||
if resp.StatusCode != http.StatusOK { | ||
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) | ||
} | ||
|
||
var embResp EmbeddingResponse | ||
if err := json.NewDecoder(resp.Body).Decode(&embResp); err != nil { | ||
return nil, fmt.Errorf("failed to decode response: %w", err) | ||
} | ||
|
||
return &embResp, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
package ai | ||
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
"net/http" | ||
"net/http/httptest" | ||
"reflect" | ||
"testing" | ||
) | ||
|
||
func TestEmbeddingService_GenerateEmbedding(t *testing.T) { | ||
tests := []struct { | ||
name string | ||
input interface{} | ||
model EmbeddingModel | ||
response EmbeddingResponse | ||
wantErr bool | ||
}{ | ||
{ | ||
name: "single string input", | ||
input: "test text", | ||
model: "all-MiniLM-L6-v2", | ||
response: EmbeddingResponse{ | ||
Object: "list", | ||
Data: []EmbeddingObject{ | ||
{ | ||
Object: "embedding", | ||
Embedding: []float64{0.1, 0.2, 0.3}, | ||
Index: 0, | ||
}, | ||
}, | ||
Model: "all-MiniLM-L6-v2", | ||
Usage: Usage{ | ||
PromptTokens: 2, | ||
TotalTokens: 2, | ||
}, | ||
}, | ||
}, | ||
{ | ||
name: "multiple string input", | ||
input: []string{"test1", "test2"}, | ||
model: "all-MiniLM-L6-v2", | ||
response: EmbeddingResponse{ | ||
Object: "list", | ||
Data: []EmbeddingObject{ | ||
{ | ||
Object: "embedding", | ||
Embedding: []float64{0.1, 0.2, 0.3}, | ||
Index: 0, | ||
}, | ||
{ | ||
Object: "embedding", | ||
Embedding: []float64{0.4, 0.5, 0.6}, | ||
Index: 1, | ||
}, | ||
}, | ||
Model: EmbeddingModelAllMiniLML6V2, | ||
Usage: Usage{ | ||
PromptTokens: 4, | ||
TotalTokens: 4, | ||
}, | ||
}, | ||
}, | ||
} | ||
|
||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
if r.Method != "POST" { | ||
t.Errorf("Expected POST request, got %s", r.Method) | ||
} | ||
if r.Header.Get("Content-Type") != "application/json" { | ||
t.Errorf("Expected Content-Type: application/json, got %s", r.Header.Get("Content-Type")) | ||
} | ||
|
||
json.NewEncoder(w).Encode(tt.response) | ||
})) | ||
defer server.Close() | ||
|
||
service := NewEmbeddingService(server.URL, server.Client()) | ||
gotResp, err := service.GenerateEmbedding(context.Background(), tt.input, tt.model) | ||
|
||
if (err != nil) != tt.wantErr { | ||
t.Errorf("GenerateEmbedding() error = %v, wantErr %v", err, tt.wantErr) | ||
return | ||
} | ||
|
||
if !reflect.DeepEqual(gotResp, &tt.response) { | ||
t.Errorf("GenerateEmbedding() = %v, want %v", gotResp, tt.response) | ||
} | ||
}) | ||
} | ||
} |