Skip to content

Commit

Permalink
Separate stack persistence from repo implementation (#462)
Browse files Browse the repository at this point in the history
* Rip out stack persistance from Repository implementation

* Refactor common code from Local to BaseStackStore

* Use correct subclass constructor for component flavors

* Add sqlite stackstore backend

* Make all BaseStackStore interfaces serializable

* Improve API documentation for stack store

* Remove StackConfiguration in favor of simple dict

* Move active_stack back to repository, fix legacy loading

* Let mypy do its thing on sql stack store

* Add unit tests for public methods of stack store

* Update src/zenml/utils/enum_utils.py

Co-authored-by: Michael Schuster <schustmi@users.noreply.github.com>

* Remove plus sign

Co-authored-by: Michael Schuster <schustmi@users.noreply.github.com>
  • Loading branch information
jwwwb and schustmi authored Mar 23, 2022
1 parent dfbe6c2 commit 8714a8e
Show file tree
Hide file tree
Showing 23 changed files with 1,455 additions and 265 deletions.
15 changes: 5 additions & 10 deletions examples/kubeflow/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,17 @@
import click
from pipeline import (
TrainerConfig,
mnist_pipeline,
importer,
trainer,
evaluator,
importer,
mnist_pipeline,
normalizer,
trainer,
)
from rich import print
from zenml.integrations.tensorflow.services import (
TensorboardService,
TensorboardServiceConfig,
)

from zenml.integrations.tensorflow.visualizers import (
visualize_tensorboard,
stop_tensorboard_server,
visualize_tensorboard,
)
from zenml.repository import Repository


@click.command()
Expand Down
9 changes: 2 additions & 7 deletions examples/kubeflow/run_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,8 @@

from datetime import datetime, timedelta

from pipeline import (
mnist_pipeline,
importer,
trainer,
evaluator,
normalizer,
)
from pipeline import evaluator, importer, mnist_pipeline, normalizer, trainer

from zenml.pipelines import Schedule

if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion examples/sagemaker_step_operator/train_on_sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def importer() -> Output(
# setting the custom_step_operator param will tell ZenML
# to run this step on a custom backend defined by the name
# of the operator you provide.
@step(custom_step_operator='sagemaker')
@step(custom_step_operator="sagemaker")
def trainer(
X_train: np.ndarray,
y_train: np.ndarray,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ distro = "^1.6.0"
rich = {extras = ["jupyter"], version = "^12.0.0"}
httplib2 = "<0.20,>=0.19.1"
pyparsing = "<3,>=2.4.0"
sqlmodel = "0.0.6"


[tool.poetry.dev-dependencies]
Expand Down
11 changes: 7 additions & 4 deletions src/zenml/cli/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from zenml.cli.cli import cli
from zenml.cli.utils import confirmation, declare, error
from zenml.console import console
from zenml.enums import StorageType
from zenml.exceptions import InitializationException
from zenml.repository import Repository

Expand All @@ -31,9 +32,8 @@
exists=True, file_okay=False, dir_okay=True, path_type=Path
),
)
def init(
path: Optional[Path],
) -> None:
@click.option("--storage-type", type=click.Choice(StorageType.names()))
def init(path: Optional[Path], storage_type: Optional[StorageType]) -> None:
"""Initialize ZenML on given path.
Args:
Expand All @@ -45,9 +45,12 @@ def init(
if path is None:
path = Path.cwd()

if storage_type is None:
storage_type = StorageType.YAML_STORAGE

with console.status(f"Initializing ZenML repository at {path}.\n"):
try:
Repository.initialize(root=path)
Repository.initialize(root=path, storage_type=storage_type)
declare(f"ZenML repository initialized at {path}.")
except InitializationException as e:
error(f"{e}")
Expand Down
4 changes: 2 additions & 2 deletions src/zenml/cli/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ def list_stacks() -> None:
"ACTIVE": ":point_right:" if is_active else "",
"STACK NAME": stack_name,
**{
key.upper(): value
for key, value in stack_configuration.dict().items()
component_type.value.upper(): value
for component_type, value in stack_configuration.items()
},
}
stack_dicts.append(stack_config)
Expand Down
6 changes: 3 additions & 3 deletions src/zenml/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@

from zenml.console import console
from zenml.constants import IS_DEBUG_ENV
from zenml.enums import StackComponentType
from zenml.logger import get_logger
from zenml.repository import StackConfiguration
from zenml.stack import StackComponent

logger = get_logger(__name__)
Expand Down Expand Up @@ -167,7 +167,7 @@ def print_stack_component_list(


def print_stack_configuration(
component: StackConfiguration, active: bool, stack_name: str
config: Dict[StackComponentType, str], active: bool, stack_name: str
) -> None:
"""Prints the configuration options of a stack."""
stack_caption = f"'{stack_name}' stack"
Expand All @@ -181,7 +181,7 @@ def print_stack_configuration(
)
rich_table.add_column("COMPONENT_TYPE")
rich_table.add_column("COMPONENT_NAME")
items = component.dict().items()
items = {typ.value: name for typ, name in config.items()}
for item in items:
rich_table.add_row(*list(item))

Expand Down
3 changes: 3 additions & 0 deletions src/zenml/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
ENV_ZENML_PREVENT_PIPELINE_EXECUTION
)

# Repository Directory Path
LOCAL_CONFIG_DIRECTORY_NAME = ".zen"

USER_MAIN_MODULE: Optional[str] = None

# Rich config
Expand Down
66 changes: 48 additions & 18 deletions src/zenml/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import logging
from enum import Enum
from typing import Type

from zenml.utils.enum_utils import StrEnum

Expand All @@ -38,9 +39,47 @@ class LoggingLevels(Enum):
CRITICAL = logging.CRITICAL


class StackComponentType(StrEnum):
"""All possible types a `StackComponent` can have."""

ORCHESTRATOR = "orchestrator"
METADATA_STORE = "metadata_store"
ARTIFACT_STORE = "artifact_store"
CONTAINER_REGISTRY = "container_registry"
STEP_OPERATOR = "step_operator"

@property
def plural(self) -> str:
"""Returns the plural of the enum value."""
if self == StackComponentType.CONTAINER_REGISTRY:
return "container_registries"

return f"{self.value}s"


class StackComponentFlavor(StrEnum):
"""Abstract base class for all stack component flavors."""

@staticmethod
def for_type(
component_type: StackComponentType,
) -> Type["StackComponentFlavor"]:
"""Get the corresponding flavor child-type for a component type."""
if component_type == StackComponentType.ARTIFACT_STORE:
return ArtifactStoreFlavor
elif component_type == StackComponentType.METADATA_STORE:
return MetadataStoreFlavor
elif component_type == StackComponentType.CONTAINER_REGISTRY:
return ContainerRegistryFlavor
elif component_type == StackComponentType.ORCHESTRATOR:
return OrchestratorFlavor
elif component_type == StackComponentType.STEP_OPERATOR:
return StepOperatorFlavor
else:
raise ValueError(
f"Unsupported Stack Component Type {component_type.value}"
)


class ArtifactStoreFlavor(StackComponentFlavor):
"""All supported artifact store flavors."""
Expand Down Expand Up @@ -80,26 +119,17 @@ class StepOperatorFlavor(StackComponentFlavor):
SAGEMAKER = "sagemaker"


class StackComponentType(StrEnum):
"""All possible types a `StackComponent` can have."""

ORCHESTRATOR = "orchestrator"
METADATA_STORE = "metadata_store"
ARTIFACT_STORE = "artifact_store"
CONTAINER_REGISTRY = "container_registry"
STEP_OPERATOR = "step_operator"

@property
def plural(self) -> str:
"""Returns the plural of the enum value."""
if self == StackComponentType.CONTAINER_REGISTRY:
return "container_registries"

return f"{self.value}s"


class MetadataContextTypes(Enum):
"""All possible types that contexts can have within pipeline nodes"""

STACK = "stack"
PIPELINE_REQUIREMENTS = "pipeline_requirements"


class StorageType(StrEnum):
"""Storage Backend Types"""

YAML_STORAGE = "yaml_storage"
SQLITE_STORAGE = "sqlite_storage"
MYSQL_STORAGE = "mysql_storage"
REST_STORAGE = "rest_storage"
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _get_plot_method(
raise ValueError(
f"Invalid whylogs plot type: {plot} \n\n"
f"Valid and supported options are: {nl}- "
f'{f"{nl}- ".join(WhylogsPlots.list())}'
f'{f"{nl}- ".join(WhylogsPlots.names())}'
)
return plot_method

Expand Down
Loading

0 comments on commit 8714a8e

Please sign in to comment.