This repository has been archived by the owner on Feb 15, 2025. It is now read-only.

feat: add working embedding endpoint for all-minilm-l6-v2 #141

Merged 1 commit on Jul 12, 2023
69 changes: 40 additions & 29 deletions api/backends/openai/embeddings.go
@@ -3,11 +3,9 @@ package openai
 import (
     "fmt"
     "log"
-    "reflect"
 
     "github.com/defenseunicorns/leapfrogai/pkg/client/embeddings"
     "github.com/gin-gonic/gin"
-    "github.com/sashabaranov/go-openai"
     "google.golang.org/grpc"
 )

@@ -23,67 +21,80 @@ type EmbeddingRequest struct {
     Input any `json:"input"`
     // ID of the model to use. You can use the List models API to see all of your available models,
     // or see our Model overview for descriptions of them.
-    Model openai.EmbeddingModel `json:"model"`
+    Model string `json:"model"`
     // A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
     User string `json:"user"`
 }
 
+type EmbeddingResponse struct {
+    Object string `json:"object"`
+    Data []Embedding `json:"data"`
+    Model string `json:"model"`
+    Usage Usage `json:"usage"`
+}
+
+type Embedding struct {
+    Object string `json:"object"`
+    Embedding []float32 `json:"embedding"`
+    Index int `json:"index"`
+}
+
+type Usage struct {
+    PromptTokens int `json:"prompt_tokens"`
+    CompletionTokens int `json:"completion_tokens"`
+    TotalTokens int `json:"total_tokens"`
+}
+
 func (o *OpenAIHandler) createEmbeddings(c *gin.Context) {
-    var input openai.EmbeddingRequest
-    var i2 EmbeddingRequest
-    if err := c.BindJSON(&i2); err != nil {
+    var input EmbeddingRequest
+    if err := c.BindJSON(&input); err != nil {
         log.Printf("500: Error marshalling input to object: %v\n", err)
         // Handle error
         c.JSON(500, err)
         return
     }
-    // input was just a string, so convert back to the openai object
-    input = openai.EmbeddingRequest{
-        Model: i2.Model,
-        User: i2.User,
+    conn := o.getModelClient(c, input.Model)
+    if conn == nil {
+        return
     }
-    log.Printf("DEBUG: INPUT TYPE: %v\n", reflect.TypeOf(i2.Input))
-    switch v := i2.Input.(type) {
+    var request embeddings.EmbeddingRequest
+    switch v := input.Input.(type) {
     case string:
         log.Printf("embedding request for string")
-        input.Input = []string{i2.Input.(string)}
+        request = embeddings.EmbeddingRequest{
+            Inputs: []string{v},
+        }
     case []interface{}:
         log.Printf("embedding request for []interface")
-        input.Input = make([]string, len(v))
+        inputString := make([]string, len(v))
         for i, s := range v {
-            input.Input[i] = s.(string)
+            inputString[i] = s.(string)
         }
+        request = embeddings.EmbeddingRequest{
+            Inputs: inputString,
+        }
 
     default:
         log.Printf("400: embedding request for unknown type: %v", v)
         c.JSON(400, fmt.Errorf("object Input was not of type string or []string: %v", v))
     }
 
-    conn := o.getModelClient(c, input.Model.String())
-    if conn == nil {
-        return
-    }
     client := embeddings.NewEmbeddingsServiceClient(conn)
-    request := embeddings.EmbeddingRequest{
-        Inputs: input.Input,
-    }
     grpcResponse, err := client.CreateEmbedding(c, &request)
     if err != nil {
-        log.Printf("500: Error creating embedding for %v: %v", input.Model.String(), err)
+        log.Printf("500: Error creating embedding for %v: %v", input.Model, err)
         c.JSON(500, fmt.Errorf("error creating embedding: %v", err))
         return
     }
 
-    response := openai.EmbeddingResponse{
+    response := EmbeddingResponse{
         // Don't know what this object is
         Object: "",
         Model: input.Model,
         // No idea what this is for
-        Usage: openai.Usage{},
+        Usage: Usage{},
     }
-    response.Data = make([]openai.Embedding, len(grpcResponse.Embeddings))
+    response.Data = make([]Embedding, len(grpcResponse.Embeddings))
     for i, e := range grpcResponse.Embeddings {
-        embed := openai.Embedding{
+        embed := Embedding{
             Object: "", //No idea what this should be
             Embedding: e.Embedding,
             Index: i,
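As a usage reference, here is a minimal client sketch against the new handler. The JSON field names come from the EmbeddingRequest and EmbeddingResponse structs above; the base URL, port, and the /v1/embeddings route are assumptions, since routing is not part of this diff.

package main

import (
    "bytes"
    "encoding/json"
    "fmt"
    "net/http"
)

func main() {
    // Field names follow the EmbeddingRequest struct added in embeddings.go:
    // "input" may be a single string or a list of strings, and "model"
    // selects an entry registered in models2.toml.
    body, err := json.Marshal(map[string]any{
        "model": "all-minilm-l6-v2",
        "input": []string{"LeapfrogAI generates embeddings"},
    })
    if err != nil {
        panic(err)
    }

    // The host, port, and route are assumptions, not part of this diff.
    resp, err := http.Post("http://localhost:8080/v1/embeddings", "application/json", bytes.NewReader(body))
    if err != nil {
        panic(err)
    }
    defer resp.Body.Close()

    // The response mirrors the new EmbeddingResponse type: each element of
    // "data" carries the float32 vector and the index of the matching input.
    var out struct {
        Data []struct {
            Embedding []float32 `json:"embedding"`
            Index     int       `json:"index"`
        } `json:"data"`
    }
    if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
        panic(err)
    }
    fmt.Printf("received %d embedding(s)\n", len(out.Data))
}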
11 changes: 11 additions & 0 deletions api/models2.toml
@@ -10,6 +10,17 @@ tasks = ["completion", "chat"]
 url = 'localhost:50051'
 type = 'gRPC'
 
+[all-minilm-l6-v2]
+
+[all-minilm-l6-v2.metadata]
+owned_by = ''
+permission = []
+description = 'all-minilm-l6-v2 english embeddings model'
+tasks = ["embeddings"]
+
+[all-minilm-l6-v2.network]
+url = 'localhost:50051'
+type = 'gRPC'
 
 [ctransformers]
 
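For context, a hedged sketch of the path this config entry feeds: dial the configured gRPC address and call the embeddings service directly. Only NewEmbeddingsServiceClient, CreateEmbedding, EmbeddingRequest.Inputs, and the Embeddings response field are taken from the diff above; the insecure transport credentials and the direct-dial pattern are assumptions for illustration.

package main

import (
    "context"
    "log"

    "github.com/defenseunicorns/leapfrogai/pkg/client/embeddings"
    "google.golang.org/grpc"
    "google.golang.org/grpc/credentials/insecure"
)

func main() {
    // Dial the address configured for all-minilm-l6-v2 in models2.toml.
    // Plaintext credentials are an assumption for a local backend.
    conn, err := grpc.Dial("localhost:50051", grpc.WithTransportCredentials(insecure.NewCredentials()))
    if err != nil {
        log.Fatal(err)
    }
    defer conn.Close()

    // Same client and request shape the handler in embeddings.go uses.
    client := embeddings.NewEmbeddingsServiceClient(conn)
    resp, err := client.CreateEmbedding(context.Background(), &embeddings.EmbeddingRequest{
        Inputs: []string{"hello world"},
    })
    if err != nil {
        log.Fatal(err)
    }
    log.Printf("got %d embedding(s)", len(resp.Embeddings))
}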