Skip to content

Commit

Permalink
Rebased and solved conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
stefannica committed Jan 12, 2022
1 parent 3f2dd2a commit 6b87e6b
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 28 deletions.
14 changes: 10 additions & 4 deletions examples/drift_detection/evidently.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,10 @@
"from rich import print\n",
"from sklearn import datasets\n",
"\n",
"from zenml.integrations.evidently import steps as evidently_steps\n",
"from zenml.integrations.evidently.steps import (\n",
" EvidentlyProfileConfig,\n",
" EvidentlyProfileStep,\n",
")\n",
"from zenml.pipelines import pipeline\n",
"from zenml.steps import step"
]
Expand Down Expand Up @@ -247,8 +250,11 @@
},
"outputs": [],
"source": [
"drift_detector = evidently_steps.EvidentlyDriftDetectionStep(\n",
" evidently_steps.EvidentlyDriftDetectionConfig(column_mapping=None)\n",
"drift_detector = EvidentlyProfileStep(\n",
" EvidentlyProfileConfig(\n",
" column_mapping=None,\n",
" profile_sections=[\"datadrift\"],\n",
" )\n",
")"
]
},
Expand Down Expand Up @@ -340,7 +346,7 @@
"id": "NrJA5OSgnydC"
},
"source": [
"Running the pipeline is as simple as calling the `run()` method on an instance of the defined pipeline. Here we explicitly name our pipeline run to make it easier to access later on. Be aware that you can only run the pipeline once with this name. To rerun, rename the run, or remove the run name."
"Running the pipeline is as simple as calling the `run()` method on an instance of the defined pipeline."
]
},
{
Expand Down
22 changes: 14 additions & 8 deletions examples/drift_detection/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# or implied. See the License for the specific language governing
# permissions and limitations under the License.

import json
import pandas as pd
from rich import print
from sklearn import datasets
Expand Down Expand Up @@ -58,7 +59,7 @@ def partial_split(
drift_detector = EvidentlyProfileStep(
EvidentlyProfileConfig(
column_mapping=None,
profile_section=["datadrift"],
profile_sections=["datadrift"],
)
)

Expand Down Expand Up @@ -109,11 +110,16 @@ def visualize_statistics():
pipeline.run()

repo = Repository()
pipeline = repo.get_pipelines()[-1]
runs = pipeline.runs
run = runs[-1]
steps = run.steps
step = steps[-1]
output = step.output
print(output.read())
pipeline = repo.get_pipelines()[0]
last_run = pipeline.runs[-1]
drift_analysis_step = last_run.get_step(
name="drift_analyzer"
)
print(f'Data drift detected: {drift_analysis_step.output.read()}')

drift_detection_step = last_run.get_step(
name="drift_detector"
)
print(json.dumps(drift_detection_step.outputs['profile'].read(), indent=2))

visualize_statistics()
72 changes: 58 additions & 14 deletions src/zenml/integrations/evidently/steps/evidently_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import Optional, Sequence, cast
from typing import List, Optional, Sequence, Tuple

from evidently.dashboard import Dashboard # type: ignore
from evidently.model_profile import Profile # type: ignore
from evidently.pipeline.column_mapping import ColumnMapping # type: ignore
from evidently.profile_sections import ( # type: ignore
Expand All @@ -26,9 +27,18 @@
from evidently.profile_sections.base_profile_section import ( # type: ignore
ProfileSection,
)
from evidently.tabs import ( # type: ignore
CatTargetDriftTab,
ClassificationPerformanceTab,
DataDriftTab,
NumTargetDriftTab,
ProbClassificationPerformanceTab,
RegressionPerformanceTab,
)
from evidently.tabs.base_tab import Tab # type: ignore

from zenml.artifacts import DataArtifact
from zenml.steps import StepContext
from zenml.artifacts import DataAnalysisArtifact, DataArtifact
from zenml.steps import Output, StepContext
from zenml.steps.step_interfaces.base_drift_detection_step import (
BaseDriftDetectionConfig,
BaseDriftDetectionStep,
Expand All @@ -43,6 +53,15 @@
"probabilisticmodelperformance": ProbClassificationPerformanceProfileSection,
}

# Maps each supported profile section name (the same keys that are looked up
# from `profile_sections`) to the Evidently dashboard Tab class that gets
# instantiated for that section.
dashboard_mapper = {
"datadrift": DataDriftTab,
"categoricaltargetdrift": CatTargetDriftTab,
"numericaltargetdrift": NumTargetDriftTab,
"classificationmodelperformance": ClassificationPerformanceTab,
"regressionmodelperformance": RegressionPerformanceTab,
"probabilisticmodelperformance": ProbClassificationPerformanceTab,
}


class EvidentlyProfileConfig(BaseDriftDetectionConfig):
"""Config class for Evidently profile steps.
Expand All @@ -58,34 +77,50 @@ class EvidentlyProfileConfig(BaseDriftDetectionConfig):
- "probabilisticmodelperformance"
"""

def get_profile_sections(self) -> ProfileSection:
def get_profile_sections_and_tabs(
self,
) -> Tuple[List[ProfileSection], List[Tab]]:
try:
return [
profile_mapper[profile]() for profile in self.profile_section
]
return (
[
profile_mapper[profile]()
for profile in self.profile_sections
],
[
dashboard_mapper[profile]()
for profile in self.profile_sections
],
)
except KeyError:
nl = "\n"
raise ValueError(
f"Invalid profile section: {self.profile_section} \n\n"
f"Invalid profile section: {self.profile_sections} \n\n"
f"Valid and supported options are: {nl}- "
f'{f"{nl}- ".join(list(profile_mapper.keys()))}'
)

column_mapping: Optional[ColumnMapping]
profile_section: Sequence[str]
profile_sections: Sequence[str]


class EvidentlyProfileStep(BaseDriftDetectionStep):
"""Simple step implementation which implements Evidently's functionality for
creating a profile."""

OUTPUT_SPEC = {
"profile": DataAnalysisArtifact,
"dashboard": DataAnalysisArtifact,
}

def entrypoint( # type: ignore[override]
self,
reference_dataset: DataArtifact,
comparison_dataset: DataArtifact,
config: EvidentlyProfileConfig,
context: StepContext,
) -> dict: # type: ignore[type-arg]
) -> Output( # type:ignore[valid-type]
profile=dict, dashboard=str
):
"""Main entrypoint for the Evidently categorical target drift detection
step.
Expand All @@ -97,13 +132,22 @@ def entrypoint( # type: ignore[override]
context: the context of the step
Returns:
a dict containing the results of the drift detection
profile: dictionary report extracted from an Evidently Profile
generated for the data drift
dashboard: HTML report extracted from an Evidently Dashboard
generated for the data drift
"""

data_drift_profile = Profile(sections=config.get_profile_sections())
sections, tabs = config.get_profile_sections_and_tabs()
data_drift_dashboard = Dashboard(tabs=tabs)
data_drift_dashboard.calculate(
reference_dataset,
comparison_dataset,
column_mapping=config.column_mapping or None,
)
data_drift_profile = Profile(sections=sections)
data_drift_profile.calculate(
reference_dataset,
comparison_dataset,
column_mapping=config.column_mapping or None,
)
return cast(dict, data_drift_profile.object()) # type: ignore[type-arg]
return [data_drift_profile.object(), data_drift_dashboard.html()]
4 changes: 2 additions & 2 deletions src/zenml/steps/step_interfaces/base_drift_detection_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# permissions and limitations under the License.

from abc import abstractmethod
from typing import Dict
from typing import Any

from zenml.artifacts import DataArtifact
from zenml.steps import BaseStep, BaseStepConfig, StepContext
Expand All @@ -36,5 +36,5 @@ def entrypoint( # type: ignore[override]
comparison_dataset: DataArtifact,
config: BaseDriftDetectionConfig,
context: StepContext,
) -> Dict: # type: ignore[type-arg]
) -> Any:
"""Base entrypoint for any drift detection implementation"""

0 comments on commit 6b87e6b

Please sign in to comment.