ai: Update llm pipeline #3336

Merged · 13 commits · Jan 27, 2025
2 changes: 2 additions & 0 deletions CHANGELOG_PENDING.md
@@ -6,6 +6,8 @@

### Features ⚒

- [#3336](https://github.com/livepeer/go-livepeer/pull/3336/) Updated the AI LLM pipeline to the new OpenAI-compatible API format.

#### General

#### Broadcaster
2 changes: 1 addition & 1 deletion core/ai.go
@@ -23,7 +23,7 @@ type AI interface {
ImageToVideo(context.Context, worker.GenImageToVideoMultipartRequestBody) (*worker.VideoResponse, error)
Upscale(context.Context, worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error)
AudioToText(context.Context, worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error)
LLM(context.Context, worker.GenLLMFormdataRequestBody) (interface{}, error)
LLM(context.Context, worker.GenLLMJSONRequestBody) (interface{}, error)
SegmentAnything2(context.Context, worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
ImageToText(context.Context, worker.GenImageToTextMultipartRequestBody) (*worker.ImageToTextResponse, error)
TextToSpeech(context.Context, worker.GenTextToSpeechJSONRequestBody) (*worker.AudioResponse, error)
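The interface method now takes an OpenAI-style JSON request instead of form data, and (per the comment in core/ai_worker.go below) the interface{} result is either a full *worker.LLMResponse or a channel of streamed chunks. A minimal caller sketch, assuming the field names that appear elsewhere in this diff (Model, MaxTokens, Stream, Messages) and treating the message element type as worker.LLMMessage; pointer-vs-value shapes are illustrative assumptions, not the generated code's exact signatures:

```go
// Sketch: driving the new JSON-based LLM method (AI is the interface above).
// Assumes context, fmt, and github.com/livepeer/ai-worker/worker are imported.
func callLLM(ctx context.Context, ai AI) error {
	model := "llm_model"
	maxTokens := 256
	stream := false
	req := worker.GenLLMJSONRequestBody{
		Model:     &model,
		MaxTokens: &maxTokens,
		Stream:    &stream,
		Messages:  []worker.LLMMessage{{Role: "user", Content: "Hello"}},
	}

	res, err := ai.LLM(ctx, req)
	if err != nil {
		return err
	}

	switch v := res.(type) {
	case *worker.LLMResponse: // non-streaming: one full response
		// a real caller would guard len(Choices) before indexing
		fmt.Println(v.Choices[0].Delta.Content, v.TokensUsed.TotalTokens)
	case <-chan *worker.LLMResponse: // streaming: chunks until FinishReason is set
		for chunk := range v {
			fmt.Print(chunk.Choices[0].Delta.Content)
		}
	}
	return nil
}
```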
7 changes: 5 additions & 2 deletions core/ai_test.go
@@ -651,8 +651,11 @@ func (a *stubAIWorker) SegmentAnything2(ctx context.Context, req worker.GenSegme
return &worker.MasksResponse{Logits: "logits", Masks: "masks", Scores: "scores"}, nil
}

func (a *stubAIWorker) LLM(ctx context.Context, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
return &worker.LLMResponse{Response: "response tokens", TokensUsed: 10}, nil
func (a *stubAIWorker) LLM(ctx context.Context, req worker.GenLLMJSONRequestBody) (interface{}, error) {
var choices []worker.LLMChoice
choices = append(choices, worker.LLMChoice{Delta: &worker.LLMMessage{Content: "choice1", Role: "assistant"}, Index: 0})
tokensUsed := worker.LLMTokenUsage{PromptTokens: 40, CompletionTokens: 10, TotalTokens: 50}
return &worker.LLMResponse{Choices: choices, Created: 1, Model: "llm_model", TokensUsed: tokensUsed}, nil
}

func (a *stubAIWorker) ImageToText(ctx context.Context, req worker.GenImageToTextMultipartRequestBody) (*worker.ImageToTextResponse, error) {
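The stub now mirrors the OpenAI-style response shape: a slice of choices plus a token-usage struct. A small helper sketch for reading it back out; whether a non-streaming choice carries its text in Delta or in a separate Message field is an assumption here (the stub populates Delta):

```go
// Sketch: extracting the generated text and token counts from the new
// response shape. Assumes the worker package import; nil/empty guards are
// added here for safety.
func summarizeLLMResponse(res *worker.LLMResponse) (text string, tokens int) {
	if res == nil {
		return "", 0
	}
	if len(res.Choices) > 0 && res.Choices[0].Delta != nil {
		text = res.Choices[0].Delta.Content
	}
	return text, res.TokensUsed.TotalTokens
}
```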
8 changes: 4 additions & 4 deletions core/ai_worker.go
@@ -787,14 +787,14 @@
}

// Return type is LLMResponse, but a stream is available as well as chan(string)
func (orch *orchestrator) LLM(ctx context.Context, requestID string, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
func (orch *orchestrator) LLM(ctx context.Context, requestID string, req worker.GenLLMJSONRequestBody) (interface{}, error) {

// local AIWorker processes job if combined orchestrator/ai worker
if orch.node.AIWorker != nil {
// no file response to save, response is text sent back to gateway
return orch.node.AIWorker.LLM(ctx, req)
}

res, err := orch.node.AIWorkerManager.Process(ctx, requestID, "llm", *req.ModelId, "", AIJobRequestData{Request: req})
res, err := orch.node.AIWorkerManager.Process(ctx, requestID, "llm", *req.Model, "", AIJobRequestData{Request: req})

if err != nil {
return nil, err
}
@@ -805,7 +805,7 @@
if err != nil {
clog.Errorf(ctx, "Error saving remote ai result err=%q", err)
if monitor.Enabled {
monitor.AIResultSaveError(ctx, "llm", *req.ModelId, string(monitor.SegmentUploadErrorUnknown))
monitor.AIResultSaveError(ctx, "llm", *req.Model, string(monitor.SegmentUploadErrorUnknown))

}
return nil, err

@@ -1050,7 +1050,7 @@
return n.AIWorker.SegmentAnything2(ctx, req)
}

func (n *LivepeerNode) LLM(ctx context.Context, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
func (n *LivepeerNode) LLM(ctx context.Context, req worker.GenLLMJSONRequestBody) (interface{}, error) {

return n.AIWorker.LLM(ctx, req)
}

2 changes: 1 addition & 1 deletion go.mod
@@ -15,7 +15,7 @@ require (
github.com/google/uuid v1.6.0
github.com/jaypipes/ghw v0.10.0
github.com/jaypipes/pcidb v1.0.0
github.com/livepeer/ai-worker v0.12.7-0.20241219141308-c19289d128a3
github.com/livepeer/ai-worker v0.13.1
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b
github.com/livepeer/livepeer-data v0.7.5-0.20231004073737-06f1f383fb18
github.com/livepeer/lpms v0.0.0-20250118014304-79e6dcf08057
4 changes: 2 additions & 2 deletions go.sum
@@ -605,8 +605,8 @@ github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4n
github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI=
github.com/libp2p/go-openssl v0.1.0 h1:LBkKEcUv6vtZIQLVTegAil8jbNpJErQ9AnT+bWV+Ooo=
github.com/libp2p/go-openssl v0.1.0/go.mod h1:OiOxwPpL3n4xlenjx2h7AwSGaFSC/KZvf6gNdOBQMtc=
github.com/livepeer/ai-worker v0.12.7-0.20241219141308-c19289d128a3 h1:uutmGZq2YdIKnKhn6QGHtGnKfBGYAUMMOr44LXYs23w=
github.com/livepeer/ai-worker v0.12.7-0.20241219141308-c19289d128a3/go.mod h1:ZibfmZQQh6jFvnPLHeIPInghfX5ln+JpN845nS3GuyM=
github.com/livepeer/ai-worker v0.13.1 h1:BnqzmBD/E5gHM0P6UXt9M2/bZwU3ZryEfNpbW+NYJr0=
github.com/livepeer/ai-worker v0.13.1/go.mod h1:ZibfmZQQh6jFvnPLHeIPInghfX5ln+JpN845nS3GuyM=
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b h1:VQcnrqtCA2UROp7q8ljkh2XA/u0KRgVv0S1xoUvOweE=
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b/go.mod h1:hwJ5DKhl+pTanFWl+EUpw1H7ukPO/H+MFpgA7jjshzw=
github.com/livepeer/joy4 v0.1.2-0.20191121080656-b2fea45cbded h1:ZQlvR5RB4nfT+cOQee+WqmaDOgGtP2oDMhcVvR4L0yA=
18 changes: 9 additions & 9 deletions server/ai_http.go
@@ -66,7 +66,7 @@
lp.transRPC.Handle("/image-to-video", oapiReqValidator(aiHttpHandle(lp, multipartDecoder[worker.GenImageToVideoMultipartRequestBody])))
lp.transRPC.Handle("/upscale", oapiReqValidator(aiHttpHandle(lp, multipartDecoder[worker.GenUpscaleMultipartRequestBody])))
lp.transRPC.Handle("/audio-to-text", oapiReqValidator(aiHttpHandle(lp, multipartDecoder[worker.GenAudioToTextMultipartRequestBody])))
lp.transRPC.Handle("/llm", oapiReqValidator(aiHttpHandle(lp, multipartDecoder[worker.GenLLMFormdataRequestBody])))
lp.transRPC.Handle("/llm", oapiReqValidator(aiHttpHandle(lp, jsonDecoder[worker.GenLLMJSONRequestBody])))

lp.transRPC.Handle("/segment-anything-2", oapiReqValidator(aiHttpHandle(lp, multipartDecoder[worker.GenSegmentAnything2MultipartRequestBody])))
lp.transRPC.Handle("/image-to-text", oapiReqValidator(aiHttpHandle(lp, multipartDecoder[worker.GenImageToTextMultipartRequestBody])))
lp.transRPC.Handle("/text-to-speech", oapiReqValidator(aiHttpHandle(lp, jsonDecoder[worker.GenTextToSpeechJSONRequestBody])))
@@ -404,10 +404,10 @@
return
}
outPixels *= 1000 // Convert to milliseconds
case worker.GenLLMFormdataRequestBody:
case worker.GenLLMJSONRequestBody:

pipeline = "llm"
cap = core.Capability_LLM
modelID = *v.ModelId
modelID = *v.Model

submitFn = func(ctx context.Context) (interface{}, error) {
return orch.LLM(ctx, requestID, v)
}
@@ -585,7 +585,7 @@
}

// Check if the response is a streaming response
if streamChan, ok := resp.(<-chan worker.LlmStreamChunk); ok {
if streamChan, ok := resp.(<-chan *worker.LLMResponse); ok {

glog.Infof("Streaming response for request id=%v", requestID)

// Set headers for SSE
@@ -609,7 +609,7 @@
fmt.Fprintf(w, "data: %s\n\n", data)
flusher.Flush()

if chunk.Done {
if chunk.Choices[0].FinishReason != nil && *chunk.Choices[0].FinishReason != "" {

break
}
}
@@ -682,8 +682,8 @@
case "text/event-stream":
resultType = "streaming"
glog.Infof("Received %s response from remote worker=%s taskId=%d", resultType, r.RemoteAddr, tid)
resChan := make(chan worker.LlmStreamChunk, 100)
workerResult.Results = (<-chan worker.LlmStreamChunk)(resChan)
resChan := make(chan *worker.LLMResponse, 100)
workerResult.Results = (<-chan *worker.LLMResponse)(resChan)


defer r.Body.Close()
defer close(resChan)
@@ -702,12 +702,12 @@
line := scanner.Text()
if strings.HasPrefix(line, "data: ") {
data := strings.TrimPrefix(line, "data: ")
var chunk worker.LlmStreamChunk
var chunk worker.LLMResponse

if err := json.Unmarshal([]byte(data), &chunk); err != nil {
clog.Errorf(ctx, "Error unmarshaling stream data: %v", err)
continue
}
resChan <- chunk
resChan <- &chunk

}
}
}
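With the old LlmStreamChunk.Done flag gone, both the SSE relay above and the gateway handler below decide that a stream has ended by looking at FinishReason on the first choice. A small sketch of that check, with a length guard added defensively here (the handlers above index Choices[0] directly):

```go
// Sketch: OpenAI-style stream-termination check used throughout this PR.
// Assumes the worker package import.
func llmStreamDone(chunk *worker.LLMResponse) bool {
	if chunk == nil || len(chunk.Choices) == 0 {
		return false
	}
	fr := chunk.Choices[0].FinishReason
	return fr != nil && *fr != "" // e.g. "stop" or "length" once the model finishes
}
```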
19 changes: 9 additions & 10 deletions server/ai_mediaserver.go
@@ -256,20 +256,19 @@
requestID := string(core.RandomManifestID())
ctx = clog.AddVal(ctx, "request_id", requestID)

var req worker.GenLLMFormdataRequestBody

multiRdr, err := r.MultipartReader()
if err != nil {
var req worker.GenLLMJSONRequestBody
if err := jsonDecoder(&req, r); err != nil {

respondJsonError(ctx, w, err, http.StatusBadRequest)
return
}

if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondJsonError(ctx, w, err, http.StatusBadRequest)
//check required fields
if req.Model == nil || req.Messages == nil || req.Stream == nil || req.MaxTokens == nil || len(req.Messages) == 0 {
respondJsonError(ctx, w, errors.New("missing required fields"), http.StatusBadRequest)

return
}

clog.V(common.VERBOSE).Infof(ctx, "Received LLM request prompt=%v model_id=%v stream=%v", req.Prompt, *req.ModelId, *req.Stream)
clog.V(common.VERBOSE).Infof(ctx, "Received LLM request model_id=%v stream=%v", *req.Model, *req.Stream)


params := aiRequestParams{
node: ls.LivepeerNode,
@@ -290,9 +289,9 @@
}

took := time.Since(start)
clog.V(common.VERBOSE).Infof(ctx, "Processed LLM request prompt=%v model_id=%v took=%v", req.Prompt, *req.ModelId, took)
clog.V(common.VERBOSE).Infof(ctx, "Processed LLM request model_id=%v took=%v", *req.Model, took)


if streamChan, ok := resp.(chan worker.LlmStreamChunk); ok {
if streamChan, ok := resp.(chan *worker.LLMResponse); ok {

// Handle streaming response (SSE)
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
@@ -302,7 +301,7 @@
data, _ := json.Marshal(chunk)
fmt.Fprintf(w, "data: %s\n\n", data)
w.(http.Flusher).Flush()
if chunk.Done {
if chunk.Choices[0].FinishReason != nil && *chunk.Choices[0].FinishReason != "" {

break
}
}
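For context, a sketch of what a gateway client call might look like now that the handler above decodes a JSON body instead of multipart form data. The /llm path and the JSON field names (model, messages, max_tokens, stream) are assumptions based on the OpenAI-compatible format this PR adopts; only the orchestrator's /llm route is visible in this diff:

```go
// Sketch: POSTing an OpenAI-style JSON body to the gateway's LLM endpoint.
// Assumes bytes, encoding/json, and net/http imports; path and field names
// are assumptions, not confirmed by this diff.
func postLLM(gatewayURL string) (*http.Response, error) {
	payload := map[string]interface{}{
		"model":      "llm_model",
		"max_tokens": 256,
		"stream":     false,
		"messages":   []map[string]string{{"role": "user", "content": "Hello"}},
	}
	buf, err := json.Marshal(payload)
	if err != nil {
		return nil, err
	}
	return http.Post(gatewayURL+"/llm", "application/json", bytes.NewReader(buf))
}
```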
71 changes: 35 additions & 36 deletions server/ai_process.go
@@ -1107,14 +1107,14 @@
return took.Seconds() / float64(tokensUsed)
}

func processLLM(ctx context.Context, params aiRequestParams, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
func processLLM(ctx context.Context, params aiRequestParams, req worker.GenLLMJSONRequestBody) (interface{}, error) {

resp, err := processAIRequest(ctx, params, req)
if err != nil {
return nil, err
}

if req.Stream != nil && *req.Stream {
streamChan, ok := resp.(chan worker.LlmStreamChunk)
streamChan, ok := resp.(chan *worker.LLMResponse)

if !ok {
return nil, errors.New("unexpected response type for streaming request")
}
@@ -1129,20 +1129,12 @@
return llmResp, nil
}

func submitLLM(ctx context.Context, params aiRequestParams, sess *AISession, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
var buf bytes.Buffer
mw, err := worker.NewLLMMultipartWriter(&buf, req)
if err != nil {
if monitor.Enabled {
monitor.AIRequestError(err.Error(), "llm", *req.ModelId, nil)
}
return nil, err
}
func submitLLM(ctx context.Context, params aiRequestParams, sess *AISession, req worker.GenLLMJSONRequestBody) (interface{}, error) {


client, err := worker.NewClientWithResponses(sess.Transcoder(), worker.WithHTTPClient(httpClient))
if err != nil {
if monitor.Enabled {
monitor.AIRequestError(err.Error(), "llm", *req.ModelId, sess.OrchestratorInfo)
monitor.AIRequestError(err.Error(), "llm", *req.Model, sess.OrchestratorInfo)

}
return nil, err
}
@@ -1155,17 +1147,17 @@
setHeaders, balUpdate, err := prepareAIPayment(ctx, sess, int64(*req.MaxTokens))
if err != nil {
if monitor.Enabled {
monitor.AIRequestError(err.Error(), "llm", *req.ModelId, sess.OrchestratorInfo)
monitor.AIRequestError(err.Error(), "llm", *req.Model, sess.OrchestratorInfo)

}
return nil, err
}
defer completeBalanceUpdate(sess.BroadcastSession, balUpdate)

start := time.Now()
resp, err := client.GenLLMWithBody(ctx, mw.FormDataContentType(), &buf, setHeaders)
resp, err := client.GenLLM(ctx, req, setHeaders)

if err != nil {
if monitor.Enabled {
monitor.AIRequestError(err.Error(), "llm", *req.ModelId, sess.OrchestratorInfo)
monitor.AIRequestError(err.Error(), "llm", *req.Model, sess.OrchestratorInfo)

}
return nil, err
}
@@ -1175,83 +1167,90 @@
return nil, fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body))
}

// We treat a response as "receiving change" where the change is the difference between the credit and debit for the update
// TODO: move to after receive stream response in handleSSEStream and handleNonStreamingResponse to count input tokens
if balUpdate != nil {
balUpdate.Status = ReceivedChange
}


if req.Stream != nil && *req.Stream {
return handleSSEStream(ctx, resp.Body, sess, req, start)
}

return handleNonStreamingResponse(ctx, resp.Body, sess, req, start)
}

func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.GenLLMFormdataRequestBody, start time.Time) (chan worker.LlmStreamChunk, error) {
streamChan := make(chan worker.LlmStreamChunk, 100)
func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.GenLLMJSONRequestBody, start time.Time) (chan *worker.LLMResponse, error) {
streamChan := make(chan *worker.LLMResponse, 100)

go func() {
defer close(streamChan)
defer body.Close()
scanner := bufio.NewScanner(body)
var totalTokens int
var totalTokens worker.LLMTokenUsage

for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "data: ") {
data := strings.TrimPrefix(line, "data: ")
if data == "[DONE]" {
streamChan <- worker.LlmStreamChunk{Done: true, TokensUsed: totalTokens}
break
}
var chunk worker.LlmStreamChunk

var chunk worker.LLMResponse

if err := json.Unmarshal([]byte(data), &chunk); err != nil {
clog.Errorf(ctx, "Error unmarshaling SSE data: %v", err)
continue
}
totalTokens += chunk.TokensUsed
streamChan <- chunk
totalTokens = chunk.TokensUsed
streamChan <- &chunk
//check if stream is finished
if chunk.Choices[0].FinishReason != nil && *chunk.Choices[0].FinishReason != "" {
break

}
}
}
if err := scanner.Err(); err != nil {
clog.Errorf(ctx, "Error reading SSE stream: %v", err)
}

took := time.Since(start)
sess.LatencyScore = CalculateLLMLatencyScore(took, totalTokens)
sess.LatencyScore = CalculateLLMLatencyScore(took, totalTokens.TotalTokens)


if monitor.Enabled {
var pricePerAIUnit float64
if priceInfo := sess.OrchestratorInfo.GetPriceInfo(); priceInfo != nil && priceInfo.PixelsPerUnit != 0 {
pricePerAIUnit = float64(priceInfo.PricePerUnit) / float64(priceInfo.PixelsPerUnit)
}
monitor.AIRequestFinished(ctx, "llm", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo)
monitor.AIRequestFinished(ctx, "llm", *req.Model, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo)

}
}()

return streamChan, nil
}

func handleNonStreamingResponse(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.GenLLMFormdataRequestBody, start time.Time) (*worker.LLMResponse, error) {
func handleNonStreamingResponse(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.GenLLMJSONRequestBody, start time.Time) (*worker.LLMResponse, error) {

data, err := io.ReadAll(body)
defer body.Close()
if err != nil {
if monitor.Enabled {
monitor.AIRequestError(err.Error(), "llm", *req.ModelId, sess.OrchestratorInfo)
monitor.AIRequestError(err.Error(), "llm", *req.Model, sess.OrchestratorInfo)

}
return nil, err
}

var res worker.LLMResponse
if err := json.Unmarshal(data, &res); err != nil {
if monitor.Enabled {
monitor.AIRequestError(err.Error(), "llm", *req.ModelId, sess.OrchestratorInfo)
monitor.AIRequestError(err.Error(), "llm", *req.Model, sess.OrchestratorInfo)

}
return nil, err
}

took := time.Since(start)
sess.LatencyScore = CalculateLLMLatencyScore(took, res.TokensUsed)
sess.LatencyScore = CalculateLLMLatencyScore(took, res.TokensUsed.TotalTokens)


if monitor.Enabled {
var pricePerAIUnit float64
if priceInfo := sess.OrchestratorInfo.GetPriceInfo(); priceInfo != nil && priceInfo.PixelsPerUnit != 0 {
pricePerAIUnit = float64(priceInfo.PricePerUnit) / float64(priceInfo.PixelsPerUnit)
}
monitor.AIRequestFinished(ctx, "llm", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo)
monitor.AIRequestFinished(ctx, "llm", *req.Model, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo)

}

return &res, nil
@@ -1410,16 +1409,16 @@
submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
return submitAudioToText(ctx, params, sess, v)
}
case worker.GenLLMFormdataRequestBody:
case worker.GenLLMJSONRequestBody:

cap = core.Capability_LLM
modelID = defaultLLMModelID
if v.ModelId != nil {
modelID = *v.ModelId
if v.Model != nil {
modelID = *v.Model

}
submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
return submitLLM(ctx, params, sess, v)
}
ctx = clog.AddVal(ctx, "prompt", v.Prompt)

case worker.GenSegmentAnything2MultipartRequestBody:
cap = core.Capability_SegmentAnything2
modelID = defaultSegmentAnything2ModelID
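Token accounting in this file now flows through worker.LLMTokenUsage, and the latency score is derived from TotalTokens rather than a bare counter. A tiny worked sketch of the formula shown at the top of this diff (seconds elapsed divided by tokens used):

```go
// Sketch: latency score = elapsed seconds / total tokens reported.
// Assumes the time and worker package imports; values are illustrative.
func exampleLLMLatencyScore() float64 {
	took := 2 * time.Second
	usage := worker.LLMTokenUsage{PromptTokens: 40, CompletionTokens: 10, TotalTokens: 50}
	return took.Seconds() / float64(usage.TotalTokens) // 2.0 / 50 = 0.04 s per token
}
```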
2 changes: 1 addition & 1 deletion server/ai_process_test.go
@@ -13,7 +13,7 @@ func Test_submitLLM(t *testing.T) {
ctx context.Context
params aiRequestParams
sess *AISession
req worker.GenLLMFormdataRequestBody
req worker.GenLLMJSONRequestBody
}
tests := []struct {
name string