Skip to content

Commit

Permalink
Move active_stack back to repository, fix legacy loading
Browse files Browse the repository at this point in the history
  • Loading branch information
jwwwb committed Mar 16, 2022
1 parent 3144bb5 commit 673c89d
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 108 deletions.
50 changes: 45 additions & 5 deletions src/zenml/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import base64
import json
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional

Expand Down Expand Up @@ -58,9 +59,32 @@ class RepositoryConfiguration(BaseModel):
"""

version: str
active_stack_name: Optional[str]
storage_type: StorageType


class LegacyRepositoryConfig(BaseModel):
version: str
active_stack_name: Optional[str]
stacks: Dict[str, Dict[StackComponentType, Optional[str]]]
stack_components: Dict[StackComponentType, Dict[str, str]]

def get_stack_data(self) -> StackStoreModel:
"""Extract stack data from Legacy Repository file."""
return StackStoreModel(
stacks={
name: {
component_type: value
for component_type, value in stack.items()
if value is not None # filter out null components
}
for name, stack in self.stacks.items()
},
stack_components=defaultdict(dict, self.stack_components),
**self.dict(exclude={"stacks", "stack_components"}),
)


class Repository:
"""ZenML repository class.
Expand Down Expand Up @@ -116,9 +140,11 @@ def __init__(
"Found old style repository, converting to "
"minimal repository config with separate stack store file."
)
stack_data = StackStoreModel.parse_obj(config_dict)
legacy_config = LegacyRepositoryConfig.parse_obj(config_dict)
stack_data = legacy_config.get_stack_data()
self.__config = RepositoryConfiguration(
version=stack_data.version,
version=legacy_config.version,
active_stack_name=legacy_config.active_stack_name,
storage_type=storage_type,
)
self._write_config()
Expand Down Expand Up @@ -244,7 +270,12 @@ def active_stack(self) -> Stack:
KeyError: If no stack was found for the configured name or one
of the stack components is not registered.
"""
return self.get_stack(name=self.active_stack_name)
if self.__config.active_stack_name is None:
raise RuntimeError(
"No active stack name configured. Run "
"`zenml stack set STACK_NAME` to update the active stack."
)
return self.get_stack(name=self.__config.active_stack_name)

@property
def active_stack_name(self) -> str:
Expand All @@ -253,7 +284,12 @@ def active_stack_name(self) -> str:
Raises:
RuntimeError: If no active stack name is configured.
"""
return self.stack_store.active_stack_name
if self.__config.active_stack_name is None:
raise RuntimeError(
"No active stack name configured. Run "
"`zenml stack set STACK_NAME` to update the active stack."
)
return self.__config.active_stack_name

@track(event=AnalyticsEvent.SET_STACK)
def activate_stack(self, name: str) -> None:
Expand All @@ -265,7 +301,9 @@ def activate_stack(self, name: str) -> None:
Raises:
KeyError: If no stack exists for the given name.
"""
self.stack_store.activate_stack(name)
self.stack_store.get_stack_configuration(name) # raises KeyError
self.__config.active_stack_name = name
self._write_config()

def get_stack(self, name: str) -> Stack:
"""Fetches a stack.
Expand Down Expand Up @@ -309,6 +347,8 @@ def deregister_stack(self, name: str) -> None:
ValueError: If the stack is the currently active stack for this
repository.
"""
if name == self.active_stack_name:
raise ValueError(f"Unable to deregister active stack '{name}'.")
self.stack_store.deregister_stack(name)

def get_stack_components(
Expand Down
50 changes: 8 additions & 42 deletions src/zenml/stack_stores/base_stack_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,26 +35,6 @@ class BaseStackStore(ABC):
def version(self) -> str:
"""Get the ZenML version."""

@property
@abstractmethod
def active_stack_name(self) -> str:
"""The name of the active stack for this stack store.
Raises:
RuntimeError: If no active stack name is configured.
"""

@abstractmethod
def activate_stack(self, name: str) -> None:
"""Activate the stack for the given name.
Args:
name: Name of the stack to activate.
Raises:
KeyError: If no stack exists for the given name.
"""

@abstractmethod
def get_stack_configuration(
self, name: str
Expand Down Expand Up @@ -95,6 +75,14 @@ def register_stack_component(
and name already exists.
"""

@abstractmethod
def deregister_stack(self, name: str) -> None:
"""Delete a stack from storage.
Args:
name: The name of the stack to be deleted.
"""

# Private interface (must be implemented, not to be called by user):

@abstractmethod
Expand All @@ -108,14 +96,6 @@ def _create_stack(
stack_configuration: Dict[StackComponentType, str] to persist.
"""

@abstractmethod
def _delete_stack(self, name: str) -> None:
"""Delete a stack from storage.
Args:
name: The name of the stack to be deleted.
"""

@abstractmethod
def _get_component_flavor_and_config(
self, component_type: StackComponentType, name: str
Expand Down Expand Up @@ -231,20 +211,6 @@ def __check_component(
self._create_stack(stack.name, stack_configuration)
return metadata

def deregister_stack(self, name: str) -> None:
"""Deregister a stack.
Args:
name: The name of the stack to deregister.
Raises:
ValueError: If the stack is the currently active stack for this
stack store.
"""
if name == self.active_stack_name:
raise ValueError(f"Unable to deregister active stack '{name}'.")
self._delete_stack(name)

def get_stack_component(
self, component_type: StackComponentType, name: str
) -> StackComponentWrapper:
Expand Down
64 changes: 18 additions & 46 deletions src/zenml/stack_stores/local_stack_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(
self.__store = StackStoreModel.parse_obj(config_dict)
else:
self.__store = StackStoreModel.empty_store()
self._write_store()

# Public interface implementations:

Expand All @@ -63,35 +64,6 @@ def version(self) -> str:
"""Get the ZenML version."""
return self.__store.version

@property
def active_stack_name(self) -> str:
"""The name of the active stack for this stack store.
Raises:
RuntimeError: If no active stack name is configured.
"""
if not self.__store.active_stack_name:
raise RuntimeError(
"No active stack name configured. Run "
"`zenml stack set STACK_NAME` to update the active stack."
)
return self.__store.active_stack_name

def activate_stack(self, name: str) -> None:
"""Activate the stack for the given name.
Args:
name: Name of the stack to activate.
Raises:
KeyError: If no stack exists for the given name.
"""
if name not in self.__store.stacks:
raise KeyError(f"Unable to find stack for name '{name}'.")

self.__store.active_stack_name = name
self._write_store()

def get_stack_configuration(
self, name: str
) -> Dict[StackComponentType, str]:
Expand Down Expand Up @@ -164,23 +136,8 @@ def register_stack_component(
"Registered stack component with name '%s'.", component.name
)

# Private interface implementations:

def _create_stack(
self, name: str, stack_configuration: Dict[StackComponentType, str]
) -> None:
"""Add a stack to storage.
Args:
name: The name to save the stack as.
stack_configuration: Dict[StackComponentType, str] to persist.
"""
self.__store.stacks[name] = stack_configuration
self._write_store()
logger.info("Registered stack with name '%s'.", name)

def _delete_stack(self, name: str) -> None:
"""Delete a stack from storage.
def deregister_stack(self, name: str) -> None:
"""Remove a stack from storage.
Args:
name: The name of the stack to be deleted.
Expand All @@ -196,6 +153,21 @@ def _delete_stack(self, name: str) -> None:
name,
)

# Private interface implementations:

def _create_stack(
self, name: str, stack_configuration: Dict[StackComponentType, str]
) -> None:
"""Add a stack to storage.
Args:
name: The name to save the stack as.
stack_configuration: Dict[StackComponentType, str] to persist.
"""
self.__store.stacks[name] = stack_configuration
self._write_store()
logger.info("Registered stack with name '%s'.", name)

def _get_component_flavor_and_config(
self, component_type: StackComponentType, name: str
) -> Tuple[str, bytes]:
Expand Down
3 changes: 1 addition & 2 deletions src/zenml/stack_stores/models/stack_store_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# permissions and limitations under the License.

from collections import defaultdict
from typing import DefaultDict, Dict, Optional
from typing import DefaultDict, Dict

from pydantic import BaseModel, validator

Expand All @@ -33,7 +33,6 @@ class StackStoreModel(BaseModel):
"""

version: str
active_stack_name: Optional[str]
stacks: Dict[str, Dict[StackComponentType, str]]
stack_components: DefaultDict[StackComponentType, Dict[str, str]]

Expand Down
26 changes: 13 additions & 13 deletions src/zenml/stack_stores/sql_stack_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,19 @@ def register_stack_component(
session.add(new_component)
session.commit()

def deregister_stack(self, name: str) -> None:
"""Delete a stack from storage.
Args:
name: The name of the stack to be deleted.
"""
with Session(self.engine) as session:
stack = session.exec(
select(ZenStack).where(ZenStack.name == name)
).one()
session.delete(stack)
session.commit()

# Private interface implementations:

def _create_stack(
Expand All @@ -246,19 +259,6 @@ def _create_stack(
)
session.commit()

def _delete_stack(self, name: str) -> None:
"""Delete a stack from storage.
Args:
name: The name of the stack to be deleted.
"""
with Session(self.engine) as session:
stack = session.exec(
select(ZenStack).where(ZenStack.name == name)
).one()
session.delete(stack)
session.commit()

def _get_component_flavor_and_config(
self, component_type: StackComponentType, name: str
) -> Tuple[str, bytes]:
Expand Down

0 comments on commit 673c89d

Please sign in to comment.