diff --git a/backend/src/apiserver/BUILD.bazel b/backend/src/apiserver/BUILD.bazel index a3b7950ab997..cddbf491fbae 100644 --- a/backend/src/apiserver/BUILD.bazel +++ b/backend/src/apiserver/BUILD.bazel @@ -18,7 +18,6 @@ go_library( "//backend/src/apiserver/server:go_default_library", "//backend/src/apiserver/storage:go_default_library", "//backend/src/common/util:go_default_library", - "//backend/src/crd/pkg/client/clientset/versioned/typed/scheduledworkflow/v1beta1:go_default_library", "@com_github_cenkalti_backoff//:go_default_library", "@com_github_fsnotify_fsnotify//:go_default_library", "@com_github_golang_glog//:go_default_library", diff --git a/backend/src/apiserver/client/BUILD.bazel b/backend/src/apiserver/client/BUILD.bazel index 4163defc388c..e69b20775068 100644 --- a/backend/src/apiserver/client/BUILD.bazel +++ b/backend/src/apiserver/client/BUILD.bazel @@ -11,14 +11,18 @@ go_library( "kubernetes_core_fake.go", "minio.go", "pod_fake.go", - "scheduled_workflow.go", + "scheduled_workflow_fake.go", "sql.go", + "swf.go", + "swf_fake.go", "workflow_fake.go", ], importpath = "github.com/kubeflow/pipelines/backend/src/apiserver/client", visibility = ["//visibility:public"], deps = [ + "//backend/src/apiserver/common:go_default_library", "//backend/src/common/util:go_default_library", + "//backend/src/crd/pkg/apis/scheduledworkflow/v1beta1:go_default_library", "//backend/src/crd/pkg/client/clientset/versioned:go_default_library", "//backend/src/crd/pkg/client/clientset/versioned/typed/scheduledworkflow/v1beta1:go_default_library", "@com_github_argoproj_argo//pkg/apis/workflow/v1alpha1:go_default_library", diff --git a/backend/src/apiserver/resource/scheduled_workflow_fake.go b/backend/src/apiserver/client/scheduled_workflow_fake.go similarity index 79% rename from backend/src/apiserver/resource/scheduled_workflow_fake.go rename to backend/src/apiserver/client/scheduled_workflow_fake.go index 47baefd7153e..a8fe234386eb 100644 --- 
a/backend/src/apiserver/resource/scheduled_workflow_fake.go +++ b/backend/src/apiserver/client/scheduled_workflow_fake.go @@ -12,38 +12,38 @@ // See the License for the specific language governing permissions and // limitations under the License. -package resource +package client import ( "errors" "github.com/golang/glog" "github.com/kubeflow/pipelines/backend/src/crd/pkg/apis/scheduledworkflow/v1beta1" - "k8s.io/apimachinery/pkg/apis/meta/v1" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/watch" ) type FakeScheduledWorkflowClient struct { - workflows map[string]*v1beta1.ScheduledWorkflow + scheduledWorkflows map[string]*v1beta1.ScheduledWorkflow } func NewScheduledWorkflowClientFake() *FakeScheduledWorkflowClient { return &FakeScheduledWorkflowClient{ - workflows: make(map[string]*v1beta1.ScheduledWorkflow), + scheduledWorkflows: make(map[string]*v1beta1.ScheduledWorkflow), } } -func (c *FakeScheduledWorkflowClient) Create(workflow *v1beta1.ScheduledWorkflow) (*v1beta1.ScheduledWorkflow, error) { - workflow.UID = "123" - workflow.Namespace = "default" - workflow.Name = workflow.GenerateName - c.workflows[workflow.Name] = workflow - return workflow, nil +func (c *FakeScheduledWorkflowClient) Create(scheduledWorkflow *v1beta1.ScheduledWorkflow) (*v1beta1.ScheduledWorkflow, error) { + scheduledWorkflow.UID = "123e4567-e89b-12d3-a456-426655440000" + scheduledWorkflow.Namespace = "ns1" + scheduledWorkflow.Name = scheduledWorkflow.GenerateName + c.scheduledWorkflows[scheduledWorkflow.Name] = scheduledWorkflow + return scheduledWorkflow, nil } func (c *FakeScheduledWorkflowClient) Delete(name string, options *v1.DeleteOptions) error { - delete(c.workflows, name) + delete(c.scheduledWorkflows, name) return nil } @@ -52,9 +52,9 @@ func (c *FakeScheduledWorkflowClient) Patch(name string, pt types.PatchType, dat } func (c *FakeScheduledWorkflowClient) Get(name string, options v1.GetOptions) (*v1beta1.ScheduledWorkflow, 
error) { - workflow, ok := c.workflows[name] + scheduledWorkflow, ok := c.scheduledWorkflows[name] if ok { - return workflow, nil + return scheduledWorkflow, nil } return nil, errors.New("not found") } diff --git a/backend/src/apiserver/client/scheduled_workflow.go b/backend/src/apiserver/client/swf.go similarity index 69% rename from backend/src/apiserver/client/scheduled_workflow.go rename to backend/src/apiserver/client/swf.go index 4a88df8ed01d..b6018a1e37b9 100644 --- a/backend/src/apiserver/client/scheduled_workflow.go +++ b/backend/src/apiserver/client/swf.go @@ -24,16 +24,28 @@ import ( "k8s.io/client-go/rest" ) +type SwfClientInterface interface { + ScheduledWorkflow(namespace string) v1beta1.ScheduledWorkflowInterface +} + +type SwfClient struct { + swfV1beta1Client v1beta1.ScheduledworkflowV1beta1Interface +} + +func (swfClient *SwfClient) ScheduledWorkflow(namespace string) v1beta1.ScheduledWorkflowInterface { + return swfClient.swfV1beta1Client.ScheduledWorkflows(namespace) +} + // creates a new client for the Kubernetes ScheduledWorkflow CRD. -func CreateScheduledWorkflowClientOrFatal(namespace string, initConnectionTimeout time.Duration) v1beta1.ScheduledWorkflowInterface { - var swfClient v1beta1.ScheduledWorkflowInterface +func NewScheduledWorkflowClientOrFatal(initConnectionTimeout time.Duration) *SwfClient { + var swfClient v1beta1.ScheduledworkflowV1beta1Interface var operation = func() error { restConfig, err := rest.InClusterConfig() if err != nil { return err } swfClientSet := swfclient.NewForConfigOrDie(restConfig) - swfClient = swfClientSet.ScheduledworkflowV1beta1().ScheduledWorkflows(namespace) + swfClient = swfClientSet.ScheduledworkflowV1beta1() return nil } @@ -43,5 +55,5 @@ func CreateScheduledWorkflowClientOrFatal(namespace string, initConnectionTimeou glog.Fatalf("Failed to create scheduled workflow client. 
Error: %v", err) } - return swfClient + return &SwfClient{swfClient} } diff --git a/backend/src/apiserver/client/swf_fake.go b/backend/src/apiserver/client/swf_fake.go new file mode 100644 index 000000000000..f29c27a2c0d5 --- /dev/null +++ b/backend/src/apiserver/client/swf_fake.go @@ -0,0 +1,47 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package client + +import ( + "github.com/kubeflow/pipelines/backend/src/common/util" + "github.com/kubeflow/pipelines/backend/src/crd/pkg/client/clientset/versioned/typed/scheduledworkflow/v1beta1" +) + +type FakeSwfClient struct { + scheduledWorkflowClientFake *FakeScheduledWorkflowClient +} + +func NewFakeSwfClient() *FakeSwfClient { + return &FakeSwfClient{NewScheduledWorkflowClientFake()} +} + +func (c *FakeSwfClient) ScheduledWorkflow(namespace string) v1beta1.ScheduledWorkflowInterface { + if len(namespace) == 0 { + panic(util.NewResourceNotFoundError("Namespace", namespace)) + } + return c.scheduledWorkflowClientFake +} + +type FakeSwfClientWithBadWorkflow struct { + scheduledWorkflowClientFake *FakeBadScheduledWorkflowClient +} + +func NewFakeSwfClientWithBadWorkflow() *FakeSwfClientWithBadWorkflow { + return &FakeSwfClientWithBadWorkflow{&FakeBadScheduledWorkflowClient{}} +} + +func (c *FakeSwfClientWithBadWorkflow) ScheduledWorkflow(namespace string) v1beta1.ScheduledWorkflowInterface { + return c.scheduledWorkflowClientFake +} diff --git 
a/backend/src/apiserver/client_manager.go b/backend/src/apiserver/client_manager.go index 6bbac153e515..774cd4c558d8 100644 --- a/backend/src/apiserver/client_manager.go +++ b/backend/src/apiserver/client_manager.go @@ -29,7 +29,6 @@ import ( "github.com/kubeflow/pipelines/backend/src/apiserver/model" "github.com/kubeflow/pipelines/backend/src/apiserver/storage" "github.com/kubeflow/pipelines/backend/src/common/util" - scheduledworkflowclient "github.com/kubeflow/pipelines/backend/src/crd/pkg/client/clientset/versioned/typed/scheduledworkflow/v1beta1" "github.com/minio/minio-go" ) @@ -68,7 +67,7 @@ type ClientManager struct { defaultExperimentStore storage.DefaultExperimentStoreInterface objectStore storage.ObjectStoreInterface argoClient client.ArgoClientInterface - swfClient scheduledworkflowclient.ScheduledWorkflowInterface + swfClient client.SwfClientInterface k8sCoreClient client.KubernetesCoreInterface kfamClient client.KFAMClientInterface time util.TimeInterface @@ -111,7 +110,7 @@ func (c *ClientManager) ArgoClient() client.ArgoClientInterface { return c.argoClient } -func (c *ClientManager) ScheduledWorkflow() scheduledworkflowclient.ScheduledWorkflowInterface { +func (c *ClientManager) SwfClient() client.SwfClientInterface { return c.swfClient } @@ -152,8 +151,7 @@ func (c *ClientManager) init() { c.argoClient = client.NewArgoClientOrFatal(common.GetDurationConfig(initConnectionTimeout)) - c.swfClient = client.CreateScheduledWorkflowClientOrFatal( - common.GetPodNamespace(), common.GetDurationConfig(initConnectionTimeout)) + c.swfClient = client.NewScheduledWorkflowClientOrFatal(common.GetDurationConfig(initConnectionTimeout)) c.k8sCoreClient = client.CreateKubernetesCoreOrFatal(common.GetDurationConfig(initConnectionTimeout)) diff --git a/backend/src/apiserver/resource/BUILD.bazel b/backend/src/apiserver/resource/BUILD.bazel index 9274c460470d..576d567ff4ec 100644 --- a/backend/src/apiserver/resource/BUILD.bazel +++ 
b/backend/src/apiserver/resource/BUILD.bazel @@ -7,7 +7,6 @@ go_library( "model_converter.go", "resource_manager.go", "resource_manager_util.go", - "scheduled_workflow_fake.go", ], importpath = "github.com/kubeflow/pipelines/backend/src/apiserver/resource", visibility = ["//visibility:public"], @@ -30,7 +29,6 @@ go_library( "@io_k8s_apimachinery//pkg/api/errors:go_default_library", "@io_k8s_apimachinery//pkg/apis/meta/v1:go_default_library", "@io_k8s_apimachinery//pkg/types:go_default_library", - "@io_k8s_apimachinery//pkg/watch:go_default_library", ], ) diff --git a/backend/src/apiserver/resource/client_manager_fake.go b/backend/src/apiserver/resource/client_manager_fake.go index 5cdf7fd4dc59..90b8fddb24b4 100644 --- a/backend/src/apiserver/resource/client_manager_fake.go +++ b/backend/src/apiserver/resource/client_manager_fake.go @@ -19,7 +19,6 @@ import ( "github.com/kubeflow/pipelines/backend/src/apiserver/client" "github.com/kubeflow/pipelines/backend/src/apiserver/storage" "github.com/kubeflow/pipelines/backend/src/common/util" - scheduledworkflowclient "github.com/kubeflow/pipelines/backend/src/crd/pkg/client/clientset/versioned/typed/scheduledworkflow/v1beta1" ) const ( @@ -28,21 +27,21 @@ const ( ) type FakeClientManager struct { - db *storage.DB - experimentStore storage.ExperimentStoreInterface - pipelineStore storage.PipelineStoreInterface - jobStore storage.JobStoreInterface - runStore storage.RunStoreInterface - resourceReferenceStore storage.ResourceReferenceStoreInterface - dBStatusStore storage.DBStatusStoreInterface - defaultExperimentStore storage.DefaultExperimentStoreInterface - objectStore storage.ObjectStoreInterface - ArgoClientFake *client.FakeArgoClient - scheduledWorkflowClientFake *FakeScheduledWorkflowClient - k8sCoreClientFake *client.FakeKuberneteCoreClient - KfamClientFake client.KFAMClientInterface - time util.TimeInterface - uuid util.UUIDGeneratorInterface + db *storage.DB + experimentStore storage.ExperimentStoreInterface + 
pipelineStore storage.PipelineStoreInterface + jobStore storage.JobStoreInterface + runStore storage.RunStoreInterface + resourceReferenceStore storage.ResourceReferenceStoreInterface + dBStatusStore storage.DBStatusStoreInterface + defaultExperimentStore storage.DefaultExperimentStoreInterface + objectStore storage.ObjectStoreInterface + ArgoClientFake *client.FakeArgoClient + swfClientFake *client.FakeSwfClient + k8sCoreClientFake *client.FakeKuberneteCoreClient + KfamClientFake client.KFAMClientInterface + time util.TimeInterface + uuid util.UUIDGeneratorInterface } func NewFakeClientManager(time util.TimeInterface, uuid util.UUIDGeneratorInterface) ( @@ -64,21 +63,21 @@ func NewFakeClientManager(time util.TimeInterface, uuid util.UUIDGeneratorInterf // TODO(neuromage): Pass in metadata.Store instance for tests as well. return &FakeClientManager{ - db: db, - experimentStore: storage.NewExperimentStore(db, time, uuid), - pipelineStore: storage.NewPipelineStore(db, time, uuid), - jobStore: storage.NewJobStore(db, time), - runStore: storage.NewRunStore(db, time), - ArgoClientFake: client.NewFakeArgoClient(), - resourceReferenceStore: storage.NewResourceReferenceStore(db), - dBStatusStore: storage.NewDBStatusStore(db), - defaultExperimentStore: storage.NewDefaultExperimentStore(db), - objectStore: storage.NewFakeObjectStore(), - scheduledWorkflowClientFake: NewScheduledWorkflowClientFake(), - k8sCoreClientFake: client.NewFakeKuberneteCoresClient(), - KfamClientFake: client.NewFakeKFAMClientAuthorized(), - time: time, - uuid: uuid, + db: db, + experimentStore: storage.NewExperimentStore(db, time, uuid), + pipelineStore: storage.NewPipelineStore(db, time, uuid), + jobStore: storage.NewJobStore(db, time), + runStore: storage.NewRunStore(db, time), + ArgoClientFake: client.NewFakeArgoClient(), + resourceReferenceStore: storage.NewResourceReferenceStore(db), + dBStatusStore: storage.NewDBStatusStore(db), + defaultExperimentStore: storage.NewDefaultExperimentStore(db), + 
objectStore: storage.NewFakeObjectStore(), + swfClientFake: client.NewFakeSwfClient(), + k8sCoreClientFake: client.NewFakeKuberneteCoresClient(), + KfamClientFake: client.NewFakeKFAMClientAuthorized(), + time: time, + uuid: uuid, }, nil } @@ -139,8 +138,8 @@ func (f *FakeClientManager) DefaultExperimentStore() storage.DefaultExperimentSt return f.defaultExperimentStore } -func (f *FakeClientManager) ScheduledWorkflow() scheduledworkflowclient.ScheduledWorkflowInterface { - return f.scheduledWorkflowClientFake +func (f *FakeClientManager) SwfClient() client.SwfClientInterface { + return f.swfClientFake } func (f *FakeClientManager) KubernetesCoreClient() client.KubernetesCoreInterface { diff --git a/backend/src/apiserver/resource/resource_manager.go b/backend/src/apiserver/resource/resource_manager.go index 339ccdc391a6..666e563fb665 100644 --- a/backend/src/apiserver/resource/resource_manager.go +++ b/backend/src/apiserver/resource/resource_manager.go @@ -57,7 +57,7 @@ type ClientManagerInterface interface { DefaultExperimentStore() storage.DefaultExperimentStoreInterface ObjectStore() storage.ObjectStoreInterface ArgoClient() client.ArgoClientInterface - ScheduledWorkflow() scheduledworkflowclient.ScheduledWorkflowInterface + SwfClient() client.SwfClientInterface KubernetesCoreClient() client.KubernetesCoreInterface KFAMClient() client.KFAMClientInterface Time() util.TimeInterface @@ -65,38 +65,38 @@ type ClientManagerInterface interface { } type ResourceManager struct { - experimentStore storage.ExperimentStoreInterface - pipelineStore storage.PipelineStoreInterface - jobStore storage.JobStoreInterface - runStore storage.RunStoreInterface - resourceReferenceStore storage.ResourceReferenceStoreInterface - dBStatusStore storage.DBStatusStoreInterface - defaultExperimentStore storage.DefaultExperimentStoreInterface - objectStore storage.ObjectStoreInterface - argoClient client.ArgoClientInterface - scheduledWorkflowClient 
scheduledworkflowclient.ScheduledWorkflowInterface - k8sCoreClient client.KubernetesCoreInterface - kfamClient client.KFAMClientInterface - time util.TimeInterface - uuid util.UUIDGeneratorInterface + experimentStore storage.ExperimentStoreInterface + pipelineStore storage.PipelineStoreInterface + jobStore storage.JobStoreInterface + runStore storage.RunStoreInterface + resourceReferenceStore storage.ResourceReferenceStoreInterface + dBStatusStore storage.DBStatusStoreInterface + defaultExperimentStore storage.DefaultExperimentStoreInterface + objectStore storage.ObjectStoreInterface + argoClient client.ArgoClientInterface + swfClient client.SwfClientInterface + k8sCoreClient client.KubernetesCoreInterface + kfamClient client.KFAMClientInterface + time util.TimeInterface + uuid util.UUIDGeneratorInterface } func NewResourceManager(clientManager ClientManagerInterface) *ResourceManager { return &ResourceManager{ - experimentStore: clientManager.ExperimentStore(), - pipelineStore: clientManager.PipelineStore(), - jobStore: clientManager.JobStore(), - runStore: clientManager.RunStore(), - resourceReferenceStore: clientManager.ResourceReferenceStore(), - dBStatusStore: clientManager.DBStatusStore(), - defaultExperimentStore: clientManager.DefaultExperimentStore(), - objectStore: clientManager.ObjectStore(), - argoClient: clientManager.ArgoClient(), - scheduledWorkflowClient: clientManager.ScheduledWorkflow(), - k8sCoreClient: clientManager.KubernetesCoreClient(), - kfamClient: clientManager.KFAMClient(), - time: clientManager.Time(), - uuid: clientManager.UUID(), + experimentStore: clientManager.ExperimentStore(), + pipelineStore: clientManager.PipelineStore(), + jobStore: clientManager.JobStore(), + runStore: clientManager.RunStore(), + resourceReferenceStore: clientManager.ResourceReferenceStore(), + dBStatusStore: clientManager.DBStatusStore(), + defaultExperimentStore: clientManager.DefaultExperimentStore(), + objectStore: clientManager.ObjectStore(), + argoClient: 
clientManager.ArgoClient(), + swfClient: clientManager.SwfClient(), + k8sCoreClient: clientManager.KubernetesCoreClient(), + kfamClient: clientManager.KFAMClient(), + time: clientManager.Time(), + uuid: clientManager.UUID(), } } @@ -104,6 +104,10 @@ func (r *ResourceManager) getWorkflowClient(namespace string) workflowclient.Wor return r.argoClient.Workflow(namespace) } +func (r *ResourceManager) getScheduledWorkflowClient(namespace string) scheduledworkflowclient.ScheduledWorkflowInterface { + return r.swfClient.ScheduledWorkflow(namespace) +} + func (r *ResourceManager) GetTime() util.TimeInterface { return r.time } @@ -274,15 +278,7 @@ func (r *ResourceManager) CreateRun(apiRun *api.Run) (*model.RunDetail, error) { return nil, util.Wrap(err, "Failed to verify parameters.") } - multiuserMode := common.IsMultiUserMode() - if multiuserMode == true { - if len(workflow.Spec.ServiceAccountName) == 0 || workflow.Spec.ServiceAccountName == defaultPipelineRunnerServiceAccount { - // To reserve SDK backward compatibility, the backend currently replaces the serviceaccount in multi-user mode. 
- workflow.SetServiceAccount(defaultServiceAccount) - } - } else { - workflow.SetServiceAccount(r.getDefaultSA()) - } + r.setWorkflowServiceAccount(&workflow) // Disable istio sidecar injection workflow.SetAnnotationsToAllTemplates(util.AnnotationKeyIstioSidecarInject, util.AnnotationValueIstioSidecarInjectDisabled) @@ -319,34 +315,20 @@ func (r *ResourceManager) CreateRun(apiRun *api.Run) (*model.RunDetail, error) { } } - resourceReferences := apiRun.GetResourceReferences() - experimentID := common.GetExperimentIDFromAPIResourceReferences(resourceReferences) - if len(experimentID) == 0 { - if multiuserMode { - return nil, util.NewInternalServerError(errors.New("Missing experiment"), "Experiment is required for CreateRun/CreateJob.") - } else { - // Add a reference to the default experiment - ref, err := r.getDefaultExperimentResourceReference(resourceReferences) - if err != nil { - return nil, util.Wrap(err, "Failed to create run.") - } - apiRun.ResourceReferences = append(apiRun.ResourceReferences, ref) - experimentID = ref.GetKey().GetId() - } - } - experiment, err := r.GetExperiment(experimentID) + // Add a reference to the default experiment if run does not already have a containing experiment + ref, err := r.getDefaultExperimentIfNoExperiment(apiRun.GetResourceReferences()) if err != nil { - return nil, util.NewInternalServerError(err, "Failed to get experiment.") + return nil, err + } + if ref != nil { + apiRun.ResourceReferences = append(apiRun.GetResourceReferences(), ref) } - namespace := experiment.Namespace - if len(namespace) == 0 { - if multiuserMode { - return nil, util.NewInternalServerError(errors.New("Missing namespace"), "Experiment %v doesn't have a namespace.", experiment.Name) - } else { - namespace = common.GetPodNamespace() - } + namespace, err := r.getNamespaceFromExperiment(apiRun.GetResourceReferences()) + if err != nil { + return nil, err } + // Create argo workflow CRD resource newWorkflow, err := 
r.getWorkflowClient(namespace).Create(workflow.Get()) if err != nil { @@ -541,14 +523,16 @@ func (r *ResourceManager) CreateJob(apiJob *api.Job) (*model.Job, error) { if err != nil { return nil, util.Wrap(err, "Create job failed") } + + r.setWorkflowServiceAccount(&workflow) + + // Disable istio sidecar injection + workflow.SetAnnotationsToAllTemplates(util.AnnotationKeyIstioSidecarInject, util.AnnotationValueIstioSidecarInjectDisabled) + swfGeneratedName, err := toSWFCRDResourceGeneratedName(apiJob.Name) if err != nil { return nil, util.Wrap(err, "Create job failed") } - - // Set workflow to be run using default pipeline runner service account. - workflow.SetServiceAccount(r.getDefaultSA()) - scheduledWorkflow := &scheduledworkflow.ScheduledWorkflow{ ObjectMeta: v1.ObjectMeta{GenerateName: swfGeneratedName}, Spec: scheduledworkflow.ScheduledWorkflowSpec{ @@ -573,11 +557,6 @@ func (r *ResourceManager) CreateJob(apiJob *api.Job) (*model.Job, error) { } } - newScheduledWorkflow, err := r.scheduledWorkflowClient.Create(scheduledWorkflow) - if err != nil { - return nil, util.NewInternalServerError(err, "Failed to create a scheduled workflow for (%s)", scheduledWorkflow.Name) - } - // Add a reference to the default experiment if run does not already have a containing experiment ref, err := r.getDefaultExperimentIfNoExperiment(apiJob.GetResourceReferences()) if err != nil { @@ -587,6 +566,16 @@ func (r *ResourceManager) CreateJob(apiJob *api.Job) (*model.Job, error) { apiJob.ResourceReferences = append(apiJob.GetResourceReferences(), ref) } + namespace, err := r.getNamespaceFromExperiment(apiJob.GetResourceReferences()) + if err != nil { + return nil, err + } + + newScheduledWorkflow, err := r.getScheduledWorkflowClient(namespace).Create(scheduledWorkflow) + if err != nil { + return nil, util.NewInternalServerError(err, "Failed to create a scheduled workflow for (%s)", scheduledWorkflow.Name) + } + job, err := r.ToModelJob(apiJob, 
util.NewScheduledWorkflow(newScheduledWorkflow), string(workflowSpecManifestBytes)) if err != nil { return nil, util.Wrap(err, "Create job failed") @@ -603,7 +592,8 @@ func (r *ResourceManager) EnableJob(jobID string, enabled bool) error { if err != nil { return util.Wrap(err, "Enable/Disable job failed") } - _, err = r.scheduledWorkflowClient.Patch( + + _, err = r.getScheduledWorkflowClient(job.Namespace).Patch( job.Name, types.MergePatchType, []byte(fmt.Sprintf(`{"spec":{"enabled":%s}}`, strconv.FormatBool(enabled)))) @@ -627,7 +617,8 @@ func (r *ResourceManager) DeleteJob(jobID string) error { if err != nil { return util.Wrap(err, "Delete job failed") } - err = r.scheduledWorkflowClient.Delete(job.Name, &v1.DeleteOptions{}) + + err = r.getScheduledWorkflowClient(job.Namespace).Delete(job.Name, &v1.DeleteOptions{}) if err != nil { return util.NewInternalServerError(err, "Delete job CRD failed.") } @@ -764,7 +755,8 @@ func (r *ResourceManager) checkJobExist(jobID string) (*model.Job, error) { if err != nil { return nil, util.Wrap(err, "Check job exist failed") } - scheduledWorkflow, err := r.scheduledWorkflowClient.Get(job.Name, v1.GetOptions{}) + + scheduledWorkflow, err := r.getScheduledWorkflowClient(job.Namespace).Get(job.Name, v1.GetOptions{}) if err != nil { return nil, util.NewInternalServerError(err, "Check job exist failed") } @@ -862,6 +854,9 @@ func (r *ResourceManager) getDefaultExperimentIfNoExperiment(references []*api.R return nil, nil } } + if common.IsMultiUserMode() { + return nil, util.NewInvalidInputError("Experiment is required in resource references.") + } return r.getDefaultExperimentResourceReference(references) } @@ -1050,15 +1045,41 @@ func (r *ResourceManager) GetNamespaceFromRunID(runId string) (string, error) { if err != nil { return "", util.Wrap(err, "Failed to get namespace from run id.") } - namespace := runDetail.Namespace + return runDetail.Namespace, nil +} + +func (r *ResourceManager) GetNamespaceFromJobID(jobId string) 
(string, error) { + job, err := r.GetJob(jobId) + if err != nil { + return "", util.Wrap(err, "Failed to get namespace from Job ID.") + } + return job.Namespace, nil +} + +func (r *ResourceManager) setWorkflowServiceAccount(workflow *util.Workflow) { + if common.IsMultiUserMode() { + if len(workflow.Spec.ServiceAccountName) == 0 || workflow.Spec.ServiceAccountName == defaultPipelineRunnerServiceAccount { + // To reserve SDK backward compatibility, the backend currently replaces the serviceaccount in multi-user mode. + workflow.SetServiceAccount(defaultServiceAccount) + } + } else { + workflow.SetServiceAccount(r.getDefaultSA()) + } +} + +func (r *ResourceManager) getNamespaceFromExperiment(references []*api.ResourceReference) (string, error) { + experimentID := common.GetExperimentIDFromAPIResourceReferences(references) + experiment, err := r.GetExperiment(experimentID) + if err != nil { + return "", util.NewInternalServerError(err, "Failed to get experiment.") + } + + namespace := experiment.Namespace if len(namespace) == 0 { if common.IsMultiUserMode() { - // All runs should have namespace in multi user mode. - return "", errors.New("Invalid db data: run_details doesn't have a namespace") + return "", util.NewInternalServerError(errors.New("Missing namespace"), "Experiment %v doesn't have a namespace.", experiment.Name) } else { - // When db model doesn't have namespace stored (e.g. legacy runs), use - // pod namespace as default. 
- return common.GetPodNamespace(), nil + namespace = common.GetPodNamespace() } } return namespace, nil diff --git a/backend/src/apiserver/resource/resource_manager_test.go b/backend/src/apiserver/resource/resource_manager_test.go index 35cf2bfeeadc..6ac7ac730275 100644 --- a/backend/src/apiserver/resource/resource_manager_test.go +++ b/backend/src/apiserver/resource/resource_manager_test.go @@ -875,10 +875,10 @@ func TestCreateJob_ThroughWorkflowSpec(t *testing.T) { store, _, job := initWithJob(t) defer store.Close() expectedJob := &model.Job{ - UUID: "123", + UUID: "123e4567-e89b-12d3-a456-426655440000", DisplayName: "j1", Name: "j1", - Namespace: "default", + Namespace: "ns1", Enabled: true, CreatedAtInSec: 2, UpdatedAtInSec: 2, @@ -888,7 +888,7 @@ func TestCreateJob_ThroughWorkflowSpec(t *testing.T) { }, ResourceReferences: []*model.ResourceReference{ { - ResourceUUID: "123", + ResourceUUID: "123e4567-e89b-12d3-a456-426655440000", ResourceType: common.Job, ReferenceUUID: DefaultFakeUUID, ReferenceName: "e1", @@ -923,10 +923,10 @@ func TestCreateJob_ThroughPipelineID(t *testing.T) { } newJob, err := manager.CreateJob(job) expectedJob := &model.Job{ - UUID: "123", + UUID: "123e4567-e89b-12d3-a456-426655440000", DisplayName: "j1", Name: "j1", - Namespace: "default", + Namespace: "ns1", Enabled: true, CreatedAtInSec: 3, UpdatedAtInSec: 3, @@ -939,7 +939,7 @@ func TestCreateJob_ThroughPipelineID(t *testing.T) { }, ResourceReferences: []*model.ResourceReference{ { - ResourceUUID: "123", + ResourceUUID: "123e4567-e89b-12d3-a456-426655440000", ResourceType: common.Job, ReferenceUUID: experiment.UUID, ReferenceName: "e1", @@ -994,10 +994,10 @@ func TestCreateJob_ThroughPipelineVersion(t *testing.T) { } newJob, err := manager.CreateJob(job) expectedJob := &model.Job{ - UUID: "123", + UUID: "123e4567-e89b-12d3-a456-426655440000", DisplayName: "j1", Name: "j1", - Namespace: "default", + Namespace: "ns1", Enabled: true, CreatedAtInSec: 4, UpdatedAtInSec: 4, @@ -1008,7 
+1008,7 @@ func TestCreateJob_ThroughPipelineVersion(t *testing.T) { }, ResourceReferences: []*model.ResourceReference{ { - ResourceUUID: "123", + ResourceUUID: "123e4567-e89b-12d3-a456-426655440000", ResourceType: common.Job, ReferenceUUID: experiment.UUID, ReferenceName: "e1", @@ -1016,7 +1016,7 @@ func TestCreateJob_ThroughPipelineVersion(t *testing.T) { Relationship: common.Owner, }, { - ResourceUUID: "123", + ResourceUUID: "123e4567-e89b-12d3-a456-426655440000", ResourceType: common.Job, ReferenceUUID: version.UUID, ReferenceName: "version_for_job", @@ -1087,7 +1087,7 @@ func TestCreateJob_ExtraInputParameterError(t *testing.T) { func TestCreateJob_FailedToCreateScheduleWorkflow(t *testing.T) { store, manager, p := initWithPipeline(t) defer store.Close() - manager.scheduledWorkflowClient = &FakeBadScheduledWorkflowClient{} + manager.swfClient = client.NewFakeSwfClientWithBadWorkflow() job := &api.Job{ Name: "pp1", Enabled: true, @@ -1104,10 +1104,10 @@ func TestEnableJob(t *testing.T) { err := manager.EnableJob(job.UUID, false) job, err = manager.GetJob(job.UUID) expectedJob := &model.Job{ - UUID: "123", + UUID: "123e4567-e89b-12d3-a456-426655440000", DisplayName: "j1", Name: "j1", - Namespace: "default", + Namespace: "ns1", Enabled: false, CreatedAtInSec: 2, UpdatedAtInSec: 3, @@ -1117,7 +1117,7 @@ func TestEnableJob(t *testing.T) { }, ResourceReferences: []*model.ResourceReference{ { - ResourceUUID: "123", + ResourceUUID: "123e4567-e89b-12d3-a456-426655440000", ResourceType: common.Job, ReferenceUUID: DefaultFakeUUID, ReferenceName: "e1", @@ -1142,7 +1142,7 @@ func TestEnableJob_JobNotExist(t *testing.T) { func TestEnableJob_CrdFailure(t *testing.T) { store, manager, job := initWithJob(t) defer store.Close() - manager.scheduledWorkflowClient = &FakeBadScheduledWorkflowClient{} + manager.swfClient = client.NewFakeSwfClientWithBadWorkflow() err := manager.EnableJob(job.UUID, false) assert.Equal(t, codes.Internal, err.(*util.UserError).ExternalStatusCode()) 
assert.Contains(t, err.Error(), "Check job exist failed: some error") @@ -1165,7 +1165,7 @@ func TestDeleteJob(t *testing.T) { _, err = manager.GetJob(job.UUID) assert.Equal(t, codes.NotFound, err.(*util.UserError).ExternalStatusCode()) - assert.Contains(t, err.Error(), "Job 123 not found") + assert.Contains(t, err.Error(), fmt.Sprintf("Job %v not found", job.UUID)) } func TestDeleteJob_JobNotExist(t *testing.T) { @@ -1181,7 +1181,7 @@ func TestDeleteJob_CrdFailure(t *testing.T) { store, manager, job := initWithJob(t) defer store.Close() - manager.scheduledWorkflowClient = &FakeBadScheduledWorkflowClient{} + manager.swfClient = client.NewFakeSwfClientWithBadWorkflow() err := manager.DeleteJob(job.UUID) assert.Equal(t, codes.Internal, err.(*util.UserError).ExternalStatusCode()) assert.Contains(t, err.Error(), "Check job exist failed: some error") diff --git a/backend/src/apiserver/server/job_server.go b/backend/src/apiserver/server/job_server.go index c264afd781f8..d2c4f6cf0bea 100644 --- a/backend/src/apiserver/server/job_server.go +++ b/backend/src/apiserver/server/job_server.go @@ -32,13 +32,15 @@ type JobServer struct { } func (s *JobServer) CreateJob(ctx context.Context, request *api.CreateJobRequest) (*api.Job, error) { - if common.IsMultiUserMode() == true { - return nil, util.NewBadRequestError(errors.New("Job APIs are temporarily disabled in the multi-user mode until it is fully ready."), "Job APIs are temporarily disabled in the multi-user mode until it is fully ready.") - } err := s.validateCreateJobRequest(request) if err != nil { - return nil, err + return nil, util.Wrap(err, "Validate create job request failed.") + } + err = CanAccessExperimentInResourceReferences(s.resourceManager, ctx, request.Job.ResourceReferences) + if err != nil { + return nil, util.Wrap(err, "Failed to authorize the request.") } + newJob, err := s.resourceManager.CreateJob(request.Job) if err != nil { return nil, err @@ -47,6 +49,11 @@ func (s *JobServer) CreateJob(ctx 
context.Context, request *api.CreateJobRequest } func (s *JobServer) GetJob(ctx context.Context, request *api.GetJobRequest) (*api.Job, error) { + err := s.canAccessJob(ctx, request.Id) + if err != nil { + return nil, util.Wrap(err, "Failed to authorize the request.") + } + job, err := s.resourceManager.GetJob(request.Id) if err != nil { return nil, err @@ -65,6 +72,35 @@ func (s *JobServer) ListJobs(ctx context.Context, request *api.ListJobsRequest) if err != nil { return nil, util.Wrap(err, "Validating filter failed.") } + + if common.IsMultiUserMode() { + refKey := filterContext.ReferenceKey + if refKey == nil { + return nil, util.NewInvalidInputError("ListJobs must filter by resource reference in multi-user mode.") + } + if refKey.Type == common.Namespace { + namespace := refKey.ID + if len(namespace) == 0 { + return nil, util.NewInvalidInputError("Invalid resource references for ListJobs. Namespace is empty.") + } + err = isAuthorized(s.resourceManager, ctx, namespace) + if err != nil { + return nil, util.Wrap(err, "Failed to authorize with namespace resource reference.") + } + } else if refKey.Type == common.Experiment { + experimentID := refKey.ID + if len(experimentID) == 0 { + return nil, util.NewInvalidInputError("Invalid resource references for job. Experiment ID is empty.") + } + err = CanAccessExperiment(s.resourceManager, ctx, experimentID) + if err != nil { + return nil, util.Wrap(err, "Failed to authorize with experiment resource reference.") + } + } else { + return nil, util.NewInvalidInputError("Invalid resource references for ListJobs. 
Got %+v", request.ResourceReferenceKey) + } + } + jobs, total_size, nextPageToken, err := s.resourceManager.ListJobs(filterContext, opts) if err != nil { return nil, util.Wrap(err, "Failed to list jobs.") @@ -73,18 +109,30 @@ func (s *JobServer) ListJobs(ctx context.Context, request *api.ListJobsRequest) } func (s *JobServer) EnableJob(ctx context.Context, request *api.EnableJobRequest) (*empty.Empty, error) { + err := s.canAccessJob(ctx, request.Id) + if err != nil { + return nil, util.Wrap(err, "Failed to authorize the request.") + } + return s.enableJob(request.Id, true) } func (s *JobServer) DisableJob(ctx context.Context, request *api.DisableJobRequest) (*empty.Empty, error) { + err := s.canAccessJob(ctx, request.Id) + if err != nil { + return nil, util.Wrap(err, "Failed to authorize the request.") + } + return s.enableJob(request.Id, false) } func (s *JobServer) DeleteJob(ctx context.Context, request *api.DeleteJobRequest) (*empty.Empty, error) { - if common.IsMultiUserMode() == true { - return nil, util.NewBadRequestError(errors.New("Job APIs are temporarily disabled in the multi-user mode until it is fully ready."), "Job APIs are temporarily disabled in the multi-user mode until it is fully ready.") + err := s.canAccessJob(ctx, request.Id) + if err != nil { + return nil, util.Wrap(err, "Failed to authorize the request.") } - err := s.resourceManager.DeleteJob(request.Id) + + err = s.resourceManager.DeleteJob(request.Id) if err != nil { return nil, err } @@ -127,6 +175,26 @@ func (s *JobServer) enableJob(id string, enabled bool) (*empty.Empty, error) { return &empty.Empty{}, nil } +func (s *JobServer) canAccessJob(ctx context.Context, jobID string) error { + if common.IsMultiUserMode() == false { + // Skip authorization if not multi-user mode. 
+ return nil + } + namespace, err := s.resourceManager.GetNamespaceFromJobID(jobID) + if err != nil { + return util.Wrap(err, "Failed to authorize with the job ID.") + } + if len(namespace) == 0 { + return util.NewInternalServerError(errors.New("Empty namespace"), "The job doesn't have a valid namespace.") + } + + err = isAuthorized(s.resourceManager, ctx, namespace) + if err != nil { + return util.Wrap(err, "Failed to authorize with API resource references") + } + return nil +} + func NewJobServer(resourceManager *resource.ResourceManager) *JobServer { return &JobServer{resourceManager: resourceManager} } diff --git a/backend/src/apiserver/server/job_server_test.go b/backend/src/apiserver/server/job_server_test.go index e40e7f7e4ff9..7a8f78a5fc7c 100644 --- a/backend/src/apiserver/server/job_server_test.go +++ b/backend/src/apiserver/server/job_server_test.go @@ -1,22 +1,26 @@ package server import ( + "context" + "strings" "testing" "github.com/golang/protobuf/ptypes/timestamp" + "github.com/google/go-cmp/cmp" api "github.com/kubeflow/pipelines/backend/api/go_client" + "github.com/kubeflow/pipelines/backend/src/apiserver/client" + "github.com/kubeflow/pipelines/backend/src/apiserver/common" + "github.com/kubeflow/pipelines/backend/src/apiserver/resource" "github.com/kubeflow/pipelines/backend/src/common/util" + "github.com/spf13/viper" "github.com/stretchr/testify/assert" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" ) -func TestValidateApiJob(t *testing.T) { - clients, manager, experiment := initWithExperiment(t) - defer clients.Close() - server := NewJobServer(manager) - apiJob := &api.Job{ - Id: "job1", - Name: "name1", +var ( + commonApiJob = &api.Job{ + Name: "job1", Enabled: true, MaxConcurrency: 1, Trigger: &api.Trigger{ @@ -29,10 +33,44 @@ func TestValidateApiJob(t *testing.T) { Parameters: []*api.Parameter{{Name: "param1", Value: "world"}}, }, ResourceReferences: []*api.ResourceReference{ - {Key: &api.ResourceKey{Type: 
api.ResourceType_EXPERIMENT, Id: experiment.UUID}, Relationship: api.Relationship_OWNER}, + { + Key: &api.ResourceKey{Type: api.ResourceType_EXPERIMENT, Id: "123e4567-e89b-12d3-a456-426655440000"}, + Relationship: api.Relationship_OWNER, + }, }, } - err := server.validateCreateJobRequest(&api.CreateJobRequest{Job: apiJob}) + + commonExpectedJob = &api.Job{ + Id: "123e4567-e89b-12d3-a456-426655440000", + Name: "job1", + Enabled: true, + MaxConcurrency: 1, + Trigger: &api.Trigger{ + Trigger: &api.Trigger_CronSchedule{CronSchedule: &api.CronSchedule{ + StartTime: &timestamp.Timestamp{Seconds: 1}, + Cron: "1 * * * *", + }}}, + CreatedAt: &timestamp.Timestamp{Seconds: 2}, + UpdatedAt: &timestamp.Timestamp{Seconds: 2}, + Status: "NO_STATUS", + PipelineSpec: &api.PipelineSpec{ + WorkflowManifest: testWorkflow.ToStringForStore(), + Parameters: []*api.Parameter{{Name: "param1", Value: "world"}}, + }, + ResourceReferences: []*api.ResourceReference{ + { + Key: &api.ResourceKey{Type: api.ResourceType_EXPERIMENT, Id: "123e4567-e89b-12d3-a456-426655440000"}, + Name: "exp1", Relationship: api.Relationship_OWNER, + }, + }, + } +) + +func TestValidateApiJob(t *testing.T) { + clients, manager, _ := initWithExperiment(t) + defer clients.Close() + server := NewJobServer(manager) + err := server.validateCreateJobRequest(&api.CreateJobRequest{Job: commonApiJob}) assert.Nil(t, err) } @@ -41,8 +79,7 @@ func TestValidateApiJob_WithPipelineVersion(t *testing.T) { defer clients.Close() server := NewJobServer(manager) apiJob := &api.Job{ - Id: "job1", - Name: "name1", + Name: "job1", Enabled: true, MaxConcurrency: 1, Trigger: &api.Trigger{ @@ -61,8 +98,7 @@ func TestValidateApiJob_ValidateNoExperimentResourceReferenceSucceeds(t *testing defer clients.Close() server := NewJobServer(manager) apiJob := &api.Job{ - Id: "job1", - Name: "name1", + Name: "job1", Enabled: true, MaxConcurrency: 1, Trigger: &api.Trigger{ @@ -85,8 +121,7 @@ func TestValidateApiJob_ValidatePipelineSpecFailed(t *testing.T) { defer 
clients.Close() server := NewJobServer(manager) apiJob := &api.Job{ - Id: "job1", - Name: "name1", + Name: "job1", Enabled: true, MaxConcurrency: 1, Trigger: &api.Trigger{ @@ -112,8 +147,7 @@ func TestValidateApiJob_NoValidPipelineSpecOrPipelineVersion(t *testing.T) { defer clients.Close() server := NewJobServer(manager) apiJob := &api.Job{ - Id: "job1", - Name: "name1", + Name: "job1", Enabled: true, MaxConcurrency: 1, Trigger: &api.Trigger{ @@ -133,8 +167,7 @@ func TestValidateApiJob_InvalidCron(t *testing.T) { defer clients.Close() server := NewJobServer(manager) apiJob := &api.Job{ - Id: "job1", - Name: "name1", + Name: "job1", Enabled: true, MaxConcurrency: 1, Trigger: &api.Trigger{ @@ -160,8 +193,7 @@ func TestValidateApiJob_MaxConcurrencyOutOfRange(t *testing.T) { defer clients.Close() server := NewJobServer(manager) apiJob := &api.Job{ - Id: "job1", - Name: "name1", + Name: "job1", Enabled: true, MaxConcurrency: 0, Trigger: &api.Trigger{ @@ -187,8 +219,7 @@ func TestValidateApiJob_NegativeIntervalSecond(t *testing.T) { defer clients.Close() server := NewJobServer(manager) apiJob := &api.Job{ - Id: "job1", - Name: "name1", + Name: "job1", Enabled: true, MaxConcurrency: 0, Trigger: &api.Trigger{ @@ -207,3 +238,276 @@ func TestValidateApiJob_NegativeIntervalSecond(t *testing.T) { assert.Equal(t, codes.InvalidArgument, err.(*util.UserError).ExternalStatusCode()) assert.Contains(t, err.Error(), "The max concurrency of the job is out of range") } + +func TestCreateJob(t *testing.T) { + clients, manager, _ := initWithExperiment(t) + defer clients.Close() + server := NewJobServer(manager) + job, err := server.CreateJob(nil, &api.CreateJobRequest{Job: commonApiJob}) + assert.Nil(t, err) + assert.Equal(t, commonExpectedJob, job) +} + +func TestCreateJob_Unauthorized(t *testing.T) { + viper.Set(common.MultiUserMode, "true") + defer viper.Set(common.MultiUserMode, "false") + + md := metadata.New(map[string]string{common.GoogleIAPUserIdentityHeader: 
"accounts.google.com:user@google.com"}) + ctx := metadata.NewIncomingContext(context.Background(), md) + + clients, manager, _ := initWithExperiment_KFAM_Unauthorized(t) + defer clients.Close() + server := NewJobServer(manager) + _, err := server.CreateJob(ctx, &api.CreateJobRequest{Job: commonApiJob}) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "Unauthorized access") +} + +func TestGetJob_Unauthorized(t *testing.T) { + viper.Set(common.MultiUserMode, "true") + defer viper.Set(common.MultiUserMode, "false") + + md := metadata.New(map[string]string{common.GoogleIAPUserIdentityHeader: "accounts.google.com:user@google.com"}) + ctx := metadata.NewIncomingContext(context.Background(), md) + + clients, manager, _ := initWithExperiment(t) + defer clients.Close() + server := NewJobServer(manager) + job, err := server.CreateJob(ctx, &api.CreateJobRequest{Job: commonApiJob}) + assert.Nil(t, err) + + clients.KfamClientFake = client.NewFakeKFAMClientUnauthorized() + manager = resource.NewResourceManager(clients) + server = NewJobServer(manager) + + _, err = server.GetJob(ctx, &api.GetJobRequest{Id: job.Id}) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "Unauthorized access") +} + +func TestGetJob_Multiuser(t *testing.T) { + viper.Set(common.MultiUserMode, "true") + defer viper.Set(common.MultiUserMode, "false") + + md := metadata.New(map[string]string{common.GoogleIAPUserIdentityHeader: "accounts.google.com:user@google.com"}) + ctx := metadata.NewIncomingContext(context.Background(), md) + + clients, manager, _ := initWithExperiment(t) + defer clients.Close() + server := NewJobServer(manager) + createdJob, err := server.CreateJob(ctx, &api.CreateJobRequest{Job: commonApiJob}) + assert.Nil(t, err) + + job, err := server.GetJob(ctx, &api.GetJobRequest{Id: createdJob.Id}) + assert.Nil(t, err) + assert.Equal(t, commonExpectedJob, job) +} + +func TestListJobs_Unauthorized(t *testing.T) { + viper.Set(common.MultiUserMode, "true") + defer 
viper.Set(common.MultiUserMode, "false") + + md := metadata.New(map[string]string{common.GoogleIAPUserIdentityHeader: "accounts.google.com:user@google.com"}) + ctx := metadata.NewIncomingContext(context.Background(), md) + + clients, manager, experiment := initWithExperiment_KFAM_Unauthorized(t) + defer clients.Close() + server := NewJobServer(manager) + _, err := server.ListJobs(ctx, &api.ListJobsRequest{ + ResourceReferenceKey: &api.ResourceKey{ + Type: api.ResourceType_EXPERIMENT, + Id: experiment.UUID, + }, + }) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "Unauthorized access") + + _, err = server.ListJobs(ctx, &api.ListJobsRequest{ + ResourceReferenceKey: &api.ResourceKey{ + Type: api.ResourceType_NAMESPACE, + Id: "ns1", + }, + }) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "Unauthorized access") +} + +func TestListJobs_Multiuser(t *testing.T) { + viper.Set(common.MultiUserMode, "true") + defer viper.Set(common.MultiUserMode, "false") + + md := metadata.New(map[string]string{common.GoogleIAPUserIdentityHeader: "accounts.google.com:user@google.com"}) + ctx := metadata.NewIncomingContext(context.Background(), md) + + clients, manager, _ := initWithExperiment(t) + defer clients.Close() + server := NewJobServer(manager) + _, err := server.CreateJob(ctx, &api.CreateJobRequest{Job: commonApiJob}) + assert.Nil(t, err) + + var expectedJobs []*api.Job + expectedJobs = append(expectedJobs, commonExpectedJob) + expectedJobsEmpty := []*api.Job{} + + tests := []struct { + name string + request *api.ListJobsRequest + wantError bool + errorMessage string + expectedJobs []*api.Job + }{ + { + "Valid - filter by experiment", + &api.ListJobsRequest{ + ResourceReferenceKey: &api.ResourceKey{ + Type: api.ResourceType_EXPERIMENT, + Id: "123e4567-e89b-12d3-a456-426655440000", + }, + }, + false, + "", + expectedJobs, + }, + { + "Valid - filter by namespace", + &api.ListJobsRequest{ + ResourceReferenceKey: &api.ResourceKey{ + Type: 
api.ResourceType_NAMESPACE, + Id: "ns1", + }, + }, + false, + "", + expectedJobs, + }, + { + "Vailid - filter by namespace - no result", + &api.ListJobsRequest{ + ResourceReferenceKey: &api.ResourceKey{ + Type: api.ResourceType_NAMESPACE, + Id: "no-such-ns", + }, + }, + false, + "", + expectedJobsEmpty, + }, + { + "Invalid - no filter", + &api.ListJobsRequest{}, + true, + "ListJobs must filter by resource reference", + nil, + }, + { + "Inalid - invalid filter type", + &api.ListJobsRequest{ + ResourceReferenceKey: &api.ResourceKey{ + Type: api.ResourceType_UNKNOWN_RESOURCE_TYPE, + Id: "unknown", + }, + }, + true, + "Unrecognized resource reference type", + nil, + }, + } + + for _, tc := range tests { + response, err := server.ListJobs(ctx, tc.request) + + if tc.wantError { + if err == nil { + t.Errorf("TestListJobs_Multiuser(%v) expect error but got nil", tc.name) + } else if !strings.Contains(err.Error(), tc.errorMessage) { + t.Errorf("TestListJobs_Multiusert(%v) expect error containing: %v, but got: %v", tc.name, tc.errorMessage, err) + } + } else { + if err != nil { + t.Errorf("TestListJobs_Multiuser(%v) expect no error but got %v", tc.name, err) + } else if !cmp.Equal(tc.expectedJobs, response.Jobs) { + t.Errorf("TestListJobs_Multiuser(%v) expect (%+v) but got (%+v)", tc.name, tc.expectedJobs, response.Jobs) + } + } + } +} + +func TestEnableJob_Unauthorized(t *testing.T) { + viper.Set(common.MultiUserMode, "true") + defer viper.Set(common.MultiUserMode, "false") + + md := metadata.New(map[string]string{common.GoogleIAPUserIdentityHeader: "accounts.google.com:user@google.com"}) + ctx := metadata.NewIncomingContext(context.Background(), md) + + clients, manager, _ := initWithExperiment(t) + defer clients.Close() + server := NewJobServer(manager) + job, err := server.CreateJob(ctx, &api.CreateJobRequest{Job: commonApiJob}) + assert.Nil(t, err) + + clients.KfamClientFake = client.NewFakeKFAMClientUnauthorized() + manager = resource.NewResourceManager(clients) + 
server = NewJobServer(manager) + + _, err = server.EnableJob(ctx, &api.EnableJobRequest{Id: job.Id}) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "Unauthorized access") +} + +func TestEnableJob_Multiuser(t *testing.T) { + viper.Set(common.MultiUserMode, "true") + defer viper.Set(common.MultiUserMode, "false") + + md := metadata.New(map[string]string{common.GoogleIAPUserIdentityHeader: "accounts.google.com:user@google.com"}) + ctx := metadata.NewIncomingContext(context.Background(), md) + + clients, manager, _ := initWithExperiment(t) + defer clients.Close() + server := NewJobServer(manager) + + job, err := server.CreateJob(ctx, &api.CreateJobRequest{Job: commonApiJob}) + assert.Nil(t, err) + + _, err = server.EnableJob(ctx, &api.EnableJobRequest{Id: job.Id}) + assert.Nil(t, err) +} + +func TestDisableJob_Unauthorized(t *testing.T) { + viper.Set(common.MultiUserMode, "true") + defer viper.Set(common.MultiUserMode, "false") + + md := metadata.New(map[string]string{common.GoogleIAPUserIdentityHeader: "accounts.google.com:user@google.com"}) + ctx := metadata.NewIncomingContext(context.Background(), md) + + clients, manager, _ := initWithExperiment(t) + defer clients.Close() + server := NewJobServer(manager) + job, err := server.CreateJob(ctx, &api.CreateJobRequest{Job: commonApiJob}) + assert.Nil(t, err) + + clients.KfamClientFake = client.NewFakeKFAMClientUnauthorized() + manager = resource.NewResourceManager(clients) + server = NewJobServer(manager) + + _, err = server.DisableJob(ctx, &api.DisableJobRequest{Id: job.Id}) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "Unauthorized access") +} + +func TestDisableJob_Multiuser(t *testing.T) { + viper.Set(common.MultiUserMode, "true") + defer viper.Set(common.MultiUserMode, "false") + + md := metadata.New(map[string]string{common.GoogleIAPUserIdentityHeader: "accounts.google.com:user@google.com"}) + ctx := metadata.NewIncomingContext(context.Background(), md) + + clients, manager, _ := 
initWithExperiment(t) + defer clients.Close() + server := NewJobServer(manager) + + job, err := server.CreateJob(ctx, &api.CreateJobRequest{Job: commonApiJob}) + assert.Nil(t, err) + + _, err = server.DisableJob(ctx, &api.DisableJobRequest{Id: job.Id}) + assert.Nil(t, err) +} diff --git a/backend/src/apiserver/storage/job_store.go b/backend/src/apiserver/storage/job_store.go index 825e3c1b21b4..f16ebd3d2d18 100644 --- a/backend/src/apiserver/storage/job_store.go +++ b/backend/src/apiserver/storage/job_store.go @@ -112,8 +112,18 @@ func (s *JobStore) ListJobs( func (s *JobStore) buildSelectJobsQuery(selectCount bool, opts *list.Options, filterContext *common.FilterContext) (string, []interface{}, error) { - filteredSelectBuilder, err := list.FilterOnResourceReference("jobs", jobColumns, - common.Job, selectCount, filterContext) + + var filteredSelectBuilder sq.SelectBuilder + var err error + + refKey := filterContext.ReferenceKey + if refKey != nil && refKey.Type == common.Namespace { + filteredSelectBuilder, err = list.FilterOnNamespace("jobs", jobColumns, + selectCount, refKey.ID) + } else { + filteredSelectBuilder, err = list.FilterOnResourceReference("jobs", jobColumns, + common.Job, selectCount, filterContext) + } if err != nil { return "", nil, util.NewInternalServerError(err, "Failed to list jobs: %v", err) } diff --git a/backend/src/apiserver/storage/job_store_test.go b/backend/src/apiserver/storage/job_store_test.go index d20c6d0dcac8..0726950bb885 100644 --- a/backend/src/apiserver/storage/job_store_test.go +++ b/backend/src/apiserver/storage/job_store_test.go @@ -38,9 +38,9 @@ const ( func initializeDbAndStore() (*DB, *JobStore) { db := NewFakeDbOrFatal() expStore := NewExperimentStore(db, util.NewFakeTimeForEpoch(), util.NewFakeUUIDGeneratorOrFatal(defaultFakeExpId, nil)) - expStore.CreateExperiment(&model.Experiment{Name: "exp1"}) + expStore.CreateExperiment(&model.Experiment{Name: "exp1", Namespace: "n1"}) expStore = NewExperimentStore(db, 
util.NewFakeTimeForEpoch(), util.NewFakeUUIDGeneratorOrFatal(defaultFakeExpIdTwo, nil)) - expStore.CreateExperiment(&model.Experiment{Name: "exp2"}) + expStore.CreateExperiment(&model.Experiment{Name: "exp2", Namespace: "n1"}) jobStore := NewJobStore(db, util.NewFakeTimeForEpoch()) job1 := &model.Job{ UUID: "1", @@ -419,6 +419,12 @@ func TestListJobs_FilterByReferenceKey(t *testing.T) { assert.Equal(t, "", nextPageToken) assert.Equal(t, 1, total_size) assert.Equal(t, jobsExpected, jobs) + + jobs, total_size, nextPageToken, err = jobStore.ListJobs( + &common.FilterContext{ReferenceKey: &common.ReferenceKey{Type: common.Namespace, ID: "n1"}}, opts) + assert.Nil(t, err) + assert.Equal(t, "", nextPageToken) + assert.Equal(t, 2, total_size) // both test jobs belong to namespace `n1` } func TestListJobsError(t *testing.T) {