Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable generic step inputs and outputs #440

Merged
merged 2 commits into from
Mar 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/zenml/steps/base_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
SINGLE_RETURN_OUT_NAME,
_ZenMLSimpleComponent,
generate_component_class,
resolve_type_annotation,
)
from zenml.utils.source_utils import get_hashed_source

Expand Down Expand Up @@ -115,6 +116,7 @@ def __new__(
# Verify the input arguments of the step function
for arg in step_function_args:
arg_type = step_function_signature.annotations.get(arg, None)
arg_type = resolve_type_annotation(arg_type)

if not arg_type:
raise StepInterfaceError(
Expand Down Expand Up @@ -157,9 +159,14 @@ def __new__(
return_type = step_function_signature.annotations.get("return", None)
if return_type is not None:
if isinstance(return_type, Output):
cls.OUTPUT_SIGNATURE = dict(return_type.items())
cls.OUTPUT_SIGNATURE = {
name: resolve_type_annotation(type_)
for (name, type_) in return_type.items()
}
else:
cls.OUTPUT_SIGNATURE[SINGLE_RETURN_OUT_NAME] = return_type
cls.OUTPUT_SIGNATURE[
SINGLE_RETURN_OUT_NAME
] = resolve_type_annotation(return_type)

# Raise an exception if input and output names of a step overlap as
# tfx requires them to be unique
Expand Down
17 changes: 17 additions & 0 deletions src/zenml/steps/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import inspect
import json
import sys
import typing
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -90,6 +91,20 @@ def do_types_match(type_a: Type[Any], type_b: Type[Any]) -> bool:
return type_a == type_b


def resolve_type_annotation(obj: Any) -> Any:
"""Returns the non-generic class for generic aliases of the typing module.

If the input is no generic typing alias, the input itself is returned.

Example: if the input object is `typing.Dict`, this method will return the
concrete class `dict`.
"""
if isinstance(obj, typing._GenericAlias): # type: ignore[attr-defined]
return obj.__origin__
else:
return obj


def generate_component_spec_class(
step_name: str,
input_spec: Dict[str, Type[BaseArtifact]],
Expand Down Expand Up @@ -379,6 +394,8 @@ def Do(

for arg in args:
arg_type = spec.annotations.get(arg, None)
arg_type = resolve_type_annotation(arg_type)

if issubclass(arg_type, BaseStepConfig):
try:
config_object = arg_type.parse_obj(exec_properties)
Expand Down
40 changes: 39 additions & 1 deletion tests/unit/steps/test_base_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
from contextlib import ExitStack as does_not_raise
from typing import Optional
from typing import Dict, List, Optional

import pytest

Expand All @@ -21,6 +21,7 @@
from zenml.exceptions import MissingStepParameterError, StepInterfaceError
from zenml.materializers import BuiltInMaterializer
from zenml.materializers.base_materializer import BaseMaterializer
from zenml.pipelines import pipeline
from zenml.steps import BaseStepConfig, Output, StepContext, step


Expand Down Expand Up @@ -786,3 +787,40 @@ def some_step_7() -> Output(a=list, b=int):

with pytest.raises(StepInterfaceError):
pipeline_.run()


def test_step_can_output_generic_types(clean_repo, one_step_pipeline):
"""Tests that a step can output generic typing classes."""

@step
def some_step_1() -> Dict:
return {}

@step
def some_step_2() -> List:
return []

for step_function in [some_step_1, some_step_2]:
pipeline_ = one_step_pipeline(step_function())

with does_not_raise():
pipeline_.run()


def test_step_can_have_generic_input_types(clean_repo):
"""Tests that a step can have generic typing classes as input."""

@step
def step_1() -> Output(dict_output=Dict, list_output=List):
return {}, []

@step
def step_2(dict_input: Dict, list_input: List) -> None:
pass

@pipeline
def p(s1, s2):
s2(*s1())

with does_not_raise():
p(step_1(), step_2()).run()
15 changes: 15 additions & 0 deletions tests/unit/steps/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import Dict, List, Set

from numpy import ndarray

from zenml.steps.utils import resolve_type_annotation


def test_type_annotation_resolving():
"""Tests that resolving type annotations works as expected."""
assert resolve_type_annotation(Dict) is dict
assert resolve_type_annotation(List[int]) is list
assert resolve_type_annotation(Set[str]) is set

assert resolve_type_annotation(set) is set
assert resolve_type_annotation(ndarray) is ndarray