diff --git a/backend/src/apiserver/main.go b/backend/src/apiserver/main.go index 8d87bb47fc3..7bdd6243e57 100644 --- a/backend/src/apiserver/main.go +++ b/backend/src/apiserver/main.go @@ -83,6 +83,7 @@ func startRpcServer(resourceManager *resource.ResourceManager) { api.RegisterRunServiceServer(s, server.NewRunServer(resourceManager)) api.RegisterJobServiceServer(s, server.NewJobServer(resourceManager)) api.RegisterReportServiceServer(s, server.NewReportServer(resourceManager)) + api.RegisterVisualizationServiceServer(s, server.NewVisualizationServer(resourceManager)) // Register reflection service on gRPC server. reflection.Register(s) @@ -106,6 +107,7 @@ func startHttpProxy(resourceManager *resource.ResourceManager) { registerHttpHandlerFromEndpoint(api.RegisterJobServiceHandlerFromEndpoint, "JobService", ctx, mux) registerHttpHandlerFromEndpoint(api.RegisterRunServiceHandlerFromEndpoint, "RunService", ctx, mux) registerHttpHandlerFromEndpoint(api.RegisterReportServiceHandlerFromEndpoint, "ReportService", ctx, mux) + registerHttpHandlerFromEndpoint(api.RegisterVisualizationServiceHandlerFromEndpoint, "Visualization", ctx, mux) // Create a top level mux to include both pipeline upload server and gRPC servers. topMux := http.NewServeMux() diff --git a/backend/src/apiserver/server/BUILD.bazel b/backend/src/apiserver/server/BUILD.bazel index e72c46b11e1..83c307b0b8a 100644 --- a/backend/src/apiserver/server/BUILD.bazel +++ b/backend/src/apiserver/server/BUILD.bazel @@ -14,6 +14,7 @@ go_library( "run_server.go", "test_util.go", "util.go", + "visualization_server.go", ], importpath = "github.com/kubeflow/pipelines/backend/src/apiserver/server", visibility = ["//visibility:public"], @@ -50,6 +51,7 @@ go_test( "run_metric_util_test.go", "run_server_test.go", "util_test.go", + "visualization_server_test.go", ], data = glob(["test/**/*"]), # keep embed = [":go_default_library"], diff --git a/backend/src/apiserver/server/visualization_server.go b/backend/src/apiserver/server/visualization_server.go new file mode 100644 index 00000000000..e9982dbde57 --- /dev/null +++ b/backend/src/apiserver/server/visualization_server.go @@ -0,0 +1,78 @@ +package server + +import ( + "context" + "encoding/json" + "fmt" + "github.com/kubeflow/pipelines/backend/api/go_client" + "github.com/kubeflow/pipelines/backend/src/apiserver/resource" + "github.com/kubeflow/pipelines/backend/src/common/util" + "io/ioutil" + "net/http" + "net/url" + "strings" +) + +type VisualizationServer struct { + resourceManager *resource.ResourceManager + serviceURL string +} + +func (s *VisualizationServer) CreateVisualization(ctx context.Context, request *go_client.CreateVisualizationRequest) (*go_client.Visualization, error) { + if err := s.validateCreateVisualizationRequest(request); err != nil { + return nil, err + } + body, err := s.generateVisualizationFromRequest(request) + if err != nil { + return nil, err + } + request.Visualization.Html = string(body) + return request.Visualization, nil +} + +// validateCreateVisualizationRequest ensures that a go_client.Visualization +// object has valid values. +// It returns an error if a go_client.Visualization object does not have valid +// values. +func (s *VisualizationServer) validateCreateVisualizationRequest(request *go_client.CreateVisualizationRequest) error { + if len(request.Visualization.InputPath) == 0 { + return util.NewInvalidInputError("A visualization requires an InputPath to be provided. Received %s", request.Visualization.InputPath) + } + // Manually set Arguments to empty JSON if nothing is provided. This is done + // because visualizations such as TFDV and TFMA only require an InputPath to + // be provided for a visualization to be generated. If no JSON is provided + // json.Valid will fail without this check as an empty string is provided for + // those visualizations. + if len(request.Visualization.Arguments) == 0 { + request.Visualization.Arguments = "{}" + } + if !json.Valid([]byte(request.Visualization.Arguments)) { + return util.NewInvalidInputError("A visualization requires valid JSON to be provided as Arguments. Received %s", request.Visualization.Arguments) + } + return nil +} + +// generateVisualizationFromRequest communicates with the python visualization +// service to generate HTML visualizations from a request. +// It returns the generated HTML as a string and any error that is encountered. +func (s *VisualizationServer) generateVisualizationFromRequest(request *go_client.CreateVisualizationRequest) ([]byte, error) { + visualizationType := strings.ToLower(go_client.Visualization_Type_name[int32(request.Visualization.Type)]) + arguments := fmt.Sprintf("--type %s --input_path %s --arguments '%s'", visualizationType, request.Visualization.InputPath, request.Visualization.Arguments) + resp, err := http.PostForm(s.serviceURL, url.Values{"arguments": {arguments}}) + if err != nil { + return nil, util.Wrap(err, "Unable to initialize visualization request.") + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf(resp.Status) + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, util.Wrap(err, "Unable to parse visualization response.") + } + return body, nil +} + +func NewVisualizationServer(resourceManager *resource.ResourceManager) *VisualizationServer { + return &VisualizationServer{resourceManager: resourceManager, serviceURL: "http://visualization-service.kubeflow"} +} diff --git a/backend/src/apiserver/server/visualization_server_test.go b/backend/src/apiserver/server/visualization_server_test.go new file mode 100644 index 00000000000..bceb52785fe --- /dev/null +++ b/backend/src/apiserver/server/visualization_server_test.go @@ -0,0 +1,117 @@ +package server + +import ( + "github.com/kubeflow/pipelines/backend/api/go_client" + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "testing" +) + +func TestValidateCreateVisualizationRequest(t *testing.T) { + clients, manager, _ := initWithExperiment(t) + defer clients.Close() + server := NewVisualizationServer(manager) + visualization := &go_client.Visualization{ + Type: go_client.Visualization_ROC_CURVE, + InputPath: "gs://ml-pipeline/roc/data.csv", + Arguments: "{}", + } + request := &go_client.CreateVisualizationRequest{ + Visualization: visualization, + } + err := server.validateCreateVisualizationRequest(request) + assert.Nil(t, err) +} + +func TestValidateCreateVisualizationRequest_ArgumentsAreEmpty(t *testing.T) { + clients, manager, _ := initWithExperiment(t) + defer clients.Close() + server := NewVisualizationServer(manager) + visualization := &go_client.Visualization{ + Type: go_client.Visualization_ROC_CURVE, + InputPath: "gs://ml-pipeline/roc/data.csv", + Arguments: "", + } + request := &go_client.CreateVisualizationRequest{ + Visualization: visualization, + } + err := server.validateCreateVisualizationRequest(request) + assert.Nil(t, err) +} + +func TestValidateCreateVisualizationRequest_InputPathIsEmpty(t *testing.T) { + clients, manager, _ := initWithExperiment(t) + defer clients.Close() + server := NewVisualizationServer(manager) + visualization := &go_client.Visualization{ + Type: go_client.Visualization_ROC_CURVE, + InputPath: "", + Arguments: "{}", + } + request := &go_client.CreateVisualizationRequest{ + Visualization: visualization, + } + err := server.validateCreateVisualizationRequest(request) + assert.Contains(t, err.Error(), "A visualization requires an InputPath to be provided. Received") +} + +func TestValidateCreateVisualizationRequest_ArgumentsNotValidJSON(t *testing.T) { + clients, manager, _ := initWithExperiment(t) + defer clients.Close() + server := NewVisualizationServer(manager) + visualization := &go_client.Visualization{ + Type: go_client.Visualization_ROC_CURVE, + InputPath: "gs://ml-pipeline/roc/data.csv", + Arguments: "{", + } + request := &go_client.CreateVisualizationRequest{ + Visualization: visualization, + } + err := server.validateCreateVisualizationRequest(request) + assert.Contains(t, err.Error(), "A visualization requires valid JSON to be provided as Arguments. Received {") +} + +func TestGenerateVisualization(t *testing.T) { + clients, manager, _ := initWithExperiment(t) + defer clients.Close() + httpServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + assert.Equal(t, "/", req.URL.String()) + rw.Write([]byte("roc_curve")) + })) + defer httpServer.Close() + server := &VisualizationServer{resourceManager: manager, serviceURL: httpServer.URL} + visualization := &go_client.Visualization{ + Type: go_client.Visualization_ROC_CURVE, + InputPath: "gs://ml-pipeline/roc/data.csv", + Arguments: "{}", + } + request := &go_client.CreateVisualizationRequest{ + Visualization: visualization, + } + body, err := server.generateVisualizationFromRequest(request) + assert.Equal(t, []byte("roc_curve"), body) + assert.Nil(t, err) +} + +func TestGenerateVisualization_ServerError(t *testing.T) { + clients, manager, _ := initWithExperiment(t) + defer clients.Close() + httpServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + assert.Equal(t, "/", req.URL.String()) + rw.WriteHeader(500) + })) + defer httpServer.Close() + server := &VisualizationServer{resourceManager: manager, serviceURL: httpServer.URL} + visualization := &go_client.Visualization{ + Type: go_client.Visualization_ROC_CURVE, + InputPath: "gs://ml-pipeline/roc/data.csv", + Arguments: "{}", + } + request := &go_client.CreateVisualizationRequest{ + Visualization: visualization, + } + body, err := server.generateVisualizationFromRequest(request) + assert.Nil(t, body) + assert.Equal(t, "500 Internal Server Error", err.Error()) +}