Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Separate stack persistence from repo implementation #462

Merged
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions examples/kubeflow/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,17 @@
import click
from pipeline import (
TrainerConfig,
mnist_pipeline,
importer,
trainer,
evaluator,
importer,
mnist_pipeline,
normalizer,
trainer,
)
from rich import print
from zenml.integrations.tensorflow.services import (
TensorboardService,
TensorboardServiceConfig,
)

from zenml.integrations.tensorflow.visualizers import (
visualize_tensorboard,
stop_tensorboard_server,
visualize_tensorboard,
)
from zenml.repository import Repository


@click.command()
Expand Down
9 changes: 2 additions & 7 deletions examples/kubeflow/run_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,8 @@

from datetime import datetime, timedelta

from pipeline import (
mnist_pipeline,
importer,
trainer,
evaluator,
normalizer,
)
from pipeline import evaluator, importer, mnist_pipeline, normalizer, trainer

from zenml.pipelines import Schedule

if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion examples/sagemaker_step_operator/train_on_sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def importer() -> Output(
# setting the custom_step_operator param will tell ZenML
# to run this step on a custom backend defined by the name
# of the operator you provide.
@step(custom_step_operator='sagemaker')
@step(custom_step_operator="sagemaker")
def trainer(
X_train: np.ndarray,
y_train: np.ndarray,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ distro = "^1.6.0"
rich = {extras = ["jupyter"], version = "^12.0.0"}
httplib2 = "<0.20,>=0.19.1"
pyparsing = "<3,>=2.4.0"
sqlmodel = "0.0.6"


[tool.poetry.dev-dependencies]
Expand Down
11 changes: 7 additions & 4 deletions src/zenml/cli/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from zenml.cli.cli import cli
from zenml.cli.utils import confirmation, declare, error
from zenml.console import console
from zenml.enums import StorageType
from zenml.exceptions import InitializationException
from zenml.repository import Repository

Expand All @@ -31,9 +32,8 @@
exists=True, file_okay=False, dir_okay=True, path_type=Path
),
)
def init(
path: Optional[Path],
) -> None:
@click.option("--storage-type", type=click.Choice(StorageType.names()))
def init(path: Optional[Path], storage_type: Optional[StorageType]) -> None:
"""Initialize ZenML on given path.

Args:
Expand All @@ -45,9 +45,12 @@ def init(
if path is None:
path = Path.cwd()

if storage_type is None:
storage_type = StorageType.YAML_STORAGE

with console.status(f"Initializing ZenML repository at {path}.\n"):
try:
Repository.initialize(root=path)
Repository.initialize(root=path, storage_type=storage_type)
declare(f"ZenML repository initialized at {path}.")
except InitializationException as e:
error(f"{e}")
Expand Down
4 changes: 2 additions & 2 deletions src/zenml/cli/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ def list_stacks() -> None:
"ACTIVE": ":point_right:" if is_active else "",
"STACK NAME": stack_name,
**{
key.upper(): value
for key, value in stack_configuration.dict().items()
component_type.value.upper(): value
for component_type, value in stack_configuration.items()
},
}
stack_dicts.append(stack_config)
Expand Down
6 changes: 3 additions & 3 deletions src/zenml/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@

from zenml.console import console
from zenml.constants import IS_DEBUG_ENV
from zenml.enums import StackComponentType
from zenml.logger import get_logger
from zenml.repository import StackConfiguration
from zenml.stack import StackComponent

logger = get_logger(__name__)
Expand Down Expand Up @@ -167,7 +167,7 @@ def print_stack_component_list(


def print_stack_configuration(
component: StackConfiguration, active: bool, stack_name: str
config: Dict[StackComponentType, str], active: bool, stack_name: str
) -> None:
"""Prints the configuration options of a stack."""
stack_caption = f"'{stack_name}' stack"
Expand All @@ -181,7 +181,7 @@ def print_stack_configuration(
)
rich_table.add_column("COMPONENT_TYPE")
rich_table.add_column("COMPONENT_NAME")
items = component.dict().items()
items = {typ.value: name for typ, name in config.items()}
for item in items:
rich_table.add_row(*list(item))

Expand Down
3 changes: 3 additions & 0 deletions src/zenml/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
ENV_ZENML_PREVENT_PIPELINE_EXECUTION
)

# Repository Directory Path
LOCAL_CONFIG_DIRECTORY_NAME = ".zen"

USER_MAIN_MODULE: Optional[str] = None
66 changes: 48 additions & 18 deletions src/zenml/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import logging
from enum import Enum
from typing import Type

from zenml.utils.enum_utils import StrEnum

Expand All @@ -38,9 +39,47 @@ class LoggingLevels(Enum):
CRITICAL = logging.CRITICAL


class StackComponentType(StrEnum):
"""All possible types a `StackComponent` can have."""

ORCHESTRATOR = "orchestrator"
METADATA_STORE = "metadata_store"
ARTIFACT_STORE = "artifact_store"
CONTAINER_REGISTRY = "container_registry"
STEP_OPERATOR = "step_operator"

@property
def plural(self) -> str:
"""Returns the plural of the enum value."""
if self == StackComponentType.CONTAINER_REGISTRY:
return "container_registries"

return f"{self.value}s"


class StackComponentFlavor(StrEnum):
"""Abstract base class for all stack component flavors."""

@staticmethod
def for_type(
component_type: StackComponentType,
) -> Type["StackComponentFlavor"]:
"""Get the corresponding flavor child-type for a component type."""
if component_type == StackComponentType.ARTIFACT_STORE:
return ArtifactStoreFlavor
elif component_type == StackComponentType.METADATA_STORE:
return MetadataStoreFlavor
elif component_type == StackComponentType.CONTAINER_REGISTRY:
return ContainerRegistryFlavor
elif component_type == StackComponentType.ORCHESTRATOR:
return OrchestratorFlavor
elif component_type == StackComponentType.STEP_OPERATOR:
return StepOperatorFlavor
else:
jwwwb marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
f"Unsupported Stack Component Type {component_type.value}"
)


class ArtifactStoreFlavor(StackComponentFlavor):
"""All supported artifact store flavors."""
Expand Down Expand Up @@ -80,26 +119,17 @@ class StepOperatorFlavor(StackComponentFlavor):
SAGEMAKER = "sagemaker"


class StackComponentType(StrEnum):
"""All possible types a `StackComponent` can have."""

ORCHESTRATOR = "orchestrator"
METADATA_STORE = "metadata_store"
ARTIFACT_STORE = "artifact_store"
CONTAINER_REGISTRY = "container_registry"
STEP_OPERATOR = "step_operator"

@property
def plural(self) -> str:
"""Returns the plural of the enum value."""
if self == StackComponentType.CONTAINER_REGISTRY:
return "container_registries"

return f"{self.value}s"


class MetadataContextTypes(Enum):
"""All possible types that contexts can have within pipeline nodes"""

STACK = "stack"
PIPELINE_REQUIREMENTS = "pipeline_requirements"


class StorageType(StrEnum):
"""Storage Backend Types"""

YAML_STORAGE = "yaml_storage"
SQLITE_STORAGE = "sqlite_storage"
MYSQL_STORAGE = "mysql_storage"
REST_STORAGE = "rest_storage"
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _get_plot_method(
raise ValueError(
f"Invalid whylogs plot type: {plot} \n\n"
f"Valid and supported options are: {nl}- "
f'{f"{nl}- ".join(WhylogsPlots.list())}'
f'{f"{nl}- ".join(WhylogsPlots.names())}'
)
return plot_method

Expand Down
Loading