ai: Update llm pipeline (#3336)
* Update the LLM pipeline to match the updated LLM pipeline in livepeer/ai-worker, providing better compatibility with the OpenAI SDK; also includes a fix for a nil error and a small fix for payments.

---------

Co-authored-by: Rick Staa <rick.staa@outlook.com>
ad-astra-video and rickstaa authored Jan 27, 2025
1 parent 3acf4ea commit 1065885
Showing 14 changed files with 88 additions and 79 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG_PENDING.md
@@ -6,6 +6,8 @@

### Features ⚒

- [#3336](https://github.com/livepeer/go-livepeer/pull/3336/) Updated the AI LLM pipeline to the new OpenAI-compatible API format.

#### General

#### Broadcaster
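For reference, the OpenAI-compatible request format the updated pipeline expects can be sketched with plain Go structs. The field names below (model, messages, max_tokens, stream) follow the OpenAI chat-completions convention and the identifiers visible in this diff; the authoritative type is the generated worker.GenLLMJSONRequestBody in livepeer/ai-worker, so treat this as an illustrative assumption rather than the real schema.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Local sketch of the OpenAI-style chat request the updated /llm pipeline accepts.
// The authoritative request type lives in livepeer/ai-worker (worker.GenLLMJSONRequestBody).
type llmMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type llmRequest struct {
	Model     string       `json:"model"`
	Messages  []llmMessage `json:"messages"`
	MaxTokens int          `json:"max_tokens"`
	Stream    bool         `json:"stream"`
}

func main() {
	// Build and print an example request body.
	body, _ := json.Marshal(llmRequest{
		Model:     "llm_model", // placeholder model id
		Messages:  []llmMessage{{Role: "user", Content: "Say hello"}},
		MaxTokens: 256,
		Stream:    false,
	})
	fmt.Println(string(body))
}
```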
2 changes: 1 addition & 1 deletion core/ai.go
@@ -23,7 +23,7 @@ type AI interface {
ImageToVideo(context.Context, worker.GenImageToVideoMultipartRequestBody) (*worker.VideoResponse, error)
Upscale(context.Context, worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error)
AudioToText(context.Context, worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error)
LLM(context.Context, worker.GenLLMFormdataRequestBody) (interface{}, error)
LLM(context.Context, worker.GenLLMJSONRequestBody) (interface{}, error)
SegmentAnything2(context.Context, worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
ImageToText(context.Context, worker.GenImageToTextMultipartRequestBody) (*worker.ImageToTextResponse, error)
TextToSpeech(context.Context, worker.GenTextToSpeechJSONRequestBody) (*worker.AudioResponse, error)
7 changes: 5 additions & 2 deletions core/ai_test.go
@@ -651,8 +651,11 @@ func (a *stubAIWorker) SegmentAnything2(ctx context.Context, req worker.GenSegme
return &worker.MasksResponse{Logits: "logits", Masks: "masks", Scores: "scores"}, nil
}

func (a *stubAIWorker) LLM(ctx context.Context, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
return &worker.LLMResponse{Response: "response tokens", TokensUsed: 10}, nil
func (a *stubAIWorker) LLM(ctx context.Context, req worker.GenLLMJSONRequestBody) (interface{}, error) {
var choices []worker.LLMChoice
choices = append(choices, worker.LLMChoice{Delta: &worker.LLMMessage{Content: "choice1", Role: "assistant"}, Index: 0})
tokensUsed := worker.LLMTokenUsage{PromptTokens: 40, CompletionTokens: 10, TotalTokens: 50}
return &worker.LLMResponse{Choices: choices, Created: 1, Model: "llm_model", TokensUsed: tokensUsed}, nil
}

func (a *stubAIWorker) ImageToText(ctx context.Context, req worker.GenImageToTextMultipartRequestBody) (*worker.ImageToTextResponse, error) {
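The stub above shows the new response shape: choices carrying a delta message plus a token-usage struct. A small standalone sketch of pulling the assistant text and total token count out of a response shaped that way, using local stand-in types rather than the worker package:

```go
package main

import "fmt"

// Local stand-ins for the worker.LLMResponse fields populated by the stub above
// (Choices[i].Delta.{Role,Content}, TokensUsed.{Prompt,Completion,Total}Tokens).
type llmMessage struct{ Role, Content string }

type llmChoice struct {
	Delta *llmMessage
	Index int
}

type llmTokenUsage struct{ PromptTokens, CompletionTokens, TotalTokens int }

type llmResponse struct {
	Choices    []llmChoice
	TokensUsed llmTokenUsage
}

func main() {
	// Mirror the values the stub worker returns.
	resp := llmResponse{
		Choices:    []llmChoice{{Delta: &llmMessage{Role: "assistant", Content: "choice1"}}},
		TokensUsed: llmTokenUsage{PromptTokens: 40, CompletionTokens: 10, TotalTokens: 50},
	}
	if len(resp.Choices) > 0 && resp.Choices[0].Delta != nil {
		fmt.Printf("text=%q total_tokens=%d\n", resp.Choices[0].Delta.Content, resp.TokensUsed.TotalTokens)
	}
}
```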
8 changes: 4 additions & 4 deletions core/ai_worker.go
@@ -787,14 +787,14 @@ func (orch *orchestrator) SegmentAnything2(ctx context.Context, requestID string
}

// Return type is LLMResponse, but a stream is available as well as chan(string)
func (orch *orchestrator) LLM(ctx context.Context, requestID string, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
func (orch *orchestrator) LLM(ctx context.Context, requestID string, req worker.GenLLMJSONRequestBody) (interface{}, error) {
// local AIWorker processes job if combined orchestrator/ai worker
if orch.node.AIWorker != nil {
// no file response to save, response is text sent back to gateway
return orch.node.AIWorker.LLM(ctx, req)
}

res, err := orch.node.AIWorkerManager.Process(ctx, requestID, "llm", *req.ModelId, "", AIJobRequestData{Request: req})
res, err := orch.node.AIWorkerManager.Process(ctx, requestID, "llm", *req.Model, "", AIJobRequestData{Request: req})
if err != nil {
return nil, err
}
@@ -805,7 +805,7 @@ func (orch *orchestrator) LLM(ctx context.Context, requestID string, req worker.
if err != nil {
clog.Errorf(ctx, "Error saving remote ai result err=%q", err)
if monitor.Enabled {
monitor.AIResultSaveError(ctx, "llm", *req.ModelId, string(monitor.SegmentUploadErrorUnknown))
monitor.AIResultSaveError(ctx, "llm", *req.Model, string(monitor.SegmentUploadErrorUnknown))
}
return nil, err

@@ -1050,7 +1050,7 @@ func (n *LivepeerNode) SegmentAnything2(ctx context.Context, req worker.GenSegme
return n.AIWorker.SegmentAnything2(ctx, req)
}

func (n *LivepeerNode) LLM(ctx context.Context, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
func (n *LivepeerNode) LLM(ctx context.Context, req worker.GenLLMJSONRequestBody) (interface{}, error) {
return n.AIWorker.LLM(ctx, req)
}

2 changes: 1 addition & 1 deletion go.mod
@@ -15,7 +15,7 @@ require (
github.com/google/uuid v1.6.0
github.com/jaypipes/ghw v0.10.0
github.com/jaypipes/pcidb v1.0.0
github.com/livepeer/ai-worker v0.12.7-0.20241219141308-c19289d128a3
github.com/livepeer/ai-worker v0.13.1
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b
github.com/livepeer/livepeer-data v0.7.5-0.20231004073737-06f1f383fb18
github.com/livepeer/lpms v0.0.0-20250118014304-79e6dcf08057
4 changes: 2 additions & 2 deletions go.sum
@@ -605,8 +605,8 @@ github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4n
github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI=
github.com/libp2p/go-openssl v0.1.0 h1:LBkKEcUv6vtZIQLVTegAil8jbNpJErQ9AnT+bWV+Ooo=
github.com/libp2p/go-openssl v0.1.0/go.mod h1:OiOxwPpL3n4xlenjx2h7AwSGaFSC/KZvf6gNdOBQMtc=
github.com/livepeer/ai-worker v0.12.7-0.20241219141308-c19289d128a3 h1:uutmGZq2YdIKnKhn6QGHtGnKfBGYAUMMOr44LXYs23w=
github.com/livepeer/ai-worker v0.12.7-0.20241219141308-c19289d128a3/go.mod h1:ZibfmZQQh6jFvnPLHeIPInghfX5ln+JpN845nS3GuyM=
github.com/livepeer/ai-worker v0.13.1 h1:BnqzmBD/E5gHM0P6UXt9M2/bZwU3ZryEfNpbW+NYJr0=
github.com/livepeer/ai-worker v0.13.1/go.mod h1:ZibfmZQQh6jFvnPLHeIPInghfX5ln+JpN845nS3GuyM=
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b h1:VQcnrqtCA2UROp7q8ljkh2XA/u0KRgVv0S1xoUvOweE=
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b/go.mod h1:hwJ5DKhl+pTanFWl+EUpw1H7ukPO/H+MFpgA7jjshzw=
github.com/livepeer/joy4 v0.1.2-0.20191121080656-b2fea45cbded h1:ZQlvR5RB4nfT+cOQee+WqmaDOgGtP2oDMhcVvR4L0yA=
18 changes: 9 additions & 9 deletions server/ai_http.go
@@ -66,7 +66,7 @@ func startAIServer(lp *lphttp) error {
lp.transRPC.Handle("/image-to-video", oapiReqValidator(aiHttpHandle(lp, multipartDecoder[worker.GenImageToVideoMultipartRequestBody])))
lp.transRPC.Handle("/upscale", oapiReqValidator(aiHttpHandle(lp, multipartDecoder[worker.GenUpscaleMultipartRequestBody])))
lp.transRPC.Handle("/audio-to-text", oapiReqValidator(aiHttpHandle(lp, multipartDecoder[worker.GenAudioToTextMultipartRequestBody])))
lp.transRPC.Handle("/llm", oapiReqValidator(aiHttpHandle(lp, multipartDecoder[worker.GenLLMFormdataRequestBody])))
lp.transRPC.Handle("/llm", oapiReqValidator(aiHttpHandle(lp, jsonDecoder[worker.GenLLMJSONRequestBody])))
lp.transRPC.Handle("/segment-anything-2", oapiReqValidator(aiHttpHandle(lp, multipartDecoder[worker.GenSegmentAnything2MultipartRequestBody])))
lp.transRPC.Handle("/image-to-text", oapiReqValidator(aiHttpHandle(lp, multipartDecoder[worker.GenImageToTextMultipartRequestBody])))
lp.transRPC.Handle("/text-to-speech", oapiReqValidator(aiHttpHandle(lp, jsonDecoder[worker.GenTextToSpeechJSONRequestBody])))
@@ -404,10 +404,10 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
return
}
outPixels *= 1000 // Convert to milliseconds
case worker.GenLLMFormdataRequestBody:
case worker.GenLLMJSONRequestBody:
pipeline = "llm"
cap = core.Capability_LLM
modelID = *v.ModelId
modelID = *v.Model
submitFn = func(ctx context.Context) (interface{}, error) {
return orch.LLM(ctx, requestID, v)
}
@@ -585,7 +585,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
}

// Check if the response is a streaming response
if streamChan, ok := resp.(<-chan worker.LlmStreamChunk); ok {
if streamChan, ok := resp.(<-chan *worker.LLMResponse); ok {
glog.Infof("Streaming response for request id=%v", requestID)

// Set headers for SSE
@@ -609,7 +609,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
fmt.Fprintf(w, "data: %s\n\n", data)
flusher.Flush()

if chunk.Done {
if chunk.Choices[0].FinishReason != nil && *chunk.Choices[0].FinishReason != "" {
break
}
}
@@ -682,8 +682,8 @@ func (h *lphttp) AIResults() http.Handler {
case "text/event-stream":
resultType = "streaming"
glog.Infof("Received %s response from remote worker=%s taskId=%d", resultType, r.RemoteAddr, tid)
resChan := make(chan worker.LlmStreamChunk, 100)
workerResult.Results = (<-chan worker.LlmStreamChunk)(resChan)
resChan := make(chan *worker.LLMResponse, 100)
workerResult.Results = (<-chan *worker.LLMResponse)(resChan)

defer r.Body.Close()
defer close(resChan)
@@ -702,12 +702,12 @@
line := scanner.Text()
if strings.HasPrefix(line, "data: ") {
data := strings.TrimPrefix(line, "data: ")
var chunk worker.LlmStreamChunk
var chunk worker.LLMResponse
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
clog.Errorf(ctx, "Error unmarshaling stream data: %v", err)
continue
}
resChan <- chunk
resChan <- &chunk
}
}
}
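The streaming path now forwards OpenAI-style chunks and terminates when a choice reports a finish reason instead of a Done flag. A minimal client-side sketch of that parsing loop; the JSON field names (choices, delta, content, finish_reason) and the [DONE] sentinel are assumed from the OpenAI wire format and the handling elsewhere in this diff, not taken from the generated worker types:

```go
package main

import (
	"bufio"
	"encoding/json"
	"fmt"
	"strings"
)

// sseChunk is a local sketch of one streamed chunk; the JSON tags assume the
// OpenAI chat-completions wire format rather than the generated worker types.
type sseChunk struct {
	Choices []struct {
		Delta struct {
			Content string `json:"content"`
		} `json:"delta"`
		FinishReason *string `json:"finish_reason"`
	} `json:"choices"`
}

func main() {
	// Stand-in for an SSE body like the one forwarded by the /llm handler above.
	stream := "data: {\"choices\":[{\"delta\":{\"content\":\"Hel\"},\"finish_reason\":null}]}\n\n" +
		"data: {\"choices\":[{\"delta\":{\"content\":\"lo\"},\"finish_reason\":\"stop\"}]}\n\n" +
		"data: [DONE]\n\n"

	scanner := bufio.NewScanner(strings.NewReader(stream))
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		data := strings.TrimPrefix(line, "data: ")
		if data == "[DONE]" { // sentinel emitted at the end of the stream
			break
		}
		var chunk sseChunk
		if err := json.Unmarshal([]byte(data), &chunk); err != nil || len(chunk.Choices) == 0 {
			continue // skip malformed or empty chunks
		}
		fmt.Print(chunk.Choices[0].Delta.Content)
		// New termination check: stop once a finish reason is reported.
		if fr := chunk.Choices[0].FinishReason; fr != nil && *fr != "" {
			break
		}
	}
	fmt.Println()
}
```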
19 changes: 9 additions & 10 deletions server/ai_mediaserver.go
@@ -256,20 +256,19 @@ func (ls *LivepeerServer) LLM() http.Handler {
requestID := string(core.RandomManifestID())
ctx = clog.AddVal(ctx, "request_id", requestID)

var req worker.GenLLMFormdataRequestBody

multiRdr, err := r.MultipartReader()
if err != nil {
var req worker.GenLLMJSONRequestBody
if err := jsonDecoder(&req, r); err != nil {
respondJsonError(ctx, w, err, http.StatusBadRequest)
return
}

if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondJsonError(ctx, w, err, http.StatusBadRequest)
//check required fields
if req.Model == nil || req.Messages == nil || req.Stream == nil || req.MaxTokens == nil || len(req.Messages) == 0 {
respondJsonError(ctx, w, errors.New("missing required fields"), http.StatusBadRequest)
return
}

clog.V(common.VERBOSE).Infof(ctx, "Received LLM request prompt=%v model_id=%v stream=%v", req.Prompt, *req.ModelId, *req.Stream)
clog.V(common.VERBOSE).Infof(ctx, "Received LLM request model_id=%v stream=%v", *req.Model, *req.Stream)

params := aiRequestParams{
node: ls.LivepeerNode,
@@ -290,9 +289,9 @@ }
}

took := time.Since(start)
clog.V(common.VERBOSE).Infof(ctx, "Processed LLM request prompt=%v model_id=%v took=%v", req.Prompt, *req.ModelId, took)
clog.V(common.VERBOSE).Infof(ctx, "Processed LLM request model_id=%v took=%v", *req.Model, took)

if streamChan, ok := resp.(chan worker.LlmStreamChunk); ok {
if streamChan, ok := resp.(chan *worker.LLMResponse); ok {
// Handle streaming response (SSE)
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
@@ -302,7 +301,7 @@
data, _ := json.Marshal(chunk)
fmt.Fprintf(w, "data: %s\n\n", data)
w.(http.Flusher).Flush()
if chunk.Done {
if chunk.Choices[0].FinishReason != nil && *chunk.Choices[0].FinishReason != "" {
break
}
}
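With the gateway handler now decoding a JSON body and requiring model, messages, max_tokens, and stream, a request can be sent with plain net/http. The /llm route comes from this diff; the gateway address and the exact JSON field names are assumptions for illustration:

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// Assumed gateway address; adjust to your deployment. The /llm route and the
	// required fields (model, messages, max_tokens, stream) come from this diff.
	gatewayURL := "http://localhost:8935/llm"

	body := []byte(`{
		"model": "llm_model",
		"messages": [{"role": "user", "content": "Say hello"}],
		"max_tokens": 256,
		"stream": false
	}`)

	resp, err := http.Post(gatewayURL, "application/json", bytes.NewReader(body))
	if err != nil {
		fmt.Println("request failed:", err)
		return
	}
	defer resp.Body.Close()

	out, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.Status)
	// Non-streaming requests come back as a single LLMResponse JSON object.
	fmt.Println(string(out))
}
```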
71 changes: 35 additions & 36 deletions server/ai_process.go
@@ -1107,14 +1107,14 @@ func CalculateLLMLatencyScore(took time.Duration, tokensUsed int) float64 {
return took.Seconds() / float64(tokensUsed)
}

func processLLM(ctx context.Context, params aiRequestParams, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
func processLLM(ctx context.Context, params aiRequestParams, req worker.GenLLMJSONRequestBody) (interface{}, error) {
resp, err := processAIRequest(ctx, params, req)
if err != nil {
return nil, err
}

if req.Stream != nil && *req.Stream {
streamChan, ok := resp.(chan worker.LlmStreamChunk)
streamChan, ok := resp.(chan *worker.LLMResponse)
if !ok {
return nil, errors.New("unexpected response type for streaming request")
}
Expand All @@ -1129,20 +1129,12 @@ func processLLM(ctx context.Context, params aiRequestParams, req worker.GenLLMFo
return llmResp, nil
}

func submitLLM(ctx context.Context, params aiRequestParams, sess *AISession, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
var buf bytes.Buffer
mw, err := worker.NewLLMMultipartWriter(&buf, req)
if err != nil {
if monitor.Enabled {
monitor.AIRequestError(err.Error(), "llm", *req.ModelId, nil)
}
return nil, err
}
func submitLLM(ctx context.Context, params aiRequestParams, sess *AISession, req worker.GenLLMJSONRequestBody) (interface{}, error) {

client, err := worker.NewClientWithResponses(sess.Transcoder(), worker.WithHTTPClient(httpClient))
if err != nil {
if monitor.Enabled {
monitor.AIRequestError(err.Error(), "llm", *req.ModelId, sess.OrchestratorInfo)
monitor.AIRequestError(err.Error(), "llm", *req.Model, sess.OrchestratorInfo)
}
return nil, err
}
@@ -1155,17 +1147,17 @@
setHeaders, balUpdate, err := prepareAIPayment(ctx, sess, int64(*req.MaxTokens))
if err != nil {
if monitor.Enabled {
monitor.AIRequestError(err.Error(), "llm", *req.ModelId, sess.OrchestratorInfo)
monitor.AIRequestError(err.Error(), "llm", *req.Model, sess.OrchestratorInfo)
}
return nil, err
}
defer completeBalanceUpdate(sess.BroadcastSession, balUpdate)

start := time.Now()
resp, err := client.GenLLMWithBody(ctx, mw.FormDataContentType(), &buf, setHeaders)
resp, err := client.GenLLM(ctx, req, setHeaders)
if err != nil {
if monitor.Enabled {
monitor.AIRequestError(err.Error(), "llm", *req.ModelId, sess.OrchestratorInfo)
monitor.AIRequestError(err.Error(), "llm", *req.Model, sess.OrchestratorInfo)
}
return nil, err
}
@@ -1175,83 +1167,90 @@
return nil, fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body))
}

// We treat a response as "receiving change" where the change is the difference between the credit and debit for the update
// TODO: move to after receive stream response in handleSSEStream and handleNonStreamingResponse to count input tokens
if balUpdate != nil {
balUpdate.Status = ReceivedChange
}

if req.Stream != nil && *req.Stream {
return handleSSEStream(ctx, resp.Body, sess, req, start)
}

return handleNonStreamingResponse(ctx, resp.Body, sess, req, start)
}

func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.GenLLMFormdataRequestBody, start time.Time) (chan worker.LlmStreamChunk, error) {
streamChan := make(chan worker.LlmStreamChunk, 100)
func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.GenLLMJSONRequestBody, start time.Time) (chan *worker.LLMResponse, error) {
streamChan := make(chan *worker.LLMResponse, 100)
go func() {
defer close(streamChan)
defer body.Close()
scanner := bufio.NewScanner(body)
var totalTokens int
var totalTokens worker.LLMTokenUsage
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "data: ") {
data := strings.TrimPrefix(line, "data: ")
if data == "[DONE]" {
streamChan <- worker.LlmStreamChunk{Done: true, TokensUsed: totalTokens}
break
}
var chunk worker.LlmStreamChunk

var chunk worker.LLMResponse
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
clog.Errorf(ctx, "Error unmarshaling SSE data: %v", err)
continue
}
totalTokens += chunk.TokensUsed
streamChan <- chunk
totalTokens = chunk.TokensUsed
streamChan <- &chunk
//check if stream is finished
if chunk.Choices[0].FinishReason != nil && *chunk.Choices[0].FinishReason != "" {
break
}
}
}
if err := scanner.Err(); err != nil {
clog.Errorf(ctx, "Error reading SSE stream: %v", err)
}

took := time.Since(start)
sess.LatencyScore = CalculateLLMLatencyScore(took, totalTokens)
sess.LatencyScore = CalculateLLMLatencyScore(took, totalTokens.TotalTokens)

if monitor.Enabled {
var pricePerAIUnit float64
if priceInfo := sess.OrchestratorInfo.GetPriceInfo(); priceInfo != nil && priceInfo.PixelsPerUnit != 0 {
pricePerAIUnit = float64(priceInfo.PricePerUnit) / float64(priceInfo.PixelsPerUnit)
}
monitor.AIRequestFinished(ctx, "llm", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo)
monitor.AIRequestFinished(ctx, "llm", *req.Model, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo)
}
}()

return streamChan, nil
}

func handleNonStreamingResponse(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.GenLLMFormdataRequestBody, start time.Time) (*worker.LLMResponse, error) {
func handleNonStreamingResponse(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.GenLLMJSONRequestBody, start time.Time) (*worker.LLMResponse, error) {
data, err := io.ReadAll(body)
defer body.Close()
if err != nil {
if monitor.Enabled {
monitor.AIRequestError(err.Error(), "llm", *req.ModelId, sess.OrchestratorInfo)
monitor.AIRequestError(err.Error(), "llm", *req.Model, sess.OrchestratorInfo)
}
return nil, err
}

var res worker.LLMResponse
if err := json.Unmarshal(data, &res); err != nil {
if monitor.Enabled {
monitor.AIRequestError(err.Error(), "llm", *req.ModelId, sess.OrchestratorInfo)
monitor.AIRequestError(err.Error(), "llm", *req.Model, sess.OrchestratorInfo)
}
return nil, err
}

took := time.Since(start)
sess.LatencyScore = CalculateLLMLatencyScore(took, res.TokensUsed)
sess.LatencyScore = CalculateLLMLatencyScore(took, res.TokensUsed.TotalTokens)

if monitor.Enabled {
var pricePerAIUnit float64
if priceInfo := sess.OrchestratorInfo.GetPriceInfo(); priceInfo != nil && priceInfo.PixelsPerUnit != 0 {
pricePerAIUnit = float64(priceInfo.PricePerUnit) / float64(priceInfo.PixelsPerUnit)
}
monitor.AIRequestFinished(ctx, "llm", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo)
monitor.AIRequestFinished(ctx, "llm", *req.Model, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo)
}

return &res, nil
@@ -1410,16 +1409,16 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
return submitAudioToText(ctx, params, sess, v)
}
case worker.GenLLMFormdataRequestBody:
case worker.GenLLMJSONRequestBody:
cap = core.Capability_LLM
modelID = defaultLLMModelID
if v.ModelId != nil {
modelID = *v.ModelId
if v.Model != nil {
modelID = *v.Model
}
submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
return submitLLM(ctx, params, sess, v)
}
ctx = clog.AddVal(ctx, "prompt", v.Prompt)

case worker.GenSegmentAnything2MultipartRequestBody:
cap = core.Capability_SegmentAnything2
modelID = defaultSegmentAnything2ModelID
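As a quick illustration of the scoring change in this file, the latency score is still seconds elapsed per token, but the token count now comes from the usage struct's TotalTokens rather than a raw integer. A self-contained sketch mirroring CalculateLLMLatencyScore above; the zero-token guard is added here for illustration only:

```go
package main

import (
	"fmt"
	"time"
)

// Mirrors CalculateLLMLatencyScore above: seconds spent per token generated.
func latencyScore(took time.Duration, totalTokens int) float64 {
	if totalTokens == 0 {
		return 0 // guard added for this sketch; the original assumes a non-zero count
	}
	return took.Seconds() / float64(totalTokens)
}

func main() {
	// With the new response format, the count comes from TokensUsed.TotalTokens
	// (prompt + completion tokens), e.g. 50 in the stub worker used by the tests.
	fmt.Println(latencyScore(2*time.Second, 50)) // 0.04 seconds per token
}
```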
2 changes: 1 addition & 1 deletion server/ai_process_test.go
@@ -13,7 +13,7 @@ func Test_submitLLM(t *testing.T) {
ctx context.Context
params aiRequestParams
sess *AISession
req worker.GenLLMFormdataRequestBody
req worker.GenLLMJSONRequestBody
}
tests := []struct {
name string