Skip to content

Commit

Permalink
Merge pull request #175 from zenml-io/michael/ENG-83-step-and-pipelin…
Browse files Browse the repository at this point in the history
…e-interface

Improve step and pipeline interface
  • Loading branch information
schustmi authored Nov 18, 2021
2 parents 314300e + 32bab9f commit 708bbb1
Show file tree
Hide file tree
Showing 12 changed files with 792 additions and 271 deletions.
9 changes: 5 additions & 4 deletions src/zenml/materializers/base_materializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,11 @@ def __new__(
"You should specify a list of ASSOCIATED_TYPES when creating a "
"Materializer!"
)
[
default_materializer_registry.register_materializer_type(x, cls)
for x in cls.ASSOCIATED_TYPES
]
for associated_type in cls.ASSOCIATED_TYPES:
default_materializer_registry.register_materializer_type(
associated_type, cls
)

return cls


Expand Down
33 changes: 14 additions & 19 deletions src/zenml/materializers/default_materializer_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import TYPE_CHECKING, Any, ClassVar, Dict, Type
from typing import TYPE_CHECKING, Any, Dict, Type

from zenml.logger import get_logger

Expand All @@ -22,28 +22,28 @@
from zenml.materializers.base_materializer import BaseMaterializer


class DefaultMaterializerRegistry(object):
class MaterializerRegistry:
"""Matches a python type to a default materializer."""

materializer_types: ClassVar[Dict[Type[Any], Type["BaseMaterializer"]]] = {}
def __init__(self) -> None:
self.materializer_types: Dict[Type[Any], Type["BaseMaterializer"]] = {}

@classmethod
def register_materializer_type(
cls, key: Type[Any], type_: Type["BaseMaterializer"]
self, key: Type[Any], type_: Type["BaseMaterializer"]
) -> None:
"""Registers a new materializer.
Args:
key: Indicates the type of an object.
type_: A BaseMaterializer subclass.
"""
if key not in cls.materializer_types:
cls.materializer_types[key] = type_
if key not in self.materializer_types:
self.materializer_types[key] = type_
logger.debug(f"Registered materializer {type_} for {key}")
else:
logger.debug(
f"{key} already registered with "
f"{cls.materializer_types[key]}. Cannot register {type_}."
f"{self.materializer_types[key]}. Cannot register {type_}."
)

def register_and_overwrite_type(
Expand All @@ -58,17 +58,14 @@ def register_and_overwrite_type(
self.materializer_types[key] = type_
logger.debug(f"Registered materializer {type_} for {key}")

def get_single_materializer_type(
self, key: Type[Any]
) -> Type["BaseMaterializer"]:
def __getitem__(self, key: Type[Any]) -> Type["BaseMaterializer"]:
"""Get a single materializers based on the key.
Args:
key: Indicates the type of an object.
Returns:
Instance of a `BaseMaterializer` subclass initialized with the
artifact of this factory.
`BaseMaterializer` subclass that was registered for this key.
"""
if key in self.materializer_types:
return self.materializer_types[key]
Expand All @@ -81,14 +78,12 @@ def get_single_materializer_type(
def get_materializer_types(
self,
) -> Dict[Type[Any], Type["BaseMaterializer"]]:
"""Get all registered materializers."""
"""Get all registered materializer types."""
return self.materializer_types

def is_registered(self, key: Type[Any]) -> bool:
"""Returns true if key type is registered, else returns False."""
if key in self.materializer_types:
return True
return False
"""Returns if a materializer class is registered for the given type."""
return key in self.materializer_types


default_materializer_registry = DefaultMaterializerRegistry()
default_materializer_registry = MaterializerRegistry()
69 changes: 0 additions & 69 deletions src/zenml/materializers/spec_materializer_registry.py

This file was deleted.

7 changes: 7 additions & 0 deletions src/zenml/post_execution/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,10 @@ def __repr__(self) -> str:
f"type='{self._type}', uri='{self._uri}', "
f"materializer='{self._materializer}')"
)

def __eq__(self, other: Any) -> bool:
"""Returns whether the other object is referring to the
same artifact."""
if isinstance(other, ArtifactView):
return self._id == other._id and self._uri == other._uri
return NotImplemented
12 changes: 11 additions & 1 deletion src/zenml/post_execution/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING, Any, List

from zenml.logger import get_logger
from zenml.post_execution.pipeline_run import PipelineRunView
Expand Down Expand Up @@ -95,3 +95,13 @@ def __repr__(self) -> str:
f"{self.__class__.__qualname__}(id={self._id}, "
f"name='{self._name}')"
)

def __eq__(self, other: Any) -> bool:
"""Returns whether the other object is referring to the
same pipeline."""
if isinstance(other, PipelineView):
return (
self._id == other._id
and self._metadata_store.uuid == other._metadata_store.uuid
)
return NotImplemented
12 changes: 11 additions & 1 deletion src/zenml/post_execution/pipeline_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# permissions and limitations under the License.

from collections import OrderedDict
from typing import TYPE_CHECKING, Dict, List
from typing import TYPE_CHECKING, Any, Dict, List

from ml_metadata import proto

Expand Down Expand Up @@ -121,3 +121,13 @@ def __repr__(self) -> str:
f"{self.__class__.__qualname__}(id={self._id}, "
f"name='{self._name}')"
)

def __eq__(self, other: Any) -> bool:
"""Returns whether the other object is referring to the same
pipeline run."""
if isinstance(other, PipelineRunView):
return (
self._id == other._id
and self._metadata_store.uuid == other._metadata_store.uuid
)
return NotImplemented
9 changes: 9 additions & 0 deletions src/zenml/post_execution/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,12 @@ def __repr__(self) -> str:
f"{self.__class__.__qualname__}(id={self._id}, "
f"name='{self._name}', parameters={self._parameters})"
)

def __eq__(self, other: Any) -> bool:
"""Returns whether the other object is referring to the same step."""
if isinstance(other, StepView):
return (
self._id == other._id
and self._metadata_store.uuid == other._metadata_store.uuid
)
return NotImplemented
Loading

0 comments on commit 708bbb1

Please sign in to comment.