Skip to content

Commit

Permalink
Add prefix cache aware routing (#641)
Browse files Browse the repository at this point in the history
* Add prefix cache aware routing

* end to end stitching

* fix lint errors

* nit

* add integ test for prefix caching

* add constants

* address review comments

* add prefix cache eviction

* add unit test for prefix cache eviction
  • Loading branch information
varungup90 authored Feb 10, 2025
1 parent 573d254 commit f40c973
Show file tree
Hide file tree
Showing 22 changed files with 759 additions and 110 deletions.
2 changes: 2 additions & 0 deletions config/gateway/gateway-plugin/gateway-plugin.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ spec:
value: "6379"
- name: AIBRIX_POD_METRIC_REFRESH_INTERVAL_MS
value: "50"
# - name: AIBRIX_PREFIX_CACHE_EVICTION_DURATION_MINS
# value: "1"
- name: POD_NAME
valueFrom:
fieldRef:
Expand Down
4 changes: 4 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@ require (

require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash v1.1.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/emicklei/go-restful/v3 v3.12.1 // indirect
github.com/envoyproxy/protoc-gen-validate v1.0.4 // indirect
github.com/evanphx/json-patch v5.9.0+incompatible // indirect
Expand Down Expand Up @@ -72,6 +74,8 @@ require (
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f // indirect
github.com/nxadm/tail v1.4.8 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pkoukk/tiktoken-go v0.1.7 // indirect
github.com/pkoukk/tiktoken-go-loader v0.0.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
github.com/spf13/pflag v1.0.5 // indirect
Expand Down
10 changes: 10 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
dario.cat/mergo v0.3.16 h1:wrt7QIfeqlABnUvmf9WpFwB0mGBwtySAJKTgCpnsbOE=
dario.cat/mergo v0.3.16/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY=
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko=
github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b h1:ga8SEFjZ60pxLcmhnThWgvH2wg8376yUJmPhEH4H3kw=
Expand All @@ -16,6 +19,8 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/emicklei/go-restful/v3 v3.12.1 h1:PJMDIM/ak7btuL8Ex0iYET9hxM3CI2sjZtzpL63nKAU=
github.com/emicklei/go-restful/v3 v3.12.1/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc=
github.com/envoyproxy/go-control-plane v0.12.0 h1:4X+VP1GHd1Mhj6IB5mMeGbLCleqxjletLK6K0rbxyZI=
Expand Down Expand Up @@ -131,6 +136,10 @@ github.com/openai/openai-go v0.1.0-alpha.37 h1:dstNWRmODNmcvVrNhJ1tzmD8J9hy+aayc
github.com/openai/openai-go v0.1.0-alpha.37/go.mod h1:3SdE6BffOX9HPEQv8IL/fi3LYZ5TUpRYaqGQZbyk11A=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw=
github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pkoukk/tiktoken-go-loader v0.0.1 h1:aOB2gRFzZTCCPi3YsOQXJO771P/5876JAsdebMyazig=
github.com/pkoukk/tiktoken-go-loader v0.0.1/go.mod h1:4mIkYyZooFlnenDlormIo6cd5wrlUKNr97wp9nGgEKo=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.19.1 h1:wZWJDwK+NameRJuPGDhlnFgx8e8HN3XHQeLaYJFJBOE=
Expand All @@ -147,6 +156,7 @@ github.com/redis/go-redis/v9 v9.6.1 h1:HHDteefn6ZkTtY5fGUE8tj8uy85AHk6zP7CpzIAM0
github.com/redis/go-redis/v9 v9.6.1/go.mod h1:0C0c6ycQsdpVNQpxb1njEQIqkx5UcsM8FJCQLgE9+RA=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
Expand Down
208 changes: 185 additions & 23 deletions pkg/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,20 @@ limitations under the License.
package cache

import (
"bytes"
"context"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"math"
"os"
"strconv"
"sync"
"sync/atomic"
"time"

"github.com/aibrix/aibrix/pkg/utils"

crdinformers "github.com/aibrix/aibrix/pkg/client/informers/externalversions"
"github.com/cespare/xxhash"
"github.com/redis/go-redis/v9"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/util/runtime"
Expand All @@ -44,6 +44,7 @@ import (
v1alpha1 "github.com/aibrix/aibrix/pkg/client/clientset/versioned"
v1alpha1scheme "github.com/aibrix/aibrix/pkg/client/clientset/versioned/scheme"
"github.com/aibrix/aibrix/pkg/metrics"
"github.com/aibrix/aibrix/pkg/utils"
prometheusv1 "github.com/prometheus/client_golang/api/prometheus/v1"
dto "github.com/prometheus/client_model/go"
"k8s.io/client-go/kubernetes/scheme"
Expand All @@ -52,6 +53,7 @@ import (
var once sync.Once

// type global
// TODO: split pod/model events, request trace and prefix cache into separate modules.
type Cache struct {
mu sync.RWMutex
redisClient *redis.Client
Expand All @@ -68,13 +70,23 @@ type Cache struct {
requestTrace *sync.Map // model_name: RequestTrace
numRequestsTraces int32 // counter for requestTrace
pendingRequests *sync.Map // model_name: *int32

prefixBlocks map[uint64]Block //prefix_hash:Block
}

// Block is one entry of the prefix cache, stored in Cache.prefixBlocks under
// the hash of a chunk of tokens. It records which pods have been registered
// for that chunk, grouped by model.
type Block struct {
// modelToPods maps model name -> pod name -> time the pod was registered on
// this block.
modelToPods map[string]map[string]time.Time // model_name: map[pod_name]pod_last_access_time
// lastAccessTime is when the block was last touched; prefixCacheEviction
// removes the block once this is older than prefixCacheEvictionDuration.
lastAccessTime time.Time //block_last_access_time
}

const (
modelIdentifier = "model.aibrix.ai/name"
podPort = 8000
defaultPodMetricRefreshIntervalInMS = 50
expireWriteRequestTraceIntervalInMins = 10
modelIdentifier = "model.aibrix.ai/name"
podPort = 8000
defaultPodMetricRefreshIntervalInMS = 50
expireWriteRequestTraceIntervalInMins = 10
defaultPrefixCacheBlockSize = 16
defaultPrefixCacheEvictionInternalInMS = 50
defaultPrefixCacheEvictionDurationInMins = 60
)

var (
Expand Down Expand Up @@ -120,14 +132,18 @@ var (
metrics.RunningLoraAdapters,
}

podMetricRefreshInterval = getPodMetricRefreshInterval()
// TODO: add a helper function for get methods.
podMetricRefreshInterval = getPodMetricRefreshInterval()
prefixCacheBlockSize = getPrefixCacheBlockSize()
prefixCacheEvictionInterval = getPrefixCacheEvictionInterval()
prefixCacheEvictionDuration = getPrefixCacheEvictionDuration()
)

func getPodMetricRefreshInterval() time.Duration {
value := LoadEnv("AIBRIX_POD_METRIC_REFRESH_INTERVAL_MS", "")
value := utils.LoadEnv("AIBRIX_POD_METRIC_REFRESH_INTERVAL_MS", "")
if value != "" {
intValue, err := strconv.Atoi(value)
if err != nil {
if err != nil || intValue <= 0 {
klog.Infof("invalid AIBRIX_POD_METRIC_REFRESH_INTERVAL_MS: %s, falling back to default", value)
} else {
klog.Infof("using AIBRIX_POD_METRIC_REFRESH_INTERVAL_MS env value for pod metrics refresh interval: %d ms", intValue)
Expand All @@ -138,23 +154,58 @@ func getPodMetricRefreshInterval() time.Duration {
return defaultPodMetricRefreshIntervalInMS * time.Millisecond
}

// getPrefixCacheBlockSize resolves the number of tokens per prefix cache block.
// The AIBRIX_PREFIX_CACHE_BLOCK_SIZE environment variable is used when it
// parses as a positive integer; otherwise defaultPrefixCacheBlockSize applies.
func getPrefixCacheBlockSize() int {
	raw := utils.LoadEnv("AIBRIX_PREFIX_CACHE_BLOCK_SIZE", "")
	if raw != "" {
		size, parseErr := strconv.Atoi(raw)
		if parseErr == nil && size > 0 {
			klog.Infof("using AIBRIX_PREFIX_CACHE_BLOCK_SIZE env value for prefix cache block size: %d", size)
			return size
		}
		// Unparseable or non-positive: report it and use the default below.
		klog.Infof("invalid AIBRIX_PREFIX_CACHE_BLOCK_SIZE: %s, falling back to default", raw)
	}

	klog.Infof("using default prefix cache block size: %d", defaultPrefixCacheBlockSize)
	return defaultPrefixCacheBlockSize
}

// getPrefixCacheEvictionInterval resolves how often the prefix cache eviction
// pass runs. AIBRIX_PREFIX_CACHE_EVICTION_INTERVAL_MS is used when it parses
// as a positive integer number of milliseconds; otherwise the default applies.
func getPrefixCacheEvictionInterval() time.Duration {
	raw := utils.LoadEnv("AIBRIX_PREFIX_CACHE_EVICTION_INTERVAL_MS", "")
	if raw != "" {
		millis, parseErr := strconv.Atoi(raw)
		if parseErr == nil && millis > 0 {
			klog.Infof("using AIBRIX_PREFIX_CACHE_EVICTION_INTERVAL_MS env value for prefix cache eviction interval: %d ms", millis)
			return time.Duration(millis) * time.Millisecond
		}
		// Unparseable or non-positive: report it and use the default below.
		klog.Infof("invalid AIBRIX_PREFIX_CACHE_EVICTION_INTERVAL_MS: %s, falling back to default", raw)
	}

	klog.Infof("using default prefix cache eviction interval: %d ms", defaultPrefixCacheEvictionInternalInMS)
	return defaultPrefixCacheEvictionInternalInMS * time.Millisecond
}

// getPrefixCacheEvictionDuration resolves how long a prefix cache block may go
// untouched before it is evicted. AIBRIX_PREFIX_CACHE_EVICTION_DURATION_MINS
// is used when it parses as a positive integer number of minutes; otherwise
// defaultPrefixCacheEvictionDurationInMins applies.
func getPrefixCacheEvictionDuration() time.Duration {
	value := utils.LoadEnv("AIBRIX_PREFIX_CACHE_EVICTION_DURATION_MINS", "")
	if value != "" {
		intValue, err := strconv.Atoi(value)
		if err == nil && intValue > 0 {
			// The value is in minutes; the log line must say so (it previously said "ms").
			klog.Infof("using AIBRIX_PREFIX_CACHE_EVICTION_DURATION_MINS env value for prefix cache eviction duration: %d mins", intValue)
			return time.Duration(intValue) * time.Minute
		}
		klog.Infof("invalid AIBRIX_PREFIX_CACHE_EVICTION_DURATION_MINS: %s, falling back to default", value)
	}
	klog.Infof("using default prefix cache eviction duration: %d mins", defaultPrefixCacheEvictionDurationInMins)
	return defaultPrefixCacheEvictionDurationInMins * time.Minute
}

// GetCache returns the package-level cache instance, or an error when the
// instance has not been marked as initialized yet.
func GetCache() (*Cache, error) {
	if instance.initialized {
		return &instance, nil
	}
	return nil, errors.New("cache is not initialized")
}

// LoadEnv returns the value of the environment variable key. When the variable
// is unset or empty, a warning is logged and defaultValue is returned instead.
func LoadEnv(key, defaultValue string) string {
	if envValue := os.Getenv(key); envValue != "" {
		return envValue
	}
	klog.Warningf("environment variable %s is not set, using default value: %s", key, defaultValue)
	return defaultValue
}

func NewCache(config *rest.Config, stopCh <-chan struct{}, redisClient *redis.Client) *Cache {
once.Do(func() {
if err := v1alpha1scheme.AddToScheme(scheme.Scheme); err != nil {
Expand Down Expand Up @@ -187,9 +238,9 @@ func NewCache(config *rest.Config, stopCh <-chan struct{}, redisClient *redis.Cl
}

// Load environment variables
prometheusEndpoint := LoadEnv("PROMETHEUS_ENDPOINT", "")
prometheusBasicAuthUsername := LoadEnv("PROMETHEUS_BASIC_AUTH_USERNAME", "")
prometheusBasicAuthPassword := LoadEnv("PROMETHEUS_BASIC_AUTH_PASSWORD", "")
prometheusEndpoint := utils.LoadEnv("PROMETHEUS_ENDPOINT", "")
prometheusBasicAuthUsername := utils.LoadEnv("PROMETHEUS_BASIC_AUTH_USERNAME", "")
prometheusBasicAuthPassword := utils.LoadEnv("PROMETHEUS_BASIC_AUTH_PASSWORD", "")

// Initialize Prometheus API
var prometheusApi prometheusv1.API
Expand All @@ -214,6 +265,7 @@ func NewCache(config *rest.Config, stopCh <-chan struct{}, redisClient *redis.Cl
ModelToPodMapping: map[string]map[string]*v1.Pod{},
requestTrace: &sync.Map{},
pendingRequests: &sync.Map{},
prefixBlocks: map[uint64]Block{},
}
if _, err := podInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{
AddFunc: instance.addPod,
Expand Down Expand Up @@ -246,6 +298,19 @@ func NewCache(config *rest.Config, stopCh <-chan struct{}, redisClient *redis.Cl
}
}()

ticker = time.NewTicker(prefixCacheEvictionInterval)
go func() {
for {
select {
case <-ticker.C:
instance.prefixCacheEviction(time.Now())
case <-stopCh:
ticker.Stop()
return
}
}
}()

tickerOffset := time.Duration(time.Now().UnixNano()) % RequestTraceWriteInterval
var traceAlignmentTimer *time.Timer
// TODO: Using ticker may be a problem if writeRequestTraceToStorage takes too long.
Expand Down Expand Up @@ -963,3 +1028,100 @@ func (c *Cache) aggregateMetrics() {
}
}
}

// MatchPrefix walks tokens in chunks of prefixCacheBlockSize, hashes each chunk
// and looks it up in c.prefixBlocks. Matching stops at the first chunk that is
// unknown or has no pods recorded for model.
//
// It returns matchedTokens (the leading tokens covered by matched chunks),
// unMatchedTokens (the remainder), and matchedPods (the entries of pods whose
// name is recorded for model on the last matched block; nil if nothing matched).
//
// TODO: add an interface with multiple implementations such as hash or radix tree
func (c *Cache) MatchPrefix(tokens []int, model string, pods []*v1.Pod) ([]int, []int, []*v1.Pod) {
	// The exclusive lock is required: every hit refreshes lastAccessTime and
	// stores the block back into c.prefixBlocks. Doing that under RLock lets
	// concurrent callers write the map at once, which is a data race and can
	// abort the process with "concurrent map writes".
	c.mu.Lock()
	defer c.mu.Unlock()

	var lastMatchedBlock Block
	var lastTokenMatchIndex int

	for i := 0; i < len(tokens); i += prefixCacheBlockSize {
		end := i + prefixCacheBlockSize
		if end > len(tokens) {
			end = len(tokens)
		}

		chunk := tokens[i:end]
		prefixHash := xxhash.Sum64(IntArrayToByteArray(chunk))
		block, ok := c.prefixBlocks[prefixHash]
		if !ok || len(block.modelToPods[model]) == 0 {
			lastTokenMatchIndex = i
			break
		}

		lastTokenMatchIndex = end
		// Block is stored by value, so the refreshed timestamp has to be
		// written back into the map to be visible to the eviction pass.
		block.lastAccessTime = time.Now()
		c.prefixBlocks[prefixHash] = block
		lastMatchedBlock = block
	}

	matchedTokens := tokens[0:lastTokenMatchIndex]
	unMatchedTokens := tokens[lastTokenMatchIndex:]

	var matchedPods []*v1.Pod
	// Reading a nil map is safe, so this also covers the "nothing matched" case.
	blockPods := lastMatchedBlock.modelToPods[model]
	for _, pod := range pods {
		if _, ok := blockPods[pod.Name]; ok {
			matchedPods = append(matchedPods, pod)
		}
	}

	return matchedTokens, unMatchedTokens, matchedPods
}

// AddPrefixBlock records that pod serves model for every chunk of
// unMatchedTokens. Tokens are split into chunks of prefixCacheBlockSize (the
// final chunk may be shorter); each chunk is hashed and the corresponding
// block in c.prefixBlocks is created if missing.
//
// Registering a pod counts as an access: the block's lastAccessTime is
// refreshed even when the block already existed, so a block that has just
// gained a pod is not removed by the next eviction pass because of a stale
// timestamp.
func (c *Cache) AddPrefixBlock(unMatchedTokens []int, model, pod string) {
	c.mu.Lock()
	defer c.mu.Unlock()

	for i := 0; i < len(unMatchedTokens); i += prefixCacheBlockSize {
		end := i + prefixCacheBlockSize
		if end > len(unMatchedTokens) {
			end = len(unMatchedTokens)
		}

		chunk := unMatchedTokens[i:end]
		prefixHash := xxhash.Sum64(IntArrayToByteArray(chunk))
		now := time.Now()

		block, ok := c.prefixBlocks[prefixHash]
		if !ok {
			block = Block{
				modelToPods: map[string]map[string]time.Time{},
			}
		}
		block.lastAccessTime = now

		blockPods, ok := block.modelToPods[model]
		if !ok {
			blockPods = map[string]time.Time{}
			block.modelToPods[model] = blockPods
		}
		blockPods[pod] = now

		// Block is stored by value; write it back so the refreshed
		// lastAccessTime (and a newly created block) lands in the map.
		c.prefixBlocks[prefixHash] = block
	}
}

// prefixCacheEviction removes every prefix block whose lastAccessTime lies
// more than prefixCacheEvictionDuration before now, logging each removal.
// It holds the cache write lock for the whole pass.
func (c *Cache) prefixCacheEviction(now time.Time) {
	c.mu.Lock()
	defer c.mu.Unlock()

	for key := range c.prefixBlocks {
		idle := now.Sub(c.prefixBlocks[key].lastAccessTime)
		if idle <= prefixCacheEvictionDuration {
			continue
		}
		// Deleting the current key while ranging over a map is permitted in Go.
		delete(c.prefixBlocks, key)
		klog.InfoS("prefix cache block evicted", "hash", key)
	}
}

// IntArrayToByteArray serializes intArray into bytes, writing each element as
// a 4-byte little-endian int32 in order. Values outside the int32 range are
// truncated to their low 32 bits by the int32 conversion. The result has
// length 4*len(intArray).
func IntArrayToByteArray(intArray []int) []byte {
	buf := new(bytes.Buffer)
	// Reserve the exact output size once instead of growing per element.
	buf.Grow(4 * len(intArray))

	var word [4]byte
	for _, val := range intArray {
		// Encode directly rather than through binary.Write, which dispatches
		// on the value's dynamic type for every element.
		binary.LittleEndian.PutUint32(word[:], uint32(int32(val)))
		buf.Write(word[:]) // bytes.Buffer.Write always returns a nil error
	}
	return buf.Bytes()
}
Loading

0 comments on commit f40c973

Please sign in to comment.