From 4a5341510f4107b98268e6398195d6223ad3c626 Mon Sep 17 00:00:00 2001 From: Varun Gupta Date: Thu, 29 Aug 2024 17:38:45 -0700 Subject: [PATCH] Add custom cache and interface for model adapter scheduling (#100) * Add custom CRD clientset * nit * Add custom cache * test * test * nit * clean up .DS_Store files * add interface * fix lint errors * revert controller name and tag --------- Co-authored-by: varungupta --- Dockerfile | 2 + cmd/main.go | 6 + docs/tutorial/lora/README.md | 2 +- go.mod | 1 + go.sum | 2 + pkg/cache/cache.go | 258 ++++++++++++++++++ .../modeladapter/modeladapter_controller.go | 17 +- .../modeladapter/scheduling/leastadapters.go | 60 ++++ .../modeladapter/scheduling/scheduler.go | 28 ++ 9 files changed, 372 insertions(+), 4 deletions(-) create mode 100644 pkg/cache/cache.go create mode 100644 pkg/controller/modeladapter/scheduling/leastadapters.go create mode 100644 pkg/controller/modeladapter/scheduling/scheduler.go diff --git a/Dockerfile b/Dockerfile index 7fb7e903..424a4483 100644 --- a/Dockerfile +++ b/Dockerfile @@ -16,6 +16,8 @@ COPY cmd/main.go cmd/main.go COPY api/ api/ COPY pkg/controller/ pkg/controller/ COPY pkg/utils/ pkg/utils/ +COPY pkg/cache/ pkg/cache/ +COPY pkg/client/ pkg/client/ # Build # the GOARCH has not a default value to allow the binary be built according to the host where the command diff --git a/cmd/main.go b/cmd/main.go index 3a7af4f8..9a8555e0 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -40,6 +40,7 @@ import ( autoscalingv1alpha1 "github.com/aibrix/aibrix/api/autoscaling/v1alpha1" modelv1alpha1 "github.com/aibrix/aibrix/api/model/v1alpha1" orchestrationv1alpha1 "github.com/aibrix/aibrix/api/orchestration/v1alpha1" + "github.com/aibrix/aibrix/pkg/cache" "github.com/aibrix/aibrix/pkg/controller" //+kubebuilder:scaffold:imports ) @@ -161,6 +162,11 @@ func main() { os.Exit(1) } + setupLog.Info("starting cache") + stopCh := make(chan struct{}) + defer close(stopCh) + cache.NewCache(stopCh) + // Kind controller registration is encapsulated inside the pkg/controller/controller.go // So here we can use more clean registration flow and there's no need to change logics in future. if err = controller.SetupWithManager(mgr); err != nil { diff --git a/docs/tutorial/lora/README.md b/docs/tutorial/lora/README.md index 55c5abca..041cf27f 100644 --- a/docs/tutorial/lora/README.md +++ b/docs/tutorial/lora/README.md @@ -27,7 +27,7 @@ curl -X POST http://localhost:8000/v1/load_lora_adapter \ ``` # check available models -curl http://localhost:8000/v1/models +curl http://localhost:8000/v1/models | jq . ``` 4. Unload Model diff --git a/go.mod b/go.mod index 6d63c129..4182529d 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( k8s.io/apimachinery v0.29.2 k8s.io/client-go v0.29.2 k8s.io/code-generator v0.29.2 + k8s.io/klog v0.2.0 k8s.io/klog/v2 v2.110.1 k8s.io/kube-openapi v0.0.0-20231010175941-2dd684a91f00 k8s.io/utils v0.0.0-20230726121419-3b25d923346b diff --git a/go.sum b/go.sum index 5c4f96f5..f649cfe5 100644 --- a/go.sum +++ b/go.sum @@ -183,6 +183,8 @@ k8s.io/component-base v0.29.2 h1:lpiLyuvPA9yV1aQwGLENYyK7n/8t6l3nn3zAtFTJYe8= k8s.io/component-base v0.29.2/go.mod h1:BfB3SLrefbZXiBfbM+2H1dlat21Uewg/5qtKOl8degM= k8s.io/gengo v0.0.0-20230829151522-9cce18d56c01 h1:pWEwq4Asjm4vjW7vcsmijwBhOr1/shsbSYiWXmNGlks= k8s.io/gengo v0.0.0-20230829151522-9cce18d56c01/go.mod h1:FiNAH4ZV3gBg2Kwh89tzAEV2be7d5xI0vBa/VySYy3E= +k8s.io/klog v0.2.0 h1:0ElL0OHzF3N+OhoJTL0uca20SxtYt4X4+bzHeqrB83c= +k8s.io/klog v0.2.0/go.mod h1:Gq+BEi5rUBO/HRz0bTSXDUcqjScdoY3a9IHpCEIOOfk= k8s.io/klog/v2 v2.2.0/go.mod h1:Od+F08eJP+W3HUb4pSrPpgp9DGU4GzlpG/TmITuYh/Y= k8s.io/klog/v2 v2.110.1 h1:U/Af64HJf7FcwMcXyKm2RPM22WZzyR7OSpYj5tg3cL0= k8s.io/klog/v2 v2.110.1/go.mod h1:YGtd1984u+GgbuZ7e08/yBuAfKLSO0+uR1Fhi6ExXjo= diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go new file mode 100644 index 00000000..d09f61d2 --- /dev/null +++ b/pkg/cache/cache.go @@ -0,0 +1,258 @@ +/* +Copyright 2024 The Aibrix Team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package cache + +import ( + "errors" + "fmt" + "log" + "strings" + "sync" + + crdinformers "github.com/aibrix/aibrix/pkg/client/informers/externalversions" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/util/runtime" + "k8s.io/client-go/informers" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" + "k8s.io/client-go/tools/cache" + "k8s.io/client-go/tools/clientcmd" + "k8s.io/klog" + + modelv1alpha1 "github.com/aibrix/aibrix/api/model/v1alpha1" + v1alpha1 "github.com/aibrix/aibrix/pkg/client/clientset/versioned" + v1alpha1scheme "github.com/aibrix/aibrix/pkg/client/clientset/versioned/scheme" + "k8s.io/client-go/kubernetes/scheme" +) + +var once sync.Once + +// type global +type Cache struct { + mu sync.RWMutex + initialized bool + pods map[string]*v1.Pod + modelAdapterToPodMapping map[string][]string + podToModelAdapterMapping map[string]map[string]struct{} +} + +var ( + instance Cache + kubeconfig string +) + +func GetCache() (*Cache, error) { + if !instance.initialized { + return nil, errors.New("cache is not initialized") + } + return &instance, nil +} + +func NewCache(stopCh <-chan struct{}) *Cache { + once.Do(func() { + var config *rest.Config + var err error + + if kubeconfig == "" { + log.Printf("using in-cluster configuration") + config, err = rest.InClusterConfig() + } else { + log.Printf("using configuration from '%s'", kubeconfig) + config, err = clientcmd.BuildConfigFromFlags("", kubeconfig) + } + + if err != nil { + panic(err) + } + + if err := v1alpha1scheme.AddToScheme(scheme.Scheme); err != nil { + panic(err) + } + + k8sClientSet, err := kubernetes.NewForConfig(config) + if err != nil { + panic(err) + } + + crdClientSet, err := v1alpha1.NewForConfig(config) + if err != nil { + panic(err) + } + + factory := informers.NewSharedInformerFactoryWithOptions(k8sClientSet, 0) + crdFactory := crdinformers.NewSharedInformerFactoryWithOptions(crdClientSet, 0) + + podInformer := factory.Core().V1().Pods().Informer() + modeInformer := crdFactory.Model().V1alpha1().ModelAdapters().Informer() + + defer runtime.HandleCrash() + factory.Start(stopCh) + crdFactory.Start(stopCh) + + // factory.WaitForCacheSync(stopCh) + // crdFactory.WaitForCacheSync(stopCh) + + if !cache.WaitForCacheSync(stopCh, podInformer.HasSynced, modeInformer.HasSynced) { + runtime.HandleError(fmt.Errorf("timed out waiting for caches to sync")) + return + } + + instance = Cache{ + initialized: true, + pods: map[string]*v1.Pod{}, + modelAdapterToPodMapping: map[string][]string{}, + podToModelAdapterMapping: map[string]map[string]struct{}{}, + } + + if _, err := podInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{ + AddFunc: instance.addPod, + UpdateFunc: instance.updatePod, + DeleteFunc: instance.deletePod, + }); err != nil { + panic(err) + } + + if _, err = modeInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{ + AddFunc: instance.addModel, + UpdateFunc: instance.updateModel, + DeleteFunc: instance.deleteModel, + }); err != nil { + panic(err) + } + }) + + return &instance +} + +func (c *Cache) addPod(obj interface{}) { + c.mu.Lock() + defer c.mu.Unlock() + + pod := obj.(*v1.Pod) + c.pods[pod.Name] = pod + c.podToModelAdapterMapping[pod.Name] = map[string]struct{}{} + klog.Infof("POD CREATED: %s/%s", pod.Namespace, pod.Name) +} + +func (c *Cache) updatePod(oldObj interface{}, newObj interface{}) { + c.mu.Lock() + defer c.mu.Unlock() + + oldPod := oldObj.(*v1.Pod) + newPod := newObj.(*v1.Pod) + klog.Infof("POD UPDATED. %s/%s %s", oldPod.Namespace, oldPod.Name, newPod.Status.Phase) +} + +func (c *Cache) deletePod(obj interface{}) { + c.mu.Lock() + defer c.mu.Unlock() + + pod := obj.(*v1.Pod) + delete(c.pods, pod.Name) + klog.Infof("POD DELETED: %s/%s", pod.Namespace, pod.Name) +} + +func (c *Cache) addModel(obj interface{}) { + c.mu.Lock() + defer c.mu.Unlock() + + model := obj.(*modelv1alpha1.ModelAdapter) + c.modelAdapterToPodMapping[model.Name] = model.Status.Instances + c.addModelAdapterMapping(model) + + klog.Infof("MODELADAPTER CREATED: %s/%s", model.Namespace, model.Name) +} + +func (c *Cache) updateModel(oldObj interface{}, newObj interface{}) { + c.mu.Lock() + defer c.mu.Unlock() + + oldModel := oldObj.(*modelv1alpha1.ModelAdapter) + newModel := newObj.(*modelv1alpha1.ModelAdapter) + c.modelAdapterToPodMapping[newModel.Name] = newModel.Status.Instances + c.deleteModelAdapterMapping(oldModel) + c.addModelAdapterMapping(newModel) + + klog.Infof("MODELADAPTER UPDATED. %s/%s %s", oldModel.Namespace, oldModel.Name, newModel.Status.Phase) +} + +func (c *Cache) deleteModel(obj interface{}) { + c.mu.Lock() + defer c.mu.Unlock() + + model := obj.(*modelv1alpha1.ModelAdapter) + delete(c.modelAdapterToPodMapping, model.Name) + c.deleteModelAdapterMapping(model) + + klog.Infof("MODELADAPTER DELETED: %s/%s", model.Namespace, model.Name) +} + +func (c *Cache) addModelAdapterMapping(model *modelv1alpha1.ModelAdapter) { + for _, pod := range model.Status.Instances { + models, ok := c.podToModelAdapterMapping[pod] + if !ok { + c.podToModelAdapterMapping[pod] = map[string]struct{}{ + model.Name: {}, + } + continue + } + + models[model.Name] = struct{}{} + c.podToModelAdapterMapping[pod] = models + } +} + +func (c *Cache) deleteModelAdapterMapping(model *modelv1alpha1.ModelAdapter) { + for _, pod := range model.Status.Instances { + modelAdapters := c.podToModelAdapterMapping[pod] + delete(modelAdapters, model.Name) + c.podToModelAdapterMapping[pod] = modelAdapters + } +} + +func (c *Cache) debugInfo() { + for model, instances := range c.modelAdapterToPodMapping { + klog.Infof("modelName: %s, instances: %v", model, instances) + } + + for pod, models := range c.podToModelAdapterMapping { + if !strings.HasPrefix(pod, "llama") { + continue + } + + modelsArr := []string{} + for m := range models { + modelsArr = append(modelsArr, m) + } + + klog.Infof("podName: %s, modelAdapters: %v", pod, modelsArr) + } +} + +func (c *Cache) GetPods() map[string]*v1.Pod { + c.mu.RLock() + defer c.mu.RUnlock() + + return c.pods +} + +func (c *Cache) GetPodToModelAdapterMapping() map[string]map[string]struct{} { + c.mu.RLock() + defer c.mu.RUnlock() + + return c.podToModelAdapterMapping +} diff --git a/pkg/controller/modeladapter/modeladapter_controller.go b/pkg/controller/modeladapter/modeladapter_controller.go index 4377d491..e67a37ab 100644 --- a/pkg/controller/modeladapter/modeladapter_controller.go +++ b/pkg/controller/modeladapter/modeladapter_controller.go @@ -28,6 +28,8 @@ import ( "time" modelv1alpha1 "github.com/aibrix/aibrix/api/model/v1alpha1" + "github.com/aibrix/aibrix/pkg/cache" + "github.com/aibrix/aibrix/pkg/controller/modeladapter/scheduling" corev1 "k8s.io/api/core/v1" discoveryv1 "k8s.io/api/discovery/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" @@ -112,6 +114,13 @@ func newReconciler(mgr manager.Manager) (reconcile.Reconciler, error) { eventBroadcaster.StartRecordingToSink(&clientv1core.EventSinkImpl{Interface: k8sClient.CoreV1().Events("")}) recorder := eventBroadcaster.NewRecorder(mgr.GetScheme(), corev1.EventSource{Component: "model-adapter-controller"}) + c, err := cache.GetCache() + if err != nil { + klog.Fatal(err.Error()) + } + + scheduler := scheduling.NewLeastAdapters(c) + reconciler := &ModelAdapterReconciler{ Client: mgr.GetClient(), Scheme: mgr.GetScheme(), @@ -119,6 +128,7 @@ func newReconciler(mgr manager.Manager) (reconcile.Reconciler, error) { ServiceLister: serviceLister, EndpointSliceLister: endpointSliceLister, Recorder: recorder, + scheduler: scheduler, } return reconciler, nil } @@ -142,8 +152,9 @@ var _ reconcile.Reconciler = &ModelAdapterReconciler{} // ModelAdapterReconciler reconciles a ModelAdapter object type ModelAdapterReconciler struct { client.Client - Scheme *runtime.Scheme - Recorder record.EventRecorder + Scheme *runtime.Scheme + Recorder record.EventRecorder + scheduler scheduling.Scheduler // PodLister is able to list/get pods from a shared informer's cache store PodLister corelisters.PodLister // ServiceLister is able to list/get services from a shared informer's cache store @@ -394,7 +405,7 @@ func (r *ModelAdapterReconciler) schedulePod(ctx context.Context, instance *mode // TODO: let's build the scheduling algorithm later // we should also fetch > mappings later. - return &podList.Items[0], nil // Returning the first Pod for simplicity + return r.scheduler.SelectPod(ctx, podList.Items) } // GetEnvKey retrieves the value of the environment variable named by the key. diff --git a/pkg/controller/modeladapter/scheduling/leastadapters.go b/pkg/controller/modeladapter/scheduling/leastadapters.go new file mode 100644 index 00000000..f3cdf148 --- /dev/null +++ b/pkg/controller/modeladapter/scheduling/leastadapters.go @@ -0,0 +1,60 @@ +/* +Copyright 2024 The Aibrix Team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package scheduling + +import ( + "context" + "errors" + "math" + + "github.com/aibrix/aibrix/pkg/cache" + v1 "k8s.io/api/core/v1" + "k8s.io/klog" +) + +type leastAdapters struct { + cache *cache.Cache +} + +func NewLeastAdapters(c *cache.Cache) Scheduler { + return leastAdapters{ + cache: c, + } +} + +func (r leastAdapters) SelectPod(ctx context.Context, pods []v1.Pod) (*v1.Pod, error) { + modelAdapterCountMin := math.MaxInt + selectedPod := v1.Pod{} + podMap := r.cache.GetPods() + podToModelAdapterMapping := r.cache.GetPodToModelAdapterMapping() + + for _, pod := range pods { + if _, ok := podMap[pod.Name]; !ok { + return nil, errors.New("pod not found in the cache") + } + + modelAdapters := podToModelAdapterMapping[pod.Name] + if len(modelAdapters) < modelAdapterCountMin { + selectedPod = pod + modelAdapterCountMin = len(modelAdapters) + } + } + + klog.Infof("pod selected with least model adapters: %s", selectedPod.Name) + + return &selectedPod, nil +} diff --git a/pkg/controller/modeladapter/scheduling/scheduler.go b/pkg/controller/modeladapter/scheduling/scheduler.go new file mode 100644 index 00000000..c5120e99 --- /dev/null +++ b/pkg/controller/modeladapter/scheduling/scheduler.go @@ -0,0 +1,28 @@ +/* +Copyright 2024 The Aibrix Team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package scheduling + +import ( + "context" + + v1 "k8s.io/api/core/v1" +) + +type Scheduler interface { + // Returns the pod to schedule model adapter + SelectPod(ctx context.Context, pods []v1.Pod) (*v1.Pod, error) +}