Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement base component tests #211

Merged
merged 3 commits into from
Dec 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions src/zenml/cli/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,14 @@

@cli.command("init", help="Initialize zenml on a given path.")
@click.option("--repo_path", type=click.Path(exists=True))
@click.option("--analytics_opt_in", "-a", type=click.BOOL)
@track(event=INITIALIZE_REPO)
def init(
repo_path: Optional[str],
analytics_opt_in: bool = True,
) -> None:
"""Initialize ZenML on given path.

Args:
repo_path: Path to the repository.
analytics_opt_in: (Default value = True)

Raises:
InvalidGitRepositoryError: If repo is not a git repo.
Expand All @@ -55,10 +52,7 @@ def init(
declare(f"Initializing at {repo_path}")

try:
Repository.init_repo(
repo_path=repo_path,
analytics_opt_in=analytics_opt_in,
)
Repository.init_repo(repo_path=repo_path)
declare(f"ZenML repo initialized at {repo_path}")
except git.InvalidGitRepositoryError: # type: ignore[attr-defined]
error(
Expand Down
9 changes: 0 additions & 9 deletions src/zenml/core/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

from git import InvalidGitRepositoryError # type: ignore[attr-defined]

from zenml.config.global_config import GlobalConfig
from zenml.core.constants import ZENML_DIR_NAME
from zenml.core.git_wrapper import GitWrapper
from zenml.core.local_service import LocalService
Expand Down Expand Up @@ -70,15 +69,13 @@ def __init__(self, path: Optional[str] = None):
def init_repo(
repo_path: str = os.getcwd(),
stack: Optional[BaseStack] = None,
analytics_opt_in: bool = True,
) -> None:
"""
Initializes a git repo with zenml.

Args:
repo_path (str): path to root of a git repo
stack: Initial stack.
analytics_opt_in: opt-in flag for analytics code.

Raises:
InvalidGitRepositoryError: If repository is not a git repository.
Expand All @@ -88,12 +85,6 @@ def init_repo(
if fileio.is_zenml_dir(repo_path):
raise AssertionError(f"{repo_path} is already initialized!")

# Edit global config
if analytics_opt_in is not None:
gc = GlobalConfig()
gc.analytics_opt_in = analytics_opt_in
gc.update()

try:
GitWrapper(repo_path)
except InvalidGitRepositoryError:
Expand Down
21 changes: 0 additions & 21 deletions tests/config/test_global_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,3 @@ def test_global_config_file_creation():

# Raw config should now exist
assert fileio.file_exists(os.path.join(APP_DIR, GLOBAL_CONFIG_NAME))


# def test_global_config_persistence():
# """A simple test to check whether the persistence logic works."""
# gc = GlobalConfig()
#
# # Track old one
# old_analytics_opt_in = gc.analytics_opt_in
#
# # Toggle it
# gc.analytics_opt_in = not old_analytics_opt_in
#
# # Initialize new config
# gc = GlobalConfig()
#
# # It still should be equal to the old value, as we have not saved
# assert old_analytics_opt_in == gc.analytics_opt_in
#
# # Get raw config
# raw_config = yaml_utils.read_json(os.path.join(APP_DIR, GLOBAL_CONFIG_NAME))
# assert raw_config["analytics_opt_in"] == old_analytics_opt_in
58 changes: 43 additions & 15 deletions tests/core/test_base_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,57 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
import os
from unittest.mock import Mock, patch
from uuid import uuid4

from typing import Text
import pytest

from zenml.core.base_component import BaseComponent
from zenml.utils.yaml_utils import read_json, write_json


class MockComponent(BaseComponent):
"""Mocking the base component for testing."""
@pytest.fixture(scope="module", autouse=True)
def patch_serialization_dir(tmp_path_factory):
"""Patches the abstract `BaseComponent` class so it can be instantiated
and returns a temporary serialization dir."""
directory = str(tmp_path_factory.mktemp("base_component"))
with patch.multiple(
"zenml.core.base_component.BaseComponent",
__abstractmethods__=set(),
get_serialization_dir=Mock(return_value=directory),
):
yield

tmp_path: str

def get_serialization_dir(self) -> Text:
"""Mock serialization dir"""
return self.tmp_path
def test_base_component_detects_superfluous_arguments():
"""Tests that the base component correctly detects arguments that are
not defined in the class."""
component = BaseComponent(some_random_key=None, uuid=uuid4())
assert "some_random_key" in component._superfluous_options
assert "uuid" not in component._superfluous_options


def test_base_component_serialization_logic(tmp_path):
"""Tests the UUID serialization logic of BaseComponent"""
def test_base_component_creates_backup_file_if_schema_changes():
"""Tests that a base component creates a backup file if the json file
schema is different than the current class definiton."""
uuid = uuid4()
component = BaseComponent(uuid=uuid)
config_path = component.get_serialization_full_path()
# write a config file with a superfluous key
write_json(config_path, {"uuid": str(uuid), "superfluous": 0})

# Application of the monkeypatch to replace Path.home
# with the behavior of mockreturn defined above.
# mc = MockComponent(tmp_path=str(tmp_path))
# instantiate new BaseComponent which should create a backup file
BaseComponent(uuid=uuid)

# Calling getssh() will use mockreturn in place of Path.home
# for this test with the monkeypatch.
# print(mc.get_serialization_dir())
# config dict should be with new schema, so no superfluous options
config_dict = read_json(config_path)
assert config_dict["uuid"] == str(uuid)
assert "superfluous" not in config_dict

# backup file should contain both the uuid and superfluous options
backup_path = config_path + ".backup"
assert os.path.exists(backup_path)
backup_dict = read_json(backup_path)
assert backup_dict["superfluous"] == 0
assert backup_dict["uuid"] == str(uuid)
12 changes: 3 additions & 9 deletions tests/core/test_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,29 +87,23 @@ def test_repo_double_init(tmp_path: str) -> None:
os.mkdir(os.path.join(tmp_path, ZENML_DIR_NAME))

with pytest.raises(Exception):
_ = Repository(str(tmp_path)).init_repo(
repo_path=tmp_path, analytics_opt_in=False
)
Repository(str(tmp_path)).init_repo(repo_path=tmp_path)


def test_repo_init_without_git_repo_initialized_raises_error(
tmp_path: str,
) -> None:
"""Check initializing repository without git repository raises error"""
with pytest.raises(Exception):
_ = Repository(str(tmp_path)).init_repo(
repo_path=tmp_path, analytics_opt_in=False
)
Repository(str(tmp_path)).init_repo(repo_path=tmp_path)


def test_init_repo_creates_a_zen_folder(tmp_path: str) -> None:
"""Check initializing repository creates a ZenML folder"""
_ = Repo.init(tmp_path)
repo = Repository(str(tmp_path))
local_stack = LocalService().get_stack("local_stack")
repo.init_repo(
repo_path=tmp_path, analytics_opt_in=False, stack=local_stack
)
repo.init_repo(repo_path=tmp_path, stack=local_stack)
assert os.path.exists(os.path.join(tmp_path, ZENML_DIR_NAME))


Expand Down