diff --git a/src/zenml/materializers/base_materializer.py b/src/zenml/materializers/base_materializer.py
index db2e40a4d0d..65dae5e9023 100644
--- a/src/zenml/materializers/base_materializer.py
+++ b/src/zenml/materializers/base_materializer.py
@@ -39,10 +39,11 @@ def __new__(
                 "You should specify a list of ASSOCIATED_TYPES when creating a "
                 "Materializer!"
             )
-            [
-                default_materializer_registry.register_materializer_type(x, cls)
-                for x in cls.ASSOCIATED_TYPES
-            ]
+            for associated_type in cls.ASSOCIATED_TYPES:
+                default_materializer_registry.register_materializer_type(
+                    associated_type, cls
+                )
+
         return cls
diff --git a/src/zenml/materializers/default_materializer_registry.py b/src/zenml/materializers/default_materializer_registry.py
index ca4e9afe625..b507211a7c8 100644
--- a/src/zenml/materializers/default_materializer_registry.py
+++ b/src/zenml/materializers/default_materializer_registry.py
@@ -12,7 +12,7 @@
 # or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-from typing import TYPE_CHECKING, Any, ClassVar, Dict, Type
+from typing import TYPE_CHECKING, Any, Dict, Type
 
 from zenml.logger import get_logger
 
@@ -22,14 +22,14 @@
     from zenml.materializers.base_materializer import BaseMaterializer
 
 
-class DefaultMaterializerRegistry(object):
+class MaterializerRegistry:
     """Matches a python type to a default materializer."""
 
-    materializer_types: ClassVar[Dict[Type[Any], Type["BaseMaterializer"]]] = {}
+    def __init__(self) -> None:
+        self.materializer_types: Dict[Type[Any], Type["BaseMaterializer"]] = {}
 
-    @classmethod
     def register_materializer_type(
-        cls, key: Type[Any], type_: Type["BaseMaterializer"]
+        self, key: Type[Any], type_: Type["BaseMaterializer"]
     ) -> None:
         """Registers a new materializer.
 
@@ -37,13 +37,13 @@ def register_materializer_type(
             key: Indicates the type of an object.
             type_: A BaseMaterializer subclass.
         """
-        if key not in cls.materializer_types:
-            cls.materializer_types[key] = type_
+        if key not in self.materializer_types:
+            self.materializer_types[key] = type_
             logger.debug(f"Registered materializer {type_} for {key}")
         else:
             logger.debug(
                 f"{key} already registered with "
-                f"{cls.materializer_types[key]}. Cannot register {type_}."
+                f"{self.materializer_types[key]}. Cannot register {type_}."
             )
 
     def register_and_overwrite_type(
@@ -58,17 +58,14 @@ def register_and_overwrite_type(
         self.materializer_types[key] = type_
         logger.debug(f"Registered materializer {type_} for {key}")
 
-    def get_single_materializer_type(
-        self, key: Type[Any]
-    ) -> Type["BaseMaterializer"]:
+    def __getitem__(self, key: Type[Any]) -> Type["BaseMaterializer"]:
         """Get a single materializers based on the key.
 
         Args:
             key: Indicates the type of an object.
 
         Returns:
-            Instance of a `BaseMaterializer` subclass initialized with the
-            artifact of this factory.
+            `BaseMaterializer` subclass that was registered for this key.
""" if key in self.materializer_types: return self.materializer_types[key] @@ -81,14 +78,12 @@ def get_single_materializer_type( def get_materializer_types( self, ) -> Dict[Type[Any], Type["BaseMaterializer"]]: - """Get all registered materializers.""" + """Get all registered materializer types.""" return self.materializer_types def is_registered(self, key: Type[Any]) -> bool: - """Returns true if key type is registered, else returns False.""" - if key in self.materializer_types: - return True - return False + """Returns if a materializer class is registered for the given type.""" + return key in self.materializer_types -default_materializer_registry = DefaultMaterializerRegistry() +default_materializer_registry = MaterializerRegistry() diff --git a/src/zenml/materializers/spec_materializer_registry.py b/src/zenml/materializers/spec_materializer_registry.py deleted file mode 100644 index ae10f0c55a8..00000000000 --- a/src/zenml/materializers/spec_materializer_registry.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright (c) ZenML GmbH 2021. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at: -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing -# permissions and limitations under the License. - -from typing import TYPE_CHECKING, Dict, Type - -from zenml.logger import get_logger - -logger = get_logger(__name__) - -if TYPE_CHECKING: - from zenml.materializers.base_materializer import BaseMaterializer - - -class SpecMaterializerRegistry: - """Matches spec of a step to a materializer.""" - - def __init__(self) -> None: - """Materializer types registry.""" - self.materializer_types: Dict[str, Type["BaseMaterializer"]] = {} - - def register_materializer_type( - self, key: str, type_: Type["BaseMaterializer"] - ) -> None: - """Registers a new materializer. - - Args: - key: Name of input or output parameter. - type_: A BaseMaterializer subclass. - """ - self.materializer_types[key] = type_ - logger.debug(f"Registered materializer {type_} for {key}") - - def get_materializer_types( - self, - ) -> Dict[str, Type["BaseMaterializer"]]: - """Get all registered materializers.""" - return self.materializer_types - - def get_single_materializer_type( - self, key: str - ) -> Type["BaseMaterializer"]: - """Gets a single pre-registered materializer type based on `key`.""" - if key in self.materializer_types: - return self.materializer_types[key] - logger.debug( - "Tried to fetch %s but its not registered. Available keys: %s", - key, - self.materializer_types.keys(), - ) - raise KeyError( - f"Key '{key}' does not have a registered `Materializer`!" 
-        )
-
-    def is_registered(self, key: str) -> bool:
-        """Returns true if key type is registered, else returns False."""
-        if key in self.materializer_types:
-            return True
-        return False
diff --git a/src/zenml/post_execution/artifact.py b/src/zenml/post_execution/artifact.py
index 41fc3e0c166..d24da839e72 100644
--- a/src/zenml/post_execution/artifact.py
+++ b/src/zenml/post_execution/artifact.py
@@ -113,3 +113,10 @@ def __repr__(self) -> str:
             f"type='{self._type}', uri='{self._uri}', "
             f"materializer='{self._materializer}')"
         )
+
+    def __eq__(self, other: Any) -> bool:
+        """Returns whether the other object is referring to the
+        same artifact."""
+        if isinstance(other, ArtifactView):
+            return self._id == other._id and self._uri == other._uri
+        return NotImplemented
diff --git a/src/zenml/post_execution/pipeline.py b/src/zenml/post_execution/pipeline.py
index 7cffb449183..5bafd9540a1 100644
--- a/src/zenml/post_execution/pipeline.py
+++ b/src/zenml/post_execution/pipeline.py
@@ -12,7 +12,7 @@
 # or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, Any, List
 
 from zenml.logger import get_logger
 from zenml.post_execution.pipeline_run import PipelineRunView
@@ -95,3 +95,13 @@ def __repr__(self) -> str:
             f"{self.__class__.__qualname__}(id={self._id}, "
             f"name='{self._name}')"
         )
+
+    def __eq__(self, other: Any) -> bool:
+        """Returns whether the other object is referring to the
+        same pipeline."""
+        if isinstance(other, PipelineView):
+            return (
+                self._id == other._id
+                and self._metadata_store.uuid == other._metadata_store.uuid
+            )
+        return NotImplemented
diff --git a/src/zenml/post_execution/pipeline_run.py b/src/zenml/post_execution/pipeline_run.py
index 0ba74e46cfb..453925d87af 100644
--- a/src/zenml/post_execution/pipeline_run.py
+++ b/src/zenml/post_execution/pipeline_run.py
@@ -13,7 +13,7 @@
 # permissions and limitations under the License.
 
 from collections import OrderedDict
-from typing import TYPE_CHECKING, Dict, List
+from typing import TYPE_CHECKING, Any, Dict, List
 
 from ml_metadata import proto
 
@@ -121,3 +121,13 @@ def __repr__(self) -> str:
             f"{self.__class__.__qualname__}(id={self._id}, "
             f"name='{self._name}')"
         )
+
+    def __eq__(self, other: Any) -> bool:
+        """Returns whether the other object is referring to the same
+        pipeline run."""
+        if isinstance(other, PipelineRunView):
+            return (
+                self._id == other._id
+                and self._metadata_store.uuid == other._metadata_store.uuid
+            )
+        return NotImplemented
diff --git a/src/zenml/post_execution/step.py b/src/zenml/post_execution/step.py
index 2761b8053a0..a211e881309 100644
--- a/src/zenml/post_execution/step.py
+++ b/src/zenml/post_execution/step.py
@@ -137,3 +137,12 @@ def __repr__(self) -> str:
             f"{self.__class__.__qualname__}(id={self._id}, "
             f"name='{self._name}', parameters={self._parameters})"
         )
+
+    def __eq__(self, other: Any) -> bool:
+        """Returns whether the other object is referring to the same step."""
+        if isinstance(other, StepView):
+            return (
+                self._id == other._id
+                and self._metadata_store.uuid == other._metadata_store.uuid
+            )
+        return NotImplemented
diff --git a/src/zenml/steps/base_step.py b/src/zenml/steps/base_step.py
index c03b51d6097..e4a1f7842f3 100644
--- a/src/zenml/steps/base_step.py
+++ b/src/zenml/steps/base_step.py
@@ -39,9 +39,6 @@
 from zenml.materializers.default_materializer_registry import (
     default_materializer_registry,
 )
-from zenml.materializers.spec_materializer_registry import (
-    SpecMaterializerRegistry,
-)
 from zenml.steps.base_step_config import BaseStepConfig
 from zenml.steps.step_output import Output
 from zenml.steps.utils import (
@@ -56,19 +53,6 @@
 logger = get_logger(__name__)
 
 
-def check_dict_keys_match(x: Dict[Any, Any], y: Dict[Any, Any]) -> bool:
-    """Checks whether there is even one key shared between two dicts.
-
-    Returns:
-        True if there is a shared key, otherwise False.
-    """
-    shared_items = {k: x[k] for k in x if k in y and x[k] == y[k]}
-    if len(shared_items) == 0:
-        return False
-    logger.debug(f"Matched keys for dicts {x} and {y}: {shared_items}")
-    return True
-
-
 class BaseStepMeta(type):
     """Meta class for `BaseStep`.
 
@@ -89,53 +73,86 @@ def __new__(
         cls.CONFIG_PARAMETER_NAME = None
         cls.CONFIG_CLASS = None
 
-        # Looking into the signature of the provided process function
-        process_spec = inspect.getfullargspec(
+        # Get the signature of the step function
+        step_function_signature = inspect.getfullargspec(
             getattr(cls, STEP_INNER_FUNC_NAME)
         )
-        process_args = process_spec.args
-        logger.debug(f"{name} args: {process_args}")
-
-        # Remove the self from the signature if it exists
-        if process_args and process_args[0] == "self":
-            process_args.pop(0)
-
-        # Parse the input signature of the function
-        for arg in process_args:
-            arg_type = process_spec.annotations.get(arg, None)
-            # Check whether its a `BaseStepConfig` or a registered
-            # materializer type.
+
+        if bases:
+            # We're not creating the abstract `BaseStep` class
+            # but a concrete implementation. Make sure the step function
+            # signature does not contain variable *args or **kwargs
+            variable_arguments = None
+            if step_function_signature.varargs:
+                variable_arguments = f"*{step_function_signature.varargs}"
+            elif step_function_signature.varkw:
+                variable_arguments = f"**{step_function_signature.varkw}"
+
+            if variable_arguments:
+                raise StepInterfaceError(
+                    f"Unable to create step '{name}' with variable arguments "
+                    f"'{variable_arguments}'. Please make sure your step "
+                    f"functions are defined with a fixed number of arguments."
+                )
+
+        step_function_args = (
+            step_function_signature.args + step_function_signature.kwonlyargs
+        )
+
+        # Remove 'self' from the signature if it exists
+        if step_function_args and step_function_args[0] == "self":
+            step_function_args.pop(0)
+
+        # Verify the input arguments of the step function
+        for arg in step_function_args:
+            arg_type = step_function_signature.annotations.get(arg, None)
+
+            if not arg_type:
+                raise StepInterfaceError(
+                    f"Missing type annotation for argument '{arg}' when "
+                    f"trying to create step '{name}'. Please make sure to "
+                    f"include type annotations for all your step inputs "
+                    f"and outputs."
+                )
+
             if issubclass(arg_type, BaseStepConfig):
-                # It needs to be None at this point, otherwise multi configs.
+                # Raise an error if we already found a config in the signature
                 if cls.CONFIG_CLASS is not None:
                     raise StepInterfaceError(
-                        "Please only use one `BaseStepConfig` type object in "
-                        "your step."
+                        f"Found multiple configuration arguments "
+                        f"('{cls.CONFIG_PARAMETER_NAME}' and '{arg}') when "
+                        f"trying to create step '{name}'. Please make sure to "
+                        f"only have one `BaseStepConfig` subclass as input "
+                        f"argument for a step."
                     )
                 cls.CONFIG_PARAMETER_NAME = arg
                 cls.CONFIG_CLASS = arg_type
             else:
+                # Can't do any check for existing materializers right now
+                # as they might get passed later, so we simply store the
+                # argument name and type for later use.
                 cls.INPUT_SIGNATURE.update({arg: arg_type})
 
-        # Infer the returned values
-        return_spec = process_spec.annotations.get("return", None)
-        if return_spec is not None:
-            if isinstance(return_spec, Output):
-                # If its a named, potentially multi, outputs we go through
-                # each and create a spec.
-                for return_tuple in return_spec.items():
-                    cls.OUTPUT_SIGNATURE.update(
-                        {return_tuple[0]: return_tuple[1]}
-                    )
+        # Parse the returns of the step function
+        return_type = step_function_signature.annotations.get("return", None)
+        if return_type is not None:
+            if isinstance(return_type, Output):
+                cls.OUTPUT_SIGNATURE = dict(return_type.items())
             else:
-                # If its one output, then give it a single return name.
-                cls.OUTPUT_SIGNATURE.update(
-                    {SINGLE_RETURN_OUT_NAME: return_spec}
-                )
-
-        if check_dict_keys_match(cls.INPUT_SIGNATURE, cls.OUTPUT_SIGNATURE):
+                cls.OUTPUT_SIGNATURE[SINGLE_RETURN_OUT_NAME] = return_type
+
+        # Raise an exception if input and output names of a step overlap as
+        # tfx requires them to be unique
+        # TODO [MEDIUM]: Can we prefix inputs and outputs to avoid this
+        # restriction?
+        shared_input_output_keys = set(cls.INPUT_SIGNATURE).intersection(
+            set(cls.OUTPUT_SIGNATURE)
+        )
+        if shared_input_output_keys:
             raise StepInterfaceError(
-                "The input names and output names cannot be the same!"
+                f"There is an overlap in the input and output names of "
+                f"step '{name}': {shared_input_output_keys}. Please make "
+                f"sure that your input and output names are distinct."
             )
 
         return cls
 
@@ -155,40 +172,92 @@ class BaseStep(metaclass=BaseStepMeta):
     CONFIG_CLASS: ClassVar[Optional[Type[BaseStepConfig]]] = None
 
     def __init__(self, *args: Any, **kwargs: Any) -> None:
-        self.materializers: Dict[str, Type[BaseMaterializer]] = {}
-        self.__component = None
         self.step_name = self.__class__.__name__
         self.enable_cache = getattr(self, PARAM_ENABLE_CACHE)
+        self.PARAM_SPEC: Dict[str, Any] = {}
         self.INPUT_SPEC: Dict[str, Type[BaseArtifact]] = {}
         self.OUTPUT_SPEC: Dict[str, Type[BaseArtifact]] = {}
-        self.spec_materializer_registry = SpecMaterializerRegistry()
+        self._explicit_materializers: Dict[str, Type[BaseMaterializer]] = {}
+        self.__component = None
         self._verify_arguments(*args, **kwargs)
 
-    @property
-    def _internal_execution_properties(self) -> Dict[str, str]:
-        """ZenML internal execution properties for this step.
+    def get_materializers(
+        self, ensure_complete: bool = False
+    ) -> Dict[str, Type[BaseMaterializer]]:
+        """Returns available materializers for the outputs of this step.
+
+        Args:
+            ensure_complete: If set to `True`, this method will raise a
+                `StepInterfaceError` if no materializer can be found for an
+                output.
+
+        Returns:
+            A dictionary mapping output names to `BaseMaterializer` subclasses.
+            If no explicit materializer was set using
+            `step.with_return_materializers(...)`, this checks the
+            default materializer registry to find a materializer for the
+            type of the output. If no materializer is registered, the
+            output of this method will not contain an entry for this output.
 
-        **IMPORTANT**: When modifying this dictionary, make sure to
-        prefix the key with `INTERNAL_EXECUTION_PARAMETER_PREFIX` and serialize
-        the value using `json.dumps(...)`.
+        Raises:
+            StepInterfaceError: (Only if `ensure_complete` is set to `True`)
+                If an output does not have an explicit materializer assigned
+                to it and there is no default materializer registered for
+                the output type.
         """
+        materializers = self._explicit_materializers
+
+        for output_name, output_type in self.OUTPUT_SIGNATURE.items():
+            if output_name in materializers:
+                # Materializer for this output was set explicitly
+                pass
+            elif default_materializer_registry.is_registered(output_type):
+                materializer = default_materializer_registry[output_type]
+                materializers[output_name] = materializer
+            else:
+                if ensure_complete:
+                    raise StepInterfaceError(
+                        f"Unable to find materializer for output "
+                        f"'{output_name}' of type `{output_type}` in step "
+                        f"'{self.step_name}'. Please make sure to either "
+                        f"explicitly set a materializer for step outputs "
+                        f"using `step.with_return_materializers(...)` or "
+                        f"register a default materializer for specific "
+                        f"types by subclassing `BaseMaterializer` and setting "
+                        f"its `ASSOCIATED_TYPES` class variable."
+                    )
+
+        return materializers
+
+    @property
+    def _internal_execution_parameters(self) -> Dict[str, str]:
+        """ZenML internal execution parameters for this step."""
         properties = {}
-        if not self.enable_cache:
-            # add a random string to the execution properties to disable caching
-            key = INTERNAL_EXECUTION_PARAMETER_PREFIX + "disable_cache"
-            random_string = f"{random.getrandbits(128):032x}"
-            properties[key] = json.dumps(random_string)
+        if self.enable_cache:
+            # Caching is enabled so we compute a hash of the step function code
+            # and materializers to catch changes in the step behavior
+
+            def _get_hashed_source(value: Any) -> str:
+                """Returns a hash of the object's source code."""
+                source_code = inspect.getsource(value)
+                return hashlib.sha256(source_code.encode("utf-8")).hexdigest()
+
+            properties["step_source"] = _get_hashed_source(self.process)
+
+            for name, materializer in self.get_materializers().items():
+                key = f"{name}_materializer_source"
+                properties[key] = _get_hashed_source(materializer)
         else:
-            # caching is enabled so we compute a hash of the step function code
-            # to catch changes in the step implementation
-            key = INTERNAL_EXECUTION_PARAMETER_PREFIX + "step_source"
-            step_source = inspect.getsource(self.process)
-            step_hash = hashlib.sha256(step_source.encode("utf-8")).hexdigest()
-            properties[key] = json.dumps(step_hash)
+            # Add a random string to the execution properties to disable caching
+            random_string = f"{random.getrandbits(128):032x}"
+            properties["disable_cache"] = random_string
 
-        return properties
+        return {
+            INTERNAL_EXECUTION_PARAMETER_PREFIX + key: value
+            for key, value in properties.items()
+        }
 
     def _verify_arguments(self, *args: Any, **kwargs: Any) -> None:
         """Verifies the initialization args and kwargs of this step.
@@ -245,7 +314,7 @@ def _verify_arguments(self, *args: Any, **kwargs: Any) -> None:
 
         self.PARAM_SPEC = config.dict()
 
-    def _prepare_parameter_spec(self) -> None:
+    def _update_and_verify_parameter_spec(self) -> None:
         """Verifies and prepares the config parameters for running this step.
 
         When the step requires config parameters, this method:
@@ -279,18 +348,6 @@ def _prepare_parameter_spec(self) -> None:
                 self.step_name, missing_keys, self.CONFIG_CLASS
             )
 
-        # convert config parameter values to strings
-        try:
-            self.PARAM_SPEC = {
-                k: json.dumps(v) for k, v in self.PARAM_SPEC.items()
-            }
-        except TypeError as e:
-            raise StepInterfaceError(
-                f"Failed to serialize config parameters for step "
-                f"'{self.step_name}'. Please make sure to only use "
-                f"json serializable parameter values."
-            ) from e
-
     def _prepare_input_artifacts(
         self, *artifacts: Channel, **kw_artifacts: Channel
     ) -> Dict[str, Channel]:
@@ -377,32 +434,46 @@ def __call__(
     ) -> Union[Channel, List[Channel]]:
         """Generates a component when called."""
         # TODO [MEDIUM]: replaces Channels with ZenML class (BaseArtifact?)
-        self._prepare_parameter_spec()
+        self._update_and_verify_parameter_spec()
 
-        # Construct INPUT_SPEC from INPUT_SIGNATURE
-        self.resolve_signature_materializers(self.INPUT_SIGNATURE, True)
-        # Construct OUTPUT_SPEC from OUTPUT_SIGNATURE
-        self.resolve_signature_materializers(self.OUTPUT_SIGNATURE, False)
+        # Right now all artifacts are BaseArtifacts
+        self.INPUT_SPEC = {key: BaseArtifact for key in self.INPUT_SIGNATURE}
+        self.OUTPUT_SPEC = {key: BaseArtifact for key in self.OUTPUT_SIGNATURE}
 
         input_artifacts = self._prepare_input_artifacts(
             *artifacts, **kw_artifacts
         )
 
+        execution_parameters = {
+            **self.PARAM_SPEC,
+            **self._internal_execution_parameters,
+        }
+
+        # convert execution parameter values to strings
+        try:
+            execution_parameters = {
+                k: json.dumps(v) for k, v in execution_parameters.items()
+            }
+        except TypeError as e:
+            raise StepInterfaceError(
+                f"Failed to serialize execution parameters for step "
+                f"'{self.step_name}'. Please make sure to only use "
+                f"json serializable parameter values."
+            ) from e
+
         self.__component = generate_component(self)(
             **input_artifacts,
-            **self.PARAM_SPEC,
-            **self._internal_execution_properties,
+            **execution_parameters,
         )
 
         # Resolve the returns in the right order.
-        returns = []
-        for k in self.OUTPUT_SPEC.keys():
-            returns.append(getattr(self.component.outputs, k))
+        returns = [self.component.outputs[key] for key in self.OUTPUT_SPEC]
 
         # If its one return we just return the one channel not as a list
         if len(returns) == 1:
-            returns = returns[0]
-        return returns
+            return returns[0]
+        else:
+            return returns
 
     @property
     def component(self) -> _ZenMLSimpleComponent:
@@ -417,6 +488,7 @@ def component(self) -> _ZenMLSimpleComponent:
     @abstractmethod
     def process(self, *args: Any, **kwargs: Any) -> Any:
         """Abstract method for core step logic."""
+        raise NotImplementedError
 
     def with_return_materializers(
         self: T,
@@ -424,69 +496,63 @@ def process(self, *args: Any, **kwargs: Any) -> Any:
             Type[BaseMaterializer], Dict[str, Type[BaseMaterializer]]
         ],
     ) -> T:
-        """Inject materializers from the outside. If one materializer is passed
-        in then all outputs are assigned that materializer. If a dict is passed
-        in then we make sure the output names match.
+        """Register materializers for step outputs.
+
+        If a single materializer is passed, it will be used for all step
+        outputs. Otherwise the dictionary keys specify the output names
+        for which the materializers will be used.
 
         Args:
-            materializers: A `BaseMaterializer` subclass or a dict mapping
-                output names to `BaseMaterializer` subclasses.
-        """
-        if not isinstance(materializers, dict):
-            assert isinstance(materializers, type) and issubclass(
-                materializers, BaseMaterializer
-            ), "Need to pass in a subclass of `BaseMaterializer`!"
-            if len(self.OUTPUT_SIGNATURE) == 1:
-                # If only one return, assign to `SINGLE_RETURN_OUT_NAME`.
-                self.materializers = {SINGLE_RETURN_OUT_NAME: materializers}
-            else:
-                # If multi return, then assign to all.
-                self.materializers = {
-                    k: materializers for k in self.OUTPUT_SIGNATURE
-                }
-        else:
-            # Check whether signature matches.
-            assert all([x in self.OUTPUT_SIGNATURE for x in materializers]), (
-                f"One of {materializers.keys()} not defined in outputs: "
-                f"{self.OUTPUT_SIGNATURE.keys()}"
-            )
-            self.materializers = materializers
-        return self
+            materializers: The materializers for the outputs of this step.
 
-    def resolve_signature_materializers(
-        self, signature: Dict[str, Type[Any]], is_input: bool = True
-    ) -> None:
-        """Takes either the INPUT_SIGNATURE and OUTPUT_SIGNATURE and resolves
-        the materializers for them in the `spec_materializer_registry`.
+        Returns:
+            The object that this method was called on.
 
-        Args:
-            signature: Either self.INPUT_SIGNATURE or self.OUTPUT_SIGNATURE.
-            is_input: If True, then self.INPUT_SPEC used, else self.OUTPUT_SPEC.
+        Raises:
+            StepInterfaceError: If a materializer is not a `BaseMaterializer`
+                subclass or a materializer for a non-existent output is given.
         """
-        for arg, arg_type in signature.items():
-            if arg in self.materializers:
-                self.spec_materializer_registry.register_materializer_type(
-                    arg, self.materializers[arg]
-                )
-            elif default_materializer_registry.is_registered(arg_type):
-                self.spec_materializer_registry.register_materializer_type(
-                    arg,
-                    default_materializer_registry.get_single_materializer_type(
-                        arg_type
-                    ),
-                )
-            else:
-                raise StepInterfaceError(
-                    f"Argument `{arg}` of type `{arg_type}` does not have an "
-                    f"associated materializer. ZenML steps can only take input "
-                    f"and output artifacts with an associated materializer. It "
-                    f"looks like we do not have a default materializer for "
-                    f"`{arg_type}`, and you have not provided a custom "
-                    f"materializer either. Please do so and re-run the "
-                    f"pipeline."
-                )
-        spec = self.INPUT_SPEC if is_input else self.OUTPUT_SPEC
-        # For now, all artifacts are BaseArtifacts
-        for k in signature.keys():
-            spec[k] = BaseArtifact
+        def _is_materializer_class(value: Any) -> bool:
+            """Checks whether the given object is a `BaseMaterializer`
+            subclass."""
+            is_class = isinstance(value, type)
+            return is_class and issubclass(value, BaseMaterializer)
+
+        if isinstance(materializers, dict):
+            allowed_output_names = set(self.OUTPUT_SIGNATURE)
+
+            for output_name, materializer in materializers.items():
+                if output_name not in allowed_output_names:
+                    raise StepInterfaceError(
+                        f"Got unexpected materializer for non-existent "
+                        f"output '{output_name}' in step '{self.step_name}'. "
+                        f"Only materializers for the outputs "
+                        f"{allowed_output_names} of this step can"
+                        f" be registered."
+                    )
+
+                if not _is_materializer_class(materializer):
+                    raise StepInterfaceError(
+                        f"Got unexpected object `{materializer}` as "
+                        f"materializer for output '{output_name}' of step "
+                        f"'{self.step_name}'. Only `BaseMaterializer` "
+                        f"subclasses are allowed."
+                    )
+                self._explicit_materializers[output_name] = materializer
+
+        elif _is_materializer_class(materializers):
+            # Set the materializer for all outputs of this step
+            self._explicit_materializers = {
+                key: materializers for key in self.OUTPUT_SIGNATURE
+            }
+        else:
+            raise StepInterfaceError(
+                f"Got unexpected object `{materializers}` as output "
+                f"materializer for step '{self.step_name}'. Only "
+                f"`BaseMaterializer` subclasses or dictionaries mapping "
+                f"output names to `BaseMaterializer` subclasses are allowed "
+                f"as input when specifying return materializers."
+            )
+
+        return self
diff --git a/src/zenml/steps/utils.py b/src/zenml/steps/utils.py
index 88a824a3f19..7778a598bb7 100644
--- a/src/zenml/steps/utils.py
+++ b/src/zenml/steps/utils.py
@@ -55,9 +55,6 @@
 from zenml.exceptions import MissingStepParameterError
 from zenml.logger import get_logger
 from zenml.materializers.base_materializer import BaseMaterializer
-from zenml.materializers.spec_materializer_registry import (
-    SpecMaterializerRegistry,
-)
 from zenml.steps.base_step_config import BaseStepConfig
 from zenml.steps.step_output import Output
 from zenml.utils import source_utils
@@ -107,7 +104,7 @@ def generate_component(step: "BaseStep") -> Callable[..., Any]:
         spec_outputs[key] = component_spec.ChannelParameter(type=artifact_type)
     for key, prim_type in step.PARAM_SPEC.items():
         spec_params[key] = component_spec.ExecutionParameter(type=str)  # type: ignore[no-untyped-call] # noqa
-    for key in step._internal_execution_properties.keys():  # noqa
+    for key in step._internal_execution_parameters.keys():  # noqa
         spec_params[key] = component_spec.ExecutionParameter(type=str)  # type: ignore[no-untyped-call] # noqa
 
     component_spec_class = type(
@@ -127,7 +124,7 @@ def generate_component(step: "BaseStep") -> Callable[..., Any]:
         {
             "_FUNCTION": staticmethod(getattr(step, STEP_INNER_FUNC_NAME)),
             "__module__": step.__module__,
-            "spec_materializer_registry": step.spec_materializer_registry,
+            "materializers": step.get_materializers(ensure_complete=True),
             PARAM_STEP_NAME: step.step_name,
         },
     )
@@ -224,8 +221,8 @@ class _FunctionExecutor(BaseExecutor):
     _FUNCTION = staticmethod(lambda: None)
 
     # TODO[HIGH]: should this be an instance variable?
-    spec_materializer_registry: ClassVar[
-        Optional[SpecMaterializerRegistry]
+    materializers: ClassVar[
+        Optional[Dict[str, Type["BaseMaterializer"]]]
     ] = None
 
     def resolve_materializer_with_registry(
@@ -241,14 +238,10 @@ def resolve_materializer_with_registry(
             The right materializer based on the defaults or optionally the one
             set by the user.
         """
-        if not self.spec_materializer_registry:
-            raise ValueError("Spec Materializer Registry is not set!")
+        if not self.materializers:
+            raise ValueError("Materializers are missing!")
 
-        materializer_class = (
-            self.spec_materializer_registry.get_single_materializer_type(
-                param_name
-            )
-        )
+        materializer_class = self.materializers[param_name]
         return materializer_class
 
     def resolve_input_artifact(
diff --git a/tests/conftest.py b/tests/conftest.py
index 5af273a9858..b47278267ee 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -419,3 +419,21 @@ def _pipeline(step_1, step_2):
         pass
 
     return _pipeline
+
+
+@pytest.fixture
+def int_step_output():
+    @step
+    def _step() -> int:
+        return 1
+
+    return _step()()
+
+
+@pytest.fixture
+def step_with_two_int_inputs():
+    @step
+    def _step(input_1: int, input_2: int):
+        pass
+
+    return _step
diff --git a/tests/pipelines/test_base_pipeline.py b/tests/pipelines/test_base_pipeline.py
index 1914c3bc40b..692531ab1eb 100644
--- a/tests/pipelines/test_base_pipeline.py
+++ b/tests/pipelines/test_base_pipeline.py
@@ -12,6 +12,7 @@
 # or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 import os
+from contextlib import ExitStack as does_not_raise
 
 import pytest
 
@@ -47,21 +48,24 @@ def test_initialize_pipeline_with_args(
     unconnected_two_step_pipeline, empty_step
 ):
     """Test that a pipeline can be initialized with args."""
-    unconnected_two_step_pipeline(empty_step(), empty_step())
+    with does_not_raise():
+        unconnected_two_step_pipeline(empty_step(), empty_step())
 
 
 def test_initialize_pipeline_with_kwargs(
     unconnected_two_step_pipeline, empty_step
 ):
     """Test that a pipeline can be initialized with kwargs."""
-    unconnected_two_step_pipeline(step_1=empty_step(), step_2=empty_step())
+    with does_not_raise():
+        unconnected_two_step_pipeline(step_1=empty_step(), step_2=empty_step())
 
 
 def test_initialize_pipeline_with_args_and_kwargs(
    unconnected_two_step_pipeline, empty_step
 ):
     """Test that a pipeline can be initialized with a mix of args and kwargs."""
-    unconnected_two_step_pipeline(empty_step(), step_2=empty_step())
+    with does_not_raise():
+        unconnected_two_step_pipeline(empty_step(), step_2=empty_step())
 
 
 def test_initialize_pipeline_with_too_many_args(
diff --git a/tests/steps/test_base_step.py b/tests/steps/test_base_step.py
index cdaaa0f810d..f3c68905973 100644
--- a/tests/steps/test_base_step.py
+++ b/tests/steps/test_base_step.py
@@ -11,11 +11,79 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 # or implied. See the License for the specific language governing
 # permissions and limitations under the License.
+from contextlib import ExitStack as does_not_raise
+
 import pytest
 
-from zenml.exceptions import StepInterfaceError
+from zenml.exceptions import MissingStepParameterError, StepInterfaceError
+from zenml.materializers.base_materializer import BaseMaterializer
+from zenml.materializers.built_in_materializer import BuiltInMaterializer
+from zenml.steps import step
 from zenml.steps.base_step_config import BaseStepConfig
-from zenml.steps.step_decorator import step
+from zenml.steps.step_output import Output
+
+
+def test_define_step_with_shared_input_and_output_name():
+    """Tests that defining a step with a shared input and output name raises
+    a StepInterfaceError."""
+    with pytest.raises(StepInterfaceError):
+
+        @step
+        def some_step(shared_name: int) -> Output(shared_name=int):
+            return shared_name
+
+
+def test_define_step_with_multiple_configs():
+    """Tests that defining a step with multiple configs raises
+    a StepInterfaceError."""
+    with pytest.raises(StepInterfaceError):
+
+        @step
+        def some_step(
+            first_config: BaseStepConfig, second_config: BaseStepConfig
+        ):
+            pass
+
+
+def test_define_step_without_input_annotation():
+    """Tests that defining a step with a missing input annotation raises
+    a StepInterfaceError."""
+    with pytest.raises(StepInterfaceError):
+
+        @step
+        def some_step(some_argument, some_other_argument: int):
+            pass
+
+
+def test_define_step_with_variable_args():
+    """Tests that defining a step with variable arguments raises
+    a StepInterfaceError."""
+    with pytest.raises(StepInterfaceError):
+
+        @step
+        def some_step(*args: int):
+            pass
+
+
+def test_define_step_with_variable_kwargs():
+    """Tests that defining a step with variable keyword arguments raises
+    a StepInterfaceError."""
+    with pytest.raises(StepInterfaceError):
+
+        @step
+        def some_step(**kwargs: int):
+            pass
+
+
+def test_define_step_with_keyword_only_arguments():
+    """Tests that keyword-only arguments get included in the input signature
+    of a step."""
+
+    @step
+    def some_step(some_argument: int, *, keyword_only_argument: int):
+        pass
+
+    assert "keyword_only_argument" in some_step.INPUT_SIGNATURE
 
 
 def test_initialize_step_with_unexpected_config():
@@ -56,10 +124,12 @@ def step_with_config(config: StepConfig) -> None:
         step_with_config(wrong_config_key=StepConfig())  # noqa
 
     # initializing with correct key should work
-    step_with_config(config=StepConfig())
+    with does_not_raise():
+        step_with_config(config=StepConfig())
 
     # initializing as non-kwarg should work as well
-    step_with_config(StepConfig())
+    with does_not_raise():
+        step_with_config(StepConfig())
 
     # initializing with multiple args or kwargs should fail
     with pytest.raises(StepInterfaceError):
@@ -70,3 +140,410 @@ def step_with_config(config: StepConfig) -> None:
 
     with pytest.raises(StepInterfaceError):
         step_with_config(config=StepConfig(), config2=StepConfig())
+
+
+def test_access_step_component_before_calling():
+    """Tests that accessing a step's component before calling it raises
+    a StepInterfaceError."""
+
+    @step
+    def some_step():
+        pass
+
+    with pytest.raises(StepInterfaceError):
+        _ = some_step().component
+
+
+def test_access_step_component_after_calling():
+    """Tests that a step component exists after the step was called."""
+
+    @step
+    def some_step():
+        pass
+
+    step_instance = some_step()
+    step_instance()
+
+    with does_not_raise():
+        _ = step_instance.component
+
+
+def test_configure_step_with_wrong_materializer_class():
+    """Tests that passing a random class as a materializer raises a
+    StepInterfaceError."""
+
+    @step
+    def some_step() -> Output(some_output=int):
+        pass
+
+    with pytest.raises(StepInterfaceError):
+        some_step().with_return_materializers(str)  # noqa
+
+
+def test_configure_step_with_wrong_materializer_key():
+    """Tests that passing a materializer for a non-existent argument raises a
+    StepInterfaceError."""
+
+    @step
+    def some_step() -> Output(some_output=int):
+        pass
+
+    with pytest.raises(StepInterfaceError):
+        materializers = {"some_nonexistent_output": BaseMaterializer}
+        some_step().with_return_materializers(materializers)
+
+
+def test_configure_step_with_wrong_materializer_class_in_dict():
+    """Tests that passing a wrong class as materializer for a specific output
+    raises a StepInterfaceError."""
+
+    @step
+    def some_step() -> Output(some_output=int):
+        pass
+
+    materializers = {"some_output": "not_a_materializer_class"}
+    with pytest.raises(StepInterfaceError):
+        some_step().with_return_materializers(materializers)  # noqa
+
+
+def test_setting_a_materializer_for_a_step_with_multiple_outputs():
+    """Tests that setting a materializer for a step with multiple outputs
+    sets the materializer for all the outputs."""
+
+    @step
+    def some_step() -> Output(some_output=int, some_other_output=str):
+        pass
+
+    step_instance = some_step().with_return_materializers(BaseMaterializer)
+    assert step_instance.get_materializers()["some_output"] is BaseMaterializer
+    assert (
+        step_instance.get_materializers()["some_other_output"]
+        is BaseMaterializer
+    )
+
+
+def test_overwriting_step_materializers():
+    """Tests that calling `with_return_materializers` multiple times allows
+    overwriting of the step materializers."""
+
+    @step
+    def some_step() -> Output(some_output=int, some_other_output=str):
+        pass
+
+    step_instance = some_step()
+    assert not step_instance._explicit_materializers
+
+    step_instance = step_instance.with_return_materializers(
+        {"some_output": BaseMaterializer}
+    )
+    assert (
+        step_instance._explicit_materializers["some_output"] is BaseMaterializer
+    )
+    assert "some_other_output" not in step_instance._explicit_materializers
+
+    step_instance = step_instance.with_return_materializers(
+        {"some_other_output": BuiltInMaterializer}
+    )
+    assert (
+        step_instance._explicit_materializers["some_other_output"]
+        is BuiltInMaterializer
+    )
+    assert (
+        step_instance._explicit_materializers["some_output"] is BaseMaterializer
+    )
+
+    step_instance = step_instance.with_return_materializers(
+        {"some_output": BuiltInMaterializer}
+    )
+    assert (
+        step_instance._explicit_materializers["some_output"]
+        is BuiltInMaterializer
+    )
+
+    step_instance.with_return_materializers(BaseMaterializer)
+    assert (
+        step_instance._explicit_materializers["some_output"] is BaseMaterializer
+    )
+    assert (
+        step_instance._explicit_materializers["some_other_output"]
+        is BaseMaterializer
+    )
+
+
+def test_step_with_disabled_cache_has_random_string_as_execution_property():
+    """Tests that a step with disabled caching adds a random string as
+    execution property to disable caching."""
+
+    @step(enable_cache=False)
+    def some_step():
+        pass
+
+    step_instance_1 = some_step()
+    step_instance_2 = some_step()
+
+    assert (
+        step_instance_1._internal_execution_parameters["zenml-disable_cache"]
+        != step_instance_2._internal_execution_parameters["zenml-disable_cache"]
+    )
+
+
+def test_step_source_execution_parameter_stays_the_same_if_step_is_not_modified():
+    """Tests that the step source execution parameter remains constant when
+    creating multiple steps from the same source code."""
+
+    @step
+    def some_step():
+        pass
+
+    step_1 = some_step()
+    step_2 = some_step()
+
+    assert (
+        step_1._internal_execution_parameters["zenml-step_source"]
+        == step_2._internal_execution_parameters["zenml-step_source"]
+    )
+
+
+def test_step_source_execution_parameter_changes_when_signature_changes():
+    """Tests that modifying the input arguments or outputs of a step
+    function changes the step source execution parameter."""
+
+    @step
+    def some_step(some_argument: int) -> int:
+        pass
+
+    step_1 = some_step()
+
+    @step
+    def some_step(some_argument_with_new_name: int) -> int:
+        pass
+
+    step_2 = some_step()
+
+    assert (
+        step_1._internal_execution_parameters["zenml-step_source"]
+        != step_2._internal_execution_parameters["zenml-step_source"]
+    )
+
+    @step
+    def some_step(some_argument: int) -> str:
+        pass
+
+    step_3 = some_step()
+
+    assert (
+        step_1._internal_execution_parameters["zenml-step_source"]
+        != step_3._internal_execution_parameters["zenml-step_source"]
+    )
+
+
+def test_step_source_execution_parameter_changes_when_function_body_changes():
+    """Tests that modifying the step function code changes the step
+    source execution parameter."""
+
+    @step
+    def some_step():
+        pass
+
+    step_1 = some_step()
+
+    @step
+    def some_step():
+        # this is new
+        pass
+
+    step_2 = some_step()
+
+    assert (
+        step_1._internal_execution_parameters["zenml-step_source"]
+        != step_2._internal_execution_parameters["zenml-step_source"]
+    )
+
+
+def test_materializer_source_execution_parameter_changes_when_materializer_changes():
+    """Tests that changing the step materializer changes the materializer
+    source execution parameter."""
+
+    @step
+    def some_step() -> int:
+        return 1
+
+    class MyCustomMaterializer(BuiltInMaterializer):
+        pass
+
+    step_1 = some_step().with_return_materializers(BuiltInMaterializer)
+    step_2 = some_step().with_return_materializers(MyCustomMaterializer)
+
+    key = "zenml-output_materializer_source"
+    assert (
+        step_1._internal_execution_parameters[key]
+        != step_2._internal_execution_parameters[key]
+    )
+
+
+def test_call_step_with_args(int_step_output, step_with_two_int_inputs):
+    """Test that a step can be called with args."""
+    with does_not_raise():
+        step_with_two_int_inputs()(int_step_output, int_step_output)
+
+
+def test_call_step_with_kwargs(int_step_output, step_with_two_int_inputs):
+    """Test that a step can be called with kwargs."""
+    with does_not_raise():
+        step_with_two_int_inputs()(
+            input_1=int_step_output, input_2=int_step_output
+        )
+
+
+def test_call_step_with_args_and_kwargs(
+    int_step_output, step_with_two_int_inputs
+):
+    """Test that a step can be called with a mix of args and kwargs."""
+    with does_not_raise():
+        step_with_two_int_inputs()(int_step_output, input_2=int_step_output)
+
+
+def test_call_step_with_too_many_args(
+    int_step_output, step_with_two_int_inputs
+):
+    """Test that calling a step fails when too many args
+    are passed."""
+    with pytest.raises(StepInterfaceError):
+        step_with_two_int_inputs()(
+            int_step_output, int_step_output, int_step_output
+        )
+
+
+def test_call_step_with_too_many_args_and_kwargs(
+    int_step_output, step_with_two_int_inputs
+):
+    """Test that calling a step fails when too many args
+    and kwargs are passed."""
+    with pytest.raises(StepInterfaceError):
+        step_with_two_int_inputs()(
+            int_step_output, input_1=int_step_output, input_2=int_step_output
+        )
+
+
+def test_call_step_with_missing_key(int_step_output, step_with_two_int_inputs):
+    """Test that calling a step fails when an argument
+    is missing."""
+    with pytest.raises(StepInterfaceError):
+        step_with_two_int_inputs()(input_1=int_step_output)
+
+
+def test_call_step_with_unexpected_key(
+    int_step_output, step_with_two_int_inputs
+):
+    """Test that calling a step fails when an argument
+    has an unexpected key."""
+    with pytest.raises(StepInterfaceError):
+        step_with_two_int_inputs()(
+            input_1=int_step_output,
+            input_2=int_step_output,
+            input_3=int_step_output,
+        )
+
+
+def test_call_step_with_wrong_arg_type(
+    int_step_output, step_with_two_int_inputs
+):
+    """Test that calling a step fails when an arg has a wrong type."""
+    with pytest.raises(StepInterfaceError):
+        step_with_two_int_inputs()(1, int_step_output)
+
+
+def test_call_step_with_wrong_kwarg_type(
+    int_step_output, step_with_two_int_inputs
+):
+    """Test that calling a step fails when a kwarg has a wrong type."""
+    with pytest.raises(StepInterfaceError):
+        step_with_two_int_inputs()(input_1=1, input_2=int_step_output)
+
+
+def test_call_step_with_missing_materializer_for_type():
+    """Tests that calling a step with an output without registered
+    materializer raises a StepInterfaceError."""
+
+    class MyTypeWithoutMaterializer:
+        pass
+
+    @step
+    def some_step() -> MyTypeWithoutMaterializer:
+        return MyTypeWithoutMaterializer()
+
+    with pytest.raises(StepInterfaceError):
+        some_step()()
+
+
+def test_call_step_with_default_materializer_registered():
+    """Tests that calling a step with a registered default materializer for the
+    output works."""
+
+    class MyType:
+        pass
+
+    class MyTypeMaterializer(BaseMaterializer):
+        ASSOCIATED_TYPES = [MyType]
+
+    @step
+    def some_step() -> MyType:
+        return MyType()
+
+    with does_not_raise():
+        some_step()()
+
+
+def test_call_step_with_explicit_materializer():
+    """Tests that calling a step with an explicit materializer for the
+    output works."""
+
+    class MyType:
+        pass
+
+    class MyTypeMaterializer(BaseMaterializer):
+        # Not registered as default for `MyType`
+        ASSOCIATED_TYPES = [int]
+
+    @step
+    def some_step() -> MyType:
+        return MyType()
+
+    with does_not_raise():
+        some_step().with_return_materializers(MyTypeMaterializer)()
+
+
+def test_step_uses_config_class_default_values_if_no_config_is_passed():
+    """Tests that a step falls back to the config class default values if
+    no config object is passed at initialization."""
+
+    class ConfigWithDefaultValues(BaseStepConfig):
+        some_parameter: int = 1
+
+    @step
+    def some_step(config: ConfigWithDefaultValues):
+        pass
+
+    # don't pass the config when initializing the step
+    step_instance = some_step()
+    step_instance._update_and_verify_parameter_spec()
+
+    assert step_instance.PARAM_SPEC["some_parameter"] == 1
+
+
+def test_step_fails_if_config_parameter_value_is_missing():
+    """Tests that a step fails if no config object is passed at
+    initialization and the config class misses some default values."""
+
+    class ConfigWithoutDefaultValues(BaseStepConfig):
+        some_parameter: int
+
+    @step
+    def some_step(config: ConfigWithoutDefaultValues):
+        pass
+
+    # don't pass the config when initializing the step
+    step_instance = some_step()
+
+    with pytest.raises(MissingStepParameterError):
+        step_instance._update_and_verify_parameter_spec()
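
Taken together, the patch turns materializer resolution into a two-tier lookup: explicit per-output materializers set on the step, falling back to the instance-based registry that `BaseMaterializerMeta` fills from `ASSOCIATED_TYPES`. A minimal usage sketch pieced together from the new tests above; `MyObject`, `MyObjectMaterializer` and `make_object` are illustrative names, not part of the patch:

from zenml.materializers.base_materializer import BaseMaterializer
from zenml.materializers.default_materializer_registry import (
    default_materializer_registry,
)
from zenml.steps import step


class MyObject:
    pass


class MyObjectMaterializer(BaseMaterializer):
    # Defining the subclass is enough: the metaclass now loops over
    # ASSOCIATED_TYPES and registers the class with the default registry.
    ASSOCIATED_TYPES = [MyObject]


# The registry is a plain instance now; lookups go through __getitem__.
assert default_materializer_registry.is_registered(MyObject)
assert default_materializer_registry[MyObject] is MyObjectMaterializer


@step
def make_object() -> MyObject:
    return MyObject()


# Explicit materializers are stored per output and surfaced through
# get_materializers(); without an explicit entry, the default registry
# above is consulted instead.
step_instance = make_object().with_return_materializers(MyObjectMaterializer)
assert MyObjectMaterializer in step_instance.get_materializers().values()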
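
The caching rework can be sanity-checked the same way the new tests do; `cached_step` and `uncached_step` are again illustrative, and the `zenml-` key prefix is assumed to come from `INTERNAL_EXECUTION_PARAMETER_PREFIX`:

from zenml.steps import step


@step
def cached_step() -> int:
    return 1


@step(enable_cache=False)
def uncached_step() -> int:
    return 1


# With caching enabled, the internal execution parameters carry hashes of the
# step source (and of each output materializer's source), so unchanged code
# produces identical parameters and cache hits stay possible.
assert (
    cached_step()._internal_execution_parameters["zenml-step_source"]
    == cached_step()._internal_execution_parameters["zenml-step_source"]
)

# With caching disabled, a fresh random value is injected per instance, so the
# orchestrator never treats two runs as identical.
assert (
    uncached_step()._internal_execution_parameters["zenml-disable_cache"]
    != uncached_step()._internal_execution_parameters["zenml-disable_cache"]
)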