Skip to content

Commit

Permalink
Merge pull request #157 from zenml-io/michael/ENG-24-caching
Browse files Browse the repository at this point in the history
Step caching enabled
  • Loading branch information
htahir1 authored Nov 5, 2021
2 parents d88737b + f35b5be commit 4b68e54
Show file tree
Hide file tree
Showing 7 changed files with 178 additions and 53 deletions.
7 changes: 2 additions & 5 deletions src/zenml/pipelines/base_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
import inspect
import json
from abc import abstractmethod
from typing import Any, ClassVar, Dict, NoReturn, Optional, Tuple, Type, cast

Expand Down Expand Up @@ -229,21 +228,19 @@ def _read_config_steps(

step = self.__steps[step_name]
step_parameters = (
step.CONFIG.__fields__.keys() if step.CONFIG else {}
step.CONFIG_CLASS.__fields__.keys() if step.CONFIG_CLASS else {}
)
parameters = step_dict.get(StepConfigurationKeys.PARAMETERS_, {})
for parameter, value in parameters.items():
if parameter not in step_parameters:
raise PipelineConfigurationError(
f"Found parameter '{parameter}' for '{step_name}' step "
f"in configuration yaml but it doesn't exist in the "
f"configuration class `{step.CONFIG}`. Available "
f"configuration class `{step.CONFIG_CLASS}`. Available "
f"parameters for this step: "
f"{list(step_parameters)}."
)

# make sure the value gets serialized to a string
value = json.dumps(value)
previous_value = step.PARAM_SPEC.get(parameter, None)

if overwrite:
Expand Down
7 changes: 4 additions & 3 deletions src/zenml/pipelines/pipeline_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ def pipeline(
nested decorator structure.
Args:
_func: Optional func from outside.
name: str, the given name for the pipeline
enable_cache: Whether to use cache or not.
_func: The decorated function.
name: The name of the pipeline. If left empty, the name of the
decorated function will be used as a fallback.
enable_cache: Whether to use caching or not.
Returns:
the inner decorator which creates the pipeline class based on the
Expand Down
162 changes: 126 additions & 36 deletions src/zenml/steps/base_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import inspect
import json
import random
from abc import abstractmethod
from typing import (
Any,
Expand All @@ -31,7 +32,7 @@
from tfx.types.channel import Channel

from zenml.artifacts.base_artifact import BaseArtifact
from zenml.exceptions import StepInterfaceError
from zenml.exceptions import MissingStepParameterError, StepInterfaceError
from zenml.logger import get_logger
from zenml.materializers.base_materializer import BaseMaterializer
from zenml.materializers.default_materializer_registry import (
Expand All @@ -43,6 +44,8 @@
from zenml.steps.base_step_config import BaseStepConfig
from zenml.steps.step_output import Output
from zenml.steps.utils import (
INTERNAL_EXECUTION_PARAMETER_PREFIX,
PARAM_ENABLE_CACHE,
SINGLE_RETURN_OUT_NAME,
STEP_INNER_FUNC_NAME,
_ZenMLSimpleComponent,
Expand Down Expand Up @@ -82,7 +85,8 @@ def __new__(

cls.INPUT_SIGNATURE = {}
cls.OUTPUT_SIGNATURE = {}
cls.CONFIG = None
cls.CONFIG_PARAMETER_NAME = None
cls.CONFIG_CLASS = None

# Looking into the signature of the provided process function
process_spec = inspect.getfullargspec(
Expand All @@ -102,12 +106,13 @@ def __new__(
# materializer type.
if issubclass(arg_type, BaseStepConfig):
# It needs to be None at this point, otherwise multi configs.
if cls.CONFIG is not None:
if cls.CONFIG_CLASS is not None:
raise StepInterfaceError(
"Please only use one `BaseStepConfig` type object in "
"your step."
)
cls.CONFIG = arg_type
cls.CONFIG_PARAMETER_NAME = arg
cls.CONFIG_CLASS = arg_type
else:
cls.INPUT_SIGNATURE.update({arg: arg_type})

Expand All @@ -127,6 +132,11 @@ def __new__(
{SINGLE_RETURN_OUT_NAME: return_spec}
)

if check_dict_keys_match(cls.INPUT_SIGNATURE, cls.OUTPUT_SIGNATURE):
raise StepInterfaceError(
"The input names and output names cannot be the same!"
)

return cls


Expand All @@ -140,67 +150,145 @@ class BaseStep(metaclass=BaseStepMeta):
# TODO [MEDIUM]: Ensure these are ordered
INPUT_SIGNATURE: ClassVar[Dict[str, Type[Any]]] = None # type: ignore[assignment] # noqa
OUTPUT_SIGNATURE: ClassVar[Dict[str, Type[Any]]] = None # type: ignore[assignment] # noqa
CONFIG: ClassVar[Optional[Type[BaseStepConfig]]] = None
CONFIG_PARAMETER_NAME: ClassVar[Optional[str]] = None
CONFIG_CLASS: ClassVar[Optional[Type[BaseStepConfig]]] = None

def __init__(self, *args: Any, **kwargs: Any) -> None:
self.materializers: Dict[str, Type[BaseMaterializer]] = {}
self.__component = None
self.step_name = self.__class__.__name__
self.enable_cache = getattr(self, PARAM_ENABLE_CACHE)
self.PARAM_SPEC: Dict[str, Any] = {}
self.INPUT_SPEC: Dict[str, Type[BaseArtifact]] = {}
self.OUTPUT_SPEC: Dict[str, Type[BaseArtifact]] = {}
self.spec_materializer_registry = SpecMaterializerRegistry()

# TODO [LOW]: Support args
# TODO [LOW]: Currently the kwarg name doesn't matter
# and will be used as the config
if args:
raise StepInterfaceError(
"When you are creating an instance of a step, please only "
"use key-word arguments."
)
self._verify_arguments(*args, **kwargs)

@property
def _internal_execution_properties(self) -> Dict[str, str]:
"""ZenML internal execution properties for this step.
**IMPORTANT**: When modifying this dictionary, make sure to
prefix the key with `INTERNAL_EXECUTION_PARAMETER_PREFIX` and serialize
the value using `json.dumps(...)`.
"""
properties = {}
if not self.enable_cache:
# add a random string to the execution properties to disable caching
key = INTERNAL_EXECUTION_PARAMETER_PREFIX + "disable_cache"
random_string = f"{random.getrandbits(128):032x}"
properties[key] = json.dumps(random_string)

return properties

def _verify_arguments(self, *args: Any, **kwargs: Any) -> None:
"""Verifies the initialization args and kwargs of this step.
This method makes sure that there is only a config object passed at
initialization and that it was passed using the correct name and
type specified in the step declaration.
If the correct config object was found, additionally saves the
config parameters to `self.PARAM_SPEC`.
Args:
*args: The args passed to the init method of this step.
**kwargs: The kwargs passed to the init method of this step.
maximum_kwarg_count = 1 if self.CONFIG else 0
if len(kwargs) > maximum_kwarg_count:
Raises:
StepInterfaceError: If there are too many arguments or arguments
with a wrong name/type.
"""
maximum_arg_count = 1 if self.CONFIG_CLASS else 0
if (len(args) + len(kwargs)) > maximum_arg_count:
raise StepInterfaceError(
f"Too many keyword arguments ({len(kwargs)}, "
f"expected: {maximum_kwarg_count}) passed when "
f"creating a '{self.step_name}' step."
f"Too many arguments ({len(kwargs)}, expected: "
f"{maximum_arg_count}) passed when creating a "
f"'{self.step_name}' step."
)

if self.CONFIG and len(kwargs) == 1:
config = kwargs.popitem()[1]
if self.CONFIG_PARAMETER_NAME and self.CONFIG_CLASS:
if args:
config = args[0]
elif kwargs:
key, config = kwargs.popitem()

if not isinstance(config, self.CONFIG):
if key != self.CONFIG_PARAMETER_NAME:
raise StepInterfaceError(
f"Unknown keyword argument '{key}' when creating a "
f"'{self.step_name}' step, only expected a single "
f"argument with key '{self.CONFIG_PARAMETER_NAME}'."
)
else:
# This step requires configuration parameters but no config
# object was passed as an argument. The parameters might be
# set via default values in the config class or in a
# configuration file, so we continue for now and verify
# that all parameters are set before running the step
return

if not isinstance(config, self.CONFIG_CLASS):
raise StepInterfaceError(
f"`{config}` object passed when creating a "
f"'{self.step_name}' step is not a "
f"`{self.CONFIG.__name__}` instance."
f"`{self.CONFIG_CLASS.__name__}` instance."
)

self.PARAM_SPEC = config.dict()

def _prepare_parameter_spec(self) -> None:
"""Verifies and prepares the config parameters for running this step.
When the step requires config parameters, this method:
- checks if config parameters were set via a config object or file
- tries to set missing config parameters from default values of the
config class
Raises:
MissingStepParameterError: If no value could be found for one or
more config parameters.
StepInterfaceError: If a config parameter value couldn't be
serialized to json.
"""
if self.CONFIG_CLASS:
# we need to store a value for all config keys inside the
# metadata store to make sure caching works as expected
missing_keys = []
for name, field in self.CONFIG_CLASS.__fields__.items():
if name in self.PARAM_SPEC:
# a value for this parameter has been set already
continue

if field.default is not None:
# use default value from the pydantic config class
self.PARAM_SPEC[name] = field.default
else:
missing_keys.append(name)

if missing_keys:
raise MissingStepParameterError(
self.step_name, missing_keys, self.CONFIG_CLASS
)

# convert config parameter values to strings
try:
# TODO [MEDIUM]: include pydantic default values so they get
# stored inside the metadata store as well
self.PARAM_SPEC = {
k: json.dumps(v) for k, v in config.dict().items()
k: json.dumps(v) for k, v in self.PARAM_SPEC.items()
}
except RuntimeError as e:
# TODO [LOW]: Attach a URL with all supported types.
logger.debug(f"Pydantic Error: {str(e)}")
except TypeError as e:
raise StepInterfaceError(
"You passed in a parameter that we cannot serialize!"
)
f"Failed to serialize config parameters for step "
f"'{self.step_name}'. Please make sure to only use "
f"json serializable parameter values."
) from e

def __call__(
self, **artifacts: BaseArtifact
) -> Union[Channel, List[Channel]]:
"""Generates a component when called."""
# TODO [MEDIUM]: Support *args as well.
# register defaults
if check_dict_keys_match(self.INPUT_SIGNATURE, self.OUTPUT_SIGNATURE):
raise StepInterfaceError(
"The input names and output names cannot be the same!"
)

self._prepare_parameter_spec()

# Construct INPUT_SPEC from INPUT_SIGNATURE
self.resolve_signature_materializers(self.INPUT_SIGNATURE, True)
Expand All @@ -224,7 +312,9 @@ def __call__(
)

self.__component = generate_component(self)(
**artifacts, **self.PARAM_SPEC
**artifacts,
**self.PARAM_SPEC,
**self._internal_execution_properties,
)

# Resolve the returns in the right order.
Expand Down
18 changes: 13 additions & 5 deletions src/zenml/steps/step_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import Any, Callable, Optional, Type, TypeVar, Union, overload

from zenml.steps.base_step import BaseStep
from zenml.steps.utils import STEP_INNER_FUNC_NAME
from zenml.steps.utils import PARAM_ENABLE_CACHE, STEP_INNER_FUNC_NAME

F = TypeVar("F", bound=Callable[..., Any])

Expand All @@ -26,22 +26,29 @@ def step(_func: F) -> Type[BaseStep]:


@overload
def step(*, name: Optional[str] = None) -> Callable[[F], Type[BaseStep]]:
def step(
*, name: Optional[str] = None, enable_cache: bool = True
) -> Callable[[F], Type[BaseStep]]:
"""Type annotations for step decorator in case of arguments."""
...


def step(
_func: Optional[F] = None, *, name: Optional[str] = None
_func: Optional[F] = None,
*,
name: Optional[str] = None,
enable_cache: bool = True
) -> Union[Type[BaseStep], Callable[[F], Type[BaseStep]]]:
"""Outer decorator function for the creation of a ZenML step
In order to be able work with parameters such as `name`, it features a
nested decorator structure.
Args:
_func: Optional func from outside.
name (required) the given name for the step.
_func: The decorated function.
name: The name of the step. If left empty, the name of the decorated
function will be used as a fallback.
enable_cache: Whether to use caching or not.
Returns:
the inner decorator which creates the step class based on the
Expand All @@ -64,6 +71,7 @@ def inner_decorator(func: F) -> Type[BaseStep]:
(BaseStep,),
{
STEP_INNER_FUNC_NAME: staticmethod(func),
PARAM_ENABLE_CACHE: enable_cache,
},
)

Expand Down
11 changes: 11 additions & 0 deletions src/zenml/steps/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@
STEP_INNER_FUNC_NAME: str = "process"
SINGLE_RETURN_OUT_NAME: str = "output"
PARAM_STEP_NAME: str = "step_name"
PARAM_ENABLE_CACHE: str = "enable_cache"
INTERNAL_EXECUTION_PARAMETER_PREFIX: str = "zenml-"


def do_types_match(type_a: Type[Any], type_b: Type[Any]) -> bool:
Expand Down Expand Up @@ -105,6 +107,8 @@ def generate_component(step: "BaseStep") -> Callable[..., Any]:
spec_outputs[key] = component_spec.ChannelParameter(type=artifact_type)
for key, prim_type in step.PARAM_SPEC.items():
spec_params[key] = component_spec.ExecutionParameter(type=str) # type: ignore[no-untyped-call] # noqa
for key in step._internal_execution_properties.keys(): # noqa
spec_params[key] = component_spec.ExecutionParameter(type=str) # type: ignore[no-untyped-call] # noqa

component_spec_class = type(
"%s_Spec" % step.__class__.__name__,
Expand Down Expand Up @@ -319,6 +323,13 @@ def Do(
output_dict: dictionary containing the output artifacts
exec_properties: dictionary containing the execution parameters
"""
# remove all ZenML internal execution properties
exec_properties = {
k: v
for k, v in exec_properties.items()
if not k.startswith(INTERNAL_EXECUTION_PARAMETER_PREFIX)
}

# Building the args for the process function
function_params = {}

Expand Down
Loading

0 comments on commit 4b68e54

Please sign in to comment.