Skip to content

Commit

Permalink
Added more code documentation (#64)
Browse files Browse the repository at this point in the history
  • Loading branch information
nfx authored Mar 9, 2024
1 parent 404c4c7 commit 779ed28
Show file tree
Hide file tree
Showing 10 changed files with 143 additions and 3 deletions.
5 changes: 5 additions & 0 deletions src/databricks/labs/blueprint/cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Baseline CLI for Databricks Labs projects."""

import functools
import json
import logging
Expand Down Expand Up @@ -33,6 +35,8 @@ def __init__(self, __file: str):
self._product_info = ProductInfo(__file)

def command(self, fn=None, is_account: bool = False, is_unauthenticated: bool = False):
"""Decorator to register a function as a command."""

def register(func):
command_name = func.__name__.replace("_", "-")
if not func.__doc__:
Expand All @@ -52,6 +56,7 @@ def register(func):
return fn

def _route(self, raw):
"""Route the command. This is the entry point for the CLI."""
payload = json.loads(raw)
command = payload["command"]
if command not in self._mapping:
Expand Down
2 changes: 2 additions & 0 deletions src/databricks/labs/blueprint/entrypoint.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Entrypoint utilities for logging and project root detection"""

import logging
import os
import sys
Expand Down
86 changes: 83 additions & 3 deletions src/databricks/labs/blueprint/installation.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions src/databricks/labs/blueprint/installer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Manages ~/.{product}/state.json file on WorkspaceFS to track installations."""

import logging
import threading
from dataclasses import dataclass, field
Expand Down
2 changes: 2 additions & 0 deletions src/databricks/labs/blueprint/limiter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""This module provides a RateLimiter class and a rate_limited decorator to limit the rate of requests"""

import logging
import threading
import time
Expand Down
10 changes: 10 additions & 0 deletions src/databricks/labs/blueprint/logger.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
"""A nice formatter for logging. It uses colors and bold text if the console supports it."""

import logging
import sys


class NiceFormatter(logging.Formatter):
"""A nice formatter for logging. It uses colors and bold text if the console supports it."""

BOLD = "\033[1m"
RESET = "\033[0m"
GREEN = "\033[32m"
Expand All @@ -14,6 +18,9 @@ class NiceFormatter(logging.Formatter):
GRAY = "\033[90m"

def __init__(self, *, probe_tty: bool = False) -> None:
"""Create a new instance of the formatter. If probe_tty is True, then the formatter will
attempt to detect if the console supports colors. If probe_tty is False, colors will be
enabled by default."""
super().__init__(fmt="%(asctime)s %(levelname)s [%(name)s] %(message)s", datefmt="%H:%M")
self._levels = {
logging.NOTSET: self._bold("TRACE"),
Expand All @@ -27,9 +34,11 @@ def __init__(self, *, probe_tty: bool = False) -> None:
self.colors = sys.stdout.isatty() if probe_tty else True

def _bold(self, text):
"""Return text in bold."""
return f"{self.BOLD}{text}{self.RESET}"

def format(self, record: logging.LogRecord): # noqa: A003
"""Format the log record. If colors are enabled, use them."""
if not self.colors:
return super().format(record)
ts = self.formatTime(record, datefmt="%H:%M:%S")
Expand Down Expand Up @@ -60,6 +69,7 @@ def format(self, record: logging.LogRecord): # noqa: A003


def install_logger(level="DEBUG"):
"""Install a console logger with a nice formatter."""
for h in logging.root.handlers:
logging.root.removeHandler(h)
console_handler = logging.StreamHandler(sys.stderr)
Expand Down
8 changes: 8 additions & 0 deletions src/databricks/labs/blueprint/parallel.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Run tasks in parallel and return results and errors"""

import concurrent
import datetime as dt
import functools
Expand Down Expand Up @@ -40,6 +42,7 @@ def __init__(self, name, tasks: Sequence[Task[Result]], num_threads: int):
def gather(
cls, name: str, tasks: Sequence[Task[Result]], num_threads: int | None = None
) -> tuple[Collection[Result], list[Exception]]:
"""Run tasks in parallel and return results and errors"""
if num_threads is None:
num_cpus = os.cpu_count()
if num_cpus is None:
Expand All @@ -49,12 +52,14 @@ def gather(

@classmethod
def strict(cls, name: str, tasks: Sequence[Task[Result]]) -> Collection[Result]:
"""Run tasks in parallel and raise ManyError if any task fails"""
collected, errs = cls.gather(name, tasks)
if errs:
raise ManyError(errs)
return collected

def _run(self) -> tuple[Collection[Result], list[Exception]]:
"""Run tasks in parallel and return results and errors"""
given_cnt = len(self._tasks)
if given_cnt == 0:
return [], []
Expand All @@ -78,6 +83,7 @@ def _run(self) -> tuple[Collection[Result], list[Exception]]:
return collected, errors

def _on_finish(self, given_cnt: int, collected_cnt: int, failed_cnt: int):
"""Log the results of the parallel execution"""
since = dt.datetime.now() - self._started
success_pct = collected_cnt / given_cnt * 100
stats = f"{success_pct:.0f}% results available ({collected_cnt}/{given_cnt}). Took {since}"
Expand All @@ -91,6 +97,7 @@ def _on_finish(self, given_cnt: int, collected_cnt: int, failed_cnt: int):
logger.info(f"Finished '{self._name}' tasks: {stats}")

def _execute(self):
"""Run tasks in parallel and return futures"""
thread_name_prefix = re.sub(r"\W+", "_", self._name)
with ThreadPoolExecutor(self._num_threads, thread_name_prefix) as pool:
futures = []
Expand All @@ -103,6 +110,7 @@ def _execute(self):
return concurrent.futures.as_completed(futures)

def _progress_report(self, _):
"""Log the progress of the parallel execution"""
total_cnt = len(self._tasks)
log_every = self._default_log_every
if total_cnt > self._large_log_every:
Expand Down
2 changes: 2 additions & 0 deletions src/databricks/labs/blueprint/tui.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Text User Interface (TUI) utilities"""

import logging
import re
from collections.abc import Callable
Expand Down
6 changes: 6 additions & 0 deletions src/databricks/labs/blueprint/upgrades.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Automated rollout of application upgrades deployed in a Databricks workspace."""

import importlib.util
import logging
from dataclasses import dataclass, field
Expand Down Expand Up @@ -85,6 +87,7 @@ def apply(self, ws: WorkspaceClient):
self._installation.save(applied)

def _apply_python_script(self, script: Path, ws: WorkspaceClient):
"""Load and apply the upgrade script."""
name = "_".join(script.name.removesuffix(".py").split("_")[1:])
spec = importlib.util.spec_from_file_location(name, script.as_posix())
if not spec:
Expand All @@ -102,9 +105,11 @@ def _apply_python_script(self, script: Path, ws: WorkspaceClient):
change.upgrade(self._installation, ws)

def _installed(self) -> SemVer:
"""Load the installed version of the product."""
return self._installation.load(Version).as_semver()

def _diff(self, upgrades_folder: Path):
"""Yield the upgrade scripts that need to be applied."""
current = self._product_info.as_semver()
installed_version = self._installed()
for file in upgrades_folder.glob("v*.py"):
Expand All @@ -122,6 +127,7 @@ def _diff(self, upgrades_folder: Path):

@staticmethod
def _parse_version(name: str) -> SemVer:
"""Parse the version from the upgrade script name."""
split = name.split("_")
if len(split) < 2:
raise ValueError(f"invalid spec: {name}")
Expand Down
23 changes: 23 additions & 0 deletions src/databricks/labs/blueprint/wheels.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Product info and wheel builder."""

import inspect
import logging
import random
Expand All @@ -6,6 +8,7 @@
import subprocess
import sys
import tempfile
import warnings
from collections.abc import Iterable
from contextlib import AbstractContextManager
from dataclasses import dataclass
Expand Down Expand Up @@ -49,6 +52,7 @@ def __init__(self, __file: str, *, github_org: str = "databrickslabs", product_n

@classmethod
def from_class(cls, klass: type) -> "ProductInfo":
"""Create a product info with a class used as a starting point to determine location of the version file."""
return cls(inspect.getfile(klass))

@classmethod
Expand All @@ -57,6 +61,7 @@ def for_testing(cls, klass: type) -> "ProductInfo":
return cls(inspect.getfile(klass), product_name=cls._make_random(4))

def checkout_root(self):
"""Returns the root of the project, where .git folder is located."""
return find_project_root(self._version_file.as_posix())

def version_file(self) -> Path:
Expand All @@ -80,6 +85,7 @@ def version(self):
return self.__version

def as_semver(self) -> SemVer:
"""Returns the version as SemVer object."""
return SemVer.parse(self.version())

def product_name(self) -> str:
Expand All @@ -90,16 +96,20 @@ def product_name(self) -> str:
return version_file_folder.name.replace("_", "-")

def released_version(self) -> str:
"""Returns the version from the version file."""
return self._read_version(self._version_file)

def is_git_checkout(self) -> bool:
"""Returns True if the project is a git checkout."""
git_config = self.checkout_root() / ".git" / "config"
return git_config.exists()

def is_unreleased_version(self) -> bool:
"""Returns True if we are in the git checkout and the version is unreleased."""
return "+" in self.version()

def unreleased_version(self) -> str:
"""Returns the unreleased version based on the `git describe --tags` output."""
try:
out = subprocess.run(
["git", "describe", "--tags"], stdout=subprocess.PIPE, check=True, cwd=self.checkout_root()
Expand All @@ -115,9 +125,11 @@ def unreleased_version(self) -> str:
return self.released_version()

def current_installation(self, ws: WorkspaceClient) -> Installation:
"""Returns the current installation of the product."""
return Installation.current(ws, self.product_name())

def wheels(self, ws: WorkspaceClient) -> "WheelsV2":
"""Returns the wheel builder."""
return WheelsV2(self.current_installation(ws), self)

@staticmethod
Expand All @@ -136,6 +148,7 @@ def _make_random(k) -> str:

@staticmethod
def _semver_and_pep440(git_detached_version: str) -> str:
"""Create a version that is both SemVer and PEP440 compliant."""
dv = SemVer.parse(git_detached_version)
datestamp = datetime.now().strftime("%Y%m%d%H%M%S")
# new commits on main branch since the last tag
Expand Down Expand Up @@ -164,6 +177,7 @@ def _infer_version_file(cls, start: Path, version_file_names: list[str]) -> Path

@staticmethod
def _traverse_up(start: Path, version_file_names: list[str]) -> Iterable[Path]:
"""Traverse up the directory tree and yield the version files."""
prev_folder = start
folder = start.parent
while not folder.samefile(prev_folder):
Expand All @@ -177,6 +191,7 @@ def _traverse_up(start: Path, version_file_names: list[str]) -> Iterable[Path]:

@staticmethod
def _read_version(version_file: Path) -> str:
"""Read the version from the version file."""
version_data: dict[str, str] = {}
with version_file.open("r") as f:
exec(f.read(), version_data) # pylint: disable=exec-used
Expand Down Expand Up @@ -206,25 +221,30 @@ def __init__(self, installation: Installation, product_info: ProductInfo, *, ver
self._verbose = verbose

def upload_to_dbfs(self) -> str:
"""Uploads the wheel to DBFS location of installation and returns the remote path."""
with self._local_wheel.open("rb") as f:
return self._installation.upload_dbfs(f"wheels/{self._local_wheel.name}", f)

def upload_to_wsfs(self) -> str:
"""Uploads the wheel to WSFS location of installation and returns the remote path."""
with self._local_wheel.open("rb") as f:
remote_wheel = self._installation.upload(f"wheels/{self._local_wheel.name}", f.read())
self._installation.save(Version(self._product_info.version(), remote_wheel, self._now_iso()))
return remote_wheel

@staticmethod
def _now_iso():
"""Returns the current time in ISO format."""
return datetime.now(timezone.utc).isoformat()

def __enter__(self) -> "WheelsV2":
"""Builds the wheel and returns the instance. Use it as a context manager."""
self._tmp_dir = tempfile.TemporaryDirectory()
self._local_wheel = self._build_wheel(self._tmp_dir.name, verbose=self._verbose)
return self

def __exit__(self, __exc_type, __exc_value, __traceback):
"""Cleans up the temporary directory. Use it as a context manager."""
self._tmp_dir.cleanup()

def _build_wheel(self, tmp_dir: str, *, verbose: bool = False):
Expand Down Expand Up @@ -257,13 +277,15 @@ def _build_wheel(self, tmp_dir: str, *, verbose: bool = False):
return next(Path(tmp_dir).glob("*.whl"))

def _override_version_to_unreleased(self, tmp_dir_path: Path):
"""Overrides the version file to unreleased version."""
checkout_root = self._product_info.checkout_root()
relative_version_file = self._product_info.version_file().relative_to(checkout_root)
version_file = tmp_dir_path / relative_version_file
with version_file.open("w") as f:
f.write(f'__version__ = "{self._product_info.version()}"')

def _copy_root_to(self, tmp_dir: str | Path):
"""Copies the root to a temporary directory."""
checkout_root = self._product_info.checkout_root()
tmp_dir_path = Path(tmp_dir) / "working-copy"

Expand All @@ -287,5 +309,6 @@ class Wheels(WheelsV2):
def __init__(
self, ws: WorkspaceClient, install_state: InstallState, product_info: ProductInfo, *, verbose: bool = False
):
warnings.warn("Wheels is deprecated, use WheelsV2 instead", DeprecationWarning)
installation = Installation(ws, product_info.product_name(), install_folder=install_state.install_folder())
super().__init__(installation, product_info, verbose=verbose)

0 comments on commit 779ed28

Please sign in to comment.