Skip to content

Commit

Permalink
Add exception texts and docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
schustmi committed Nov 27, 2021
1 parent adec502 commit 84e2fe1
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 27 deletions.
5 changes: 5 additions & 0 deletions src/zenml/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ class StepInterfaceError(Exception):
in an unsupported way."""


class StepContextError(Exception):
"""Raises exception when interacting with a StepContext
in an unsupported way."""


class PipelineInterfaceError(Exception):
"""Raises exception when interacting with the Pipeline interface
in an unsupported way."""
Expand Down
123 changes: 112 additions & 11 deletions src/zenml/steps/step_context.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,64 @@
from collections import namedtuple
from typing import TYPE_CHECKING, Dict, Optional, Type
from typing import TYPE_CHECKING, Dict, NamedTuple, Optional, Type, cast

from zenml.exceptions import StepInterfaceError
from zenml.exceptions import StepContextError

if TYPE_CHECKING:
from zenml.artifacts.base_artifact import BaseArtifact
from zenml.materializers.base_materializer import BaseMaterializer


StepContextOutput = namedtuple(
"StepContextOutput", ["materializer_class", "artifact"]
)
class StepContextOutput(NamedTuple):
"""Tuple containing materializer class and artifact for a step output."""

materializer_class: Type["BaseMaterializer"]
artifact: "BaseArtifact"


class StepContext:
"""Provides additional context inside a step function.
This class is used to access materializers and artifact URIs inside
a step function. To use it, add a `StepContext` object to the signature
of your step function like this:
@step
def my_step(context: StepContext, ...)
context.get_output_materializer(...)
You do not need to create a `StepContext` object yourself and pass it
when creating the step, as long as you specify it in the signature ZenML
will create the `StepContext` and automatically pass it when executing your
step.
"""

def __init__(
self,
step_name: str,
output_materializers: Dict[str, Type["BaseMaterializer"]],
output_artifacts: Dict[str, "BaseArtifact"],
):
"""Initializes a StepContext instance.
Args:
step_name: The name of the step that this context is used in.
output_materializers: The output materializers of the step that
this context is used in.
output_artifacts: The output artifacts of the step that this
context is used in.
Raises:
StepInterfaceError: If the keys of the output materializers and
output artifacts do not match.
"""
if output_materializers.keys() != output_artifacts.keys():
raise StepInterfaceError()
raise StepContextError(
f"Mismatched keys in output materializers and output "
f"artifacts for step '{step_name}'. Output materializer "
f"keys: {set(output_materializers)}, output artifact "
f"keys: {set(output_artifacts)}"
)

self.step_name = step_name
self._outputs = {
key: StepContextOutput(
output_materializers[key], output_artifacts[key]
Expand All @@ -32,16 +69,43 @@ def __init__(
def _get_output(
self, output_name: Optional[str] = None
) -> StepContextOutput:
"""Returns the materializer and artifact URI for a given step output.
Args:
output_name: Optional name of the output for which to get the
materializer and URI.
Returns:
Tuple containing the materializer and artifact URI for the
given output.
Raises:
StepInterfaceError: If the step has no outputs, no output for
the given `output_name` or if no `output_name` was given but
the step has multiple outputs.
"""
output_count = len(self._outputs)
if output_count == 0:
raise StepInterfaceError()
raise StepContextError(
f"Unable to get step output for step '{self.step_name}': "
f"This step does not have any outputs."
)

if not output_name and output_count > 1:
raise StepInterfaceError()
raise StepContextError(
f"Unable to get step output for step '{self.step_name}': "
f"This step has multiple outputs ({set(self._outputs)}), "
f"please specify which output to return."
)

if output_name:
if output_name not in self._outputs:
raise StepInterfaceError()
raise StepContextError(
f"Unable to get step output '{output_name}' for "
f"step '{self.step_name}'. This step does not have an "
f"output with the given name, please specify one of the "
f"available outputs: {set(self._outputs)}."
)
return self._outputs[output_name]
else:
return next(iter(self._outputs.values()))
Expand All @@ -51,11 +115,48 @@ def get_output_materializer(
output_name: Optional[str] = None,
custom_materializer_class: Optional[Type["BaseMaterializer"]] = None,
) -> "BaseMaterializer":
"""Returns a materializer for a given step output.
Args:
output_name: Optional name of the output for which to get the
materializer. If no name is given and the step only has a
single output, the materializer of this output will be
returned. If the step has multiple outputs, an exception
will be raised.
custom_materializer_class: If given, this `BaseMaterializer`
subclass will be initialized with the output artifact instead
of the materializer that was registered for this step output.
Returns:
A materializer initialized with the output artifact for
the given output.
Raises:
StepInterfaceError: If the step has no outputs, no output for
the given `output_name` or if no `output_name` was given but
the step has multiple outputs.
"""
materializer_class, artifact = self._get_output(output_name)
# use custom materializer class if provided or fallback to default
# materializer for output
materializer_class = custom_materializer_class or materializer_class
return materializer_class(artifact)

def get_output_artifact_uri(self, output_name: Optional[str] = None) -> str:
return self._get_output(output_name).artifact.uri
"""Returns the artifact URI for a given step output.
Args:
output_name: Optional name of the output for which to get the URI.
If no name is given and the step only has a single output,
the URI of this output will be returned. If the step has
multiple outputs, an exception will be raised.
Returns:
Artifact URI for the given output.
Raises:
StepInterfaceError: If the step has no outputs, no output for
the given `output_name` or if no `output_name` was given but
the step has multiple outputs.
"""
return cast(str, self._get_output(output_name).artifact.uri)
5 changes: 2 additions & 3 deletions src/zenml/steps/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,7 @@ class _FunctionExecutor(BaseExecutor):
"""Base TFX Executor class which is compatible with ZenML steps"""

_FUNCTION = staticmethod(lambda: None)
materializers: ClassVar[
Optional[Dict[str, Type["BaseMaterializer"]]]
] = None
materializers: ClassVar[Dict[str, Type["BaseMaterializer"]]] = {}

def resolve_materializer_with_registry(
self, param_name: str, artifact: BaseArtifact
Expand Down Expand Up @@ -349,6 +347,7 @@ def Do(
elif issubclass(arg_type, StepContext):
output_artifacts = {k: v[0] for k, v in output_dict.items()}
context = StepContext(
step_name=getattr(self, PARAM_STEP_NAME),
output_materializers=self.materializers,
output_artifacts=output_artifacts,
)
Expand Down
12 changes: 9 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,9 @@ def _step(input_1: int, input_2: int):

@pytest.fixture
def step_context_with_no_output():
return StepContext(output_materializers={}, output_artifacts={})
return StepContext(
step_name="", output_materializers={}, output_artifacts={}
)


@pytest.fixture
Expand All @@ -453,7 +455,9 @@ def step_context_with_single_output():
artifacts = {"output_1": BaseArtifact()}

return StepContext(
output_materializers=materializers, output_artifacts=artifacts
step_name="",
output_materializers=materializers,
output_artifacts=artifacts,
)


Expand All @@ -463,5 +467,7 @@ def step_context_with_two_outputs():
artifacts = {"output_1": BaseArtifact(), "output_2": BaseArtifact()}

return StepContext(
output_materializers=materializers, output_artifacts=artifacts
step_name="",
output_materializers=materializers,
output_artifacts=artifacts,
)
24 changes: 14 additions & 10 deletions tests/steps/test_step_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import pytest

from zenml.artifacts.base_artifact import BaseArtifact
from zenml.exceptions import StepInterfaceError
from zenml.exceptions import StepContextError
from zenml.materializers.base_materializer import BaseMaterializer
from zenml.materializers.built_in_materializer import BuiltInMaterializer
from zenml.steps.step_context import StepContext
Expand All @@ -29,9 +29,11 @@ def test_initialize_step_context_with_mismatched_keys():
materializers = {"some_output_name": BaseMaterializer}
artifacts = {"some_different_output_name": BaseArtifact()}

with pytest.raises(StepInterfaceError):
with pytest.raises(StepContextError):
StepContext(
output_materializers=materializers, output_artifacts=artifacts
step_name="",
output_materializers=materializers,
output_artifacts=artifacts,
)


Expand All @@ -44,7 +46,9 @@ def test_initialize_step_context_with_matching_keys():

with does_not_raise():
StepContext(
output_materializers=materializers, output_artifacts=artifacts
step_name="",
output_materializers=materializers,
output_artifacts=artifacts,
)


Expand All @@ -54,10 +58,10 @@ def test_get_step_context_output_for_step_with_no_outputs(
"""Tests that getting the artifact uri or materializer for a step context
with no outputs raises an exception."""

with pytest.raises(StepInterfaceError):
with pytest.raises(StepContextError):
step_context_with_no_output.get_output_artifact_uri()

with pytest.raises(StepInterfaceError):
with pytest.raises(StepContextError):
step_context_with_no_output.get_output_materializer()


Expand All @@ -78,10 +82,10 @@ def test_get_step_context_output_for_step_with_multiple_outputs(
"""Tests that getting the artifact uri or materializer for a step context
with multiple outputs raises an exception."""

with pytest.raises(StepInterfaceError):
with pytest.raises(StepContextError):
step_context_with_two_outputs.get_output_artifact_uri()

with pytest.raises(StepInterfaceError):
with pytest.raises(StepContextError):
step_context_with_two_outputs.get_output_materializer()


Expand All @@ -91,12 +95,12 @@ def test_get_step_context_output_for_non_existent_output_key(
"""Tests that getting the artifact uri or materializer for a non-existent
output raises an exception."""

with pytest.raises(StepInterfaceError):
with pytest.raises(StepContextError):
step_context_with_single_output.get_output_artifact_uri(
"some_non_existent_output_name"
)

with pytest.raises(StepInterfaceError):
with pytest.raises(StepContextError):
step_context_with_single_output.get_output_materializer(
"some_non_existent_output_name"
)
Expand Down

0 comments on commit 84e2fe1

Please sign in to comment.