Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add step context #196

Merged
merged 4 commits into from
Nov 27, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Initial implementation of step context
  • Loading branch information
schustmi committed Nov 26, 2021
commit 64b89ef2a0b36b9e0dc51201d9a7d9ab5015797f
14 changes: 14 additions & 0 deletions src/zenml/steps/base_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
default_materializer_registry,
)
from zenml.steps.base_step_config import BaseStepConfig
from zenml.steps.step_context import StepContext
from zenml.steps.step_output import Output
from zenml.steps.utils import (
INTERNAL_EXECUTION_PARAMETER_PREFIX,
Expand Down Expand Up @@ -72,6 +73,7 @@ def __new__(
cls.OUTPUT_SIGNATURE = {}
cls.CONFIG_PARAMETER_NAME = None
cls.CONFIG_CLASS = None
cls.CONTEXT_PARAMETER_NAME = None

# Get the signature of the step function
step_function_signature = inspect.getfullargspec(
Expand Down Expand Up @@ -127,6 +129,16 @@ def __new__(
)
cls.CONFIG_PARAMETER_NAME = arg
cls.CONFIG_CLASS = arg_type
elif issubclass(arg_type, StepContext):
if cls.CONTEXT_PARAMETER_NAME is not None:
raise StepInterfaceError(
f"Found multiple context arguments "
f"('{cls.CONTEXT_PARAMETER_NAME}' and '{arg}') when "
f"trying to create step '{name}'. Please make sure to "
f"only have one `StepContext` as input "
f"argument for a step."
)
cls.CONTEXT_PARAMETER_NAME = arg
else:
# Can't do any check for existing materializers right now
# as they might get passed later, so we simply store the
Expand Down Expand Up @@ -170,10 +182,12 @@ class BaseStep(metaclass=BaseStepMeta):
OUTPUT_SIGNATURE: ClassVar[Dict[str, Type[Any]]] = None # type: ignore[assignment] # noqa
CONFIG_PARAMETER_NAME: ClassVar[Optional[str]] = None
CONFIG_CLASS: ClassVar[Optional[Type[BaseStepConfig]]] = None
CONTEXT_PARAMETER_NAME: ClassVar[Optional[str]] = None

def __init__(self, *args: Any, **kwargs: Any) -> None:
self.step_name = self.__class__.__name__
self.enable_cache = getattr(self, PARAM_ENABLE_CACHE)
self.requires_context = bool(self.CONTEXT_PARAMETER_NAME)

self.PARAM_SPEC: Dict[str, Any] = {}
self.INPUT_SPEC: Dict[str, Type[BaseArtifact]] = {}
Expand Down
59 changes: 59 additions & 0 deletions src/zenml/steps/step_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from collections import namedtuple
from typing import TYPE_CHECKING, Dict, Optional, Type

from zenml.exceptions import StepInterfaceError

if TYPE_CHECKING:
from zenml.artifacts.base_artifact import BaseArtifact
from zenml.materializers.base_materializer import BaseMaterializer


# Lightweight record pairing the materializer class registered for a step
# output with the artifact that belongs to the same output.
_STEP_CONTEXT_OUTPUT_FIELDS = ("materializer_class", "artifact")
StepContextOutput = namedtuple("StepContextOutput", _STEP_CONTEXT_OUTPUT_FIELDS)


class StepContext:
    """Gives a step function access to its output materializers and artifacts.

    The context stores, for every output name, the materializer class and
    the artifact that were passed to the constructor, and exposes helpers to
    instantiate a materializer for an output or to read an output
    artifact's URI.
    """

    def __init__(
        self,
        output_materializers: Dict[str, Type["BaseMaterializer"]],
        output_artifacts: Dict[str, "BaseArtifact"],
    ):
        """Initializes the context from per-output materializers and artifacts.

        Args:
            output_materializers: Materializer class for each output name.
            output_artifacts: Artifact for each output name.

        Raises:
            StepInterfaceError: If the two dictionaries do not have exactly
                the same keys.
        """
        if output_materializers.keys() != output_artifacts.keys():
            raise StepInterfaceError(
                f"Mismatched keys in output materializers and output "
                f"artifacts for step context. Output materializer keys: "
                f"{sorted(output_materializers)}, output artifact keys: "
                f"{sorted(output_artifacts)}."
            )

        # Pair each materializer class with the artifact of the same output;
        # insertion order of `output_materializers` is preserved.
        self._outputs = {
            key: StepContextOutput(materializer_class, output_artifacts[key])
            for key, materializer_class in output_materializers.items()
        }

    def _get_output(
        self, output_name: Optional[str] = None
    ) -> StepContextOutput:
        """Returns the materializer class and artifact for an output.

        Args:
            output_name: Name of the output to look up. May be omitted
                (or empty) only if the step has exactly one output.

        Returns:
            The `StepContextOutput` stored for the requested output.

        Raises:
            StepInterfaceError: If the step has no outputs, or if no name
                was given although the step has more than one output.
            KeyError: If `output_name` does not match any stored output.
        """
        output_count = len(self._outputs)
        if output_count == 0:
            raise StepInterfaceError(
                "Unable to get an output from the step context because the "
                "step does not have any outputs."
            )

        # An empty name is treated the same as no name at all.
        if not output_name and output_count > 1:
            raise StepInterfaceError(
                f"Unable to get an output from the step context without an "
                f"output name because the step has {output_count} outputs. "
                f"Please specify one of the available outputs: "
                f"{list(self._outputs)}."
            )

        if output_name:
            try:
                return self._outputs[output_name]
            except KeyError:
                # Same exception type as a plain dict lookup, but with a
                # message that tells the user which names are valid.
                raise KeyError(
                    f"Unable to find an output named '{output_name}' in the "
                    f"step context. Available outputs: "
                    f"{list(self._outputs)}."
                ) from None

        # Exactly one output exists and no name was given.
        return next(iter(self._outputs.values()))

    def get_output_materializer(
        self,
        output_name: Optional[str] = None,
        custom_materializer_class: Optional[Type["BaseMaterializer"]] = None,
    ) -> "BaseMaterializer":
        """Instantiates a materializer for a step output.

        Args:
            output_name: Name of the output. May be omitted only if the
                step has exactly one output.
            custom_materializer_class: If given, this class is instantiated
                instead of the materializer class stored for the output.

        Returns:
            The chosen materializer class called with the output artifact.

        Raises:
            StepInterfaceError: If the step has no outputs, or if no name
                was given although the step has more than one output.
            KeyError: If `output_name` does not match any stored output.
        """
        materializer_class, artifact = self._get_output(output_name)
        # Use the custom materializer class if provided, otherwise fall back
        # to the materializer class stored for this output.
        if custom_materializer_class is not None:
            materializer_class = custom_materializer_class
        return materializer_class(artifact)

    def get_output_artifact_uri(self, output_name: Optional[str] = None) -> str:
        """Returns the `uri` attribute of a step output's artifact.

        Args:
            output_name: Name of the output. May be omitted only if the
                step has exactly one output.

        Returns:
            The `uri` of the artifact stored for the output.

        Raises:
            StepInterfaceError: If the step has no outputs, or if no name
                was given although the step has more than one output.
            KeyError: If `output_name` does not match any stored output.
        """
        return self._get_output(output_name).artifact.uri
8 changes: 8 additions & 0 deletions src/zenml/steps/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from zenml.logger import get_logger
from zenml.materializers.base_materializer import BaseMaterializer
from zenml.steps.base_step_config import BaseStepConfig
from zenml.steps.step_context import StepContext
from zenml.steps.step_output import Output
from zenml.utils import source_utils

Expand Down Expand Up @@ -345,6 +346,13 @@ def Do(
getattr(self, PARAM_STEP_NAME), missing_fields, arg_type
) from None
function_params[arg] = config_object
elif issubclass(arg_type, StepContext):
output_artifacts = {k: v[0] for k, v in output_dict.items()}
context = StepContext(
output_materializers=self.materializers,
output_artifacts=output_artifacts,
)
function_params[arg] = context
else:
# At this point, it has to be an artifact, so we resolve
function_params[arg] = self.resolve_input_artifact(
Expand Down