Skip to content

Commit

Permalink
#12 Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kaemo authored and Galileo-Galilei committed Oct 31, 2020
1 parent ecaba73 commit 556572d
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .isort.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ line_length=88
ensure_newline_before_comments=True
sections=FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
known_first_party=kedro_mlflow
known_third_party=anyconfig,click,cookiecutter,jinja2,kedro,mlflow,packaging,pandas,pytest,pytest_lazyfixture,setuptools,yaml
known_third_party=anyconfig,black,click,cookiecutter,flake8,isort,jinja2,kedro,mlflow,pandas,pytest,pytest_lazyfixture,setuptools,sklearn,yaml
1 change: 1 addition & 0 deletions requirements/test_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pytest>=5.4.0, <6.0.0
pytest-cov>=2.8.0, <3.0.0
pytest-lazy-fixture>=0.6.0, <1.0.0
pytest-mock>=3.1.0, <4.0.0
scikit-learn>=0.23.0, <0.24.0
flake8>=3.0.0, <4.0.0
black==19.10b0 # pin black version because it is not compatible with a pip range (because of non semver version number)
isort>=5.0.0, <6.0.0
96 changes: 96 additions & 0 deletions tests/io/test_mlflow_model_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import mlflow
import pytest
from kedro.io.core import DataSetError
from mlflow.tracking import MlflowClient
from sklearn.linear_model import LinearRegression

from kedro_mlflow.io import MlflowModelDataSet


@pytest.fixture
def linreg_model():
    """Provide a newly constructed, unfitted ``LinearRegression`` estimator."""
    estimator = LinearRegression()
    return estimator


@pytest.fixture
def linreg_path(tmp_path):
    """Return the ``06_models/linreg`` location under pytest's ``tmp_path``.

    Only the path object is built here; nothing is created on disk.
    """
    return tmp_path.joinpath("06_models/linreg")


@pytest.fixture
def mlflow_client_run_id(tmp_path):
    """Yield ``(MlflowClient, run_id)`` for a run tracked under ``tmp_path``.

    The global MLflow tracking URI is pointed at ``<tmp_path>/mlruns`` and a
    run is started before the test body executes. On teardown the run is
    ended and the tracking URI that was active beforehand is restored, so the
    temporary location does not leak into tests that run afterwards.

    Yields:
        tuple: the client bound to the temporary tracking URI and the id of
        the active run.
    """
    # Remember the process-wide setting so it can be put back on teardown.
    previous_uri = mlflow.get_tracking_uri()
    tracking_uri = (tmp_path / "mlruns").as_uri()
    mlflow.set_tracking_uri(tracking_uri)
    try:
        client = MlflowClient(tracking_uri=tracking_uri)
        # Leaving the ``with`` block ends the run, replacing the explicit
        # ``mlflow.end_run()`` call.
        with mlflow.start_run() as run:
            yield client, run.info.run_id
    finally:
        # Restore even when setup fails part-way (e.g. start_run raises).
        mlflow.set_tracking_uri(previous_uri)


def test_flavor_does_not_exists(linreg_path):
    """Building the dataset with the flavor ``mlflow.whoops`` raises ``DataSetError``."""
    bad_config = {
        "type": "kedro_mlflow.io.MlflowModelDataSet",
        "flavor": "mlflow.whoops",
        "path": linreg_path,
    }
    # Only the construction call sits inside the raises-block.
    with pytest.raises(DataSetError):
        MlflowModelDataSet.from_config(name="whoops", config=bad_config)


def test_save_unversioned_under_same_path(
    linreg_path, linreg_model, mlflow_client_run_id
):
    """Save the same model twice through one dataset; neither call may raise.

    The dataset config sets no ``versioned`` key, and both saves go through
    the same dataset instance and therefore the same ``path``. The test has
    no explicit assertion: it passes when the second save completes.
    ``mlflow_client_run_id`` is requested but its value is not used here.
    """
    dataset = MlflowModelDataSet.from_config(
        name="linreg",
        config={
            "type": "kedro_mlflow.io.MlflowModelDataSet",
            "flavor": "mlflow.sklearn",
            "path": linreg_path,
        },
    )
    for _ in range(2):
        dataset.save(linreg_model)


@pytest.mark.parametrize(
    "versioned,from_run_id",
    [(False, False), (True, False), (False, True), (True, True)],
)
def test_save_load_local(
    linreg_path, linreg_model, mlflow_client_run_id, versioned, from_run_id
):
    """Round-trip a model through ``MlflowModelDataSet`` for each parameter pair.

    Checks, in order:
      1. after ``save`` the model exists locally -- at ``linreg_path`` when
         unversioned, or at ``linreg_path/<save version>/<linreg_path.name>``
         when versioned;
      2. the first artifact listed for the active run is named after
         ``linreg_path.name``;
      3. ``load`` returns a ``LinearRegression``; when ``from_run_id`` is set
         the dataset is rebuilt with ``run_id`` in its config first.
    """
    client, run_id = mlflow_client_run_id
    label = "Versioned" if versioned else "Unversioned"

    dataset_config = {
        "type": "kedro_mlflow.io.MlflowModelDataSet",
        "flavor": "mlflow.sklearn",
        "path": linreg_path,
        "versioned": versioned,
    }
    dataset = MlflowModelDataSet.from_config(name="linreg", config=dataset_config)
    dataset.save(linreg_model)

    if not versioned:
        assert linreg_path.exists(), "Unversioned model saved locally"
    else:
        versioned_location = linreg_path / dataset._version.save / linreg_path.name
        assert versioned_location.exists(), "Versioned model saved locally"

    first_artifact = client.list_artifacts(run_id=run_id)[0]
    assert linreg_path.name == first_artifact.path, f"{label} model logged to MLflow"

    if from_run_id:
        # Rebuild the dataset with the run id added to the same config.
        dataset_config["run_id"] = run_id
        dataset = MlflowModelDataSet.from_config(name="linreg", config=dataset_config)

    loaded = dataset.load()
    assert isinstance(loaded, LinearRegression), f"{label} model loaded"

0 comments on commit 556572d

Please sign in to comment.