From 779ed284faa804e92f11cd8e9f311b67fd604806 Mon Sep 17 00:00:00 2001 From: Serge Smertin <259697+nfx@users.noreply.github.com> Date: Sat, 9 Mar 2024 16:23:03 +0100 Subject: [PATCH] Added more code documentation (#64) --- src/databricks/labs/blueprint/cli.py | 5 ++ src/databricks/labs/blueprint/entrypoint.py | 2 + src/databricks/labs/blueprint/installation.py | 86 ++++++++++++++++++- src/databricks/labs/blueprint/installer.py | 2 + src/databricks/labs/blueprint/limiter.py | 2 + src/databricks/labs/blueprint/logger.py | 10 +++ src/databricks/labs/blueprint/parallel.py | 8 ++ src/databricks/labs/blueprint/tui.py | 2 + src/databricks/labs/blueprint/upgrades.py | 6 ++ src/databricks/labs/blueprint/wheels.py | 23 +++++ 10 files changed, 143 insertions(+), 3 deletions(-) diff --git a/src/databricks/labs/blueprint/cli.py b/src/databricks/labs/blueprint/cli.py index b9be83a..1d6c477 100644 --- a/src/databricks/labs/blueprint/cli.py +++ b/src/databricks/labs/blueprint/cli.py @@ -1,3 +1,5 @@ +"""Baseline CLI for Databricks Labs projects.""" + import functools import json import logging @@ -33,6 +35,8 @@ def __init__(self, __file: str): self._product_info = ProductInfo(__file) def command(self, fn=None, is_account: bool = False, is_unauthenticated: bool = False): + """Decorator to register a function as a command.""" + def register(func): command_name = func.__name__.replace("_", "-") if not func.__doc__: @@ -52,6 +56,7 @@ def register(func): return fn def _route(self, raw): + """Route the command. This is the entry point for the CLI.""" payload = json.loads(raw) command = payload["command"] if command not in self._mapping: diff --git a/src/databricks/labs/blueprint/entrypoint.py b/src/databricks/labs/blueprint/entrypoint.py index d7c6f31..402d2b6 100644 --- a/src/databricks/labs/blueprint/entrypoint.py +++ b/src/databricks/labs/blueprint/entrypoint.py @@ -1,3 +1,5 @@ +"""Entrypoint utilities for logging and project root detection""" + import logging import os import sys diff --git a/src/databricks/labs/blueprint/installation.py b/src/databricks/labs/blueprint/installation.py index 92f43f3..8c28ab8 100644 --- a/src/databricks/labs/blueprint/installation.py +++ b/src/databricks/labs/blueprint/installation.py @@ -1,3 +1,5 @@ +"""The `Installation` class is used to manage the `~/.{product}` folder on WorkspaceFS to track typed files.""" + import csv import dataclasses import enum @@ -43,11 +45,11 @@ class IllegalState(ValueError): class NotInstalled(NotFound): - pass + """Raised when a product is not installed.""" class SerdeError(TypeError): - pass + """Raised when a serialization or deserialization error occurs.""" class Installation: @@ -66,6 +68,8 @@ class Installation: _PRIMITIVES = (int, bool, float, str) def __init__(self, ws: WorkspaceClient, product: str, *, install_folder: str | None = None): + """The `Installation` class constructor creates an `Installation` object for the given product in + the current workspace.""" self._ws = ws self._product = product self._install_folder = install_folder @@ -129,11 +133,13 @@ def check_folder(install_folder: str) -> Installation | None: @classmethod def load_local(cls, type_ref: type[T], file: Path) -> T: + """Loads a typed file from the local file system.""" with file.open("rb") as f: as_dict = cls._convert_content(file.name, f) return cls._unmarshal_type(as_dict, file.name, type_ref) def product(self) -> str: + """The `product` method returns the name of the product associated with the installation.""" return self._product def install_folder(self) -> str: @@ -185,6 +191,7 @@ def is_global(self) -> bool: return self.install_folder() == self._global_installation(self._product) def username(self) -> str: + """Returns the username associated with the installation folder""" return os.path.basename(os.path.dirname(self.install_folder())) def load(self, type_ref: type[T], *, filename: str | None = None) -> T: @@ -202,6 +209,7 @@ def load(self, type_ref: type[T], *, filename: str | None = None) -> T: return self._unmarshal_type(as_dict, filename, type_ref) def load_or_default(self, type_ref: type[T]) -> T: + """If the file is not found, the method will return a default instance of the `type_ref` class.""" try: return self.load(type_ref) except NotFound: @@ -276,6 +284,7 @@ def upload(self, filename: str, raw: bytes): @classmethod def _strip_notebook_source_suffix(cls, dst: str, raw: bytes) -> str: + """If the file is a Databricks notebook, the method will remove the suffix from the filename.""" if "." not in dst: return dst ext = dst.split(".")[-1] @@ -305,9 +314,13 @@ def upload_dbfs(self, filename: str, raw: BinaryIO) -> str: return dst def files(self) -> list[workspace.ObjectInfo]: + """The `files` method returns a list of all files in the installation folder on WorkspaceFS. + This method is used to list the files that are managed by the `Installation` class.""" return list(self._ws.workspace.list(self.install_folder(), recursive=True)) def remove(self): + """The `remove` method deletes the installation folder on WorkspaceFS. + This method is used to remove all files and folders that are managed by the `Installation` class.""" self._ws.workspace.delete(self.install_folder(), recursive=True) def workspace_link(self, path: str) -> str: @@ -321,9 +334,11 @@ def workspace_link(self, path: str) -> str: return f"{self._host()}/#workspace{self.install_folder()}/{path.removeprefix('/')}" def workspace_markdown_link(self, label: str, path: str) -> str: + """Returns a markdown link to a file in a workspace.""" return f"[{label}]({self.workspace_link(path)})" def _host(self): + """Returns the host of the current workspace.""" return self._ws.config.host def _overwrite_content(self, filename: str, as_dict: Json, type_ref: type): @@ -346,10 +361,14 @@ def _overwrite_content(self, filename: str, as_dict: Json, type_ref: type): @staticmethod def _global_installation(product): + """The `_global_installation` method is a private method that is used to determine the installation folder + for the given product in the `/Applications` directory. This method is called by the `install_folder` method.""" return f"/Applications/{product}" @classmethod def _unmarshal_type(cls, as_dict, filename, type_ref): + """The `_unmarshal_type` method is a private method that is used to deserialize a dictionary to an object of + type `type_ref`. This method is called by the `load` method.""" expected_version = None if hasattr(type_ref, "__version__"): expected_version = getattr(type_ref, "__version__") @@ -358,6 +377,8 @@ def _unmarshal_type(cls, as_dict, filename, type_ref): return cls._unmarshal(as_dict, [], type_ref) def _load_content(self, filename: str) -> Json: + """The `_load_content` method is a private method that is used to load the contents of a file from + WorkspaceFS as a dictionary. This method is called by the `load` method.""" with self._lock: # TODO: check how to make this fail fast during unit testing, otherwise # this currently hangs with the real installation class and mocked workspace client @@ -366,6 +387,8 @@ def _load_content(self, filename: str) -> Json: @classmethod def _convert_content(cls, filename: str, raw: BinaryIO) -> Json: + """The `_convert_content` method is a private method that is used to convert the raw bytes of a file to a + dictionary. This method is called by the `_load_content` method.""" converters: dict[str, Callable[[BinaryIO], Any]] = { "json": json.load, "yml": cls._load_yaml, @@ -388,15 +411,20 @@ def __eq__(self, o): return self.install_folder() == o.install_folder() def __hash__(self): + """The `__hash__` method is used to hash the `Installation` object. + This method is called by the `hash` function.""" return hash(self.install_folder()) @staticmethod def _user_home_installation(ws: WorkspaceClient, product: str): + """The `_user_home_installation` method is a private method that is used to determine the installation folder + for the current user. This method is called by the `install_folder` method.""" me = ws.current_user.me() return f"/Users/{me.user_name}/.{product}" @staticmethod def _migrate_file_format(type_ref, expected_version, as_dict, filename): + """The `_migrate_file_format` method is a private method that is used to migrate the file format of a file""" actual_version = as_dict.pop("version", 1) while actual_version < expected_version: migrate = getattr(type_ref, f"v{actual_version}_migrate", None) @@ -413,6 +441,8 @@ def _migrate_file_format(type_ref, expected_version, as_dict, filename): @staticmethod def _get_filename(filename: str | None, type_ref: type) -> str: + """The `_get_filename` method is a private method that is used to determine the filename of a file based on + the type of the object that is being saved. This method is called by the `save` method.""" if not filename and hasattr(type_ref, "__file__"): return getattr(type_ref, "__file__") if not filename: @@ -422,6 +452,8 @@ def _get_filename(filename: str | None, type_ref: type) -> str: @classmethod def _get_type_ref(cls, inst) -> type: + """The `_get_type_ref` method is a private method that is used to determine the type of an object. This method + is called by the `save` method.""" type_ref = type(inst) if type_ref == list: return cls._get_list_type_ref(inst) @@ -429,6 +461,7 @@ def _get_type_ref(cls, inst) -> type: @staticmethod def _get_list_type_ref(inst: T) -> type[list[T]]: + """The `_get_list_type_ref` method is a private method that is used to determine the type of a list object.""" from_list: list = inst # type: ignore[assignment] if len(from_list) == 0: raise ValueError("List cannot be empty") @@ -437,6 +470,8 @@ def _get_list_type_ref(inst: T) -> type[list[T]]: @classmethod def _marshal(cls, type_ref: type, path: list[str], inst: Any) -> tuple[Any, bool]: + """The `_marshal` method is a private method that is used to serialize an object of type `type_ref` to + a dictionary. This method is called by the `save` method.""" # pylint: disable-next=import-outside-toplevel from typing import ( # type: ignore[attr-defined] _GenericAlias, @@ -470,6 +505,8 @@ def _marshal(cls, type_ref: type, path: list[str], inst: Any) -> tuple[Any, bool @classmethod def _marshal_union(cls, type_ref: type, path: list[str], inst: Any) -> tuple[Any, bool]: + """The `_marshal_union` method is a private method that is used to serialize an object of type `type_ref` to + a dictionary. This method is called by the `save` method.""" combo = [] for variant in get_args(type_ref): value, ok = cls._marshal(variant, [*path, f"(as {variant})"], inst) @@ -480,6 +517,8 @@ def _marshal_union(cls, type_ref: type, path: list[str], inst: Any) -> tuple[Any @classmethod def _marshal_generic(cls, type_ref: type, path: list[str], inst: Any) -> tuple[Any, bool]: + """The `_marshal_generic` method is a private method that is used to serialize an object of type `type_ref` + to a dictionary. This method is called by the `save` method.""" type_args = get_args(type_ref) if not type_args: raise SerdeError(f"Missing type arguments: {type_args}") @@ -489,12 +528,16 @@ def _marshal_generic(cls, type_ref: type, path: list[str], inst: Any) -> tuple[A @staticmethod def _marshal_generic_alias(type_ref, inst): + """The `_marshal_generic_alias` method is a private method that is used to serialize an object of type + `type_ref` to a dictionary. This method is called by the `save` method.""" if not inst: return None, False return inst, isinstance(inst, type_ref.__origin__) # type: ignore[attr-defined] @classmethod def _marshal_list(cls, type_ref: type, path: list[str], inst: Any) -> tuple[Any, bool]: + """The `_marshal_list` method is a private method that is used to serialize an object of type `type_ref` to + a dictionary. This method is called by the `save` method.""" as_list = [] if not inst: return None, False @@ -507,6 +550,8 @@ def _marshal_list(cls, type_ref: type, path: list[str], inst: Any) -> tuple[Any, @classmethod def _marshal_dict(cls, type_ref: type, path: list[str], inst: Any) -> tuple[Any, bool]: + """The `_marshal_dict` method is a private method that is used to serialize an object of type `type_ref` to + a dictionary. This method is called by the `save` method.""" if not isinstance(inst, dict): return None, False as_dict = {} @@ -518,6 +563,8 @@ def _marshal_dict(cls, type_ref: type, path: list[str], inst: Any) -> tuple[Any, @classmethod def _marshal_dataclass(cls, type_ref: type, path: list[str], inst: Any) -> tuple[Any, bool]: + """The `_marshal_dataclass` method is a private method that is used to serialize an object of type `type_ref` + to a dictionary. This method is called by the `save` method.""" if inst is None: return None, False as_dict = {} @@ -533,24 +580,34 @@ def _marshal_dataclass(cls, type_ref: type, path: list[str], inst: Any) -> tuple @staticmethod def _marshal_databricks_config(inst): + """The `_marshal_databricks_config` method is a private method that is used to serialize an object of type + `databricks.sdk.core.Config` to a dictionary. This method is called by the `save` method.""" if not inst: return None, False return inst.as_dict(), True @staticmethod def _marshal_enum(inst): + """The `_marshal_enum` method is a private method that is used to serialize an object of type `enum.Enum` to + a dictionary. This method is called by the `save` method.""" if not inst: return None, False return inst.value, True @runtime_checkable class _FromDict(Protocol): + """The `_FromDict` protocol is used to define a type that can be constructed from a dictionary. This protocol + is used to define a type that can be constructed from a dictionary. This protocol is used to define a type that + can be constructed from a dictionary.""" + @classmethod def from_dict(cls, raw: dict): pass @classmethod def _unmarshal(cls, inst: Any, path: list[str], type_ref: type[T]) -> T | None: + """The `_unmarshal` method is a private method that is used to deserialize a dictionary to an object of type + `type_ref`. This method is called by the `load` method.""" # pylint: disable-next=import-outside-toplevel from typing import ( # type: ignore[attr-defined] _GenericAlias, @@ -585,6 +642,8 @@ def _unmarshal(cls, inst: Any, path: list[str], type_ref: type[T]) -> T | None: @classmethod def _unmarshal_dataclass(cls, inst, path, type_ref): + """The `_unmarshal_dataclass` method is a private method that is used to deserialize a dictionary to an object + of type `type_ref`. This method is called by the `load` method.""" if inst is None: return None if not isinstance(inst, dict): @@ -609,6 +668,8 @@ def _unmarshal_dataclass(cls, inst, path, type_ref): @classmethod def _unmarshal_union(cls, inst, path, type_ref): + """The `_unmarshal_union` method is a private method that is used to deserialize a dictionary to an object + of type `type_ref`. This method is called by the `load` method.""" for variant in get_args(type_ref): value = cls._unmarshal(inst, path, variant) if value: @@ -617,6 +678,8 @@ def _unmarshal_union(cls, inst, path, type_ref): @classmethod def _unmarshal_generic(cls, inst, path, type_ref): + """The `_unmarshal_generic` method is a private method that is used to deserialize a dictionary to an object + of type `type_ref`. This method is called by the `load` method.""" type_args = get_args(type_ref) if not type_args: raise SerdeError(f"Missing type arguments: {type_args}") @@ -626,6 +689,8 @@ def _unmarshal_generic(cls, inst, path, type_ref): @classmethod def _unmarshal_list(cls, inst, path, hint): + """The `_unmarshal_list` method is a private method that is used to deserialize a dictionary to an object + of type `type_ref`. This method is called by the `load` method.""" if inst is None: return None as_list = [] @@ -635,6 +700,8 @@ def _unmarshal_list(cls, inst, path, hint): @classmethod def _unmarshal_dict(cls, inst, path, type_ref): + """The `_unmarshal_dict` method is a private method that is used to deserialize a dictionary to an object + of type `type_ref`. This method is called by the `load` method.""" if not inst: return None if not isinstance(inst, dict): @@ -646,6 +713,8 @@ def _unmarshal_dict(cls, inst, path, type_ref): @classmethod def _unmarshal_primitive(cls, inst, type_ref): + """The `_unmarshal_primitive` method is a private method that is used to deserialize a dictionary to an object + of type `type_ref`. This method is called by the `load` method.""" if not inst: return inst # convert from str to int if necessary @@ -654,16 +723,22 @@ def _unmarshal_primitive(cls, inst, type_ref): @staticmethod def _explain_why(type_ref: type, path: list[str], raw: Any) -> str: + """The `_explain_why` method is a private method that is used to explain why a value is not of the expected + type. This method is called by the `_unmarshal` and `_marshal` methods.""" if raw is None: raw = "value is missing" return f'{".".join(path)}: not a {type_ref.__name__}: {raw}' @staticmethod def _dump_json(as_dict: Json, _: type) -> bytes: + """The `_dump_json` method is a private method that is used to serialize a dictionary to a JSON string. This + method is called by the `save` method.""" return json.dumps(as_dict, indent=2).encode("utf8") @staticmethod def _dump_yaml(raw: Json, _: type) -> bytes: + """The `_dump_yaml` method is a private method that is used to serialize a dictionary to a YAML string. This + method is called by the `save` method.""" try: from yaml import dump # pylint: disable=import-outside-toplevel @@ -673,6 +748,8 @@ def _dump_yaml(raw: Json, _: type) -> bytes: @staticmethod def _load_yaml(raw: BinaryIO) -> Json: + """The `_load_yaml` method is a private method that is used to deserialize a YAML string to a dictionary. This + method is called by the `load` method.""" try: from yaml import ( # pylint: disable=import-outside-toplevel YAMLError, @@ -688,6 +765,8 @@ def _load_yaml(raw: BinaryIO) -> Json: @staticmethod def _dump_csv(raw: list[Json], type_ref: type) -> bytes: + """The `_dump_csv` method is a private method that is used to serialize a list of dictionaries to a CSV string. + This method is called by the `save` method.""" type_args = get_args(type_ref) if not type_args: raise SerdeError(f"Writing CSV is only supported for lists. Got {type_ref}") @@ -721,7 +800,8 @@ def _load_csv(raw: BinaryIO) -> list[Json]: return out def _enable_files_in_repos(self): - # check if "enableWorkspaceFilesystem" is set to false + """The `_enable_files_in_repos` method is a private method that is used to enable the "Files In Repos" + feature on the current workspace. This method is called by the `upload` method.""" workspace_file_system = self._ws.workspace_conf.get_status("enableWorkspaceFilesystem") logger.debug("Checking Files In Repos configuration") diff --git a/src/databricks/labs/blueprint/installer.py b/src/databricks/labs/blueprint/installer.py index 28e1401..3d31070 100644 --- a/src/databricks/labs/blueprint/installer.py +++ b/src/databricks/labs/blueprint/installer.py @@ -1,3 +1,5 @@ +"""Manages ~/.{product}/state.json file on WorkspaceFS to track installations.""" + import logging import threading from dataclasses import dataclass, field diff --git a/src/databricks/labs/blueprint/limiter.py b/src/databricks/labs/blueprint/limiter.py index fd38eaa..b207664 100644 --- a/src/databricks/labs/blueprint/limiter.py +++ b/src/databricks/labs/blueprint/limiter.py @@ -1,3 +1,5 @@ +"""This module provides a RateLimiter class and a rate_limited decorator to limit the rate of requests""" + import logging import threading import time diff --git a/src/databricks/labs/blueprint/logger.py b/src/databricks/labs/blueprint/logger.py index 752ad64..b8df54e 100644 --- a/src/databricks/labs/blueprint/logger.py +++ b/src/databricks/labs/blueprint/logger.py @@ -1,8 +1,12 @@ +"""A nice formatter for logging. It uses colors and bold text if the console supports it.""" + import logging import sys class NiceFormatter(logging.Formatter): + """A nice formatter for logging. It uses colors and bold text if the console supports it.""" + BOLD = "\033[1m" RESET = "\033[0m" GREEN = "\033[32m" @@ -14,6 +18,9 @@ class NiceFormatter(logging.Formatter): GRAY = "\033[90m" def __init__(self, *, probe_tty: bool = False) -> None: + """Create a new instance of the formatter. If probe_tty is True, then the formatter will + attempt to detect if the console supports colors. If probe_tty is False, colors will be + enabled by default.""" super().__init__(fmt="%(asctime)s %(levelname)s [%(name)s] %(message)s", datefmt="%H:%M") self._levels = { logging.NOTSET: self._bold("TRACE"), @@ -27,9 +34,11 @@ def __init__(self, *, probe_tty: bool = False) -> None: self.colors = sys.stdout.isatty() if probe_tty else True def _bold(self, text): + """Return text in bold.""" return f"{self.BOLD}{text}{self.RESET}" def format(self, record: logging.LogRecord): # noqa: A003 + """Format the log record. If colors are enabled, use them.""" if not self.colors: return super().format(record) ts = self.formatTime(record, datefmt="%H:%M:%S") @@ -60,6 +69,7 @@ def format(self, record: logging.LogRecord): # noqa: A003 def install_logger(level="DEBUG"): + """Install a console logger with a nice formatter.""" for h in logging.root.handlers: logging.root.removeHandler(h) console_handler = logging.StreamHandler(sys.stderr) diff --git a/src/databricks/labs/blueprint/parallel.py b/src/databricks/labs/blueprint/parallel.py index 5f3f22e..0c93635 100644 --- a/src/databricks/labs/blueprint/parallel.py +++ b/src/databricks/labs/blueprint/parallel.py @@ -1,3 +1,5 @@ +"""Run tasks in parallel and return results and errors""" + import concurrent import datetime as dt import functools @@ -40,6 +42,7 @@ def __init__(self, name, tasks: Sequence[Task[Result]], num_threads: int): def gather( cls, name: str, tasks: Sequence[Task[Result]], num_threads: int | None = None ) -> tuple[Collection[Result], list[Exception]]: + """Run tasks in parallel and return results and errors""" if num_threads is None: num_cpus = os.cpu_count() if num_cpus is None: @@ -49,12 +52,14 @@ def gather( @classmethod def strict(cls, name: str, tasks: Sequence[Task[Result]]) -> Collection[Result]: + """Run tasks in parallel and raise ManyError if any task fails""" collected, errs = cls.gather(name, tasks) if errs: raise ManyError(errs) return collected def _run(self) -> tuple[Collection[Result], list[Exception]]: + """Run tasks in parallel and return results and errors""" given_cnt = len(self._tasks) if given_cnt == 0: return [], [] @@ -78,6 +83,7 @@ def _run(self) -> tuple[Collection[Result], list[Exception]]: return collected, errors def _on_finish(self, given_cnt: int, collected_cnt: int, failed_cnt: int): + """Log the results of the parallel execution""" since = dt.datetime.now() - self._started success_pct = collected_cnt / given_cnt * 100 stats = f"{success_pct:.0f}% results available ({collected_cnt}/{given_cnt}). Took {since}" @@ -91,6 +97,7 @@ def _on_finish(self, given_cnt: int, collected_cnt: int, failed_cnt: int): logger.info(f"Finished '{self._name}' tasks: {stats}") def _execute(self): + """Run tasks in parallel and return futures""" thread_name_prefix = re.sub(r"\W+", "_", self._name) with ThreadPoolExecutor(self._num_threads, thread_name_prefix) as pool: futures = [] @@ -103,6 +110,7 @@ def _execute(self): return concurrent.futures.as_completed(futures) def _progress_report(self, _): + """Log the progress of the parallel execution""" total_cnt = len(self._tasks) log_every = self._default_log_every if total_cnt > self._large_log_every: diff --git a/src/databricks/labs/blueprint/tui.py b/src/databricks/labs/blueprint/tui.py index 3627b96..5bdc5b1 100644 --- a/src/databricks/labs/blueprint/tui.py +++ b/src/databricks/labs/blueprint/tui.py @@ -1,3 +1,5 @@ +"""Text User Interface (TUI) utilities""" + import logging import re from collections.abc import Callable diff --git a/src/databricks/labs/blueprint/upgrades.py b/src/databricks/labs/blueprint/upgrades.py index b044911..fcd8c4d 100644 --- a/src/databricks/labs/blueprint/upgrades.py +++ b/src/databricks/labs/blueprint/upgrades.py @@ -1,3 +1,5 @@ +"""Automated rollout of application upgrades deployed in a Databricks workspace.""" + import importlib.util import logging from dataclasses import dataclass, field @@ -85,6 +87,7 @@ def apply(self, ws: WorkspaceClient): self._installation.save(applied) def _apply_python_script(self, script: Path, ws: WorkspaceClient): + """Load and apply the upgrade script.""" name = "_".join(script.name.removesuffix(".py").split("_")[1:]) spec = importlib.util.spec_from_file_location(name, script.as_posix()) if not spec: @@ -102,9 +105,11 @@ def _apply_python_script(self, script: Path, ws: WorkspaceClient): change.upgrade(self._installation, ws) def _installed(self) -> SemVer: + """Load the installed version of the product.""" return self._installation.load(Version).as_semver() def _diff(self, upgrades_folder: Path): + """Yield the upgrade scripts that need to be applied.""" current = self._product_info.as_semver() installed_version = self._installed() for file in upgrades_folder.glob("v*.py"): @@ -122,6 +127,7 @@ def _diff(self, upgrades_folder: Path): @staticmethod def _parse_version(name: str) -> SemVer: + """Parse the version from the upgrade script name.""" split = name.split("_") if len(split) < 2: raise ValueError(f"invalid spec: {name}") diff --git a/src/databricks/labs/blueprint/wheels.py b/src/databricks/labs/blueprint/wheels.py index 53b3530..e25987c 100644 --- a/src/databricks/labs/blueprint/wheels.py +++ b/src/databricks/labs/blueprint/wheels.py @@ -1,3 +1,5 @@ +"""Product info and wheel builder.""" + import inspect import logging import random @@ -6,6 +8,7 @@ import subprocess import sys import tempfile +import warnings from collections.abc import Iterable from contextlib import AbstractContextManager from dataclasses import dataclass @@ -49,6 +52,7 @@ def __init__(self, __file: str, *, github_org: str = "databrickslabs", product_n @classmethod def from_class(cls, klass: type) -> "ProductInfo": + """Create a product info with a class used as a starting point to determine location of the version file.""" return cls(inspect.getfile(klass)) @classmethod @@ -57,6 +61,7 @@ def for_testing(cls, klass: type) -> "ProductInfo": return cls(inspect.getfile(klass), product_name=cls._make_random(4)) def checkout_root(self): + """Returns the root of the project, where .git folder is located.""" return find_project_root(self._version_file.as_posix()) def version_file(self) -> Path: @@ -80,6 +85,7 @@ def version(self): return self.__version def as_semver(self) -> SemVer: + """Returns the version as SemVer object.""" return SemVer.parse(self.version()) def product_name(self) -> str: @@ -90,16 +96,20 @@ def product_name(self) -> str: return version_file_folder.name.replace("_", "-") def released_version(self) -> str: + """Returns the version from the version file.""" return self._read_version(self._version_file) def is_git_checkout(self) -> bool: + """Returns True if the project is a git checkout.""" git_config = self.checkout_root() / ".git" / "config" return git_config.exists() def is_unreleased_version(self) -> bool: + """Returns True if we are in the git checkout and the version is unreleased.""" return "+" in self.version() def unreleased_version(self) -> str: + """Returns the unreleased version based on the `git describe --tags` output.""" try: out = subprocess.run( ["git", "describe", "--tags"], stdout=subprocess.PIPE, check=True, cwd=self.checkout_root() @@ -115,9 +125,11 @@ def unreleased_version(self) -> str: return self.released_version() def current_installation(self, ws: WorkspaceClient) -> Installation: + """Returns the current installation of the product.""" return Installation.current(ws, self.product_name()) def wheels(self, ws: WorkspaceClient) -> "WheelsV2": + """Returns the wheel builder.""" return WheelsV2(self.current_installation(ws), self) @staticmethod @@ -136,6 +148,7 @@ def _make_random(k) -> str: @staticmethod def _semver_and_pep440(git_detached_version: str) -> str: + """Create a version that is both SemVer and PEP440 compliant.""" dv = SemVer.parse(git_detached_version) datestamp = datetime.now().strftime("%Y%m%d%H%M%S") # new commits on main branch since the last tag @@ -164,6 +177,7 @@ def _infer_version_file(cls, start: Path, version_file_names: list[str]) -> Path @staticmethod def _traverse_up(start: Path, version_file_names: list[str]) -> Iterable[Path]: + """Traverse up the directory tree and yield the version files.""" prev_folder = start folder = start.parent while not folder.samefile(prev_folder): @@ -177,6 +191,7 @@ def _traverse_up(start: Path, version_file_names: list[str]) -> Iterable[Path]: @staticmethod def _read_version(version_file: Path) -> str: + """Read the version from the version file.""" version_data: dict[str, str] = {} with version_file.open("r") as f: exec(f.read(), version_data) # pylint: disable=exec-used @@ -206,10 +221,12 @@ def __init__(self, installation: Installation, product_info: ProductInfo, *, ver self._verbose = verbose def upload_to_dbfs(self) -> str: + """Uploads the wheel to DBFS location of installation and returns the remote path.""" with self._local_wheel.open("rb") as f: return self._installation.upload_dbfs(f"wheels/{self._local_wheel.name}", f) def upload_to_wsfs(self) -> str: + """Uploads the wheel to WSFS location of installation and returns the remote path.""" with self._local_wheel.open("rb") as f: remote_wheel = self._installation.upload(f"wheels/{self._local_wheel.name}", f.read()) self._installation.save(Version(self._product_info.version(), remote_wheel, self._now_iso())) @@ -217,14 +234,17 @@ def upload_to_wsfs(self) -> str: @staticmethod def _now_iso(): + """Returns the current time in ISO format.""" return datetime.now(timezone.utc).isoformat() def __enter__(self) -> "WheelsV2": + """Builds the wheel and returns the instance. Use it as a context manager.""" self._tmp_dir = tempfile.TemporaryDirectory() self._local_wheel = self._build_wheel(self._tmp_dir.name, verbose=self._verbose) return self def __exit__(self, __exc_type, __exc_value, __traceback): + """Cleans up the temporary directory. Use it as a context manager.""" self._tmp_dir.cleanup() def _build_wheel(self, tmp_dir: str, *, verbose: bool = False): @@ -257,6 +277,7 @@ def _build_wheel(self, tmp_dir: str, *, verbose: bool = False): return next(Path(tmp_dir).glob("*.whl")) def _override_version_to_unreleased(self, tmp_dir_path: Path): + """Overrides the version file to unreleased version.""" checkout_root = self._product_info.checkout_root() relative_version_file = self._product_info.version_file().relative_to(checkout_root) version_file = tmp_dir_path / relative_version_file @@ -264,6 +285,7 @@ def _override_version_to_unreleased(self, tmp_dir_path: Path): f.write(f'__version__ = "{self._product_info.version()}"') def _copy_root_to(self, tmp_dir: str | Path): + """Copies the root to a temporary directory.""" checkout_root = self._product_info.checkout_root() tmp_dir_path = Path(tmp_dir) / "working-copy" @@ -287,5 +309,6 @@ class Wheels(WheelsV2): def __init__( self, ws: WorkspaceClient, install_state: InstallState, product_info: ProductInfo, *, verbose: bool = False ): + warnings.warn("Wheels is deprecated, use WheelsV2 instead", DeprecationWarning) installation = Installation(ws, product_info.product_name(), install_folder=install_state.install_folder()) super().__init__(installation, product_info, verbose=verbose)