Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add request trace for profiling #291

Merged
merged 7 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/controllers/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ func main() {
panic(err)
}

cache.NewCache(config, stopCh)
cache.NewCache(config, stopCh, nil)

// Kind controller registration is encapsulated inside the pkg/controller/controller.go
// So here we can use more clean registration flow and there's no need to change logics in future.
Expand Down
6 changes: 3 additions & 3 deletions cmd/plugins/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func main() {
// Connect to Redis
redisClient := utils.GetRedisClient()

fmt.Println("Starting cache")
fmt.Println("starting cache")
stopCh := make(chan struct{})
defer close(stopCh)
var config *rest.Config
Expand All @@ -71,7 +71,7 @@ func main() {
panic(err)
}

cache.NewCache(config, stopCh)
cache.NewCache(config, stopCh, redisClient)

// Connect to K8s cluster
k8sClient, err := kubernetes.NewForConfig(config)
Expand All @@ -90,7 +90,7 @@ func main() {
extProcPb.RegisterExternalProcessorServer(s, gateway.NewServer(redisClient, k8sClient))
healthPb.RegisterHealthServer(s, &gateway.HealthServer{})

klog.Info("Starting gRPC server on port :50052")
klog.Info("starting gRPC server on port :50052")

// shutdown
var gracefulStop = make(chan os.Signal, 1)
Expand Down
2 changes: 1 addition & 1 deletion docs/development/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Here are some essential resources for anyone interested in AIBricks:
- **Project Roadmap**: TODO

Additional resources for contributors:

- **Kubebuilder Tutorial**: [Learn Kubebuilder](https://book.kubebuilder.io/) - A comprehensive, step-by-step guide to developing with Kubebuilder.
- **Kubernetes Documentation**: [Explore Kubernetes Docs](https://kubernetes.io/docs/home/) - Detailed explanations of Kubernetes concepts.

Expand Down
18 changes: 12 additions & 6 deletions docs/development/app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ def completion():
if not prompt or not model:
return jsonify({"status": "error", "message": "Prompt and model are required"}), 400

prompt_tokens = randint(1, 100)
completion_tokens = randint(1, 100)

# Simulated response
response = {
"id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
Expand All @@ -124,9 +127,9 @@ def completion():
}
],
"usage": {
"prompt_tokens": 5,
"completion_tokens": 7,
"total_tokens": 12
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens+completion_tokens
}
}
return jsonify(response), 200
Expand All @@ -139,16 +142,19 @@ def chat_completions():
if not messages or not model:
return jsonify({"status": "error", "message": "Messages and model are required"}), 400

prompt_tokens = randint(1, 100)
completion_tokens = randint(1, 100)

# Simulated response
response = {
"id": "chatcmpl-abc123",
"object": "chat.completion",
"created": 1677858242,
"model": model,
"usage": {
"prompt_tokens": 13,
"completion_tokens": 7,
"total_tokens": 20
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens+completion_tokens
},
"choices": [
{
Expand Down
84 changes: 80 additions & 4 deletions pkg/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,20 @@ limitations under the License.
package cache

import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"math"
"net/http"
"strconv"
"strings"
"sync"
"time"

crdinformers "github.com/aibrix/aibrix/pkg/client/informers/externalversions"
"github.com/redis/go-redis/v9"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/client-go/informers"
Expand All @@ -46,11 +50,13 @@ var once sync.Once
// type global
type Cache struct {
mu sync.RWMutex
redisClient *redis.Client
initialized bool
pods map[string]*v1.Pod
podMetrics map[string]map[string]float64 // pod_name: map[metric_name]metric_val
podToModelMapping map[string]map[string]struct{} // pod_name: map[model_name]struct{}
modelToPodMapping map[string]map[string]*v1.Pod // model_name: map[pod_name]*v1.Pod
requestTrace map[string]map[string]int // model_name: map[Log2(input_token)-Log2(output_token)]request_count
}

var (
Expand All @@ -60,9 +66,11 @@ var (
)

const (
modelIdentifier = "model.aibrix.ai/name"
podPort = 8000
podMetricRefreshIntervalInSeconds = 10
modelIdentifier = "model.aibrix.ai/name"
podPort = 8000
podMetricRefreshIntervalInSeconds = 10
writeRequestTraceIntervalInSeconds = 10
expireWriteRequestTraceIntervalInMins = 10
)

func GetCache() (*Cache, error) {
Expand All @@ -72,7 +80,7 @@ func GetCache() (*Cache, error) {
return &instance, nil
}

func NewCache(config *rest.Config, stopCh <-chan struct{}) *Cache {
func NewCache(config *rest.Config, stopCh <-chan struct{}, redisClient *redis.Client) *Cache {
once.Do(func() {
if err := v1alpha1scheme.AddToScheme(scheme.Scheme); err != nil {
panic(err)
Expand Down Expand Up @@ -105,10 +113,12 @@ func NewCache(config *rest.Config, stopCh <-chan struct{}) *Cache {

instance = Cache{
initialized: true,
redisClient: redisClient,
pods: map[string]*v1.Pod{},
podMetrics: map[string]map[string]float64{},
podToModelMapping: map[string]map[string]struct{}{},
modelToPodMapping: map[string]map[string]*v1.Pod{},
requestTrace: map[string]map[string]int{},
}

if _, err := podInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{
Expand Down Expand Up @@ -140,6 +150,27 @@ func NewCache(config *rest.Config, stopCh <-chan struct{}) *Cache {
}
}
}()

// Periodically flush the accumulated request trace to Redis.
// The trace is only persisted when a Redis client was supplied
// (the gateway plugin); the controller passes nil and this
// goroutine exits immediately.
traceTicker := time.NewTicker(writeRequestTraceIntervalInSeconds * time.Second)
go func() {
	if redisClient == nil {
		return
	}
	// Release the ticker when this goroutine exits (the original
	// code stopped the unrelated metrics `ticker` here, leaking
	// traceTicker).
	defer traceTicker.Stop()
	for {
		select {
		case <-traceTicker.C:
			// NOTE: no unsynchronized emptiness peek at
			// instance.requestTrace here — that read raced with
			// AddRequestTrace's locked writes. Flushing an empty
			// trace is a cheap no-op, so just call through; the
			// callee takes the lock itself.
			t := time.Now().Unix()
			// Round down to the interval boundary so keys are
			// stable regardless of ticker drift.
			roundT := t - t%writeRequestTraceIntervalInSeconds
			instance.writeRequestTraceToStorage(roundT)
		case <-stopCh:
			return
		}
	}
}()
})

return &instance
Expand Down Expand Up @@ -319,6 +350,11 @@ func (c *Cache) debugInfo() {
}
klog.V(4).Infof("model: %s, pods: %s", modelName, podList)
}
for inputIndex, output := range c.requestTrace {
for outputIndex, requestCount := range output {
klog.V(4).Infof("inputIndex: %v, outputIndex: %v, requestCount: %v", inputIndex, outputIndex, requestCount)
}
}
}

func (c *Cache) GetPod(podName string) (*v1.Pod, error) {
Expand Down Expand Up @@ -447,3 +483,43 @@ func parseMetricFromBody(body []byte, metricName string) (float64, error) {
}
return 0, fmt.Errorf("metrics %s not found", metricName)
}

// AddRequestTrace records one completed request for modelName, bucketed
// by Trunc(Log2(inputTokens)) and Trunc(Log2(outputTokens)). Buckets are
// keyed "<inputIndex>:<outputIndex>" and the per-model counts are later
// flushed to Redis by the cache's background trace goroutine.
func (c *Cache) AddRequestTrace(modelName string, inputTokens, outputTokens int) {
	c.mu.Lock()
	defer c.mu.Unlock()

	// Clamp token counts to at least 1: Log2(0) is -Inf and Log2 of a
	// negative value is NaN, which would otherwise produce garbage
	// bucket keys such as "-Inf:-Inf".
	if inputTokens < 1 {
		inputTokens = 1
	}
	if outputTokens < 1 {
		outputTokens = 1
	}

	inputIndex := math.Trunc(math.Log2(float64(inputTokens)))
	outputIndex := math.Trunc(math.Log2(float64(outputTokens)))

	klog.V(5).Infof("inputTokens: %v, inputIndex: %v, outputTokens: %v, outputIndex: %v",
		inputTokens, inputIndex, outputTokens, outputIndex)

	// Lazily create the per-model bucket map on first use.
	if c.requestTrace[modelName] == nil {
		c.requestTrace[modelName] = map[string]int{}
	}

	c.requestTrace[modelName][fmt.Sprintf("%v:%v", inputIndex, outputIndex)]++
}

// writeRequestTraceToStorage flushes the accumulated request trace to
// Redis under "aibrix:<model>_request_trace_<roundT>" (roundT is the
// flush timestamp rounded down to the write interval) and resets the
// in-memory trace. The snapshot is detached under the lock, but the
// Redis round-trips happen outside it so slow network I/O never blocks
// AddRequestTrace callers — the original held c.mu for the whole loop.
func (c *Cache) writeRequestTraceToStorage(roundT int64) {
	// Swap the trace map out under the lock; all I/O below operates on
	// the detached snapshot.
	c.mu.Lock()
	snapshot := c.requestTrace
	c.requestTrace = map[string]map[string]int{}
	c.mu.Unlock()

	klog.V(5).Infof("writeRequestTraceWithKey: %v", roundT)

	for modelName, trace := range snapshot {
		key := fmt.Sprintf("aibrix:%v_request_trace_%v", modelName, roundT)
		value, err := json.Marshal(trace)
		if err != nil {
			klog.ErrorS(err, "error to marshall request trace for redis set")
			continue
		}

		// Entries expire on their own so stale traces don't accumulate
		// in Redis.
		if _, err = c.redisClient.Set(context.Background(), key, value, expireWriteRequestTraceIntervalInMins*time.Minute).Result(); err != nil {
			klog.Error(err)
		}
	}
}
10 changes: 8 additions & 2 deletions pkg/plugins/gateway/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
}

func (s *Server) HandleRequestHeaders(ctx context.Context, requestID string, req *extProcPb.ProcessingRequest) (*extProcPb.ProcessingResponse, utils.User, int64, string) {
klog.Info("\n\n-- In RequestHeaders processing ...")
klog.Info("\n\n")
klog.Info("-- In RequestHeaders processing ...")
var username, routingStrategy string
var user utils.User
var rpm int64
Expand Down Expand Up @@ -307,7 +308,6 @@ func (s *Server) HandleResponseBody(ctx context.Context, requestID string, req *
b := req.Request.(*extProcPb.ProcessingRequest_ResponseBody)

var res openai.CompletionResponse
klog.Info(b.ResponseBody.String())
if err := json.Unmarshal(b.ResponseBody.Body, &res); err != nil {
return generateErrorResponse(
envoyTypePb.StatusCode_InternalServerError,
Expand All @@ -317,6 +317,12 @@ func (s *Server) HandleResponseBody(ctx context.Context, requestID string, req *
err.Error())
}

defer func() {
go func() {
s.cache.AddRequestTrace(res.Model, res.Usage.PromptTokens, res.Usage.CompletionTokens)
}()
}()

headers := []*configPb.HeaderValueOption{}
if user.Name != "" {
tpm, err := s.ratelimiter.Incr(ctx, fmt.Sprintf("%v_TPM_CURRENT", user), int64(res.Usage.TotalTokens))
Expand Down
Loading