Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Kubeflow dockerignore #249

Merged
merged 9 commits into from
Dec 21, 2021
20 changes: 13 additions & 7 deletions src/zenml/artifact_stores/base_artifact_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""Definition of an Artifact Store"""

import os
from typing import Optional
from typing import Any, Optional

from zenml.config.global_config import GlobalConfig
from zenml.core.base_component import BaseComponent
Expand All @@ -31,6 +31,18 @@ class BaseArtifactStore(BaseComponent):
path: str
_ARTIFACT_STORE_DIR_NAME: str = "artifact_stores"

def __init__(self, repo_path: str, **kwargs: Any) -> None:
"""Initializes a BaseArtifactStore instance.

Args:
repo_path: Path to the repository of this artifact store.
"""
serialization_dir = os.path.join(
get_zenml_config_dir(repo_path),
self._ARTIFACT_STORE_DIR_NAME,
)
super().__init__(serialization_dir=serialization_dir, **kwargs)

@staticmethod
def get_component_name_from_uri(artifact_uri: str) -> str:
"""Gets component name from artifact URI.
Expand All @@ -43,12 +55,6 @@ def get_component_name_from_uri(artifact_uri: str) -> str:
"""
return fileio.get_grandparent(artifact_uri)

def get_serialization_dir(self) -> str:
"""Gets the local path where artifacts are stored."""
return os.path.join(
get_zenml_config_dir(), self._ARTIFACT_STORE_DIR_NAME
)

def resolve_uri_locally(
self, artifact_uri: str, path: Optional[str] = None
) -> str:
Expand Down
2 changes: 1 addition & 1 deletion src/zenml/cli/artifact_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def register_artifact_store(name: str, type: str, args: List[str]) -> None:
from zenml.core.component_factory import artifact_store_factory

comp = artifact_store_factory.get_single_component(type)
artifact_store = comp(**parsed_args)
artifact_store = comp(repo_path=repo.path, **parsed_args)
service = repo.get_service()
service.register_artifact_store(name, artifact_store)
cli_utils.declare(f"Artifact Store `{name}` successfully registered!")
Expand Down
12 changes: 3 additions & 9 deletions src/zenml/cli/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
from typing import Optional

import click
import git

from zenml.cli.cli import cli
from zenml.cli.utils import confirmation, declare, error, warning
from zenml.core.repo import Repository
from zenml.exceptions import InitializationException
from zenml.utils.analytics_utils import INITIALIZE_REPO, track


Expand Down Expand Up @@ -52,15 +52,9 @@ def init(
declare(f"Initializing at {repo_path}")

try:
Repository.init_repo(repo_path=repo_path)
Repository.init_repo(path=repo_path)
declare(f"ZenML repo initialized at {repo_path}")
except git.InvalidGitRepositoryError: # type: ignore[attr-defined]
error(
f"{repo_path} is not a valid git repository! Please "
f"initialize ZenML within a git repository using "
f"`git init `"
)
except AssertionError as e:
except InitializationException as e:
error(f"{e}")


Expand Down
6 changes: 3 additions & 3 deletions src/zenml/cli/container_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ def register_container_registry(name: str, uri: str) -> None:
BaseContainerRegistry,
)

registry = BaseContainerRegistry(uri=uri)
service = Repository().get_service()
service.register_container_registry(name, registry)
repo = Repository()
registry = BaseContainerRegistry(uri=uri, repo_path=repo.path)
repo.get_service().register_container_registry(name, registry)
cli_utils.declare(f"Container registry `{name}` successfully registered!")


Expand Down
5 changes: 3 additions & 2 deletions src/zenml/cli/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from git.repo.base import Repo
from packaging.version import Version, parse

import zenml.io.utils
from zenml import __version__ as zenml_version_installed
from zenml.cli.cli import cli
from zenml.cli.utils import (
Expand All @@ -33,7 +34,7 @@
title,
warning,
)
from zenml.constants import APP_NAME, GIT_REPO_URL
from zenml.constants import GIT_REPO_URL
from zenml.io import fileio
from zenml.logger import get_logger

Expand Down Expand Up @@ -280,7 +281,7 @@ class GitExamplesHandler(object):

def __init__(self) -> None:
"""Create a new GitExamplesHandler instance."""
self.repo_dir = click.get_app_dir(APP_NAME)
self.repo_dir = zenml.io.utils.get_global_config_directory()
self.examples_dir = Path(
os.path.join(self.repo_dir, EXAMPLES_GITHUB_REPO)
)
Expand Down
2 changes: 1 addition & 1 deletion src/zenml/cli/metadata_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def register_metadata_store(name: str, type: str, args: List[str]) -> None:
cli_utils.error(str(e))
return

metadata_store = comp(**parsed_args)
metadata_store = comp(repo_path=repo.path, **parsed_args)
service = repo.get_service()
service.register_metadata_store(name, metadata_store)
cli_utils.declare(f"Metadata Store `{name}` successfully registered!")
Expand Down
2 changes: 1 addition & 1 deletion src/zenml/cli/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def register_orchestrator(name: str, type: str, args: List[str]) -> None:
from zenml.core.component_factory import orchestrator_store_factory

comp = orchestrator_store_factory.get_single_component(type)
orchestrator_ = comp(**parsed_args)
orchestrator_ = comp(repo_path=repo.path, **parsed_args)
service = repo.get_service()
service.register_orchestrator(name, orchestrator_)
cli_utils.declare(f"Orchestrator `{name}` successfully registered!")
Expand Down
14 changes: 5 additions & 9 deletions src/zenml/config/global_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@
from typing import Any
from uuid import UUID, uuid4

import click
from pydantic import Field

from zenml import constants
import zenml.io.utils
from zenml.config.constants import GLOBAL_CONFIG_NAME
from zenml.core.base_component import BaseComponent
from zenml.io import fileio
Expand All @@ -41,20 +40,17 @@ def __init__(self, **data: Any):
"""We persist the attributes in the config file. For the global
config, we want to persist the data as soon as it is initialized for
the first time."""
super().__init__(**data)
super().__init__(
serialization_dir=zenml.io.utils.get_global_config_directory(),
**data
)

# At this point, if the serialization file does not exist we should
# create it and dump our data.
f = self.get_serialization_full_path()
if not fileio.file_exists(str(f)):
self._dump()

def get_serialization_dir(self) -> str:
"""Gets the global config dir for installed package."""
# using a version-pinned folder avoids conflicts when
# upgrading zenml versions.
return click.get_app_dir(constants.APP_NAME)

def get_serialization_file_name(self) -> str:
"""Gets the global config dir for installed package."""
return GLOBAL_CONFIG_NAME
15 changes: 11 additions & 4 deletions src/zenml/container_registry/base_container_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Base class for all container registries."""

import os
from typing import Any

from zenml.core.base_component import BaseComponent
from zenml.io.utils import get_zenml_config_dir
Expand All @@ -25,11 +26,17 @@ class BaseContainerRegistry(BaseComponent):
uri: str
_CONTAINER_REGISTRY_DIR_NAME: str = "container_registries"

def get_serialization_dir(self) -> str:
"""Gets the local path where artifacts are stored."""
return os.path.join(
get_zenml_config_dir(), self._CONTAINER_REGISTRY_DIR_NAME
def __init__(self, repo_path: str, **kwargs: Any) -> None:
"""Initializes a BaseContainerRegistry instance.

Args:
repo_path: Path to the repository of this container registry.
"""
serialization_dir = os.path.join(
get_zenml_config_dir(repo_path),
self._CONTAINER_REGISTRY_DIR_NAME,
)
super().__init__(serialization_dir=serialization_dir, **kwargs)

class Config:
"""Configuration of settings."""
Expand Down
24 changes: 17 additions & 7 deletions src/zenml/core/base_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
import os
from abc import abstractmethod
from typing import Any, Dict, Optional
from uuid import UUID, uuid4

Expand Down Expand Up @@ -45,30 +44,32 @@ class BaseComponent(BaseSettings):
uuid: Optional[UUID] = Field(default_factory=uuid4)
_file_suffix = ".json"
_superfluous_options: Dict[str, Any] = {}
_serialization_dir: str

def __init__(self, **values: Any):
def __init__(self, serialization_dir: str, **values: Any):
# Here, we insert monkey patch the `customise_sources` function
# because we want to dynamically generate the serialization
# file path and name.

if hasattr(self, "uuid"):
self.__config__.customise_sources = generate_customise_sources( # type: ignore[assignment] # noqa
self.get_serialization_dir(),
serialization_dir,
self.get_serialization_file_name(),
)
elif "uuid" in values:
self.__config__.customise_sources = generate_customise_sources( # type: ignore[assignment] # noqa
self.get_serialization_dir(),
serialization_dir,
f"{str(values['uuid'])}{self._file_suffix}",
)
else:
self.__config__.customise_sources = generate_customise_sources( # type: ignore[assignment] # noqa
self.get_serialization_dir(),
serialization_dir,
self.get_serialization_file_name(),
)

# Initialize values from the above sources.
super().__init__(**values)
self._serialization_dir = serialization_dir
self._save_backup_file_if_required()

def _save_backup_file_if_required(self) -> None:
Expand Down Expand Up @@ -101,15 +102,24 @@ def _dump(self) -> None:
)
zenml.io.utils.write_file_contents_as_string(file_path, file_content)

def dict(self, **kwargs: Any) -> Dict[str, Any]:
"""Removes private attributes from pydantic dict so they don't get
stored in our config files."""
return {
key: value
for key, value in super().dict(**kwargs).items()
if not key.startswith("_")
}

def _create_serialization_file_if_not_exists(self) -> None:
"""Creates the serialization file if it does not exist."""
f = self.get_serialization_full_path()
if not fileio.file_exists(str(f)):
fileio.create_file_if_not_exists(str(f))

@abstractmethod
def get_serialization_dir(self) -> str:
"""Return the dir where object is serialized."""
return self._serialization_dir

def get_serialization_file_name(self) -> str:
"""Return the name of the file where object is serialized. This
Expand All @@ -125,7 +135,7 @@ def get_serialization_file_name(self) -> str:
def get_serialization_full_path(self) -> str:
"""Returns the full path of the serialization file."""
return os.path.join(
self.get_serialization_dir(), self.get_serialization_file_name()
self._serialization_dir, self.get_serialization_file_name()
)

def update(self) -> None:
Expand Down
34 changes: 24 additions & 10 deletions src/zenml/core/local_service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING, Any, Dict

import zenml.io.utils
from zenml.core import mapping_utils
Expand Down Expand Up @@ -42,9 +42,17 @@ class LocalService(BaseComponent):

_LOCAL_SERVICE_FILE_NAME = "zenservice.json"

def get_serialization_dir(self) -> str:
"""The local service stores everything in the zenml config dir."""
return zenml.io.utils.get_zenml_config_dir()
def __init__(self, repo_path: str, **kwargs: Any) -> None:
"""Initializes a LocalService instance.

Args:
repo_path: Path to the repository of this service.
"""
serialization_dir = zenml.io.utils.get_zenml_config_dir(repo_path)
super().__init__(serialization_dir=serialization_dir, **kwargs)
self._repo_path = repo_path
for stack in self.stacks.values():
stack._repo_path = repo_path

def get_serialization_file_name(self) -> str:
"""Return the name of the file where object is serialized."""
Expand All @@ -56,7 +64,9 @@ def metadata_stores(self) -> Dict[str, "BaseMetadataStore"]:
from zenml.metadata_stores import BaseMetadataStore

return mapping_utils.get_components_from_store( # type: ignore[return-value] # noqa
BaseMetadataStore._METADATA_STORE_DIR_NAME, self.metadata_store_map
BaseMetadataStore._METADATA_STORE_DIR_NAME,
self.metadata_store_map,
self._repo_path,
)

@property
Expand All @@ -65,7 +75,9 @@ def artifact_stores(self) -> Dict[str, "BaseArtifactStore"]:
from zenml.artifact_stores import BaseArtifactStore

return mapping_utils.get_components_from_store( # type: ignore[return-value] # noqa
BaseArtifactStore._ARTIFACT_STORE_DIR_NAME, self.artifact_store_map
BaseArtifactStore._ARTIFACT_STORE_DIR_NAME,
self.artifact_store_map,
self._repo_path,
)

@property
Expand All @@ -76,6 +88,7 @@ def orchestrators(self) -> Dict[str, "BaseOrchestrator"]:
return mapping_utils.get_components_from_store( # type: ignore[return-value] # noqa
BaseOrchestrator._ORCHESTRATOR_STORE_DIR_NAME,
self.orchestrator_map,
self._repo_path,
)

@property
Expand All @@ -88,6 +101,7 @@ def container_registries(self) -> Dict[str, "BaseContainerRegistry"]:
return mapping_utils.get_components_from_store( # type: ignore[return-value] # noqa
BaseContainerRegistry._CONTAINER_REGISTRY_DIR_NAME,
self.container_registry_map,
self._repo_path,
)

def get_active_stack_key(self) -> str:
Expand Down Expand Up @@ -188,7 +202,7 @@ def get_artifact_store(self, key: str) -> "BaseArtifactStore":
f"Available keys: {list(self.artifact_store_map.keys())}"
)
return mapping_utils.get_component_from_key( # type: ignore[return-value] # noqa
key, self.artifact_store_map
key, self.artifact_store_map, self._repo_path
)

@track(event=REGISTERED_ARTIFACT_STORE)
Expand Down Expand Up @@ -246,7 +260,7 @@ def get_metadata_store(self, key: str) -> "BaseMetadataStore":
f"Available keys: {list(self.metadata_store_map.keys())}"
)
return mapping_utils.get_component_from_key( # type: ignore[return-value] # noqa
key, self.metadata_store_map
key, self.metadata_store_map, self._repo_path
)

@track(event=REGISTERED_METADATA_STORE)
Expand Down Expand Up @@ -304,7 +318,7 @@ def get_orchestrator(self, key: str) -> "BaseOrchestrator":
f"Available keys: {list(self.orchestrator_map.keys())}"
)
return mapping_utils.get_component_from_key( # type: ignore[return-value] # noqa
key, self.orchestrator_map
key, self.orchestrator_map, self._repo_path
)

@track(event=REGISTERED_ORCHESTRATOR)
Expand Down Expand Up @@ -362,7 +376,7 @@ def get_container_registry(self, key: str) -> "BaseContainerRegistry":
f"Available keys: {list(self.container_registry_map.keys())}"
)
return mapping_utils.get_component_from_key( # type: ignore[return-value] # noqa
key, self.container_registry_map
key, self.container_registry_map, self._repo_path
)

@track(event=REGISTERED_CONTAINER_REGISTRY)
Expand Down
Loading