From 3d6d959a7932dba0ff61f611756a1d8de8a462cb Mon Sep 17 00:00:00 2001 From: abchoo Date: Fri, 16 Apr 2021 09:18:13 -0400 Subject: [PATCH] Introducing HTTPS protocol for MMS (#1510) * Added test for https downloader * update invalid prefix test case * Added https protocol and updated GetAllProtocol * added https provider case * set up https download for uncompressed/gzip/tar/zip files; the client copies the response body to a file if the content is not a zip/tar/gzip type * Moved SupportedProtocols into the storage package * fixed variable reference * support for http protocol * update prefix to be clearly invalid --- pkg/agent/downloader.go | 4 +- pkg/agent/storage/https.go | 210 ++++++++++++++++++ pkg/agent/storage/provider.go | 12 +- pkg/agent/storage/utils.go | 12 +- pkg/agent/watcher_test.go | 194 +++++++++++++++- .../v1alpha1/trainedmodel_webhook_test.go | 8 +- 6 files changed, 428 insertions(+), 12 deletions(-) create mode 100644 pkg/agent/storage/https.go diff --git a/pkg/agent/downloader.go b/pkg/agent/downloader.go index f5bece154b8..cac4dc74da1 100644 --- a/pkg/agent/downloader.go +++ b/pkg/agent/downloader.go @@ -37,8 +37,6 @@ type Downloader struct { Logger *zap.SugaredLogger } -var SupportedProtocols = []storage.Protocol{storage.S3, storage.GCS} - func (d *Downloader) DownloadModel(modelName string, modelSpec *v1alpha1.ModelSpec) error { if modelSpec != nil { sha256 := storage.AsSha256(modelSpec) @@ -105,7 +103,7 @@ func extractProtocol(storageURI string) (storage.Protocol, error) { return "", fmt.Errorf("there is no protocol specificed for the storageUri") } - for _, prefix := range SupportedProtocols { + for _, prefix := range storage.SupportedProtocols { if strings.HasPrefix(storageURI, string(prefix)) { return prefix, nil } diff --git a/pkg/agent/storage/https.go b/pkg/agent/storage/https.go new file mode 100644 index 00000000000..e332ecca320 --- /dev/null +++ b/pkg/agent/storage/https.go @@ -0,0 +1,210 @@ +/* +Copyright 2021 kubeflow.org. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package storage + +import ( + "archive/tar" + "archive/zip" + "bytes" + "compress/gzip" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" +) + +const ( + HEADER_SUFFIX = "-headers" +) + +type HTTPSProvider struct { + Client *http.Client +} + +func (m *HTTPSProvider) DownloadModel(modelDir string, modelName string, storageUri string) error { + log.Info("Download model ", "modelName", modelName, "storageUri", storageUri, "modelDir", modelDir) + uri, err := url.Parse(storageUri) + if err != nil { + return fmt.Errorf("unable to parse storage uri: %v", err) + } + downloader := &HTTPSDownloader{ + StorageUri: storageUri, + ModelDir: modelDir, + ModelName: modelName, + Uri: uri, + } + if err := downloader.Download(*m.Client); err != nil { + return err + } + return nil +} + +type HTTPSDownloader struct { + StorageUri string + ModelDir string + ModelName string + Uri *url.URL +} + +func (h *HTTPSDownloader) Download(client http.Client) error { + // Create request + req, err := http.NewRequest("GET", h.StorageUri, nil) + if err != nil { + return err + } + + headers, err := h.extractHeaders() + if err != nil { + return err + } + for key, element := range headers { + req.Header.Add(key, element) + } + + // Query request + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("failed to make a request: %v", err) + } + + defer resp.Body.Close() + if resp.StatusCode != 
200 { + return fmt.Errorf("URI: %s returned a %d response code", h.StorageUri, resp.StatusCode) + } + + // Write content into file(s) + contentType := resp.Header.Get("Content-type") + fileDirectory := filepath.Join(h.ModelDir, h.ModelName) + + if strings.Contains(contentType, "application/zip") { + if err := extractZipFiles(resp.Body, fileDirectory); err != nil { + return err + } + } else if strings.Contains(contentType, "application/x-tar") || strings.Contains(contentType, "application/gzip") { + if err := extractTarFiles(resp.Body, fileDirectory); err != nil { + return err + } + } else { + paths := strings.Split(h.Uri.Path, "/") + fileName := paths[len(paths)-1] + fileFullName := filepath.Join(fileDirectory, fileName) + file, err := createNewFile(fileFullName) + if err != nil { + return err + } + if _, err = io.Copy(file, resp.Body); err != nil { + return fmt.Errorf("unable to copy file content: %v", err) + } + } + + return nil +} + +func (h *HTTPSDownloader) extractHeaders() (map[string]string, error) { + var headers map[string]string + hostname := h.Uri.Hostname() + headerJSON := os.Getenv(hostname + HEADER_SUFFIX) + json.Unmarshal([]byte(headerJSON), &headers) + return headers, nil +} + +func createNewFile(fileFullName string) (*os.File, error) { + if FileExists(fileFullName) { + if err := os.Remove(fileFullName); err != nil { + return nil, fmt.Errorf("file is unable to be deleted: %v", err) + } + } + + file, err := Create(fileFullName) + if err != nil { + return nil, fmt.Errorf("unable to create file: %v", err) + } + return file, nil +} + +func extractZipFiles(reader io.Reader, dest string) error { + body, err := ioutil.ReadAll(reader) + if err != nil { + return err + } + + zipReader, err := zip.NewReader(bytes.NewReader(body), int64(len(body))) + if err != nil { + return fmt.Errorf("unable to create new reader: %v", err) + } + + // Read all the files from zip archive + for _, zipFile := range zipReader.File { + fileFullPath := filepath.Join(dest, 
zipFile.Name) + if !strings.HasPrefix(fileFullPath, filepath.Clean(dest)+string(os.PathSeparator)) { + return fmt.Errorf("%s: illegal file path", fileFullPath) + } + + file, err := createNewFile(fileFullPath) + if err != nil { + return err + } + + rc, err := zipFile.Open() + if err != nil { + return fmt.Errorf("unable to open file: %v", err) + } + + _, err = io.Copy(file, rc) + file.Close() + rc.Close() + if err != nil { + return fmt.Errorf("unable to copy file content: %v", err) + } + } + return nil +} + +func extractTarFiles(reader io.Reader, dest string) error { + gzr, err := gzip.NewReader(reader) + if err != nil { + return err + } + defer gzr.Close() + + tr := tar.NewReader(gzr) + + // Read all the files from tar archive + for { + header, err := tr.Next() + if err == io.EOF { + break + } else if err != nil { + return fmt.Errorf("unable to access next tar file: %v", err) + } + + fileFullPath := filepath.Join(dest, header.Name) + newFile, err := createNewFile(fileFullPath) + if err != nil { + return err + } + if _, err := io.Copy(newFile, tr); err != nil { + return fmt.Errorf("unable to copy contents to %s: %v", header.Name, err) + } + } + return nil +} diff --git a/pkg/agent/storage/provider.go b/pkg/agent/storage/provider.go index b297c92ba81..4f7d3f1c63e 100644 --- a/pkg/agent/storage/provider.go +++ b/pkg/agent/storage/provider.go @@ -27,9 +27,15 @@ const ( GCS Protocol = "gs://" //PVC Protocol = "pvc://" //File Protocol = "file://" - //HTTPS Protocol = "https://" + HTTPS Protocol = "https://" + HTTP Protocol = "http://" ) -func GetAllProtocol() []string { - return []string{string(S3), string(GCS)} +var SupportedProtocols = []Protocol{S3, GCS, HTTPS, HTTP} + +func GetAllProtocol() (protocols []string) { + for _, protocol := range SupportedProtocols { + protocols = append(protocols, string(protocol)) + } + return protocols } diff --git a/pkg/agent/storage/utils.go b/pkg/agent/storage/utils.go index e469c5e6fda..1d912abf4b6 100644 --- 
a/pkg/agent/storage/utils.go +++ b/pkg/agent/storage/utils.go @@ -29,6 +29,7 @@ import ( gcscredential "github.com/kubeflow/kfserving/pkg/credentials/gcs" s3credential "github.com/kubeflow/kfserving/pkg/credentials/s3" "google.golang.org/api/option" + "net/http" "os" "path/filepath" "strings" @@ -132,7 +133,16 @@ func GetProvider(providers map[Protocol]Provider, protocol Protocol) (Provider, }), } } - + case HTTPS: + httpsClient := &http.Client{} + providers[HTTPS] = &HTTPSProvider{ + Client: httpsClient, + } + case HTTP: + httpsClient := &http.Client{} + providers[HTTP] = &HTTPSProvider{ + Client: httpsClient, + } } return providers[protocol], nil diff --git a/pkg/agent/watcher_test.go b/pkg/agent/watcher_test.go index b1b4b3d2066..483dd0d6b2f 100644 --- a/pkg/agent/watcher_test.go +++ b/pkg/agent/watcher_test.go @@ -32,6 +32,8 @@ import ( "io/ioutil" "k8s.io/apimachinery/pkg/api/resource" logger "log" + "net/http" + "net/http/httptest" "os" "path/filepath" "sync" @@ -84,7 +86,7 @@ var _ = Describe("Watcher", func() { channelMap: make(map[string]*ModelChannel), completions: make(chan *ModelOp, 4), opStats: make(map[string]map[OpType]int), - waitGroup: WaitGroupWrapper{sync.WaitGroup{}}, + waitGroup: WaitGroupWrapper{sync.WaitGroup{}}, Downloader: Downloader{ ModelDir: modelDir + "/test1", Providers: map[storage.Protocol]storage.Provider{ @@ -532,4 +534,194 @@ var _ = Describe("Watcher", func() { }) }) }) + + Describe("Use HTTP(S) Downloader", func() { + Context("Download Uncompressed Model", func() { + It("should download test model and write contents", func() { + modelContents := "Temporary content" + scenarios := map[string]struct { + server *httptest.Server + }{ + "HTTP": { + httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, modelContents) + })), + }, + "HTTPS": { + httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, modelContents) + })), + }, + } + + for 
protocol, scenario := range scenarios { + logger.Printf("Setting up %s Server", protocol) + ts := scenario.server + defer ts.Close() + + modelName := "model1" + modelFile := "model.joblib" + modelStorageURI := ts.URL + "/" + modelFile + cl := storage.HTTPSProvider{ + Client: ts.Client(), + } + + err := cl.DownloadModel(modelDir, modelName, modelStorageURI) + Expect(err).To(BeNil()) + + testFile := filepath.Join(modelDir, modelName, modelFile) + dat, err := ioutil.ReadFile(testFile) + Expect(err).To(BeNil()) + Expect(string(dat)).To(Equal(modelContents + "\n")) + } + }) + }) + + Context("Model Download Failure", func() { + It("should fail out if the uri does not exist", func() { + logger.Printf("Creating Client") + modelName := "model1" + invalidModelStorageURI := "https://example.com/model.joblib" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + defer ts.Close() + cl := storage.HTTPSProvider{ + Client: ts.Client(), + } + + expectedErr := fmt.Errorf("URI: %s returned a %d response code", invalidModelStorageURI, 404) + actualErr := cl.DownloadModel(modelDir, modelName, invalidModelStorageURI) + Expect(actualErr).To(Equal(expectedErr)) + }) + }) + + Context("Download All Models", func() { + It("should download and load zip and tar files", func() { + tarContent := "1f8b0800bac550600003cbcd4f49cdd12b28c960a01d3030303033315100d1e666a660dac008c287010" + + "54313a090a189919981998281a1b1b1a1118382010ddd0407a5c525894540a754656466e464e2560754" + + "969686c71ca83fe0f4281805a360140c7200009f7e1bb400060000" + + zipContents := "504b030414000800080035b67052000000000000000000000000090020006d6f64656c2e70746855540" + + "d000786c5506086c5506086c5506075780b000104f501000004140000000300504b07080000000002000" + + "00000000000504b0102140314000800080035b6705200000000020000000000000009002000000000000" + + "0000000a481000000006d6f64656c2e70746855540d000786c5506086c5506086c5506075780b000104f" + + 
"50100000414000000504b0506000000000100010057000000590000000000" + + scenarios := map[string]struct { + tarServer *httptest.Server + zipServer *httptest.Server + }{ + "HTTP": { + httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, tarContent) + })), + httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, zipContents) + })), + }, + "HTTPS": { + httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, tarContent) + })), + httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, zipContents) + })), + }, + } + for protocol, scenario := range scenarios { + logger.Printf("Using %s Server", protocol) + logger.Printf("Setting up tar model") + tarServer := scenario.tarServer + defer tarServer.Close() + + tarModel := "model1" + tarStorageURI := tarServer.URL + "/test.tar" + tarcl := storage.HTTPSProvider{ + Client: tarServer.Client(), + } + + logger.Printf("Setting up zip model") + zipServer := scenario.zipServer + defer zipServer.Close() + + zipModel := "model2" + zipStorageURI := zipServer.URL + "/test.zip" + zipcl := storage.HTTPSProvider{ + Client: zipServer.Client(), + } + + err := zipcl.DownloadModel(modelDir, zipModel, zipStorageURI) + Expect(err).To(BeNil()) + err = tarcl.DownloadModel(modelDir, tarModel, tarStorageURI) + Expect(err).To(BeNil()) + } + }) + }) + + Context("Getting new model events", func() { + It("should download and load the new models", func() { + modelContents := "Temporary content" + scenarios := map[string]struct { + server *httptest.Server + }{ + "HTTP": { + httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, modelContents) + })), + }, + "HTTPS": { + httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, modelContents) + })), + }, + } + for protocol, scenario := 
range scenarios { + logger.Printf("Setting up %s Server", protocol) + logger.Printf("Sync model config using temp dir %v\n", modelDir) + watcher := NewWatcher("/tmp/configs", modelDir, sugar) + modelConfigs := modelconfig.ModelConfigs{ + { + Name: "model1", + Spec: v1alpha1.ModelSpec{ + StorageURI: "http://example.com/test.tar", + Framework: "sklearn", + }, + }, + { + Name: "model2", + Spec: v1alpha1.ModelSpec{ + StorageURI: "https://example.com/test.zip", + Framework: "sklearn", + }, + }, + } + + // Create HTTPS client + ts := scenario.server + defer ts.Close() + cl := storage.HTTPSProvider{ + Client: ts.Client(), + } + + watcher.parseConfig(modelConfigs, false) + puller := Puller{ + channelMap: make(map[string]*ModelChannel), + completions: make(chan *ModelOp, 4), + opStats: make(map[string]map[OpType]int), + Downloader: Downloader{ + ModelDir: modelDir + "/test1", + Providers: map[storage.Protocol]storage.Provider{ + storage.HTTPS: &cl, + }, + Logger: sugar, + }, + logger: sugar, + } + go puller.processCommands(watcher.ModelEvents) + Eventually(func() int { return len(puller.channelMap) }).Should(Equal(0)) + Eventually(func() int { return puller.opStats["model1"][Add] }).Should(Equal(1)) + Eventually(func() int { return puller.opStats["model2"][Add] }).Should(Equal(1)) + } + }) + }) + }) }) diff --git a/pkg/apis/serving/v1alpha1/trainedmodel_webhook_test.go b/pkg/apis/serving/v1alpha1/trainedmodel_webhook_test.go index 982bcd69d2d..08c283c9653 100644 --- a/pkg/apis/serving/v1alpha1/trainedmodel_webhook_test.go +++ b/pkg/apis/serving/v1alpha1/trainedmodel_webhook_test.go @@ -96,9 +96,9 @@ func TestValidateCreate(t *testing.T) { "invalid storageURI prefix": { tm: makeTestTrainModel(), update: map[string]string{ - storageURI: "https://kfserving/sklearn/iris", + storageURI: "foo://kfserving/sklearn/iris", }, - matcher: gomega.MatchError(fmt.Errorf(InvalidStorageUriFormatError, "bar", StorageUriProtocols, "https://kfserving/sklearn/iris")), + matcher: 
gomega.MatchError(fmt.Errorf(InvalidStorageUriFormatError, "bar", StorageUriProtocols, "foo://kfserving/sklearn/iris")), }, } @@ -190,9 +190,9 @@ func TestValidateUpdate(t *testing.T) { "invalid storageURI prefix": { tm: makeTestTrainModel(), update: map[string]string{ - storageURI: "https://kfserving/sklearn/iris", + storageURI: "foo://kfserving/sklearn/iris", }, - matcher: gomega.MatchError(fmt.Errorf(InvalidStorageUriFormatError, "bar", StorageUriProtocols, "https://kfserving/sklearn/iris")), + matcher: gomega.MatchError(fmt.Errorf(InvalidStorageUriFormatError, "bar", StorageUriProtocols, "foo://kfserving/sklearn/iris")), }, "framework": { tm: makeTestTrainModel(),