From fc1a4947b18ae87c1edfc75378e25edade611a9a Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 6 Sep 2022 18:56:28 -0700 Subject: [PATCH] feat: Make list of supported built-in runtime server types overridable (#220) #### Motivation Model-mesh serving makes use of a single adapter image to interface with a number of "built-in" model server types, currently comprising `triton`, `mlserver`, and `ovms`. Currently, which kind of adapter to run is controlled by the entrypoint to the adapter container within the modelmesh pods, which is set by the controller based on the `builtInAdapter.serverType` field of the `ServingRuntime` CRD. However, this field is currently validated against a hardcoded list of type strings, meaning that any extension to the runtime adapter image to support new kinds of model servers also requires a code change and rebuild of the controller image. #### Modifications Move the list of supported built-in server types to a list in the global config. Update controller validation logic and tests accordingly. #### Result The shared built-in runtime adapter can be extended to support new runtime types without requiring a code change to the controller. 
Signed-off-by: Nick Hill --- config/default/config-defaults.yaml | 4 +++ controllers/servingruntime_controller.go | 6 +--- controllers/servingruntime_validator.go | 32 ++++++++++++++------ controllers/servingruntime_validator_test.go | 6 +++- controllers/suite_test.go | 12 +++++--- fvt/tls.go | 2 +- pkg/config/config.go | 6 ++++ pkg/config/config_test.go | 21 +++++++++++++ 8 files changed, 68 insertions(+), 21 deletions(-) diff --git a/config/default/config-defaults.yaml b/config/default/config-defaults.yaml index 2b583f57..4c93c819 100644 --- a/config/default/config-defaults.yaml +++ b/config/default/config-defaults.yaml @@ -51,3 +51,7 @@ storageHelperResources: serviceAccountName: "" metrics: enabled: true +builtInServerTypes: + - triton + - mlserver + - ovms diff --git a/controllers/servingruntime_controller.go b/controllers/servingruntime_controller.go index bb7753e0..431fe521 100644 --- a/controllers/servingruntime_controller.go +++ b/controllers/servingruntime_controller.go @@ -85,10 +85,6 @@ type runtimeInfo struct { TimeTransitionedToNoPredictors *time.Time } -var builtInServerTypes = map[kserveapi.ServerType]interface{}{ - kserveapi.MLServer: nil, kserveapi.Triton: nil, kserveapi.OVMS: nil, -} - // +kubebuilder:rbac:groups=serving.kserve.io,resources=servingruntimes;servingruntimes/finalizers,verbs=get;list;watch;create;update;patch;delete // +kubebuilder:rbac:groups=serving.kserve.io,resources=servingruntimes/status,verbs=get;update;patch // +kubebuilder:rbac:groups=apps,resources=deployments;deployments/finalizers,verbs=get;list;watch;create;update;patch;delete @@ -188,7 +184,7 @@ func (r *ServingRuntimeReconciler) Reconcile(ctx context.Context, req ctrl.Reque } // Check that ServerType is provided in rt.Spec and that this value matches that of the specified container - if err = validateServingRuntimeSpec(rt); err != nil { + if err = validateServingRuntimeSpec(rt, cfg); err != nil { return ctrl.Result{}, fmt.Errorf("Invalid ServingRuntime Spec: %w", err) 
} diff --git a/controllers/servingruntime_validator.go b/controllers/servingruntime_validator.go index eb18173d..71a98ab9 100644 --- a/controllers/servingruntime_validator.go +++ b/controllers/servingruntime_validator.go @@ -18,6 +18,8 @@ import ( "fmt" "strings" + "github.com/kserve/modelmesh-serving/pkg/config" + corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/util/sets" @@ -33,34 +35,44 @@ import ( // - containers do not mount internal only volumes // - some fields in containers are controlled by model mesh and cannot be set // - check for overlaps in declared ports with internal ports -func validateServingRuntimeSpec(rt *kserveapi.ServingRuntime) error { - return validationChain(rt, +func validateServingRuntimeSpec(rt *kserveapi.ServingRuntime, config *config.Config) error { + return validationChain(rt, config, validateBuiltInAdapterSpec, validateContainers, validateVolumes, ) } -func validationChain(rt *kserveapi.ServingRuntime, funcs ...func(*kserveapi.ServingRuntime) error) error { +func validationChain(rt *kserveapi.ServingRuntime, config *config.Config, + funcs ...func(*kserveapi.ServingRuntime, *config.Config) error) error { for _, f := range funcs { - if err := f(rt); err != nil { + if err := f(rt, config); err != nil { return err } } return nil } -func validateBuiltInAdapterSpec(rt *kserveapi.ServingRuntime) error { +func validateBuiltInAdapterSpec(rt *kserveapi.ServingRuntime, config *config.Config) error { if rt.Spec.BuiltInAdapter == nil { return nil // nothing to check } - st := rt.Spec.BuiltInAdapter.ServerType - if _, ok := builtInServerTypes[st]; !ok { + st := string(rt.Spec.BuiltInAdapter.ServerType) + found := false + if config.BuiltInServerTypes != nil { + for _, bist := range config.BuiltInServerTypes { + if bist == st { + found = true + break + } + } + } + if !found { return fmt.Errorf("unrecognized built-in runtime server type %s", st) } for ic := range rt.Spec.Containers { - if rt.Spec.Containers[ic].Name == string(st) { + if 
rt.Spec.Containers[ic].Name == st { return nil // found, all good } } @@ -68,7 +80,7 @@ func validateBuiltInAdapterSpec(rt *kserveapi.ServingRuntime) error { return fmt.Errorf("must include runtime container with name %s", st) } -func validateContainers(rt *kserveapi.ServingRuntime) error { +func validateContainers(rt *kserveapi.ServingRuntime, _ *config.Config) error { for i := range rt.Spec.Containers { c := &rt.Spec.Containers[i] if err := validateContainer(c); err != nil { @@ -116,7 +128,7 @@ func validateContainer(c *corev1.Container) error { return nil } -func validateVolumes(rt *kserveapi.ServingRuntime) error { +func validateVolumes(rt *kserveapi.ServingRuntime, _ *config.Config) error { // Block volume names that conflict with injected volumes or reserved prefixes for vi := range rt.Spec.Volumes { if err := checkName(rt.Spec.Volumes[vi].Name, internalVolumes, "volume"); err != nil { diff --git a/controllers/servingruntime_validator_test.go b/controllers/servingruntime_validator_test.go index c3d848a3..98cbfa3b 100644 --- a/controllers/servingruntime_validator_test.go +++ b/controllers/servingruntime_validator_test.go @@ -310,8 +310,12 @@ func TestValidateServingRuntimeSpec(t *testing.T) { expectError: true, }, } { + cfg, err := getDefaultConfig() + if err != nil { + t.Errorf("Unexpected error: %v", err) + } t.Run(tt.name, func(t *testing.T) { - err := validateServingRuntimeSpec(tt.servingRuntime) + err := validateServingRuntimeSpec(tt.servingRuntime, cfg) if tt.expectError && err == nil { t.Errorf("Expected an error, but didn't get one") diff --git a/controllers/suite_test.go b/controllers/suite_test.go index 2b88ec7c..72c4759f 100644 --- a/controllers/suite_test.go +++ b/controllers/suite_test.go @@ -206,15 +206,19 @@ var _ = AfterEach(func() { var defaultTestConfigFileContents []byte -func resetReconcilerConfig() { +func getDefaultConfig() (*config2.Config, error) { if defaultTestConfigFileContents == nil { var err error var testConfigFile = 
"./testdata/test-config-defaults.yaml" - defaultTestConfigFileContents, err = ioutil.ReadFile(testConfigFile) - Expect(err).ToNot(HaveOccurred()) + if defaultTestConfigFileContents, err = ioutil.ReadFile(testConfigFile); err != nil { + return nil, err + } } + return config2.NewMergedConfigFromString(string(defaultTestConfigFileContents)) +} - config, err := config2.NewMergedConfigFromString(string(defaultTestConfigFileContents)) +func resetReconcilerConfig() { + config, err := getDefaultConfig() Expect(err).ToNot(HaveOccurred()) // re-assign the reference to the config diff --git a/fvt/tls.go b/fvt/tls.go index 62cce1bb..4a99f518 100644 --- a/fvt/tls.go +++ b/fvt/tls.go @@ -110,7 +110,7 @@ func (g *CertGenerator) generate() error { } if err = pem.Encode(g.PrivateKeyPEM, &pem.Block{ - Type: "RSA PRIVATE KEY", + Type: "PRIVATE KEY", Bytes: privBytes, }); err != nil { return err diff --git a/pkg/config/config.go b/pkg/config/config.go index 0ac37105..b55da5c3 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -24,6 +24,8 @@ import ( "sync/atomic" "unsafe" + kserveapi "github.com/kserve/kserve/pkg/apis/serving/v1alpha1" + "sigs.k8s.io/controller-runtime/pkg/handler" "sigs.k8s.io/controller-runtime/pkg/reconcile" @@ -72,6 +74,7 @@ type Config struct { PodsPerRuntime uint16 StorageSecretName string EnableAccessLogging bool + BuiltInServerTypes []string ServiceAccountName string @@ -335,6 +338,9 @@ func defaults(v *viper.Viper) { v.SetDefault(concatStringsWithDelimiter([]string{"ScaleToZero", "GracePeriodSeconds"}), 60) // default size 16MiB in bytes v.SetDefault("GrpcMaxMessageSizeBytes", 16777216) + v.SetDefault("BuiltInServerTypes", []string{ + string(kserveapi.MLServer), string(kserveapi.Triton), string(kserveapi.OVMS), + }) } func concatStringsWithDelimiter(elems []string) string { diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index eca22da4..4c2fd292 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -14,6 
+14,7 @@ package config import ( + "reflect" "testing" "github.com/google/go-cmp/cmp" @@ -195,6 +196,26 @@ grpcMaxMessageSizeBytes: 33554432` } } +func TestBuiltInServerTypes(t *testing.T) { + yaml := ` +builtInServerTypes: + - triton + - mlserver + - ovms + - a_new_one` + + expectedTypes := []string{"triton", "mlserver", "ovms", "a_new_one"} + + conf, err := NewMergedConfigFromString(yaml) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(expectedTypes, conf.BuiltInServerTypes) { + t.Fatalf("Expected BuiltInServerTypes=%v but found %v", expectedTypes, conf.BuiltInServerTypes) + } +} + func TestResourceRequirements(t *testing.T) { rr := ResourceRequirements{ Requests: ResourceQuantities{