Skip to content

Commit

Permalink
Merge pull request #196 from zenml-io/michael/step-context
Browse files Browse the repository at this point in the history
Add step context
  • Loading branch information
htahir1 authored Nov 27, 2021
2 parents f5e7f68 + 4ed49d4 commit 10839e2
Show file tree
Hide file tree
Showing 7 changed files with 364 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/zenml/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ class StepInterfaceError(Exception):
in an unsupported way."""


class StepContextError(Exception):
"""Raises exception when interacting with a StepContext
in an unsupported way."""


class PipelineInterfaceError(Exception):
"""Raises exception when interacting with the Pipeline interface
in an unsupported way."""
Expand Down
14 changes: 14 additions & 0 deletions src/zenml/steps/base_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
default_materializer_registry,
)
from zenml.steps.base_step_config import BaseStepConfig
from zenml.steps.step_context import StepContext
from zenml.steps.step_output import Output
from zenml.steps.utils import (
INTERNAL_EXECUTION_PARAMETER_PREFIX,
Expand Down Expand Up @@ -72,6 +73,7 @@ def __new__(
cls.OUTPUT_SIGNATURE = {}
cls.CONFIG_PARAMETER_NAME = None
cls.CONFIG_CLASS = None
cls.CONTEXT_PARAMETER_NAME = None

# Get the signature of the step function
step_function_signature = inspect.getfullargspec(
Expand Down Expand Up @@ -127,6 +129,16 @@ def __new__(
)
cls.CONFIG_PARAMETER_NAME = arg
cls.CONFIG_CLASS = arg_type
elif issubclass(arg_type, StepContext):
if cls.CONTEXT_PARAMETER_NAME is not None:
raise StepInterfaceError(
f"Found multiple context arguments "
f"('{cls.CONTEXT_PARAMETER_NAME}' and '{arg}') when "
f"trying to create step '{name}'. Please make sure to "
f"only have one `StepContext` as input "
f"argument for a step."
)
cls.CONTEXT_PARAMETER_NAME = arg
else:
# Can't do any check for existing materializers right now
# as they might get passed later, so we simply store the
Expand Down Expand Up @@ -170,10 +182,12 @@ class BaseStep(metaclass=BaseStepMeta):
OUTPUT_SIGNATURE: ClassVar[Dict[str, Type[Any]]] = None # type: ignore[assignment] # noqa
CONFIG_PARAMETER_NAME: ClassVar[Optional[str]] = None
CONFIG_CLASS: ClassVar[Optional[Type[BaseStepConfig]]] = None
CONTEXT_PARAMETER_NAME: ClassVar[Optional[str]] = None

def __init__(self, *args: Any, **kwargs: Any) -> None:
self.step_name = self.__class__.__name__
self.enable_cache = getattr(self, PARAM_ENABLE_CACHE)
self.requires_context = bool(self.CONTEXT_PARAMETER_NAME)

self.PARAM_SPEC: Dict[str, Any] = {}
self.INPUT_SPEC: Dict[str, Type[BaseArtifact]] = {}
Expand Down
162 changes: 162 additions & 0 deletions src/zenml/steps/step_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
from typing import TYPE_CHECKING, Dict, NamedTuple, Optional, Type, cast

from zenml.exceptions import StepContextError

if TYPE_CHECKING:
from zenml.artifacts.base_artifact import BaseArtifact
from zenml.materializers.base_materializer import BaseMaterializer


class StepContextOutput(NamedTuple):
"""Tuple containing materializer class and artifact for a step output."""

materializer_class: Type["BaseMaterializer"]
artifact: "BaseArtifact"


class StepContext:
"""Provides additional context inside a step function.
This class is used to access materializers and artifact URIs inside
a step function. To use it, add a `StepContext` object to the signature
of your step function like this:
@step
def my_step(context: StepContext, ...)
context.get_output_materializer(...)
You do not need to create a `StepContext` object yourself and pass it
when creating the step, as long as you specify it in the signature ZenML
will create the `StepContext` and automatically pass it when executing your
step.
"""

def __init__(
self,
step_name: str,
output_materializers: Dict[str, Type["BaseMaterializer"]],
output_artifacts: Dict[str, "BaseArtifact"],
):
"""Initializes a StepContext instance.
Args:
step_name: The name of the step that this context is used in.
output_materializers: The output materializers of the step that
this context is used in.
output_artifacts: The output artifacts of the step that this
context is used in.
Raises:
StepInterfaceError: If the keys of the output materializers and
output artifacts do not match.
"""
if output_materializers.keys() != output_artifacts.keys():
raise StepContextError(
f"Mismatched keys in output materializers and output "
f"artifacts for step '{step_name}'. Output materializer "
f"keys: {set(output_materializers)}, output artifact "
f"keys: {set(output_artifacts)}"
)

self.step_name = step_name
self._outputs = {
key: StepContextOutput(
output_materializers[key], output_artifacts[key]
)
for key in output_materializers.keys()
}

def _get_output(
self, output_name: Optional[str] = None
) -> StepContextOutput:
"""Returns the materializer and artifact URI for a given step output.
Args:
output_name: Optional name of the output for which to get the
materializer and URI.
Returns:
Tuple containing the materializer and artifact URI for the
given output.
Raises:
StepInterfaceError: If the step has no outputs, no output for
the given `output_name` or if no `output_name` was given but
the step has multiple outputs.
"""
output_count = len(self._outputs)
if output_count == 0:
raise StepContextError(
f"Unable to get step output for step '{self.step_name}': "
f"This step does not have any outputs."
)

if not output_name and output_count > 1:
raise StepContextError(
f"Unable to get step output for step '{self.step_name}': "
f"This step has multiple outputs ({set(self._outputs)}), "
f"please specify which output to return."
)

if output_name:
if output_name not in self._outputs:
raise StepContextError(
f"Unable to get step output '{output_name}' for "
f"step '{self.step_name}'. This step does not have an "
f"output with the given name, please specify one of the "
f"available outputs: {set(self._outputs)}."
)
return self._outputs[output_name]
else:
return next(iter(self._outputs.values()))

def get_output_materializer(
self,
output_name: Optional[str] = None,
custom_materializer_class: Optional[Type["BaseMaterializer"]] = None,
) -> "BaseMaterializer":
"""Returns a materializer for a given step output.
Args:
output_name: Optional name of the output for which to get the
materializer. If no name is given and the step only has a
single output, the materializer of this output will be
returned. If the step has multiple outputs, an exception
will be raised.
custom_materializer_class: If given, this `BaseMaterializer`
subclass will be initialized with the output artifact instead
of the materializer that was registered for this step output.
Returns:
A materializer initialized with the output artifact for
the given output.
Raises:
StepInterfaceError: If the step has no outputs, no output for
the given `output_name` or if no `output_name` was given but
the step has multiple outputs.
"""
materializer_class, artifact = self._get_output(output_name)
# use custom materializer class if provided or fallback to default
# materializer for output
materializer_class = custom_materializer_class or materializer_class
return materializer_class(artifact)

def get_output_artifact_uri(self, output_name: Optional[str] = None) -> str:
"""Returns the artifact URI for a given step output.
Args:
output_name: Optional name of the output for which to get the URI.
If no name is given and the step only has a single output,
the URI of this output will be returned. If the step has
multiple outputs, an exception will be raised.
Returns:
Artifact URI for the given output.
Raises:
StepInterfaceError: If the step has no outputs, no output for
the given `output_name` or if no `output_name` was given but
the step has multiple outputs.
"""
return cast(str, self._get_output(output_name).artifact.uri)
9 changes: 9 additions & 0 deletions src/zenml/steps/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from zenml.logger import get_logger
from zenml.materializers.base_materializer import BaseMaterializer
from zenml.steps.base_step_config import BaseStepConfig
from zenml.steps.step_context import StepContext
from zenml.steps.step_output import Output
from zenml.utils import source_utils

Expand Down Expand Up @@ -345,6 +346,14 @@ def Do(
getattr(self, PARAM_STEP_NAME), missing_fields, arg_type
) from None
function_params[arg] = config_object
elif issubclass(arg_type, StepContext):
output_artifacts = {k: v[0] for k, v in output_dict.items()}
context = StepContext(
step_name=getattr(self, PARAM_STEP_NAME),
output_materializers=self.materializers or {},
output_artifacts=output_artifacts,
)
function_params[arg] = context
else:
# At this point, it has to be an artifact, so we resolve
function_params[arg] = self.resolve_input_artifact(
Expand Down
34 changes: 34 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,10 +374,13 @@

import pytest

from zenml.artifacts.base_artifact import BaseArtifact
from zenml.constants import ENV_ZENML_DEBUG
from zenml.core.repo import Repository
from zenml.materializers.base_materializer import BaseMaterializer
from zenml.pipelines import pipeline
from zenml.steps import step
from zenml.steps.step_context import StepContext


def pytest_sessionstart(session):
Expand Down Expand Up @@ -437,3 +440,34 @@ def _step(input_1: int, input_2: int):
pass

return _step


@pytest.fixture
def step_context_with_no_output():
return StepContext(
step_name="", output_materializers={}, output_artifacts={}
)


@pytest.fixture
def step_context_with_single_output():
materializers = {"output_1": BaseMaterializer}
artifacts = {"output_1": BaseArtifact()}

return StepContext(
step_name="",
output_materializers=materializers,
output_artifacts=artifacts,
)


@pytest.fixture
def step_context_with_two_outputs():
materializers = {"output_1": BaseMaterializer, "output_2": BaseMaterializer}
artifacts = {"output_1": BaseArtifact(), "output_2": BaseArtifact()}

return StepContext(
step_name="",
output_materializers=materializers,
output_artifacts=artifacts,
)
11 changes: 11 additions & 0 deletions tests/steps/test_base_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from zenml.materializers.built_in_materializer import BuiltInMaterializer
from zenml.steps import step
from zenml.steps.base_step_config import BaseStepConfig
from zenml.steps.step_context import StepContext
from zenml.steps.step_output import Output


Expand Down Expand Up @@ -57,6 +58,16 @@ def some_step(
pass


def test_define_step_with_multiple_contexts():
"""Tests that defining a step with multiple contexts raises
a StepInterfaceError."""
with pytest.raises(StepInterfaceError):

@step
def some_step(first_context: StepContext, second_context: StepContext):
pass


def test_define_step_without_input_annotation():
"""Tests that defining a step with a missing input annotation raises
a StepInterfaceError."""
Expand Down
Loading

0 comments on commit 10839e2

Please sign in to comment.