Added AI -> Embedding service related functions (#14)
This pull request includes significant updates to the CI configuration and the addition of a new embedding generation feature in the `ai` package. The most important changes are detailed below:

### CI Configuration Updates:
* Updated the `actions/checkout` action from version 3 to version 4 in the `.github/workflows/CI.yaml` file.
* Updated the Go version from 1.20 to 1.21 in the `.github/workflows/CI.yaml` file.

### Embedding Generation Feature:
* Added a new section in the `README.md` file to document how to generate embedding vectors using the provider-based approach.
* Introduced the `ai` package, which includes utilities for generating text embeddings with various models and providers.
* Implemented the `EmbeddingService` and `EmbeddingProvider` interfaces, along with supporting types like `EmbeddingModel`, `EmbeddingObject`, `Usage`, and `EmbeddingResponse` in the `ai/embedding.go` file.
* Added unit tests for the `EmbeddingService` in the `ai/embedding_test.go` file to ensure proper functionality of the embedding generation feature.
shahariaazam authored Dec 13, 2024
1 parent becb126 commit bd2cecf
Showing 5 changed files with 318 additions and 6 deletions.
12 changes: 6 additions & 6 deletions .github/workflows/CI.yaml
@@ -14,13 +14,13 @@ jobs:
     runs-on: ubuntu-latest
     steps:
 
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
         with:
           fetch-depth: 0
 
       - uses: actions/setup-go@v3
         with:
-          go-version: '^1.20'
+          go-version: '^1.21'
           check-latest: true
           cache: true

@@ -35,12 +35,12 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
 
-      - name: Set up Go ^1.20
+      - name: Set up Go ^1.21
         uses: actions/setup-go@v3
         with:
-          go-version: '^1.20'
+          go-version: '^1.21'
 
       - name: Test
         run: go test ./... -covermode=atomic -coverprofile=coverage.out
@@ -59,7 +59,7 @@ jobs:
     needs: test
     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
         with:
           fetch-depth: 0

2 changes: 2 additions & 0 deletions .gitignore
@@ -13,3 +13,5 @@

 # Dependency directories (remove the comment below to include it)
 # vendor/
+
+/.idea
32 changes: 32 additions & 0 deletions README.md
@@ -139,6 +139,38 @@ type LLMProvider interface {
}
```

#### Generate Embedding Vector

You can generate embeddings using the provider-based approach:

```go
import (
	"context"
	"fmt"
	"log"
	"net/http"

	"github.com/shaharia-lab/guti/ai"
)

// Create an embedding service backed by an embedding API endpoint
service := ai.NewEmbeddingService("http://localhost:8000", &http.Client{})

// Generate an embedding
resp, err := service.GenerateEmbedding(context.Background(), "Hello world", ai.EmbeddingModelAllMiniLML6V2)
if err != nil {
	log.Fatal(err)
}

fmt.Printf("Embedding vector: %+v\n", resp.Data[0].Embedding)
```

The library supports multiple embedding providers. You can implement the `EmbeddingProvider` interface to add support for additional providers:

```go
type EmbeddingProvider interface {
	GenerateEmbedding(ctx context.Context, input interface{}, model EmbeddingModel) (*EmbeddingResponse, error)
}
```

## Documentation

Full documentation is available on [pkg.go.dev/github.com/shaharia-lab/guti](https://pkg.go.dev/github.com/shaharia-lab/guti#section-documentation)
184 changes: 184 additions & 0 deletions ai/embedding.go
@@ -0,0 +1,184 @@
// Package ai provides artificial intelligence utilities including embedding generation capabilities.
// It offers a flexible interface for generating text embeddings using various models and providers.
package ai

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"net/http"
)

// EmbeddingModel represents the type of embedding model to be used for generating embeddings.
type EmbeddingModel string

// Available embedding models that can be used with the EmbeddingService.
// These models provide different trade-offs between performance and accuracy.
const (
	// EmbeddingModelAllMiniLML6V2 is a lightweight model suitable for general-purpose embedding generation.
	// It provides a good balance between performance and quality.
	EmbeddingModelAllMiniLML6V2 EmbeddingModel = "all-MiniLM-L6-v2"

	// EmbeddingModelAllMpnetBaseV2 is a more powerful model that provides higher quality embeddings
	// at the cost of increased computation time.
	EmbeddingModelAllMpnetBaseV2 EmbeddingModel = "all-mpnet-base-v2"

	// EmbeddingModelParaphraseMultilingualMiniLML12V2 is specialized for multilingual text,
	// supporting embedding generation across multiple languages while maintaining semantic meaning.
	EmbeddingModelParaphraseMultilingualMiniLML12V2 EmbeddingModel = "paraphrase-multilingual-MiniLM-L12-v2"
)

// EmbeddingProvider defines the interface for services that can generate embeddings from text.
// Implementations of this interface can connect to different embedding services or models.
type EmbeddingProvider interface {
	// GenerateEmbedding creates an embedding vector from the provided input using the specified model.
	// The input can be a string or array of strings, and the response includes the embedding vectors
	// along with usage statistics.
	GenerateEmbedding(ctx context.Context, input interface{}, model EmbeddingModel) (*EmbeddingResponse, error)
}

// EmbeddingObject represents a single embedding result containing the generated vector
// and metadata about the embedding.
type EmbeddingObject struct {
	// Object identifies the type of the response object
	Object string `json:"object"`
	// Embedding is the generated vector representation of the input text
	Embedding []float64 `json:"embedding"`
	// Index is the position of this embedding in the response array
	Index int `json:"index"`
}

// Usage represents token usage information for the embedding generation request.
type Usage struct {
	// PromptTokens is the number of tokens in the input text
	PromptTokens int `json:"prompt_tokens"`
	// TotalTokens is the total number of tokens processed
	TotalTokens int `json:"total_tokens"`
}

// EmbeddingResponse represents the complete response from the embedding API.
// It includes the generated embeddings and usage statistics.
type EmbeddingResponse struct {
	// Object identifies the type of the response
	Object string `json:"object"`
	// Data contains the array of embedding results
	Data []EmbeddingObject `json:"data"`
	// Model identifies which embedding model was used
	Model EmbeddingModel `json:"model"`
	// Usage provides token usage statistics for the request
	Usage Usage `json:"usage"`
}

// EmbeddingService implements the EmbeddingProvider interface for generating embeddings
// using a REST API endpoint.
type EmbeddingService struct {
	// BaseURL is the base URL of the embedding API
	BaseURL string
	// HTTPClient is the HTTP client used for making requests
	HTTPClient *http.Client
}

// NewEmbeddingService creates a new EmbeddingService with the specified base URL and HTTP client.
// If httpClient is nil, it uses http.DefaultClient.
//
// Example usage:
//
//	client := NewEmbeddingService("https://api.example.com", nil)
//	resp, err := client.GenerateEmbedding(
//		context.Background(),
//		"Hello, world!",
//		EmbeddingModelAllMiniLML6V2,
//	)
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Printf("Generated embedding vector: %v\n", resp.Data[0].Embedding)
func NewEmbeddingService(baseURL string, httpClient *http.Client) *EmbeddingService {
	if httpClient == nil {
		httpClient = http.DefaultClient
	}
	return &EmbeddingService{
		BaseURL:    baseURL,
		HTTPClient: httpClient,
	}
}

// embeddingRequest represents the request body sent to the embedding API.
type embeddingRequest struct {
	// Input is the text to generate embeddings for (string or []string)
	Input interface{} `json:"input"`
	// Model specifies which embedding model to use
	Model EmbeddingModel `json:"model"`
	// EncodingFormat specifies the format of the output vectors
	EncodingFormat string `json:"encoding_format"`
}

// GenerateEmbedding generates embedding vectors for the provided input using the specified model.
// The input can be a single string or an array of strings. The method returns the embedding
// vectors along with usage statistics.
//
// Example usage:
//
//	service := NewEmbeddingService("https://api.example.com", nil)
//
//	// Generate embedding for a single string
//	resp, err := service.GenerateEmbedding(
//		context.Background(),
//		"Hello, world!",
//		EmbeddingModelAllMiniLML6V2,
//	)
//	if err != nil {
//		log.Fatal(err)
//	}
//
//	// Generate embeddings for multiple strings
//	texts := []string{"Hello", "World"}
//	resp, err = service.GenerateEmbedding(
//		context.Background(),
//		texts,
//		EmbeddingModelAllMpnetBaseV2,
//	)
//	if err != nil {
//		log.Fatal(err)
//	}
//
// The method returns an error if:
//   - The request cannot be created or sent
//   - The server returns a non-200 status code
//   - The response cannot be decoded
func (s *EmbeddingService) GenerateEmbedding(ctx context.Context, input interface{}, model EmbeddingModel) (*EmbeddingResponse, error) {
	reqBody := embeddingRequest{
		Input:          input,
		Model:          model,
		EncodingFormat: "float",
	}

	jsonBody, err := json.Marshal(reqBody)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.BaseURL+"/v1/embeddings", bytes.NewBuffer(jsonBody))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := s.HTTPClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
	}

	var embResp EmbeddingResponse
	if err := json.NewDecoder(resp.Body).Decode(&embResp); err != nil {
		return nil, fmt.Errorf("failed to decode response: %w", err)
	}

	return &embResp, nil
}
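The JSON body that `GenerateEmbedding` posts to `/v1/embeddings` can be inspected in isolation. A standalone sketch follows; since `embeddingRequest` is unexported, the struct is mirrored locally here, and `encodeEmbeddingRequest` is a hypothetical helper for illustration:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// embeddingRequest mirrors the unexported request struct in ai/embedding.go,
// redeclared so this sketch runs standalone.
type embeddingRequest struct {
	Input          interface{} `json:"input"`
	Model          string      `json:"model"`
	EncodingFormat string      `json:"encoding_format"`
}

// encodeEmbeddingRequest produces the wire format the service sends:
// input (string or []string), model name, and a fixed "float" encoding.
func encodeEmbeddingRequest(input interface{}, model string) string {
	body, err := json.Marshal(embeddingRequest{
		Input:          input,
		Model:          model,
		EncodingFormat: "float",
	})
	if err != nil {
		panic(err)
	}
	return string(body)
}

func main() {
	fmt.Println(encodeEmbeddingRequest([]string{"Hello", "World"}, "all-MiniLM-L6-v2"))
	// → {"input":["Hello","World"],"model":"all-MiniLM-L6-v2","encoding_format":"float"}
}
```

This matches the OpenAI-style embeddings request shape, which is why the same endpoint path and field names appear in the service.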
94 changes: 94 additions & 0 deletions ai/embedding_test.go
@@ -0,0 +1,94 @@
package ai

import (
	"context"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"reflect"
	"testing"
)

func TestEmbeddingService_GenerateEmbedding(t *testing.T) {
	tests := []struct {
		name     string
		input    interface{}
		model    EmbeddingModel
		response EmbeddingResponse
		wantErr  bool
	}{
		{
			name:  "single string input",
			input: "test text",
			model: EmbeddingModelAllMiniLML6V2,
			response: EmbeddingResponse{
				Object: "list",
				Data: []EmbeddingObject{
					{
						Object:    "embedding",
						Embedding: []float64{0.1, 0.2, 0.3},
						Index:     0,
					},
				},
				Model: EmbeddingModelAllMiniLML6V2,
				Usage: Usage{
					PromptTokens: 2,
					TotalTokens:  2,
				},
			},
		},
		{
			name:  "multiple string input",
			input: []string{"test1", "test2"},
			model: EmbeddingModelAllMiniLML6V2,
			response: EmbeddingResponse{
				Object: "list",
				Data: []EmbeddingObject{
					{
						Object:    "embedding",
						Embedding: []float64{0.1, 0.2, 0.3},
						Index:     0,
					},
					{
						Object:    "embedding",
						Embedding: []float64{0.4, 0.5, 0.6},
						Index:     1,
					},
				},
				Model: EmbeddingModelAllMiniLML6V2,
				Usage: Usage{
					PromptTokens: 4,
					TotalTokens:  4,
				},
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				if r.Method != http.MethodPost {
					t.Errorf("Expected POST request, got %s", r.Method)
				}
				if r.Header.Get("Content-Type") != "application/json" {
					t.Errorf("Expected Content-Type: application/json, got %s", r.Header.Get("Content-Type"))
				}

				if err := json.NewEncoder(w).Encode(tt.response); err != nil {
					t.Errorf("failed to encode response: %v", err)
				}
			}))
			defer server.Close()

			service := NewEmbeddingService(server.URL, server.Client())
			gotResp, err := service.GenerateEmbedding(context.Background(), tt.input, tt.model)

			if (err != nil) != tt.wantErr {
				t.Errorf("GenerateEmbedding() error = %v, wantErr %v", err, tt.wantErr)
				return
			}

			if !reflect.DeepEqual(gotResp, &tt.response) {
				t.Errorf("GenerateEmbedding() = %v, want %v", gotResp, tt.response)
			}
		})
	}
}
