Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add step context #196

Merged
merged 4 commits into from
Nov 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/zenml/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ class StepInterfaceError(Exception):
in an unsupported way."""


class StepContextError(Exception):
"""Raises exception when interacting with a StepContext
in an unsupported way."""


class PipelineInterfaceError(Exception):
"""Raises exception when interacting with the Pipeline interface
in an unsupported way."""
Expand Down
14 changes: 14 additions & 0 deletions src/zenml/steps/base_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
default_materializer_registry,
)
from zenml.steps.base_step_config import BaseStepConfig
from zenml.steps.step_context import StepContext
from zenml.steps.step_output import Output
from zenml.steps.utils import (
INTERNAL_EXECUTION_PARAMETER_PREFIX,
Expand Down Expand Up @@ -72,6 +73,7 @@ def __new__(
cls.OUTPUT_SIGNATURE = {}
cls.CONFIG_PARAMETER_NAME = None
cls.CONFIG_CLASS = None
cls.CONTEXT_PARAMETER_NAME = None

# Get the signature of the step function
step_function_signature = inspect.getfullargspec(
Expand Down Expand Up @@ -127,6 +129,16 @@ def __new__(
)
cls.CONFIG_PARAMETER_NAME = arg
cls.CONFIG_CLASS = arg_type
elif issubclass(arg_type, StepContext):
if cls.CONTEXT_PARAMETER_NAME is not None:
raise StepInterfaceError(
f"Found multiple context arguments "
f"('{cls.CONTEXT_PARAMETER_NAME}' and '{arg}') when "
f"trying to create step '{name}'. Please make sure to "
f"only have one `StepContext` as input "
f"argument for a step."
)
cls.CONTEXT_PARAMETER_NAME = arg
else:
# Can't do any check for existing materializers right now
# as they might get passed later, so we simply store the
Expand Down Expand Up @@ -170,10 +182,12 @@ class BaseStep(metaclass=BaseStepMeta):
OUTPUT_SIGNATURE: ClassVar[Dict[str, Type[Any]]] = None # type: ignore[assignment] # noqa
CONFIG_PARAMETER_NAME: ClassVar[Optional[str]] = None
CONFIG_CLASS: ClassVar[Optional[Type[BaseStepConfig]]] = None
CONTEXT_PARAMETER_NAME: ClassVar[Optional[str]] = None

def __init__(self, *args: Any, **kwargs: Any) -> None:
self.step_name = self.__class__.__name__
self.enable_cache = getattr(self, PARAM_ENABLE_CACHE)
self.requires_context = bool(self.CONTEXT_PARAMETER_NAME)

self.PARAM_SPEC: Dict[str, Any] = {}
self.INPUT_SPEC: Dict[str, Type[BaseArtifact]] = {}
Expand Down
162 changes: 162 additions & 0 deletions src/zenml/steps/step_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
from typing import TYPE_CHECKING, Dict, NamedTuple, Optional, Type, cast

from zenml.exceptions import StepContextError

if TYPE_CHECKING:
from zenml.artifacts.base_artifact import BaseArtifact
from zenml.materializers.base_materializer import BaseMaterializer


class StepContextOutput(NamedTuple):
"""Tuple containing materializer class and artifact for a step output."""

materializer_class: Type["BaseMaterializer"]
artifact: "BaseArtifact"


class StepContext:
"""Provides additional context inside a step function.
This class is used to access materializers and artifact URIs inside
a step function. To use it, add a `StepContext` object to the signature
of your step function like this:
@step
def my_step(context: StepContext, ...)
context.get_output_materializer(...)
You do not need to create a `StepContext` object yourself and pass it
when creating the step, as long as you specify it in the signature ZenML
will create the `StepContext` and automatically pass it when executing your
step.
"""

def __init__(
self,
step_name: str,
output_materializers: Dict[str, Type["BaseMaterializer"]],
output_artifacts: Dict[str, "BaseArtifact"],
):
"""Initializes a StepContext instance.
Args:
step_name: The name of the step that this context is used in.
output_materializers: The output materializers of the step that
this context is used in.
output_artifacts: The output artifacts of the step that this
context is used in.
Raises:
StepInterfaceError: If the keys of the output materializers and
output artifacts do not match.
"""
if output_materializers.keys() != output_artifacts.keys():
raise StepContextError(
f"Mismatched keys in output materializers and output "
f"artifacts for step '{step_name}'. Output materializer "
f"keys: {set(output_materializers)}, output artifact "
f"keys: {set(output_artifacts)}"
)

self.step_name = step_name
self._outputs = {
key: StepContextOutput(
output_materializers[key], output_artifacts[key]
)
for key in output_materializers.keys()
}

def _get_output(
self, output_name: Optional[str] = None
) -> StepContextOutput:
"""Returns the materializer and artifact URI for a given step output.
Args:
output_name: Optional name of the output for which to get the
materializer and URI.
Returns:
Tuple containing the materializer and artifact URI for the
given output.
Raises:
StepInterfaceError: If the step has no outputs, no output for
the given `output_name` or if no `output_name` was given but
the step has multiple outputs.
"""
output_count = len(self._outputs)
if output_count == 0:
raise StepContextError(
f"Unable to get step output for step '{self.step_name}': "
f"This step does not have any outputs."
)

if not output_name and output_count > 1:
raise StepContextError(
f"Unable to get step output for step '{self.step_name}': "
f"This step has multiple outputs ({set(self._outputs)}), "
f"please specify which output to return."
)

if output_name:
if output_name not in self._outputs:
raise StepContextError(
f"Unable to get step output '{output_name}' for "
f"step '{self.step_name}'. This step does not have an "
f"output with the given name, please specify one of the "
f"available outputs: {set(self._outputs)}."
)
return self._outputs[output_name]
else:
return next(iter(self._outputs.values()))

def get_output_materializer(
self,
output_name: Optional[str] = None,
custom_materializer_class: Optional[Type["BaseMaterializer"]] = None,
) -> "BaseMaterializer":
"""Returns a materializer for a given step output.
Args:
output_name: Optional name of the output for which to get the
materializer. If no name is given and the step only has a
single output, the materializer of this output will be
returned. If the step has multiple outputs, an exception
will be raised.
custom_materializer_class: If given, this `BaseMaterializer`
subclass will be initialized with the output artifact instead
of the materializer that was registered for this step output.
Returns:
A materializer initialized with the output artifact for
the given output.
Raises:
StepInterfaceError: If the step has no outputs, no output for
the given `output_name` or if no `output_name` was given but
the step has multiple outputs.
"""
materializer_class, artifact = self._get_output(output_name)
# use custom materializer class if provided or fallback to default
# materializer for output
materializer_class = custom_materializer_class or materializer_class
return materializer_class(artifact)

def get_output_artifact_uri(self, output_name: Optional[str] = None) -> str:
"""Returns the artifact URI for a given step output.
Args:
output_name: Optional name of the output for which to get the URI.
If no name is given and the step only has a single output,
the URI of this output will be returned. If the step has
multiple outputs, an exception will be raised.
Returns:
Artifact URI for the given output.
Raises:
StepInterfaceError: If the step has no outputs, no output for
the given `output_name` or if no `output_name` was given but
the step has multiple outputs.
"""
return cast(str, self._get_output(output_name).artifact.uri)
9 changes: 9 additions & 0 deletions src/zenml/steps/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from zenml.logger import get_logger
from zenml.materializers.base_materializer import BaseMaterializer
from zenml.steps.base_step_config import BaseStepConfig
from zenml.steps.step_context import StepContext
from zenml.steps.step_output import Output
from zenml.utils import source_utils

Expand Down Expand Up @@ -345,6 +346,14 @@ def Do(
getattr(self, PARAM_STEP_NAME), missing_fields, arg_type
) from None
function_params[arg] = config_object
elif issubclass(arg_type, StepContext):
output_artifacts = {k: v[0] for k, v in output_dict.items()}
context = StepContext(
step_name=getattr(self, PARAM_STEP_NAME),
output_materializers=self.materializers or {},
output_artifacts=output_artifacts,
)
function_params[arg] = context
else:
# At this point, it has to be an artifact, so we resolve
function_params[arg] = self.resolve_input_artifact(
Expand Down
34 changes: 34 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,10 +374,13 @@

import pytest

from zenml.artifacts.base_artifact import BaseArtifact
from zenml.constants import ENV_ZENML_DEBUG
from zenml.core.repo import Repository
from zenml.materializers.base_materializer import BaseMaterializer
from zenml.pipelines import pipeline
from zenml.steps import step
from zenml.steps.step_context import StepContext


def pytest_sessionstart(session):
Expand Down Expand Up @@ -437,3 +440,34 @@ def _step(input_1: int, input_2: int):
pass

return _step


@pytest.fixture
def step_context_with_no_output():
return StepContext(
step_name="", output_materializers={}, output_artifacts={}
)


@pytest.fixture
def step_context_with_single_output():
materializers = {"output_1": BaseMaterializer}
artifacts = {"output_1": BaseArtifact()}

return StepContext(
step_name="",
output_materializers=materializers,
output_artifacts=artifacts,
)


@pytest.fixture
def step_context_with_two_outputs():
materializers = {"output_1": BaseMaterializer, "output_2": BaseMaterializer}
artifacts = {"output_1": BaseArtifact(), "output_2": BaseArtifact()}

return StepContext(
step_name="",
output_materializers=materializers,
output_artifacts=artifacts,
)
11 changes: 11 additions & 0 deletions tests/steps/test_base_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from zenml.materializers.built_in_materializer import BuiltInMaterializer
from zenml.steps import step
from zenml.steps.base_step_config import BaseStepConfig
from zenml.steps.step_context import StepContext
from zenml.steps.step_output import Output


Expand Down Expand Up @@ -57,6 +58,16 @@ def some_step(
pass


def test_define_step_with_multiple_contexts():
"""Tests that defining a step with multiple contexts raises
a StepInterfaceError."""
with pytest.raises(StepInterfaceError):

@step
def some_step(first_context: StepContext, second_context: StepContext):
pass


def test_define_step_without_input_annotation():
"""Tests that defining a step with a missing input annotation raises
a StepInterfaceError."""
Expand Down
Loading