Skip to content

Commit

Permalink
Added more integration tests for Installation
Browse files Browse the repository at this point in the history
This PR fixes couple of bugs with `Installation` utilities
  • Loading branch information
nfx committed Feb 1, 2024
1 parent 96c34f9 commit 67a4145
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 35 deletions.
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ for blueprint in Installation.existing(ws, "blueprint"):

The `save(obj)` method saves a dataclass instance of type `T` to a file on WorkspaceFS. If no `filename` is provided,
the name of the `type_ref` class will be used as the filename. Any missing parent directories are created automatically.
If the object has a `__version__` attribute, the method will add a `$version` field to the serialized object
If the object has a `__version__` attribute, the method will add a `version` field to the serialized object
with the value of the `__version__` attribute. See [configuration format evolution](#configuration-format-evolution)
for more details. `save(obj)` works with JSON and YAML configurations without the need to supply `filename` keyword
attribute. When you need to save [CSV files](#saving-csv-files), the `filename` attribute is required. If you need to
Expand Down Expand Up @@ -567,7 +567,7 @@ class SomeConfig: # <-- auto-detected filename is `some-config.json`

ws = WorkspaceClient()
installation = Installation.current(ws, "blueprint")
cfg = installation.load(EvolvedConfig)
cfg = installation.load(SomeConfig)

installation.save(SomeConfig("0.1.2"))
installation.assert_file_written("some-config.json", {"version": "0.1.2"})
Expand Down Expand Up @@ -604,13 +604,13 @@ class EvolvedConfig:
@staticmethod
def v1_migrate(raw: dict) -> dict:
raw["added_in_v1"] = 111
raw["$version"] = 2
raw["version"] = 2
return raw

@staticmethod
def v2_migrate(raw: dict) -> dict:
raw["added_in_v2"] = 222
raw["$version"] = 3
raw["version"] = 3
return raw

installation = Installation.current(WorkspaceClient(), "blueprint")
Expand Down Expand Up @@ -681,7 +681,7 @@ installation = MockInstallation()
installation.save(WorkspaceConfig(inventory_database="some_blueprint"))

installation.assert_file_written("config.yml", {
"$version": 2,
"version": 2,
"inventory_database": "some_blueprint",
"log_level": "INFO",
"num_threads": 10,
Expand All @@ -696,7 +696,7 @@ ws.workspace.upload.assert_called_with(
"/Users/foo/.blueprint/config.yml",
yaml.dump(
{
"$version": 2,
"version": 2,
"num_threads": 10,
"inventory_database": "some_blueprint",
"include_group_names": ["foo", "bar"],
Expand Down
39 changes: 26 additions & 13 deletions src/databricks/labs/blueprint/installation.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def check_folder(install_folder: str) -> Installation | None:
return None

tasks = [functools.partial(check_folder, f"/Applications/{product}")]
for user in ws.users.list(attributes="user_name"):
for user in ws.users.list(attributes="userName"):
user_folder = f"/Users/{user.user_name}/.{product}"
tasks.append(functools.partial(check_folder, user_folder))
return Threads.strict(f"finding {product} installations", tasks)
Expand Down Expand Up @@ -177,7 +177,7 @@ def load(self, type_ref: type[T], *, filename: str | None = None) -> T:
def save(self, inst: T, *, filename: str | None = None):
"""The `save` method saves a dataclass object of type `T` to a file on WorkspaceFS.
If no `filename` is provided, the name of the `type_ref` class will be used as the filename.
If the object has a `__version__` attribute, the method will add a `$version` field to the serialized object
If the object has a `__version__` attribute, the method will add a `version` field to the serialized object
with the value of the `__version__` attribute.
Here is an example of how you can use the `save` method:
Expand Down Expand Up @@ -212,7 +212,7 @@ class MyClass:
version = getattr(inst, "__version__")
as_dict, _ = self._marshal(type_ref, [], inst)
if version:
as_dict["$version"] = version
as_dict["version"] = version
self._overwrite_content(filename, as_dict, type_ref)
return f"{self.install_folder()}/{filename}"

Expand Down Expand Up @@ -251,7 +251,7 @@ def _strip_notebook_source_suffix(cls, dst: str, raw: bytes) -> str:
return dst.removesuffix(f".{ext}")
return dst

def upload_dbfs(self, filename: str, raw: bytes) -> str:
def upload_dbfs(self, filename: str, raw: BinaryIO) -> str:
"""The `upload_dbfs` method uploads raw bytes to a file on DBFS (Databricks File System) with the given
`filename`. This method is used to upload files to DBFS, which is a distributed file system that is integrated
with Databricks."""
Expand Down Expand Up @@ -342,25 +342,33 @@ def _convert_content(cls, filename: str, raw: BinaryIO) -> Json:
def __repr__(self):
return self.install_folder()

def __eq__(self, o):
if not isinstance(o, Installation):
return False
return self.install_folder() == o.install_folder()

Check warning on line 348 in src/databricks/labs/blueprint/installation.py

View check run for this annotation

Codecov / codecov/patch

src/databricks/labs/blueprint/installation.py#L347-L348

Added lines #L347 - L348 were not covered by tests

def __hash__(self):
return hash(self.install_folder())

Check warning on line 351 in src/databricks/labs/blueprint/installation.py

View check run for this annotation

Codecov / codecov/patch

src/databricks/labs/blueprint/installation.py#L351

Added line #L351 was not covered by tests

@staticmethod
def _user_home_installation(ws: WorkspaceClient, product: str):
me = ws.current_user.me()
return f"/Users/{me.user_name}/.{product}"

@staticmethod
def _migrate_file_format(type_ref, expected_version, as_dict, filename):
actual_version = as_dict.pop("$version", 1)
actual_version = as_dict.pop("version", 1)
while actual_version < expected_version:
migrate = getattr(type_ref, f"v{actual_version}_migrate", None)
if not migrate:
break
as_dict = migrate(as_dict)
prev_version = actual_version
actual_version = as_dict.pop("$version", 1)
actual_version = as_dict.pop("version", 1)
if actual_version == prev_version:
raise IllegalState(f"cannot migrate {filename} from v{prev_version}")
if actual_version != expected_version:
raise IllegalState(f"expected state $version={expected_version}, got={actual_version}")
raise IllegalState(f"expected state version={expected_version}, got={actual_version}")
return as_dict

@staticmethod
Expand Down Expand Up @@ -655,10 +663,11 @@ def _dump_csv(raw: list[Json], type_ref: type) -> bytes:

@staticmethod
def _load_csv(raw: BinaryIO) -> list[Json]:
out = []
for row in csv.DictReader(raw): # type: ignore[arg-type]
out.append(row)
return out
with io.TextIOWrapper(raw, encoding="utf8") as text_file:
out = []
for row in csv.DictReader(text_file): # type: ignore[arg-type]
out.append(row)
return out


class MockInstallation(Installation):
Expand Down Expand Up @@ -687,8 +696,8 @@ def upload(self, filename: str, raw: bytes):
self._uploads[filename] = raw
return f"{self.install_folder()}/{filename}"

def upload_dbfs(self, filename: str, raw: bytes) -> str:
self._dbfs[filename] = raw
def upload_dbfs(self, filename: str, raw: BinaryIO) -> str:
self._dbfs[filename] = raw.read()
return f"{self.install_folder()}/{filename}"

def files(self) -> list[workspace.ObjectInfo]:
Expand Down Expand Up @@ -722,6 +731,10 @@ def _load_content(self, filename: str) -> Json:

def assert_file_written(self, filename: str, expected: Any):
assert filename in self._overwrites, f"{filename} had no writes"
if isinstance(expected, dict):
for k, v in expected.items():
if v == ...:
self._overwrites[filename][k] = ...
actual = self._overwrites[filename]
assert expected == actual, f"{filename} content missmatch"

Expand Down
11 changes: 8 additions & 3 deletions src/databricks/labs/blueprint/wheels.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import datetime
import inspect
import logging
import shutil
Expand All @@ -7,6 +6,7 @@
import tempfile
from contextlib import AbstractContextManager
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path

from databricks.sdk import WorkspaceClient
Expand Down Expand Up @@ -96,7 +96,7 @@ def unreleased_version(self) -> str:
@staticmethod
def _semver_and_pep440(git_detached_version: str) -> str:
dv = SemVer.parse(git_detached_version)
datestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
datestamp = datetime.now().strftime("%Y%m%d%H%M%S")

Check warning on line 99 in src/databricks/labs/blueprint/wheels.py

View check run for this annotation

Codecov / codecov/patch

src/databricks/labs/blueprint/wheels.py#L99

Added line #L99 was not covered by tests
# new commits on main branch since the last tag
new_commits = dv.pre_release.split("-")[0] if dv.pre_release else None
# show that it's a version different from the released one in stats
Expand Down Expand Up @@ -137,6 +137,7 @@ def _read_version(version_file: Path) -> str:
class Version:
version: str
wheel: str
date: str


class WheelsV2(AbstractContextManager):
Expand All @@ -156,9 +157,13 @@ def upload_to_dbfs(self) -> str:
def upload_to_wsfs(self) -> str:
with self._local_wheel.open("rb") as f:
remote_wheel = self._installation.upload(f"wheels/{self._local_wheel.name}", f.read())
self._installation.save(Version(version=self._product_info.version(), wheel=remote_wheel))
self._installation.save(Version(self._product_info.version(), remote_wheel, self._now_iso()))
return remote_wheel

@staticmethod
def _now_iso():
return datetime.now(timezone.utc).isoformat()

def __enter__(self) -> "WheelsV2":
self._tmp_dir = tempfile.TemporaryDirectory()
self._local_wheel = self._build_wheel(self._tmp_dir.name, verbose=self._verbose)
Expand Down
67 changes: 67 additions & 0 deletions tests/integration/test_installation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,71 @@
from dataclasses import dataclass

import pytest
from databricks.sdk.service.provisioning import Workspace

from databricks.labs.blueprint.installation import Installation


def test_install_folder(ws):
installation = Installation(ws, "blueprint")

assert installation.install_folder() == f"/Users/{ws.current_user.me().user_name}/.blueprint"


def test_install_folder_custom(ws):
installation = Installation(ws, "blueprint", install_folder="/custom/folder")

assert installation.install_folder() == "/custom/folder"


def test_detect_global(ws, make_random):
product = make_random(4)
Installation(ws, product, install_folder=f"/Applications/{product}").upload("some", b"...")

current = Installation.current(ws, product)

assert current.install_folder() == f"/Applications/{product}"


def test_existing(ws, make_random):
product = make_random(4)

global_install = Installation(ws, product, install_folder=f"/Applications/{product}")
global_install.upload("some", b"...")

user_install = Installation(ws, product)
user_install.upload("some2", b"...")

existing = Installation.existing(ws, product)
assert set(existing) == {global_install, user_install}


@dataclass
class MyClass:
field1: str
field2: str


def test_dataclass(new_installation):
obj = MyClass("value1", "value2")
new_installation.save(obj)

# Verify that the object was saved correctly
loaded_obj = new_installation.load(MyClass)
assert loaded_obj == obj


def test_csv(new_installation):
new_installation.save(
[
Workspace(workspace_id=1234, workspace_name="first"),
Workspace(workspace_id=1235, workspace_name="second"),
],
filename="workspaces.csv",
)

loaded = new_installation.load(list[Workspace], filename="workspaces.csv")
assert len(loaded) == 2


@pytest.mark.parametrize(
Expand Down
16 changes: 8 additions & 8 deletions tests/unit/test_installation.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def test_save_typed_file():
"/Users/foo/.blueprint/config.yml",
yaml.dump(
{
"$version": 2,
"version": 2,
"num_threads": 10,
"inventory_database": "some_blueprint",
"include_group_names": ["foo", "bar"],
Expand Down Expand Up @@ -179,7 +179,7 @@ def test_load_typed_file():
ws.workspace.download.return_value = io.StringIO(
yaml.dump(
{
"$version": 2,
"version": 2,
"num_threads": 20,
"inventory_database": "some_blueprint",
"connect": {"host": "https://foo", "token": "bar"},
Expand All @@ -197,8 +197,8 @@ def test_load_typed_file():
def test_load_csv_file():
ws = create_autospec(WorkspaceClient)
ws.current_user.me().user_name = "foo"
ws.workspace.download.return_value = io.StringIO(
"\n".join(["workspace_id,workspace_name", "1234,first", "1235,second"])
ws.workspace.download.return_value = io.BytesIO(
"\n".join(["workspace_id,workspace_name", "1234,first", "1235,second"]).encode("utf8")
)
installation = Installation(ws, "blueprint")

Expand Down Expand Up @@ -252,7 +252,7 @@ def test_mock_save_typed_file():
installation.assert_file_written(
"config.yml",
{
"$version": 2,
"version": 2,
"inventory_database": "some_blueprint",
"log_level": "INFO",
"num_threads": 10,
Expand Down Expand Up @@ -286,13 +286,13 @@ class EvolvedConfig:
@staticmethod
def v1_migrate(raw: dict) -> dict:
raw["added_in_v1"] = 111
raw["$version"] = 2
raw["version"] = 2
return raw

@staticmethod
def v2_migrate(raw: dict) -> dict:
raw["added_in_v2"] = 222
raw["$version"] = 3
raw["version"] = 3
return raw


Expand All @@ -318,7 +318,7 @@ class BrokenConfig:
@staticmethod
def v1_migrate(raw: dict) -> dict:
raw["added_in_v1"] = 111
raw["$version"] = 2
raw["version"] = 2
return {}


Expand Down
8 changes: 4 additions & 4 deletions tests/unit/test_installer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_install_folder():
def test_jobs_state():
ws = create_autospec(WorkspaceClient)
ws.current_user.me().user_name = "foo"
ws.workspace.download.return_value = io.StringIO('{"$version":1, "resources": {"jobs": {"foo": 123}}}')
ws.workspace.download.return_value = io.StringIO('{"version":1, "resources": {"jobs": {"foo": 123}}}')

state = InstallState(ws, "blueprint")

Expand All @@ -39,7 +39,7 @@ def test_jobs_state():
def test_invalid_config_version():
ws = create_autospec(WorkspaceClient)
ws.current_user.me().user_name = "foo"
ws.workspace.download.return_value = io.StringIO('{"$version":9, "resources": {"jobs": [1,2,3]}}')
ws.workspace.download.return_value = io.StringIO('{"version":9, "resources": {"jobs": [1,2,3]}}')

state = InstallState(ws, "blueprint")

Expand Down Expand Up @@ -70,13 +70,13 @@ def test_state_corrupt():
def test_state_overwrite_existing():
ws = create_autospec(WorkspaceClient)
ws.current_user.me().user_name = "foo"
ws.workspace.download.return_value = io.StringIO('{"$version":1, "resources": {"sql": {"a": "b"}}}')
ws.workspace.download.return_value = io.StringIO('{"version":1, "resources": {"sql": {"a": "b"}}}')

state = InstallState(ws, "blueprint")
state.jobs["foo"] = "bar"
state.save()

new_state = {"resources": {"sql": {"a": "b"}, "jobs": {"foo": "bar"}}, "$version": 1}
new_state = {"resources": {"sql": {"a": "b"}, "jobs": {"foo": "bar"}}, "version": 1}
ws.workspace.upload.assert_called_with(
"/Users/foo/.blueprint/state.json",
json.dumps(new_state, indent=2).encode("utf8"),
Expand Down
9 changes: 8 additions & 1 deletion tests/unit/test_wheels.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,14 @@ def test_build_and_upload_wheel():

remote_on_wsfs = wheels.upload_to_wsfs()
installation.assert_file_uploaded(re.compile("wheels/databricks_labs_blueprint-*"))
installation.assert_file_written("version.json", {"version": product_info.version(), "wheel": remote_on_wsfs})
installation.assert_file_written(
"version.json",
{
"version": product_info.version(),
"wheel": remote_on_wsfs,
"date": ...,
},
)

wheels.upload_to_dbfs()
installation.assert_file_dbfs_uploaded(re.compile("wheels/databricks_labs_blueprint-*"))
Expand Down

0 comments on commit 67a4145

Please sign in to comment.