-
Notifications
You must be signed in to change notification settings - Fork 56
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add an integration test for Artifact API
- Loading branch information
Showing
3 changed files
with
610 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
//go:build test_integration | ||
|
||
/* | ||
Copyright 2024. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
|
||
package integration | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"encoding/json" | ||
"errors" | ||
"fmt" | ||
"io" | ||
"io/ioutil" | ||
"net/http" | ||
"net/url" | ||
"os" | ||
"os/exec" | ||
"strings" | ||
"testing" | ||
"time" | ||
|
||
TestUtil "github.com/opendatahub-io/data-science-pipelines-operator/tests/util" | ||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func (suite *IntegrationTestSuite) TestFetchArtifacts() { | ||
|
||
suite.T().Run("Should successfully fetch and download artifacts", func(t *testing.T) { | ||
|
||
// Start port-forwarding | ||
cmd := exec.CommandContext(context.Background(), | ||
"kubectl", "port-forward", "-n", suite.DSPANamespace, "svc/artifact-service", fmt.Sprintf("%d:8080", PortForwardLocalPort)) | ||
err := cmd.Start() | ||
require.NoError(t, err, "Failed to start port-forwarding") | ||
|
||
// Ensure the port-forwarding process is terminated after the test | ||
defer func() { | ||
_ = cmd.Process.Kill() | ||
cmd.Wait() // Wait for the process to terminate completely | ||
}() | ||
|
||
// Wait briefly to ensure port-forwarding is established | ||
time.Sleep(5 * time.Second) | ||
|
||
type ResponseArtifact struct { | ||
ArtifactID string `json:"artifact_id"` | ||
DownloadUrl string `json:"download_url"` | ||
} | ||
type ResponseArtifactData struct { | ||
Artifacts []ResponseArtifact `json:"artifacts"` | ||
} | ||
|
||
name := "Test Iris Pipeline" | ||
uploadUrl := fmt.Sprintf("%s/apis/v2beta1/pipelines/upload?name=%s", APIServerURL, url.QueryEscape(name)) | ||
vals := map[string]string{ | ||
"uploadfile": "@resources/iris_pipeline_without_cache_compiled.yaml", | ||
} | ||
bodyUpload, contentTypeUpload := TestUtil.FormFromFile(t, vals) | ||
response, err := suite.Clientmgr.httpClient.Post(uploadUrl, contentTypeUpload, bodyUpload) | ||
require.NoError(t, err) | ||
responseData, err := io.ReadAll(response.Body) | ||
require.NoError(t, err) | ||
assert.Equal(t, http.StatusOK, response.StatusCode) | ||
|
||
// Retrieve Pipeline ID to create a new run | ||
pipelineID, err := TestUtil.RetrievePipelineId(t, suite.Clientmgr.httpClient, APIServerURL, name) | ||
require.NoError(t, err) | ||
|
||
// Create a new run | ||
runUrl := fmt.Sprintf("%s/apis/v2beta1/runs", APIServerURL) | ||
bodyRun := TestUtil.FormatRequestBody(t, pipelineID, name) | ||
contentTypeRun := "application/json" | ||
response, err = suite.Clientmgr.httpClient.Post(runUrl, contentTypeRun, bytes.NewReader(bodyRun)) | ||
require.NoError(t, err) | ||
responseData, err = io.ReadAll(response.Body) | ||
require.NoError(t, err) | ||
require.Equal(t, http.StatusOK, response.StatusCode) | ||
err = TestUtil.WaitForPipelineRunCompletion(t, suite.Clientmgr.httpClient, APIServerURL) | ||
require.NoError(t, err) | ||
|
||
// fetch artifacts | ||
artifactsUrl := fmt.Sprintf("%s/apis/v2beta1/artifacts?namespace=%s", APIServerURL, suite.DSPANamespace) | ||
response, err = suite.Clientmgr.httpClient.Get(artifactsUrl) | ||
require.NoError(t, err) | ||
responseData, err = io.ReadAll(response.Body) | ||
require.NoError(t, err) | ||
assert.Equal(t, http.StatusOK, response.StatusCode) | ||
|
||
// iterate over the artifacts | ||
var responseArtifactsData ResponseArtifactData | ||
err = json.Unmarshal([]byte(string(responseData)), &responseArtifactsData) | ||
if err != nil { | ||
t.Errorf("Error unmarshaling JSON: %v", err) | ||
return | ||
} | ||
hasDownloadError := false | ||
for _, artifact := range responseArtifactsData.Artifacts { | ||
// get the artifact by ID | ||
artifactsByIdUrl := fmt.Sprintf("%s/apis/v2beta1/artifacts/%s", APIServerURL, artifact.ArtifactID) | ||
response, err = suite.Clientmgr.httpClient.Get(artifactsByIdUrl) | ||
require.NoError(t, err) | ||
responseData, err = io.ReadAll(response.Body) | ||
require.NoError(t, err) | ||
assert.Equal(t, http.StatusOK, response.StatusCode) | ||
|
||
// get download url | ||
artifactsByIdUrl = fmt.Sprintf("%s/apis/v2beta1/artifacts/%s?view=DOWNLOAD", APIServerURL, artifact.ArtifactID) | ||
response, err = suite.Clientmgr.httpClient.Get(artifactsByIdUrl) | ||
require.NoError(t, err) | ||
responseData, err = io.ReadAll(response.Body) | ||
require.NoError(t, err) | ||
assert.Equal(t, http.StatusOK, response.StatusCode) | ||
loggr.Info(string(responseData)) | ||
|
||
var responseArtifactData ResponseArtifact | ||
err = json.Unmarshal([]byte(string(responseData)), &responseArtifactData) | ||
if err != nil { | ||
t.Errorf("Error unmarshaling JSON: %v", err) | ||
return | ||
} | ||
|
||
content, err := downloadFile(responseArtifactData.DownloadUrl, "/tmp/download", suite.Clientmgr.httpClient) | ||
|
||
require.NoError(t, err) | ||
// There were an issue in the past that the URL was returning Access Denied | ||
if strings.Contains(content, "Access Denied") { | ||
hasDownloadError = true | ||
loggr.Error(errors.New("error downloading the artifact"), content) | ||
} | ||
} | ||
if hasDownloadError { | ||
t.Errorf("Error downloading the artifacts. Double check the error messages in the log") | ||
} | ||
}) | ||
} | ||
|
||
// downloadFile fetches url with httpClient, stores the response body at
// filepath, and returns the body as a string.
//
// An error is returned when the request fails, the server answers with a
// status other than 200 OK, the body cannot be read, or the file cannot be
// created, written, or closed. The returned string is empty on error.
func downloadFile(url, filepath string, httpClient http.Client) (string, error) {
	// Create an HTTP GET request to fetch the file from the URL
	response, err := httpClient.Get(url)
	if err != nil {
		return "", fmt.Errorf("failed to fetch the file: %w", err)
	}
	defer response.Body.Close()

	// Check if the response status is OK (200)
	if response.StatusCode != http.StatusOK {
		return "", fmt.Errorf("failed to download file: status code %d", response.StatusCode)
	}

	// Read the content from the response body.
	// NOTE(review): ioutil.ReadAll is deprecated in favour of io.ReadAll; it
	// is kept because this appears to be the only use of the io/ioutil
	// import, and an unused import would not compile.
	content, err := ioutil.ReadAll(response.Body)
	if err != nil {
		return "", fmt.Errorf("failed to read content: %w", err)
	}

	// Create the file
	file, err := os.Create(filepath)
	if err != nil {
		return "", fmt.Errorf("failed to create file: %w", err)
	}

	// Write the content to the file
	if _, err := file.Write(content); err != nil {
		_ = file.Close() // best effort; the write error is the one worth reporting
		return "", fmt.Errorf("failed to write content to file: %w", err)
	}

	// Close explicitly instead of deferring: a deferred Close would drop its
	// error, and for a file that was just written that error can mean the
	// data never reached the disk.
	if err := file.Close(); err != nil {
		return "", fmt.Errorf("failed to close file: %w", err)
	}

	return string(content), nil
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
from kfp import compiler, dsl | ||
from kfp.dsl import ClassificationMetrics, Dataset, Input, Model, Output | ||
|
||
# Container image shared by every component below, pinned by digest.
common_base_image = (
    "registry.redhat.io/ubi8/python-39"
    "@sha256:3523b184212e1f2243e76d8094ab52b01ea3015471471290d011625e1763af61"
)
# common_base_image = "quay.io/opendatahub/ds-pipelines-sample-base:v1.0"
|
||
|
||
# Component that parses an inline sample of Iris rows and writes it as CSV to
# the path of the `iris_dataset` output artifact.
@dsl.component(base_image=common_base_image, packages_to_install=["pandas==2.2.0"])
def create_dataset(iris_dataset: Output[Dataset]):
    import io  # noqa: PLC0415

    import pandas as pd  # noqa: PLC0415

    # Header-less CSV rows: four measurements followed by the species label.
    raw_rows = """
5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa
5.7,3.8,1.7,0.3,Iris-setosa
5.1,3.8,1.5,0.3,Iris-setosa
5.4,3.4,1.7,0.2,Iris-setosa
5.1,3.7,1.5,0.4,Iris-setosa
5.1,3.4,1.5,0.2,Iris-setosa
5.0,3.5,1.3,0.3,Iris-setosa
4.5,2.3,1.3,0.3,Iris-setosa
4.4,3.2,1.3,0.2,Iris-setosa
5.0,3.5,1.6,0.6,Iris-setosa
5.1,3.8,1.9,0.4,Iris-setosa
4.8,3.0,1.4,0.3,Iris-setosa
5.1,3.8,1.6,0.2,Iris-setosa
4.6,3.2,1.4,0.2,Iris-setosa
5.3,3.7,1.5,0.2,Iris-setosa
5.0,3.3,1.4,0.2,Iris-setosa
7.0,3.2,4.7,1.4,Iris-versicolor
6.4,3.2,4.5,1.5,Iris-versicolor
6.9,3.1,4.9,1.5,Iris-versicolor
5.5,2.3,4.0,1.3,Iris-versicolor
6.5,2.8,4.6,1.5,Iris-versicolor
6.2,2.2,4.5,1.5,Iris-versicolor
5.6,2.5,3.9,1.1,Iris-versicolor
5.9,3.2,4.8,1.8,Iris-versicolor
6.1,2.8,4.0,1.3,Iris-versicolor
6.3,2.5,4.9,1.5,Iris-versicolor
6.1,2.8,4.7,1.2,Iris-versicolor
6.4,2.9,4.3,1.3,Iris-versicolor
6.6,3.0,4.4,1.4,Iris-versicolor
5.6,2.7,4.2,1.3,Iris-versicolor
5.7,3.0,4.2,1.2,Iris-versicolor
5.7,2.9,4.2,1.3,Iris-versicolor
6.2,2.9,4.3,1.3,Iris-versicolor
5.1,2.5,3.0,1.1,Iris-versicolor
5.7,2.8,4.1,1.3,Iris-versicolor
6.3,3.3,6.0,2.5,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
7.1,3.0,5.9,2.1,Iris-virginica
6.3,2.9,5.6,1.8,Iris-virginica
6.5,3.0,5.8,2.2,Iris-virginica
6.9,3.1,5.1,2.3,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
6.8,3.2,5.9,2.3,Iris-virginica
6.7,3.3,5.7,2.5,Iris-virginica
6.7,3.0,5.2,2.3,Iris-virginica
6.3,2.5,5.0,1.9,Iris-virginica
6.5,3.0,5.2,2.0,Iris-virginica
6.2,3.4,5.4,2.3,Iris-virginica
5.9,3.0,5.1,1.8,Iris-virginica
"""
    column_names = ["Sepal_Length", "Sepal_Width", "Petal_Length", "Petal_Width", "Labels"]
    frame = pd.read_csv(io.StringIO(raw_rows), names=column_names)

    # NOTE(review): to_csv writes the row index as an extra leading column by
    # default - confirm that consumers of this artifact expect it.
    with open(iris_dataset.path, "w") as sink:
        frame.to_csv(sink)
|
||
|
||
# Component that reads the input CSV, rescales every column except "Labels"
# (StandardScaler when `standard_scaler` is true, MinMaxScaler otherwise) and
# writes the result, with the labels re-attached, to the output artifact.
@dsl.component(
    base_image=common_base_image,
    packages_to_install=["pandas==2.2.0", "scikit-learn==1.4.0"],
)
def normalize_dataset(
    input_iris_dataset: Input[Dataset],
    normalized_iris_dataset: Output[Dataset],
    standard_scaler: bool,
):
    import pandas as pd  # noqa: PLC0415
    from sklearn import preprocessing  # noqa: PLC0415

    # NOTE(review): read_csv is called without index_col, so an index column
    # present in the input file is treated as data and scaled like a feature -
    # confirm this is intended.
    with open(input_iris_dataset.path) as source:
        features = pd.read_csv(source)
    label_column = features.pop("Labels")

    if standard_scaler:
        scaler = preprocessing.StandardScaler()
    else:
        scaler = preprocessing.MinMaxScaler()

    # fit_transform returns a plain array, so the rebuilt frame has positional
    # column names rather than the original headers.
    scaled = pd.DataFrame(scaler.fit_transform(features))
    scaled["Labels"] = label_column
    normalized_iris_dataset.metadata["state"] = "Normalized"
    with open(normalized_iris_dataset.path, "w") as sink:
        scaled.to_csv(sink)
|
||
|
||
# Component that fits a k-nearest-neighbours classifier on the normalized
# data, logs a 3-fold cross-validated confusion matrix to `metrics`, and
# pickles the fitted classifier to the `model` artifact path.
@dsl.component(
    base_image=common_base_image,
    packages_to_install=["pandas==2.2.0", "scikit-learn==1.4.0"],
)
def train_model(
    normalized_iris_dataset: Input[Dataset],
    model: Output[Model],
    metrics: Output[ClassificationMetrics],
    n_neighbors: int,
):
    import pickle  # noqa: PLC0415

    import pandas as pd  # noqa: PLC0415
    from sklearn.metrics import confusion_matrix  # noqa: PLC0415
    from sklearn.model_selection import cross_val_predict, train_test_split  # noqa: PLC0415
    from sklearn.neighbors import KNeighborsClassifier  # noqa: PLC0415

    # NOTE(review): every column other than "Labels" is used as a feature,
    # including any index column stored in the CSV - confirm this is intended.
    with open(normalized_iris_dataset.path) as source:
        features = pd.read_csv(source)
    targets = features.pop("Labels")

    # Only the training part of the split is used; the held-out part is discarded.
    train_features, _, train_targets, _ = train_test_split(features, targets, random_state=0)

    classifier = KNeighborsClassifier(n_neighbors=n_neighbors)
    classifier.fit(train_features, train_targets)

    cv_predictions = cross_val_predict(classifier, train_features, train_targets, cv=3)
    # .tolist() converts the numpy array into nested Python lists.
    matrix_rows = confusion_matrix(train_targets, cv_predictions).tolist()
    # NOTE(review): the category names are hard-coded rather than derived from
    # the "Labels" column - confirm they match the label values and their order.
    metrics.log_confusion_matrix(
        ["Iris-Setosa", "Iris-Versicolour", "Iris-Virginica"],
        matrix_rows,
    )

    model.metadata["framework"] = "scikit-learn"
    with open(model.path, "wb") as sink:
        pickle.dump(classifier, sink)
|
||
|
||
# Pipeline wiring: create_dataset -> normalize_dataset -> train_model, with
# set_caching_options(False) applied to every task.
@dsl.pipeline(name="iris-training-pipeline")
def my_pipeline(
    standard_scaler: bool = True,
    neighbors: int = 3,
):
    dataset_step = create_dataset().set_caching_options(False)
    raw_dataset = dataset_step.outputs["iris_dataset"]

    normalize_step = normalize_dataset(
        input_iris_dataset=raw_dataset,
        standard_scaler=standard_scaler,
    ).set_caching_options(False)
    normalized_dataset = normalize_step.outputs["normalized_iris_dataset"]

    train_model(
        normalized_iris_dataset=normalized_dataset,
        n_neighbors=neighbors,
    ).set_caching_options(False)
|
||
|
||
if __name__ == "__main__":
    # Write the compiled pipeline next to this script, replacing ".py" in the
    # script path with "_compiled.yaml".
    compiled_path = __file__.replace(".py", "_compiled.yaml")
    compiler.Compiler().compile(my_pipeline, package_path=compiled_path)
|
Oops, something went wrong.