Add prefix cache aware routing #641

Merged

merged 11 commits on Feb 10, 2025
Changes from 9 commits
4 changes: 4 additions & 0 deletions go.mod
@@ -35,10 +35,12 @@ require (

require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash v1.1.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/emicklei/go-restful/v3 v3.12.1 // indirect
github.com/envoyproxy/protoc-gen-validate v1.0.4 // indirect
github.com/evanphx/json-patch v5.9.0+incompatible // indirect
@@ -72,6 +74,8 @@ require (
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f // indirect
github.com/nxadm/tail v1.4.8 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pkoukk/tiktoken-go v0.1.7 // indirect
github.com/pkoukk/tiktoken-go-loader v0.0.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
github.com/spf13/pflag v1.0.5 // indirect
10 changes: 10 additions & 0 deletions go.sum
@@ -1,11 +1,14 @@
dario.cat/mergo v0.3.16 h1:wrt7QIfeqlABnUvmf9WpFwB0mGBwtySAJKTgCpnsbOE=
dario.cat/mergo v0.3.16/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY=
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko=
github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b h1:ga8SEFjZ60pxLcmhnThWgvH2wg8376yUJmPhEH4H3kw=
@@ -16,6 +19,8 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/emicklei/go-restful/v3 v3.12.1 h1:PJMDIM/ak7btuL8Ex0iYET9hxM3CI2sjZtzpL63nKAU=
github.com/emicklei/go-restful/v3 v3.12.1/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc=
github.com/envoyproxy/go-control-plane v0.12.0 h1:4X+VP1GHd1Mhj6IB5mMeGbLCleqxjletLK6K0rbxyZI=
@@ -131,6 +136,10 @@ github.com/openai/openai-go v0.1.0-alpha.37 h1:dstNWRmODNmcvVrNhJ1tzmD8J9hy+aayc
github.com/openai/openai-go v0.1.0-alpha.37/go.mod h1:3SdE6BffOX9HPEQv8IL/fi3LYZ5TUpRYaqGQZbyk11A=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw=
github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pkoukk/tiktoken-go-loader v0.0.1 h1:aOB2gRFzZTCCPi3YsOQXJO771P/5876JAsdebMyazig=
github.com/pkoukk/tiktoken-go-loader v0.0.1/go.mod h1:4mIkYyZooFlnenDlormIo6cd5wrlUKNr97wp9nGgEKo=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.19.1 h1:wZWJDwK+NameRJuPGDhlnFgx8e8HN3XHQeLaYJFJBOE=
@@ -147,6 +156,7 @@ github.com/redis/go-redis/v9 v9.6.1 h1:HHDteefn6ZkTtY5fGUE8tj8uy85AHk6zP7CpzIAM0
github.com/redis/go-redis/v9 v9.6.1/go.mod h1:0C0c6ycQsdpVNQpxb1njEQIqkx5UcsM8FJCQLgE9+RA=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
202 changes: 179 additions & 23 deletions pkg/cache/cache.go
@@ -17,20 +17,20 @@ limitations under the License.
package cache

import (
"bytes"
"context"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"math"
"os"
"strconv"
"sync"
"sync/atomic"
"time"

"github.com/aibrix/aibrix/pkg/utils"

crdinformers "github.com/aibrix/aibrix/pkg/client/informers/externalversions"
"github.com/cespare/xxhash"
"github.com/redis/go-redis/v9"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/util/runtime"
@@ -44,6 +44,7 @@ import (
v1alpha1 "github.com/aibrix/aibrix/pkg/client/clientset/versioned"
v1alpha1scheme "github.com/aibrix/aibrix/pkg/client/clientset/versioned/scheme"
"github.com/aibrix/aibrix/pkg/metrics"
"github.com/aibrix/aibrix/pkg/utils"
prometheusv1 "github.com/prometheus/client_golang/api/prometheus/v1"
dto "github.com/prometheus/client_model/go"
"k8s.io/client-go/kubernetes/scheme"
@@ -52,6 +53,7 @@
var once sync.Once

// type global
// TODO: split pod/model events, request trace and prefix cache into separate modules.
type Cache struct {
mu sync.RWMutex
redisClient *redis.Client
@@ -68,13 +70,23 @@ type Cache struct {
requestTrace *sync.Map // model_name: RequestTrace
numRequestsTraces int32 // counter for requestTrace
pendingRequests *sync.Map // model_name: *int32

prefixBlocks map[uint64]Block //prefix_hash:Block
}

type Block struct {
modelToPods map[string]map[string]time.Time // model_name: map[pod_name]pod_last_access_time
lastAccessTime time.Time //block_last_access_time
}

const (
modelIdentifier = "model.aibrix.ai/name"
podPort = 8000
defaultPodMetricRefreshIntervalInMS = 50
expireWriteRequestTraceIntervalInMins = 10
modelIdentifier = "model.aibrix.ai/name"
podPort = 8000
defaultPodMetricRefreshIntervalInMS = 50
expireWriteRequestTraceIntervalInMins = 10
defaultPrefixCacheBlockSize = 16
defaultPrefixCacheEvictionInternalInMS = 50
defaultPrefixCacheEvictionDurationInMins = 60
)

var (
@@ -120,14 +132,18 @@ var (
metrics.RunningLoraAdapters,
}

podMetricRefreshInterval = getPodMetricRefreshInterval()
// TODO: add a helper function for get methods.
podMetricRefreshInterval = getPodMetricRefreshInterval()
prefixCacheBlockSize = getPrefixCacheBlockSize()
prefixCacheEvictionInterval = getPrefixCacheEvictionInterval()
prefixCacheEvictionDuration = getPrefixCacheEvictionDuration()
)

func getPodMetricRefreshInterval() time.Duration {
value := LoadEnv("AIBRIX_POD_METRIC_REFRESH_INTERVAL_MS", "")
value := utils.LoadEnv("AIBRIX_POD_METRIC_REFRESH_INTERVAL_MS", "")
if value != "" {
intValue, err := strconv.Atoi(value)
if err != nil {
if err != nil || intValue <= 0 {
klog.Infof("invalid AIBRIX_POD_METRIC_REFRESH_INTERVAL_MS: %s, falling back to default", value)
} else {
klog.Infof("using AIBRIX_POD_METRIC_REFRESH_INTERVAL_MS env value for pod metrics refresh interval: %d ms", intValue)
@@ -138,23 +154,58 @@ func getPodMetricRefreshInterval() time.Duration {
return defaultPodMetricRefreshIntervalInMS * time.Millisecond
}

func getPrefixCacheBlockSize() int {
value := utils.LoadEnv("AIBRIX_PREFIX_CACHE_BLOCK_SIZE", "")
if value != "" {
intValue, err := strconv.Atoi(value)
if err != nil || intValue <= 0 {
klog.Infof("invalid AIBRIX_PREFIX_CACHE_BLOCK_SIZE: %s, falling back to default", value)
} else {
klog.Infof("using AIBRIX_PREFIX_CACHE_BLOCK_SIZE env value for prefix cache block size: %d", intValue)
return intValue
}
}
klog.Infof("using default prefix cache block size: %d", defaultPrefixCacheBlockSize)
return defaultPrefixCacheBlockSize
}

func getPrefixCacheEvictionInterval() time.Duration {
value := utils.LoadEnv("AIBRIX_PREFIX_CACHE_EVICTION_INTERVAL_MS", "")
if value != "" {
intValue, err := strconv.Atoi(value)
if err != nil || intValue <= 0 {
klog.Infof("invalid AIBRIX_PREFIX_CACHE_EVICTION_INTERVAL_MS: %s, falling back to default", value)
} else {
klog.Infof("using AIBRIX_PREFIX_CACHE_EVICTION_INTERVAL_MS env value for prefix cache eviction interval: %d ms", intValue)
return time.Duration(intValue) * time.Millisecond
}
}
klog.Infof("using default prefix cache eviction interval: %d ms", defaultPrefixCacheEvictionInternalInMS)
return defaultPrefixCacheEvictionInternalInMS * time.Millisecond
}

func getPrefixCacheEvictionDuration() time.Duration {
value := utils.LoadEnv("AIBRIX_PREFIX_CACHE_EVICTION_INTERVAL_MS", "")
if value != "" {
intValue, err := strconv.Atoi(value)
if err != nil || intValue <= 0 {
klog.Infof("invalid AIBRIX_PREFIX_CACHE_EVICTION_INTERVAL_MS: %s, falling back to default", value)
} else {
klog.Infof("using AIBRIX_PREFIX_CACHE_EVICTION_INTERVAL_MS env value for prefix cache eviction interval: %d ms", intValue)
return time.Duration(intValue) * time.Millisecond
}
}
klog.Infof("using default prefix cache eviction interval: %d ms", defaultPrefixCacheEvictionInternalInMS)
return defaultPrefixCacheEvictionInternalInMS * time.Millisecond
}
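As the TODO above notes, the three getters repeat the same parse-and-fall-back pattern. A possible shared helper is sketched below for illustration; it is not part of this PR, and the function name is hypothetical.

// loadPositiveIntEnv is an illustrative (hypothetical) helper: it parses a
// positive integer from the named environment variable and falls back to def
// when the variable is unset or invalid.
func loadPositiveIntEnv(key string, def int) int {
	value := utils.LoadEnv(key, "")
	if value != "" {
		intValue, err := strconv.Atoi(value)
		if err != nil || intValue <= 0 {
			klog.Infof("invalid %s: %s, falling back to default %d", key, value, def)
		} else {
			klog.Infof("using %s env value: %d", key, intValue)
			return intValue
		}
	}
	return def
}

getPrefixCacheBlockSize, for example, could then reduce to return loadPositiveIntEnv("AIBRIX_PREFIX_CACHE_BLOCK_SIZE", defaultPrefixCacheBlockSize).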

func GetCache() (*Cache, error) {
if !instance.initialized {
return nil, errors.New("cache is not initialized")
}
return &instance, nil
}

// LoadEnv loads an environment variable or returns a default value if not set.
func LoadEnv(key, defaultValue string) string {
value := os.Getenv(key)
if value == "" {
klog.Warningf("environment variable %s is not set, using default value: %s", key, defaultValue)
return defaultValue
}
return value
}

func NewCache(config *rest.Config, stopCh <-chan struct{}, redisClient *redis.Client) *Cache {
once.Do(func() {
if err := v1alpha1scheme.AddToScheme(scheme.Scheme); err != nil {
Expand Down Expand Up @@ -187,9 +238,9 @@ func NewCache(config *rest.Config, stopCh <-chan struct{}, redisClient *redis.Cl
}

// Load environment variables
prometheusEndpoint := LoadEnv("PROMETHEUS_ENDPOINT", "")
prometheusBasicAuthUsername := LoadEnv("PROMETHEUS_BASIC_AUTH_USERNAME", "")
prometheusBasicAuthPassword := LoadEnv("PROMETHEUS_BASIC_AUTH_PASSWORD", "")
prometheusEndpoint := utils.LoadEnv("PROMETHEUS_ENDPOINT", "")
prometheusBasicAuthUsername := utils.LoadEnv("PROMETHEUS_BASIC_AUTH_USERNAME", "")
prometheusBasicAuthPassword := utils.LoadEnv("PROMETHEUS_BASIC_AUTH_PASSWORD", "")

// Initialize Prometheus API
var prometheusApi prometheusv1.API
@@ -214,6 +265,7 @@ func NewCache(config *rest.Config, stopCh <-chan struct{}, redisClient *redis.Cl
ModelToPodMapping: map[string]map[string]*v1.Pod{},
requestTrace: &sync.Map{},
pendingRequests: &sync.Map{},
prefixBlocks: map[uint64]Block{},
}
if _, err := podInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{
AddFunc: instance.addPod,
@@ -246,6 +298,19 @@ func NewCache(config *rest.Config, stopCh <-chan struct{}, redisClient *redis.Cl
}
}()

ticker = time.NewTicker(prefixCacheEvictionInterval)
go func() {
for {
select {
case <-ticker.C:
instance.prefixCacheEviction()
case <-stopCh:
ticker.Stop()
return
}
}
}()

tickerOffset := time.Duration(time.Now().UnixNano()) % RequestTraceWriteInterval
var traceAlignmentTimer *time.Timer
// TODO: Using ticker may be a problem if writeRequestTraceToStorage takes too long.
@@ -963,3 +1028,94 @@ func (c *Cache) aggregateMetrics() {
}
}
}

// returns matchedTokens, unMatchedTokens, matchedPods
// TODO: add an interface with multiple implementations such as hash or radix tree
func (c *Cache) MatchPrefix(tokens []int, model string, pods []*v1.Pod) ([]int, []int, []*v1.Pod) {
// A write lock is needed here: the loop below updates lastAccessTime in c.prefixBlocks.
c.mu.Lock()
defer c.mu.Unlock()
var block, lastMatchedBlock Block
var ok bool
var lastTokenMatchIndex int

for i := 0; i < len(tokens); i += prefixCacheBlockSize {
end := i + prefixCacheBlockSize
if end > len(tokens) {
end = len(tokens)
}
chunk := tokens[i:end]
prefixHash := xxhash.Sum64(IntArrayToByteArray(chunk))
block, ok = c.prefixBlocks[prefixHash]

if !ok || len(block.modelToPods[model]) == 0 {
lastTokenMatchIndex = i
break
}

lastTokenMatchIndex = end
lastMatchedBlock = block
block.lastAccessTime = time.Now()
c.prefixBlocks[prefixHash] = block
}

matchedTokens := tokens[0:lastTokenMatchIndex]
unMatchedTokens := tokens[lastTokenMatchIndex:]

var matchedPods []*v1.Pod
blockPods := lastMatchedBlock.modelToPods[model]
for _, pod := range pods {
if _, ok := blockPods[pod.Name]; ok {
matchedPods = append(matchedPods, pod)
}
}

return matchedTokens, unMatchedTokens, matchedPods
}

func (c *Cache) AddPrefixBlock(unMatchedTokens []int, model, pod string) {
c.mu.Lock()
defer c.mu.Unlock()

for i := 0; i < len(unMatchedTokens); i += prefixCacheBlockSize {
end := i + prefixCacheBlockSize
if end > len(unMatchedTokens) {
end = len(unMatchedTokens)
}
chunk := unMatchedTokens[i:end]
prefixHash := xxhash.Sum64(IntArrayToByteArray(chunk))
Inline review discussion on this line:

Collaborator: Just to confirm: the hash here only considers the current block, right? So this is not 100% the same as vLLM's solution, where the 1st block's hash is part of the 2nd block's hash input.

Collaborator (Author): Yes, the hash covers only the current block; there is no linked-list-style chaining like vLLM has. For our use case we do not need that behavior.

Collaborator: In that case, we need to do up to O(n) hash calculations, where n = number of blocks?

Collaborator (Author, varungup90, Feb 9, 2025): Yes. If the suggestion is that the linked-list behavior could avoid O(n) calculations, it would not; the total computation stays the same.

Collaborator: It is a little different: with the linked list you can do binary search and similar optimizations, whereas this way we can only do O(n).

Collaborator (Author): Let me look into this, but for our use case we need to evaluate all blocks to ensure a 50%+ hit rate.

(A short sketch contrasting per-block and chained hashing follows the AddPrefixBlock function below.)

block, ok := c.prefixBlocks[prefixHash]
if !ok {
block = Block{
modelToPods: map[string]map[string]time.Time{},
lastAccessTime: time.Now(),
}
c.prefixBlocks[prefixHash] = block
}

blockPods, ok := block.modelToPods[model]
if !ok {
blockPods = map[string]time.Time{}
block.modelToPods[model] = blockPods
}

blockPods[pod] = time.Now()
}
}
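For readers following the review thread above, here is a minimal, self-contained Go sketch (not part of this PR's diff) contrasting per-block hashing, as MatchPrefix and AddPrefixBlock do, with a chained scheme that folds the previous block's hash into the next block's input, roughly in the spirit of vLLM's prefix caching. The helper names are illustrative.

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"

	"github.com/cespare/xxhash"
)

// tokensToBytes mirrors the PR's IntArrayToByteArray: little-endian int32 encoding.
func tokensToBytes(tokens []int) []byte {
	buf := new(bytes.Buffer)
	for _, t := range tokens {
		_ = binary.Write(buf, binary.LittleEndian, int32(t))
	}
	return buf.Bytes()
}

// perBlockHashes hashes each block independently, as this PR does.
func perBlockHashes(tokens []int, blockSize int) []uint64 {
	var hashes []uint64
	for i := 0; i < len(tokens); i += blockSize {
		end := i + blockSize
		if end > len(tokens) {
			end = len(tokens)
		}
		hashes = append(hashes, xxhash.Sum64(tokensToBytes(tokens[i:end])))
	}
	return hashes
}

// chainedHashes folds the previous block's hash into the next block's input,
// so a block's hash identifies the entire prefix ending at that block.
func chainedHashes(tokens []int, blockSize int) []uint64 {
	var hashes []uint64
	var prev uint64
	for i := 0; i < len(tokens); i += blockSize {
		end := i + blockSize
		if end > len(tokens) {
			end = len(tokens)
		}
		input := make([]byte, 8)
		binary.LittleEndian.PutUint64(input, prev)
		input = append(input, tokensToBytes(tokens[i:end])...)
		prev = xxhash.Sum64(input)
		hashes = append(hashes, prev)
	}
	return hashes
}

func main() {
	tokens := []int{1, 2, 3, 4, 5, 6, 7, 8}
	// With per-block hashing, identical middle blocks from different requests hash the same;
	// with chaining, a block's hash also depends on everything before it.
	fmt.Println(perBlockHashes(tokens, 4))
	fmt.Println(chainedHashes(tokens, 4))
}

With chaining, a block hash encodes the whole prefix up to that block; with the per-block scheme used here, every block of a request is looked up independently, which is the O(n) cost discussed in the thread.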

// To be implemented
func (c *Cache) prefixCacheEviction() {
c.mu.Lock()
defer c.mu.Unlock()

}
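prefixCacheEviction is left as a stub in this revision. Purely as an illustration of what it might do with the fields and constants defined above (this is not the PR's implementation, and the method name is hypothetical), one option is to drop pod entries and blocks whose last access time exceeds prefixCacheEvictionDuration:

// Illustrative sketch only: evict pod entries and blocks not accessed within
// prefixCacheEvictionDuration.
func (c *Cache) evictStalePrefixBlocks(now time.Time) {
	c.mu.Lock()
	defer c.mu.Unlock()

	for hash, block := range c.prefixBlocks {
		for model, pods := range block.modelToPods {
			for pod, lastAccess := range pods {
				if now.Sub(lastAccess) > prefixCacheEvictionDuration {
					delete(pods, pod)
				}
			}
			if len(pods) == 0 {
				delete(block.modelToPods, model)
			}
		}
		if len(block.modelToPods) == 0 || now.Sub(block.lastAccessTime) > prefixCacheEvictionDuration {
			delete(c.prefixBlocks, hash)
		}
	}
}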

func IntArrayToByteArray(intArray []int) []byte {
buf := new(bytes.Buffer)
for _, val := range intArray {
err := binary.Write(buf, binary.LittleEndian, int32(val))
if err != nil {
panic(err)
}
}
return buf.Bytes()
}