Use vllm metrics for routing (#274)
* Cache bug fix in update pod and model mapping (#259)

* test

* Use vllm metrics for routing

* nit reverts

* update log level

* refactor cache to fetch metrics once

* remove port from random routing
varungup90 authored Oct 7, 2024
1 parent a70c82d commit c6e1c2b
Showing 8 changed files with 206 additions and 70 deletions.
7 changes: 4 additions & 3 deletions docs/development/app/README.md
@@ -67,10 +67,11 @@ kubectl -n envoy-gateway-system port-forward service/envoy-aibrix-system-aibrix-
# Add rpm/tpm config
```shell
kubectl -n aibrix-system exec -it aibrix-redis-master-767bcb955d-qrlfc -- redis-cli
kubectl -n aibrix-system port-forward svc/aibrix-gateway-users 8090:8090 &

set aibrix:your-user-name_TPM_LIMIT 100
set aibrix:your-user-name_RPM_LIMIT 10
curl http://localhost:8090/CreateUser \
-H "Content-Type: application/json" \
-d '{"name": "your-user-name","rpm": 100,"tpm": 1000}'
```
Test request (ensure the model name in the request header matches the deployment's model name for routing)
13 changes: 13 additions & 0 deletions docs/development/app/app.py
@@ -1,5 +1,6 @@
from flask import Flask, request, Response, jsonify
import time
from random import randint
import os
try:
from kubernetes import client, config
@@ -181,11 +182,23 @@ def metrics():
success_total = total / replicas
avg_prompt_throughput = total / replicas if replicas > 0 else 0
avg_generation_throughput = total / replicas if replicas > 0 else 0
running = randint(1, 100)
waiting = randint(1, 100)
swapped = randint(1, 100)

# construct Prometheus-style Metrics
metrics_output = f"""# HELP vllm:request_success_total Count of successfully processed requests.
# TYPE vllm:request_success_total counter
vllm:request_success_total{{finished_reason="stop",model_name="{model_name}"}} {success_total}
# HELP vllm:num_requests_running Number of requests currently running on GPU.
# TYPE vllm:num_requests_running gauge
vllm:num_requests_running{{model_name="{model_name}"}} {running}
# HELP vllm:num_requests_swapped Number of requests swapped to CPU.
# TYPE vllm:num_requests_swapped gauge
vllm:num_requests_swapped{{model_name="{model_name}"}} {swapped}
# HELP vllm:num_requests_waiting Number of requests waiting to be processed.
# TYPE vllm:num_requests_waiting gauge
vllm:num_requests_waiting{{model_name="{model_name}"}} {waiting}
# HELP vllm:avg_prompt_throughput_toks_per_s Average prefill throughput in tokens/s.
# TYPE vllm:avg_prompt_throughput_toks_per_s gauge
vllm:avg_prompt_throughput_toks_per_s{{model_name="{model_name}"}} {avg_prompt_throughput}
128 changes: 107 additions & 21 deletions pkg/cache/cache.go
@@ -19,7 +19,12 @@ package cache
import (
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"sync"
"time"

crdinformers "github.com/aibrix/aibrix/pkg/client/informers/externalversions"
v1 "k8s.io/api/core/v1"
@@ -43,17 +48,21 @@ type Cache struct {
mu sync.RWMutex
initialized bool
pods map[string]*v1.Pod
podMetrics map[string]map[string]float64 // pod_name: map[metric_name]metric_val
podToModelMapping map[string]map[string]struct{} // pod_name: map[model_name]struct{}
modelToPodMapping map[string]map[string]*v1.Pod // model_name: map[pod_name]*v1.Pod
podRequestTracker map[string]int
}

var (
instance Cache
instance Cache
metricNames = []string{"num_requests_running", "num_requests_waiting", "num_requests_swapped",
"avg_prompt_throughput_toks_per_s", "avg_generation_throughput_toks_per_s"} //, "e2e_request_latency_seconds_sum"}
)

const (
modelIdentifier = "model.aibrix.ai/name"
modelIdentifier = "model.aibrix.ai/name"
podPort = 8000
podMetricRefreshIntervalInSeconds = 10
)

func GetCache() (*Cache, error) {
@@ -65,7 +74,6 @@ func GetCache() (*Cache, error) {

func NewCache(config *rest.Config, stopCh <-chan struct{}) *Cache {
once.Do(func() {

if err := v1alpha1scheme.AddToScheme(scheme.Scheme); err != nil {
panic(err)
}
@@ -90,9 +98,6 @@ func NewCache(config *rest.Config, stopCh <-chan struct{}) *Cache {
factory.Start(stopCh)
crdFactory.Start(stopCh)

// factory.WaitForCacheSync(stopCh)
// crdFactory.WaitForCacheSync(stopCh)

if !cache.WaitForCacheSync(stopCh, podInformer.HasSynced, modelInformer.HasSynced) {
runtime.HandleError(fmt.Errorf("timed out waiting for caches to sync"))
return
@@ -101,9 +106,9 @@ func NewCache(config *rest.Config, stopCh <-chan struct{}) *Cache {
instance = Cache{
initialized: true,
pods: map[string]*v1.Pod{},
podMetrics: map[string]map[string]float64{},
podToModelMapping: map[string]map[string]struct{}{},
modelToPodMapping: map[string]map[string]*v1.Pod{},
podRequestTracker: map[string]int{},
}

if _, err := podInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{
@@ -121,6 +126,20 @@ func NewCache(config *rest.Config, stopCh <-chan struct{}) *Cache {
}); err != nil {
panic(err)
}

ticker := time.NewTicker(podMetricRefreshIntervalInSeconds * time.Second)
go func() {
for {
select {
case <-ticker.C:
instance.updatePodMetrics()
instance.debugInfo()
case <-stopCh:
ticker.Stop()
return
}
}
}()
})

return &instance
@@ -185,6 +204,7 @@ func (c *Cache) deletePod(obj interface{}) {
}
delete(c.podToModelMapping, pod.Name)
delete(c.pods, pod.Name)
delete(c.podMetrics, pod.Name)

klog.V(4).Infof("POD DELETED: %s/%s", pod.Namespace, pod.Name)
c.debugInfo()
@@ -280,6 +300,11 @@ func (c *Cache) debugInfo() {
for _, pod := range c.pods {
klog.V(4).Infof("pod: %s, podIP: %v", pod.Name, pod.Status.PodIP)
}
for podName, metrics := range c.podMetrics {
for metricName, metricVal := range metrics {
klog.V(4).Infof("%v_%v_%v", podName, metricName, metricVal)
}
}
for podName, models := range c.podToModelMapping {
var modelList string
for modelName := range models {
@@ -339,25 +364,86 @@ func (c *Cache) GetModelsForPod(podName string) (map[string]struct{}, error) {
return models, nil
}

func (c *Cache) IncrPodRequestCount(podName string) int {
c.mu.Lock()
defer c.mu.Unlock()
func (c *Cache) GetPodMetric(podName, metricName string) (float64, error) {
c.mu.RLock()
defer c.mu.RUnlock()

c.podRequestTracker[podName] += 1
return c.podRequestTracker[podName]
metrics, ok := c.podMetrics[podName]
if !ok {
return 0, fmt.Errorf("pod does not exist in the metrics cache")
}

metricVal, ok := metrics[metricName]
if !ok {
return 0, fmt.Errorf("no metric available for %v", metricName)
}

return metricVal, nil
}

func (c *Cache) DecrPodRequestCount(podName string) int {
func (c *Cache) updatePodMetrics() {
c.mu.Lock()
defer c.mu.Unlock()

c.podRequestTracker[podName] -= 1
return c.podRequestTracker[podName]
}
for _, pod := range c.pods {
if pod.Status.PodIP == "" {
continue
}
podName := pod.Name
if len(c.podMetrics[podName]) == 0 {
c.podMetrics[podName] = map[string]float64{}
}

func (c *Cache) GetPodRequestCount() map[string]int {
c.mu.RLock()
defer c.mu.RUnlock()
// We should use the primary container port. In future, we can decide whether to use sidecar container's port
url := fmt.Sprintf("http://%s:%d/metrics", pod.Status.PodIP, podPort)
resp, err := http.Get(url)
if err != nil {
klog.Errorf("failed to fetch metrics from pod %s %s %d: %v", pod.Name, pod.Status.PodIP, podPort, err)
continue
}
defer func() {
if err := resp.Body.Close(); err != nil {
klog.Errorf("Error closing response body: %v", err)
}
}()

body, err := io.ReadAll(resp.Body)
if err != nil {
klog.Errorf("failed to read response from pod %s %s %d: %v", pod.Name, pod.Status.PodIP, podPort, err)
continue
}

for _, metricName := range metricNames {
metricValue, err := parseMetricFromBody(body, metricName)
if err != nil {
klog.Errorf("failed to parse metrics from pod %s %s %d: %v", pod.Name, pod.Status.PodIP, podPort, err)
continue
}

c.podMetrics[pod.Name][metricName] = metricValue
klog.V(5).InfoS("Successfully parsed metrics", "metric", metricName, "PodIP", pod.Status.PodIP, "Port", podPort, "metricValue", metricValue)
}
}
}

return c.podRequestTracker
func parseMetricFromBody(body []byte, metricName string) (float64, error) {
lines := strings.Split(string(body), "\n")
for _, line := range lines {
if !strings.HasPrefix(line, "#") && strings.Contains(line, metricName) {
// format is `http_requests_total 1234.56`
parts := strings.Fields(line)
if len(parts) < 2 {
return 0, fmt.Errorf("unexpected format for metric %s", metricName)
}

// parse to float64
value, err := strconv.ParseFloat(parts[len(parts)-1], 64)
if err != nil {
return 0, fmt.Errorf("failed to parse metric value for %s: %v", metricName, err)
}

return value, nil
}
}
return 0, fmt.Errorf("metrics %s not found", metricName)
}
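
The mock `/metrics` endpoint in `app.py` and the new `parseMetricFromBody` helper meet in the Prometheus text format: one gauge per line, with the value as the last whitespace-separated field. Below is a minimal, self-contained Go sketch that mirrors that parsing logic against a sample scrape body; the helper, model name, and values are illustrative and not part of the commit.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// parseMetric mirrors the parseMetricFromBody logic above: skip "#" comment
// lines, match on the metric name, and take the last field as the value.
func parseMetric(body, metricName string) (float64, error) {
	for _, line := range strings.Split(body, "\n") {
		if !strings.HasPrefix(line, "#") && strings.Contains(line, metricName) {
			parts := strings.Fields(line)
			if len(parts) < 2 {
				return 0, fmt.Errorf("unexpected format for metric %s", metricName)
			}
			return strconv.ParseFloat(parts[len(parts)-1], 64)
		}
	}
	return 0, fmt.Errorf("metric %s not found", metricName)
}

func main() {
	// Sample scrape output in the same shape the mock app.py endpoint emits.
	body := `# HELP vllm:num_requests_running Number of requests currently running on GPU.
# TYPE vllm:num_requests_running gauge
vllm:num_requests_running{model_name="llama2-7b"} 3
vllm:num_requests_waiting{model_name="llama2-7b"} 12`

	v, err := parseMetric(body, "num_requests_waiting")
	fmt.Println(v, err) // 12 <nil>
}
```

The match is a substring check rather than an exact label parse; since none of the metric names scraped by the cache is a substring of another, the `strings.Contains` lookup stays unambiguous for this set.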
40 changes: 28 additions & 12 deletions pkg/plugins/gateway/algorithms/least_request.go
@@ -18,7 +18,6 @@ package routingalgorithms

import (
"context"
"fmt"
"math"

"github.com/aibrix/aibrix/pkg/cache"
@@ -44,22 +43,39 @@ func NewLeastRequestRouter(ratelimiter ratelimiter.RateLimiter) Router {
}
}

func (r leastRequestRouter) Get(ctx context.Context, pods map[string]*v1.Pod) (string, error) {
func (r leastRequestRouter) Route(ctx context.Context, pods map[string]*v1.Pod) (string, error) {
var targetPodIP string
minCount := math.MaxInt
podRequestCounts := r.cache.GetPodRequestCount()
minCount := math.MaxFloat64

for _, pod := range pods {
podIP := pod.Status.PodIP + ":8000"
podRequestCount := fmt.Sprintf("%v_REQUEST_COUNT", podIP)
if pod.Status.PodIP == "" {
continue
}

runningReq, err := r.cache.GetPodMetric(pod.Name, num_requests_running)
if err != nil {
klog.Error(err)
continue
}
waitingReq, err := r.cache.GetPodMetric(pod.Name, num_requests_waiting)
if err != nil {
klog.Error(err)
continue
}
swappedReq, err := r.cache.GetPodMetric(pod.Name, num_requests_swapped)
if err != nil {
klog.Error(err)
continue
}
totalReq := runningReq + waitingReq + swappedReq
klog.V(4).Infof("pod: %v, podIP: %v, runningReq: %v, waitingReq: %v, swappedReq: %v, totalReq: %v",
pod.Name, pod.Status.PodIP, runningReq, waitingReq, swappedReq, totalReq)

reqCount := podRequestCounts[podRequestCount]
klog.Infof("PodIP: %s, PodRequestCount: %v", podIP, reqCount)
if reqCount <= minCount {
minCount = reqCount
targetPodIP = podIP
if totalReq <= minCount {
minCount = totalReq
targetPodIP = pod.Status.PodIP
}
}

return targetPodIP, nil // TODO (varun): remove static port
return targetPodIP, nil
}
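
The router's load score is simply `running + waiting + swapped` per pod, and the pod with the smallest score wins. A minimal sketch of that selection over hypothetical metric values (pod names and numbers are made up for illustration):

```go
package main

import (
	"fmt"
	"math"
)

func main() {
	// Hypothetical per-pod queue depth (running + waiting + swapped), as the
	// router would read it from the cache via GetPodMetric.
	load := map[string]float64{
		"pod-a": 3 + 12 + 0, // 15
		"pod-b": 5 + 2 + 1,  // 8
		"pod-c": 9 + 30 + 4, // 43
	}

	target, minLoad := "", math.MaxFloat64
	for name, total := range load {
		if total <= minLoad {
			minLoad = total
			target = name
		}
	}
	fmt.Println(target) // pod-b
}
```

Because the comparison uses `<=` and Go randomizes map iteration order, ties between equally idle pods are effectively broken at random rather than always favoring the same pod.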
10 changes: 8 additions & 2 deletions pkg/plugins/gateway/algorithms/random.go
@@ -18,6 +18,7 @@ package routingalgorithms

import (
"context"
"fmt"

v1 "k8s.io/api/core/v1"
)
@@ -29,11 +30,16 @@ func NewRandomRouter() Router {
return randomRouter{}
}

func (r randomRouter) Get(ctx context.Context, pods map[string]*v1.Pod) (string, error) {
func (r randomRouter) Route(ctx context.Context, pods map[string]*v1.Pod) (string, error) {
var selectedPod *v1.Pod
if len(pods) == 0 {
return "", fmt.Errorf("no pods to forward request")
}

for _, pod := range pods {
selectedPod = pod
break
}

return selectedPod.Status.PodIP + ":8000", nil // TODO (varun): remove static port
return selectedPod.Status.PodIP, nil
}
14 changes: 12 additions & 2 deletions pkg/plugins/gateway/algorithms/router.go
@@ -22,8 +22,18 @@ import (
v1 "k8s.io/api/core/v1"
)

const (
num_requests_running = "num_requests_running"
num_requests_waiting = "num_requests_waiting"
num_requests_swapped = "num_requests_swapped"
throughput_prompt = "avg_prompt_throughput_toks_per_s"
throughput_generation = "avg_generation_throughput_toks_per_s"
latency = "e2e_request_latency_seconds_sum"

podPort = 8000
)

type Router interface {
// Returns the target pod
// TODO (varun): replace with cache util package which can watch on pods
Get(ctx context.Context, pods map[string]*v1.Pod) (string, error)
Route(ctx context.Context, pods map[string]*v1.Pod) (string, error)
}
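
With the `Route` interface and the metric-name constants in place, additional policies can be layered on the same cache. The hypothetical `throughputRouter` below is a sketch only, not part of this commit; it assumes it lives alongside router.go in the same package and reuses `GetPodMetric` with the `throughput_generation` constant.

```go
package routingalgorithms

import (
	"context"
	"fmt"

	"github.com/aibrix/aibrix/pkg/cache"
	v1 "k8s.io/api/core/v1"
	"k8s.io/klog/v2"
)

// throughputRouter is illustrative: it picks the pod currently reporting the
// highest average generation throughput.
type throughputRouter struct {
	cache *cache.Cache
}

func (r throughputRouter) Route(ctx context.Context, pods map[string]*v1.Pod) (string, error) {
	var targetPodIP string
	maxThroughput := -1.0

	for _, pod := range pods {
		if pod.Status.PodIP == "" {
			continue
		}
		tput, err := r.cache.GetPodMetric(pod.Name, throughput_generation)
		if err != nil {
			klog.Error(err)
			continue
		}
		if tput > maxThroughput {
			maxThroughput = tput
			targetPodIP = pod.Status.PodIP
		}
	}

	if targetPodIP == "" {
		return "", fmt.Errorf("no pod with usable metrics")
	}
	return targetPodIP, nil
}
```

Whether high generation throughput is a good proxy for spare capacity is debatable; the point of the sketch is only to show how another `Router` implementation plugs into the shared metrics cache.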