Skip to content

Commit

Permalink
🐛 Force return on load for old style datasets (#592) (#593)
Browse the repository at this point in the history
* 🐛 Force return on load for old style datasets (#592)

* skip test for kedro<0.19.7

* fix test assert in new mlflow version
  • Loading branch information
Galileo-Galilei authored Sep 23, 2024
1 parent ed43e0d commit f066d7d
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 5 deletions.
2 changes: 1 addition & 1 deletion kedro_mlflow/io/artifacts/mlflow_artifact_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def _load(self) -> Any: # pragma: no cover
if getattr(super().load, "__loadwrapped__", False): # modern dataset
return super().load.__wrapped__(self)
else: # legacy dataset
- super()._load()
+ return super()._load()

# rename the class
parent_name = dataset_obj.__name__
Expand Down
2 changes: 1 addition & 1 deletion tests/framework/cli/test_cli_modelify.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ def test_modelify_with_infer_input_example(
"artifact_path": "input_example.json",
"pandas_orient": "split",
"type": "dataframe",
- "serving_input_path": "serving_input_payload.json",
+ "serving_input_path": "serving_input_example.json",
}


Expand Down
12 changes: 9 additions & 3 deletions tests/io/artifacts/test_mlflow_artifact_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import mlflow
import pandas as pd
import pytest
+ from kedro import __version__ as kedro_version
from kedro.io import AbstractDataset
from kedro_datasets.pandas import CSVDataset
from kedro_datasets.partitions import PartitionedDataset
Expand All @@ -11,6 +12,8 @@

from kedro_mlflow.io.artifacts import MlflowArtifactDataset

+ KEDRO_VERSION = tuple(int(x) for x in kedro_version.split("."))


@pytest.fixture
def df1():
Expand Down Expand Up @@ -249,7 +252,7 @@ def test_artifact_dataset_load_with_run_id_and_artifact_path(


@pytest.mark.parametrize("artifact_path", [None, "partitioned_data"])
- def test_partitioned_dataset_save_and_reload(
+ def test_artifact_dataset_partitioned_dataset_save_and_reload(
tmp_path, mlflow_client, artifact_path, df1, df2
):
mlflow_dataset = MlflowArtifactDataset(
Expand Down Expand Up @@ -292,7 +295,10 @@ def test_partitioned_dataset_save_and_reload(
pd.testing.assert_frame_equal(df, reloaded_data[k])


- def test_modern_dataset(tmp_path, mlflow_client, df1):
+ @pytest.mark.skipif(
+ KEDRO_VERSION < (0, 19, 7), reason="modern datasets were introduced in kedro 0.19.7"
+ )
+ def test_artifact_dataset_modern_dataset(tmp_path, mlflow_client, df1):
class MyOwnDatasetWithoutUnderscoreMethods(AbstractDataset):
def __init__(self, filepath):
self._filepath = Path(filepath)
Expand Down Expand Up @@ -332,7 +338,7 @@ def _describe(self):
assert df1.equals(mlflow_dataset.load())


- def test_legacy_dataset(tmp_path, mlflow_client, df1):
+ def test_artifact_dataset_legacy_dataset(tmp_path, mlflow_client, df1):
class MyOwnDatasetWithUnderscoreMethods(AbstractDataset):
def __init__(self, filepath):
self._filepath = Path(filepath)
Expand Down

0 comments on commit f066d7d

Please sign in to comment.