diff --git a/data_safe_haven/__init__.py b/data_safe_haven/__init__.py index f0a613dc08..b13a7ce9ad 100644 --- a/data_safe_haven/__init__.py +++ b/data_safe_haven/__init__.py @@ -1,5 +1,8 @@ """Data Safe Haven""" +from .logging import init_logging from .version import __version__, __version_info__ +init_logging() + __all__ = ["__version__", "__version_info__"] diff --git a/data_safe_haven/administration/users/entra_users.py b/data_safe_haven/administration/users/entra_users.py index 755f441573..6ef39a5744 100644 --- a/data_safe_haven/administration/users/entra_users.py +++ b/data_safe_haven/administration/users/entra_users.py @@ -10,7 +10,7 @@ ) from data_safe_haven.external import GraphApi from data_safe_haven.functions import password -from data_safe_haven.utility import LoggingSingleton +from data_safe_haven.logging import get_logger from .research_user import ResearchUser @@ -26,7 +26,7 @@ def __init__( ) -> None: super().__init__(*args, **kwargs) self.graph_api = graph_api - self.logger = LoggingSingleton() + self.logger = get_logger() def add(self, new_users: Sequence[ResearchUser]) -> None: """ diff --git a/data_safe_haven/administration/users/user_handler.py b/data_safe_haven/administration/users/user_handler.py index d626ca6957..d24f5c300a 100644 --- a/data_safe_haven/administration/users/user_handler.py +++ b/data_safe_haven/administration/users/user_handler.py @@ -6,7 +6,8 @@ from data_safe_haven.context import Context from data_safe_haven.exceptions import DataSafeHavenUserHandlingError from data_safe_haven.external import GraphApi -from data_safe_haven.utility import LoggingSingleton +from data_safe_haven.logging import get_logger +from data_safe_haven.utility import console from .entra_users import EntraUsers from .guacamole_users import GuacamoleUsers @@ -25,7 +26,7 @@ def __init__( self.context = context self.config = config self.pulumi_config = pulumi_config - self.logger = LoggingSingleton() + self.logger = get_logger() self.sre_guacamole_users_: 
dict[str, GuacamoleUsers] = {} def add(self, users_csv_path: pathlib.Path) -> None: @@ -125,9 +126,7 @@ def list(self) -> None: ) user_data.append(user_memberships) - # Write user information as a table - for line in self.logger.tabulate(user_headers, user_data): - self.logger.info(line) + console.tabulate(user_headers, user_data) except Exception as exc: msg = f"Could not list users.\n{exc}" raise DataSafeHavenUserHandlingError(msg) from exc @@ -153,13 +152,13 @@ def remove(self, user_names: Sequence[str]) -> None: """ try: # Construct user lists - self.logger.info(f"Attempting to remove {len(user_names)} user(s).") + self.logger.debug(f"Attempting to remove {len(user_names)} user(s).") entra_users_to_remove = [ user for user in self.entra_users.list() if user.username in user_names ] # Commit changes - self.logger.info( + self.logger.debug( f"Found {len(entra_users_to_remove)} valid user(s) to remove." ) self.entra_users.remove(entra_users_to_remove) diff --git a/data_safe_haven/commands/cli.py b/data_safe_haven/commands/cli.py index 71e2d75886..49b099bc7c 100644 --- a/data_safe_haven/commands/cli.py +++ b/data_safe_haven/commands/cli.py @@ -1,13 +1,12 @@ """Command line entrypoint for Data Safe Haven application""" -import pathlib from typing import Annotated, Optional import typer +from rich import print as rprint from data_safe_haven import __version__ -from data_safe_haven.exceptions import DataSafeHavenError -from data_safe_haven.utility import LoggingSingleton +from data_safe_haven.logging import set_console_level, show_console_level from .config import config_command_group from .context import context_command_group @@ -29,22 +28,22 @@ # This is executed before @application.callback() def callback( - output: Annotated[ - Optional[pathlib.Path], # noqa: UP007 + verbose: Annotated[ # noqa: FBT002 + bool, typer.Option( - "--output", "-o", resolve_path=True, help="Path to an output log file" + "--verbose", + "-v", + help="Increase the verbosity of console 
output.", ), - ] = None, - verbosity: Annotated[ - Optional[int], # noqa: UP007 + ] = False, + show_level: Annotated[ # noqa: FBT002 + bool, typer.Option( - "--verbosity", - "-v", - help="Increase the verbosity of messages: each '-v' will increase by one step.", - count=True, - is_eager=True, + "--show-level", + "-l", + help="Show Log level.", ), - ] = None, + ] = False, version: Annotated[ Optional[bool], # noqa: UP007 typer.Option( @@ -53,13 +52,15 @@ def callback( ] = None, ) -> None: """Arguments to the main executable""" - logger = LoggingSingleton() - if output: - logger.set_log_file(output) - if verbosity: - logger.set_verbosity(verbosity) + + if verbose: + set_console_level("DEBUG") + + if show_level: + show_console_level() + if version: - print(f"Data Safe Haven {__version__}") # noqa: T201 + rprint(f"Data Safe Haven {__version__}") raise typer.Exit() @@ -95,10 +96,5 @@ def callback( def main() -> None: - """Run the application and log any exceptions""" - try: - application() - except DataSafeHavenError as exc: - logger = LoggingSingleton() - for line in str(exc).split("\n"): - logger.error(line) + """Run the application""" + application() diff --git a/data_safe_haven/commands/config.py b/data_safe_haven/commands/config.py index def533e8f8..4f9295a826 100644 --- a/data_safe_haven/commands/config.py +++ b/data_safe_haven/commands/config.py @@ -4,11 +4,12 @@ from typing import Annotated, Optional import typer -from rich import print +from rich import print as rprint from data_safe_haven.config import Config from data_safe_haven.context import ContextSettings -from data_safe_haven.utility import LoggingSingleton +from data_safe_haven.logging import get_logger +from data_safe_haven.utility import prompts config_command_group = typer.Typer() @@ -30,7 +31,7 @@ def template( with open(file, "w") as outfile: outfile.write(config_yaml) else: - print(config_yaml) + rprint(config_yaml) @config_command_group.command() @@ -39,18 +40,20 @@ def upload( ) -> None: 
"""Upload a configuration to the Data Safe Haven context""" context = ContextSettings.from_file().assert_context() - logger = LoggingSingleton() # Create configuration object from file with open(file) as config_file: config_yaml = config_file.read() config = Config.from_yaml(config_yaml) + logger = get_logger() + # Present diff to user if Config.remote_exists(context): if diff := config.remote_yaml_diff(context): - print("".join(diff)) - if not logger.confirm( + for line in "".join(diff).splitlines(): + logger.info(line) + if not prompts.confirm( ( "Configuration has changed, " "do you want to overwrite the remote configuration?" @@ -59,7 +62,7 @@ def upload( ): raise typer.Exit() else: - print("No changes, won't upload configuration.") + rprint("No changes, won't upload configuration.") raise typer.Exit() config.upload(context) @@ -79,4 +82,4 @@ def show( with open(file, "w") as outfile: outfile.write(config.to_yaml()) else: - print(config.to_yaml()) + rprint(config.to_yaml()) diff --git a/data_safe_haven/commands/context.py b/data_safe_haven/commands/context.py index 94a35fef58..92e3bda0a7 100644 --- a/data_safe_haven/commands/context.py +++ b/data_safe_haven/commands/context.py @@ -3,6 +3,7 @@ from typing import Annotated, Optional import typer +from rich import print as rprint from data_safe_haven import validators from data_safe_haven.context import ( @@ -14,15 +15,15 @@ DataSafeHavenAzureAPIAuthenticationError, DataSafeHavenConfigError, ) -from data_safe_haven.utility import LoggingSingleton +from data_safe_haven.logging import get_logger context_command_group = typer.Typer() -logger = LoggingSingleton() @context_command_group.command() def show() -> None: """Show information about the selected context.""" + logger = get_logger() try: settings = ContextSettings.from_file() except DataSafeHavenConfigError as exc: @@ -34,17 +35,21 @@ def show() -> None: current_context_key = settings.selected current_context = settings.context - logger.info(f"Current context: 
[green]{current_context_key}") + rprint(f"Current context: [green]{current_context_key}") if current_context is not None: - logger.info(f"\tName: {current_context.name}") - logger.info(f"\tAdmin Group ID: {current_context.admin_group_id}") - logger.info(f"\tSubscription name: {current_context.subscription_name}") - logger.info(f"\tLocation: {current_context.location}") + rprint( + f"\tName: {current_context.name}", + f"\tAdmin Group ID: {current_context.admin_group_id}", + f"\tSubscription name: {current_context.subscription_name}", + f"\tLocation: {current_context.location}", + sep="\n", + ) @context_command_group.command() def available() -> None: """Show the available contexts.""" + logger = get_logger() try: settings = ContextSettings.from_file() except DataSafeHavenConfigError as exc: @@ -60,7 +65,7 @@ def available() -> None: available.remove(current_context_key) available = [f"[green]{current_context_key}*[/]", *available] - logger.info("\n".join(available)) + rprint("\n".join(available)) @context_command_group.command() @@ -68,6 +73,7 @@ def switch( key: Annotated[str, typer.Argument(help="Key of the context to switch to.")] ) -> None: """Switch the selected context.""" + logger = get_logger() try: settings = ContextSettings.from_file() except DataSafeHavenConfigError as exc: @@ -164,6 +170,7 @@ def update( ] = None, ) -> None: """Update the selected context settings.""" + logger = get_logger() try: settings = ContextSettings.from_file() except DataSafeHavenConfigError as exc: @@ -186,6 +193,7 @@ def remove( key: Annotated[str, typer.Argument(help="Name of the context to remove.")], ) -> None: """Removes a context.""" + logger = get_logger() try: settings = ContextSettings.from_file() except DataSafeHavenConfigError as exc: @@ -198,6 +206,7 @@ def remove( @context_command_group.command() def create() -> None: """Create Data Safe Haven context infrastructure.""" + logger = get_logger() try: context = ContextSettings.from_file().assert_context() except 
DataSafeHavenConfigError as exc: @@ -224,6 +233,7 @@ def create() -> None: @context_command_group.command() def teardown() -> None: """Tear down Data Safe Haven context infrastructure.""" + logger = get_logger() try: context = ContextSettings.from_file().assert_context() except DataSafeHavenConfigError as exc: diff --git a/data_safe_haven/commands/pulumi.py b/data_safe_haven/commands/pulumi.py index 634798bd7f..567c235730 100644 --- a/data_safe_haven/commands/pulumi.py +++ b/data_safe_haven/commands/pulumi.py @@ -4,13 +4,14 @@ from typing import Annotated import typer -from rich import print +from rich import print as rprint from data_safe_haven.config import Config, DSHPulumiConfig from data_safe_haven.context import ContextSettings from data_safe_haven.external import GraphApi from data_safe_haven.functions import sanitise_sre_name from data_safe_haven.infrastructure import SHMProjectManager, SREProjectManager +from data_safe_haven.logging import get_logger pulumi_command_group = typer.Typer() @@ -37,8 +38,9 @@ def run( ] = "", ) -> None: """Run arbitrary Pulumi commands in a DSH project""" + logger = get_logger() if project_type == ProjectType.SRE and not sre_name: - print("--sre-name is required.") + logger.fatal("--sre-name is required.") raise typer.Exit(1) context = ContextSettings.from_file().assert_context() @@ -74,4 +76,4 @@ def run( ) stdout = project.run_pulumi_command(command) - print(stdout) + rprint(stdout) diff --git a/data_safe_haven/commands/sre.py b/data_safe_haven/commands/sre.py index f73b1688a5..3644ee4215 100644 --- a/data_safe_haven/commands/sre.py +++ b/data_safe_haven/commands/sre.py @@ -10,8 +10,8 @@ from data_safe_haven.external import GraphApi from data_safe_haven.functions import sanitise_sre_name from data_safe_haven.infrastructure import SHMProjectManager, SREProjectManager +from data_safe_haven.logging import get_logger from data_safe_haven.provisioning import SREProvisioningManager -from data_safe_haven.utility import 
LoggingSingleton sre_command_group = typer.Typer() @@ -29,7 +29,7 @@ def deploy( ] = None, ) -> None: """Deploy a Secure Research Environment""" - logger = LoggingSingleton() + logger = get_logger() context = ContextSettings.from_file().assert_context() config = Config.from_remote(context) pulumi_config = DSHPulumiConfig.from_remote(context) diff --git a/data_safe_haven/commands/users.py b/data_safe_haven/commands/users.py index ebc27dda05..ef7a9fefd0 100644 --- a/data_safe_haven/commands/users.py +++ b/data_safe_haven/commands/users.py @@ -11,7 +11,7 @@ from data_safe_haven.exceptions import DataSafeHavenError from data_safe_haven.external import GraphApi from data_safe_haven.functions import sanitise_sre_name -from data_safe_haven.utility import LoggingSingleton +from data_safe_haven.logging import get_logger users_command_group = typer.Typer() @@ -32,7 +32,7 @@ def add( shm_name = context.shm_name - logger = LoggingSingleton() + logger = get_logger() if shm_name not in pulumi_config.project_names: logger.fatal(f"No Pulumi project for '{shm_name}'.\nHave you deployed the SHM?") raise typer.Exit(1) @@ -65,7 +65,7 @@ def list_users() -> None: shm_name = context.shm_name - logger = LoggingSingleton() + logger = get_logger() if shm_name not in pulumi_config.project_names: logger.fatal(f"No Pulumi project for '{shm_name}'.\nHave you deployed the SHM?") raise typer.Exit(1) @@ -111,17 +111,21 @@ def register( # Use a JSON-safe SRE name sre_name = sanitise_sre_name(sre) - logger = LoggingSingleton() + logger = get_logger() if shm_name not in pulumi_config.project_names: - logger.fatal(f"No Pulumi project for '{shm_name}'.\nHave you deployed the SHM?") + logger.critical( + f"No Pulumi project for '{shm_name}'.\nHave you deployed the SHM?" 
+ ) raise typer.Exit(1) if sre_name not in pulumi_config.project_names: - logger.fatal(f"No Pulumi project for '{sre_name}'.\nHave you deployed the SRE?") + logger.critical( + f"No Pulumi project for '{sre_name}'.\nHave you deployed the SRE?" + ) raise typer.Exit(1) try: - logger.info( + logger.debug( f"Preparing to register {len(usernames)} user(s) with SRE '{sre_name}'" ) @@ -167,7 +171,7 @@ def remove( shm_name = context.shm_name - logger = LoggingSingleton() + logger = get_logger() if shm_name not in pulumi_config.project_names: logger.fatal(f"No Pulumi project for '{shm_name}'.\nHave you deployed the SHM?") raise typer.Exit(1) @@ -213,7 +217,7 @@ def unregister( shm_name = context.shm_name sre_name = sanitise_sre_name(sre) - logger = LoggingSingleton() + logger = get_logger() if shm_name not in pulumi_config.project_names: logger.fatal(f"No Pulumi project for '{shm_name}'.\nHave you deployed the SHM?") raise typer.Exit(1) @@ -223,7 +227,7 @@ def unregister( raise typer.Exit(1) try: - logger.info( + logger.debug( f"Preparing to unregister {len(usernames)} users with SRE '{sre_name}'" ) diff --git a/data_safe_haven/config/config.py b/data_safe_haven/config/config.py index bd2bc9fa94..df16c154d5 100644 --- a/data_safe_haven/config/config.py +++ b/data_safe_haven/config/config.py @@ -10,6 +10,7 @@ ) from data_safe_haven.exceptions import DataSafeHavenConfigError +from data_safe_haven.logging import get_logger from data_safe_haven.serialisers import AzureSerialisableModel from data_safe_haven.types import ( AzureVmSku, @@ -22,9 +23,6 @@ TimeZone, UniqueList, ) -from data_safe_haven.utility import ( - LoggingSingleton, -) class ConfigSectionAzure(BaseModel, validate_assignment=True): @@ -57,35 +55,35 @@ def update( fqdn: Fully-qualified domain name to use for this SHM timezone: Timezone in pytz format (eg. 
Europe/London) """ - logger = LoggingSingleton() + logger = get_logger() # Set admin email address if admin_email_address: self.admin_email_address = admin_email_address - logger.info( + logger.debug( f"[bold]Admin email address[/] will be [green]{self.admin_email_address}[/]." ) # Set admin IP addresses if admin_ip_addresses: self.admin_ip_addresses = admin_ip_addresses - logger.info( + logger.debug( f"[bold]IP addresses used by administrators[/] will be [green]{self.admin_ip_addresses}[/]." ) # Set Entra tenant ID if entra_tenant_id: self.entra_tenant_id = entra_tenant_id - logger.info( + logger.debug( f"[bold]Entra tenant ID[/] will be [green]{self.entra_tenant_id}[/]." ) # Set fully-qualified domain name if fqdn: self.fqdn = fqdn - logger.info( + logger.debug( f"[bold]Fully-qualified domain name[/] will be [green]{self.fqdn}[/]." ) # Set timezone if timezone: self.timezone = timezone - logger.info(f"[bold]Timezone[/] will be [green]{self.timezone}[/].") + logger.debug(f"[bold]Timezone[/] will be [green]{self.timezone}[/].") class ConfigSubsectionRemoteDesktopOpts(BaseModel, validate_assignment=True): @@ -104,13 +102,13 @@ def update( # Set whether copying text out of the SRE is allowed if allow_copy: self.allow_copy = allow_copy - LoggingSingleton().info( + get_logger().debug( f"[bold]Copying text out of the SRE[/] will be [green]{'allowed' if self.allow_copy else 'forbidden'}[/]." ) # Set whether pasting text into the SRE is allowed if allow_paste: self.allow_paste = allow_paste - LoggingSingleton().info( + get_logger().debug( f"[bold]Pasting text into the SRE[/] will be [green]{'allowed' if self.allow_paste else 'forbidden'}[/]." 
) @@ -149,11 +147,11 @@ def update( software_packages: Whether to allow packages from external repositories user_ip_addresses: List of IP addresses belonging to users """ - logger = LoggingSingleton() + logger = get_logger() # Set data provider IP addresses if data_provider_ip_addresses: self.data_provider_ip_addresses = data_provider_ip_addresses - logger.info( + logger.debug( f"[bold]IP addresses used by data providers[/] will be [green]{self.data_provider_ip_addresses}[/]." ) # Set which databases to deploy @@ -161,23 +159,25 @@ def update( self.databases = sorted(set(databases)) if len(self.databases) != len(databases): logger.warning("Discarding duplicate values for 'database'.") - logger.info( + logger.debug( f"[bold]Databases available to users[/] will be [green]{[database.value for database in self.databases]}[/]." ) # Set research desktop SKUs if workspace_skus: self.workspace_skus = workspace_skus - logger.info(f"[bold]Workspace SKUs[/] will be [green]{self.workspace_skus}[/].") + logger.debug( + f"[bold]Workspace SKUs[/] will be [green]{self.workspace_skus}[/]." + ) # Select which software packages can be installed by users if software_packages: self.software_packages = software_packages - logger.info( + logger.debug( f"[bold]Software packages[/] from [green]{self.software_packages.value}[/] sources will be installable." ) # Set user IP addresses if user_ip_addresses: self.research_user_ip_addresses = user_ip_addresses - logger.info( + logger.debug( f"[bold]IP addresses used by users[/] will be [green]{self.research_user_ip_addresses}[/]." 
) diff --git a/data_safe_haven/context/context.py b/data_safe_haven/context/context.py index f8e21b3120..2d2d20ed4f 100644 --- a/data_safe_haven/context/context.py +++ b/data_safe_haven/context/context.py @@ -8,6 +8,7 @@ from pydantic import BaseModel from data_safe_haven import __version__ +from data_safe_haven.directories import config_dir from data_safe_haven.external import AzureApi from data_safe_haven.functions import alphanumeric from data_safe_haven.serialisers import ContextBase @@ -16,7 +17,6 @@ AzureSubscriptionName, Guid, ) -from data_safe_haven.utility import config_dir class Context(ContextBase, BaseModel, validate_assignment=True): diff --git a/data_safe_haven/context/context_settings.py b/data_safe_haven/context/context_settings.py index 6fc7d997c0..16217f8d87 100644 --- a/data_safe_haven/context/context_settings.py +++ b/data_safe_haven/context/context_settings.py @@ -5,17 +5,19 @@ annotations, ) +from logging import Logger from pathlib import Path from typing import ClassVar from pydantic import Field, model_validator +from data_safe_haven.directories import config_dir from data_safe_haven.exceptions import ( DataSafeHavenConfigError, DataSafeHavenParameterError, ) +from data_safe_haven.logging import get_logger from data_safe_haven.serialisers import YAMLSerialisableModel -from data_safe_haven.utility import LoggingSingleton, config_dir from .context import Context @@ -41,7 +43,7 @@ class ContextSettings(YAMLSerialisableModel): config_type: ClassVar[str] = "ContextSettings" selected_: str | None = Field(..., alias="selected") contexts: dict[str, Context] - logger: ClassVar[LoggingSingleton] = LoggingSingleton() + logger: ClassVar[Logger] = get_logger() @model_validator(mode="after") def ensure_selected_is_valid(self) -> ContextSettings: @@ -149,7 +151,7 @@ def remove(self, key: str) -> None: def from_file(cls, config_file_path: Path | None = None) -> ContextSettings: if config_file_path is None: config_file_path = cls.default_config_file_path() - 
cls.logger.info( + cls.logger.debug( f"Reading project settings from '[green]{config_file_path}[/]'." ) return cls.from_filepath(config_file_path) @@ -159,4 +161,4 @@ def write(self, config_file_path: Path | None = None) -> None: if config_file_path is None: config_file_path = self.default_config_file_path() self.to_filepath(config_file_path) - self.logger.info(f"Saved context settings to '[green]{config_file_path}[/]'.") + self.logger.debug(f"Saved context settings to '[green]{config_file_path}[/]'.") diff --git a/data_safe_haven/context_infrastructure/infrastructure.py b/data_safe_haven/context_infrastructure/infrastructure.py index b1498170b2..327a48f759 100644 --- a/data_safe_haven/context_infrastructure/infrastructure.py +++ b/data_safe_haven/context_infrastructure/infrastructure.py @@ -1,7 +1,7 @@ from data_safe_haven.context import Context from data_safe_haven.exceptions import DataSafeHavenAzureError from data_safe_haven.external import AzureApi -from data_safe_haven.utility import LoggingSingleton +from data_safe_haven.logging import get_logger class ContextInfrastructure: @@ -89,7 +89,7 @@ def teardown(self) -> None: Raises: DataSafeHavenAzureError if any resources cannot be destroyed """ - logger = LoggingSingleton() + logger = get_logger() try: logger.info( f"Removing context {self.context.name} resource group {self.context.resource_group_name}" diff --git a/data_safe_haven/directories.py b/data_safe_haven/directories.py new file mode 100644 index 0000000000..6e520bcb37 --- /dev/null +++ b/data_safe_haven/directories.py @@ -0,0 +1,25 @@ +from os import getenv +from pathlib import Path + +import appdirs + +_appname = "data_safe_haven" + + +def config_dir() -> Path: + if config_directory_env := getenv("DSH_CONFIG_DIRECTORY"): + config_directory = Path(config_directory_env).resolve() + else: + config_directory = Path(appdirs.user_config_dir(appname=_appname)).resolve() + + return config_directory + + +def log_dir() -> Path: + if log_directory_env := 
getenv("DSH_LOG_DIRECTORY"): + log_directory = Path(log_directory_env).resolve() + else: + log_directory = Path(appdirs.user_log_dir(appname=_appname)).resolve() + log_directory.mkdir(parents=True, exist_ok=True) + + return log_directory diff --git a/data_safe_haven/exceptions/__init__.py b/data_safe_haven/exceptions/__init__.py index 1064ad63df..4921dc4932 100644 --- a/data_safe_haven/exceptions/__init__.py +++ b/data_safe_haven/exceptions/__init__.py @@ -1,5 +1,13 @@ +from data_safe_haven.logging import get_logger + + class DataSafeHavenError(Exception): - pass + def __init__(self, message: str | bytes): + super().__init__(message) + + # Log exception message as an error + logger = get_logger() + logger.error(message) class DataSafeHavenCloudError(DataSafeHavenError): diff --git a/data_safe_haven/external/api/azure_api.py b/data_safe_haven/external/api/azure_api.py index 7078e67832..730335306a 100644 --- a/data_safe_haven/external/api/azure_api.py +++ b/data_safe_haven/external/api/azure_api.py @@ -60,17 +60,15 @@ DataSafeHavenInternalError, ) from data_safe_haven.external.interface.azure_authenticator import AzureAuthenticator -from data_safe_haven.utility import LoggingSingleton, NonLoggingSingleton +from data_safe_haven.logging import get_logger class AzureApi(AzureAuthenticator): """Interface to the Azure REST API""" - def __init__( - self, subscription_name: str, *, disable_logging: bool = False - ) -> None: + def __init__(self, subscription_name: str) -> None: super().__init__(subscription_name) - self.logger = NonLoggingSingleton() if disable_logging else LoggingSingleton() + self.logger = get_logger() def blob_client( self, @@ -752,7 +750,7 @@ def remove_dns_txt_record( zone_name=zone_name, ) except ResourceNotFoundError: - self.logger.info( + self.logger.warning( f"DNS record [green]{record_name}[/] does not exist in zone [green]{zone_name}[/].", ) return diff --git a/data_safe_haven/external/api/azure_cli.py b/data_safe_haven/external/api/azure_cli.py 
index e32a597911..ebb6baa32e 100644 --- a/data_safe_haven/external/api/azure_cli.py +++ b/data_safe_haven/external/api/azure_cli.py @@ -8,7 +8,9 @@ import typer from data_safe_haven.exceptions import DataSafeHavenAzureError -from data_safe_haven.utility import LoggingSingleton, Singleton +from data_safe_haven.logging import get_logger +from data_safe_haven.singleton import Singleton +from data_safe_haven.utility import prompts @dataclass @@ -24,7 +26,7 @@ class AzureCliSingleton(metaclass=Singleton): """Interface to the Azure CLI""" def __init__(self) -> None: - self.logger = LoggingSingleton() + self.logger = get_logger() path = which("az") if path is None: @@ -70,7 +72,7 @@ def confirm(self) -> None: account = self.account self.logger.info(f"Azure user: {account.name} ({account.id_})") self.logger.info(f"Azure tenant ID: {account.tenant_id})") - if not self.logger.confirm( + if not prompts.confirm( "Is this the Azure account you expect?", default_to_yes=False ): self.logger.error( diff --git a/data_safe_haven/external/api/graph_api.py b/data_safe_haven/external/api/graph_api.py index f5356eb602..a2b51b104d 100644 --- a/data_safe_haven/external/api/graph_api.py +++ b/data_safe_haven/external/api/graph_api.py @@ -9,6 +9,7 @@ from typing import Any, ClassVar import requests +import typer from dns import resolver from msal import ( ConfidentialClientApplication, @@ -22,7 +23,8 @@ DataSafeHavenMicrosoftGraphError, ) from data_safe_haven.functions import alphanumeric -from data_safe_haven.utility import LoggingSingleton, NonLoggingSingleton +from data_safe_haven.logging import get_logger +from data_safe_haven.utility import prompts class LocalTokenCache(SerializableTokenCache): @@ -77,13 +79,12 @@ def __init__( application_secret: str | None = None, base_endpoint: str = "", default_scopes: Sequence[str] = [], - disable_logging: bool = False, ): self.base_endpoint = ( base_endpoint if base_endpoint else "https://graph.microsoft.com/v1.0" ) self.default_scopes = 
list(default_scopes) - self.logger = NonLoggingSingleton() if disable_logging else LoggingSingleton() + self.logger = get_logger() self.tenant_id = tenant_id if auth_token: self.token = auth_token @@ -1082,7 +1083,7 @@ def verify_custom_domain( while True: # Check whether all expected nameservers are active with suppress(resolver.NXDOMAIN): - self.logger.info( + self.logger.debug( f"Checking [green]{domain_name}[/] domain verification status ..." ) active_nameservers = [ @@ -1108,14 +1109,14 @@ def verify_custom_domain( self.logger.info( f"You will need to create an NS record pointing to: {ns_list}" ) - if isinstance(self.logger, LoggingSingleton): - self.logger.confirm( - f"Are you ready to check whether [green]{domain_name}[/] has been delegated to Azure?", - default_to_yes=True, + if not prompts.confirm( + f"Are you ready to check whether [green]{domain_name}[/] has been delegated to Azure?", + default_to_yes=True, + ): + self.logger.error( + "Please use `az login` to connect to the correct Azure CLI account" ) - else: - msg = "Unable to confirm Azure nameserver delegation." 
- raise NotImplementedError(msg) + raise typer.Exit(1) # Send verification request if needed if not any((d["id"] == domain_name and d["isVerified"]) for d in domains): response = self.http_post( diff --git a/data_safe_haven/external/interface/azure_container_instance.py b/data_safe_haven/external/interface/azure_container_instance.py index 1f696f4b04..012d29f429 100644 --- a/data_safe_haven/external/interface/azure_container_instance.py +++ b/data_safe_haven/external/interface/azure_container_instance.py @@ -11,7 +11,7 @@ from data_safe_haven.exceptions import DataSafeHavenAzureError from data_safe_haven.external import AzureApi -from data_safe_haven.utility import LoggingSingleton +from data_safe_haven.logging import get_logger class AzureContainerInstance: @@ -24,7 +24,7 @@ def __init__( subscription_name: str, ): self.azure_api = AzureApi(subscription_name) - self.logger = LoggingSingleton() + self.logger = get_logger() self.resource_group_name = resource_group_name self.container_group_name = container_group_name diff --git a/data_safe_haven/external/interface/azure_postgresql_database.py b/data_safe_haven/external/interface/azure_postgresql_database.py index a63b82ca6e..4c089c363b 100644 --- a/data_safe_haven/external/interface/azure_postgresql_database.py +++ b/data_safe_haven/external/interface/azure_postgresql_database.py @@ -18,8 +18,9 @@ DataSafeHavenInputError, ) from data_safe_haven.external import AzureApi +from data_safe_haven.logging import get_logger from data_safe_haven.types import PathType -from data_safe_haven.utility import FileReader, LoggingSingleton +from data_safe_haven.utility import FileReader class AzurePostgreSQLDatabase: @@ -50,7 +51,7 @@ def __init__( self.db_name = database_name self.db_server_ = None self.db_server_admin_password = database_server_admin_password - self.logger = LoggingSingleton() + self.logger = get_logger() self.port = 5432 self.resource_group_name = resource_group_name self.server_name = database_server_name @@ 
-147,7 +148,7 @@ def execute_scripts( # Apply the Guacamole initialisation script for filepath in filepaths: _filepath = pathlib.Path(filepath) - self.logger.info(f"Running SQL script: [green]{_filepath.name}[/].") + self.logger.debug(f"Running SQL script: [green]{_filepath.name}[/].") commands = self.load_sql(_filepath, mustache_values) cursor.execute(query=commands.encode()) if cursor.statusmessage and "SELECT" in cursor.statusmessage: @@ -155,7 +156,7 @@ def execute_scripts( # Commit changes connection.commit() - self.logger.info(f"Finished running {len(filepaths)} SQL scripts.") + self.logger.debug(f"Finished running {len(filepaths)} SQL scripts.") except (Exception, psycopg.Error) as exc: msg = f"Error while connecting to PostgreSQL.\n{exc}" raise DataSafeHavenAzureError(msg) from exc diff --git a/data_safe_haven/infrastructure/components/dynamic/blob_container_acl.py b/data_safe_haven/infrastructure/components/dynamic/blob_container_acl.py index 813af650d3..8c2291c501 100644 --- a/data_safe_haven/infrastructure/components/dynamic/blob_container_acl.py +++ b/data_safe_haven/infrastructure/components/dynamic/blob_container_acl.py @@ -55,7 +55,7 @@ def create(self, props: dict[str, Any]) -> CreateResult: """Set ACLs for a given blob container.""" outs = dict(**props) try: - azure_api = AzureApi(props["subscription_name"], disable_logging=True) + azure_api = AzureApi(props["subscription_name"]) azure_api.set_blob_container_acl( container_name=props["container_name"], desired_acl=props["desired_acl"], @@ -75,7 +75,7 @@ def delete(self, id_: str, props: dict[str, Any]) -> None: # Use `id` as a no-op to avoid ARG002 while maintaining function signature id(id_) try: - azure_api = AzureApi(props["subscription_name"], disable_logging=True) + azure_api = AzureApi(props["subscription_name"]) azure_api.set_blob_container_acl( container_name=props["container_name"], desired_acl="user::rwx,group::r-x,other::---", diff --git 
a/data_safe_haven/infrastructure/components/dynamic/entra_application.py b/data_safe_haven/infrastructure/components/dynamic/entra_application.py index c6c0e3758a..0fc4d11eb6 100644 --- a/data_safe_haven/infrastructure/components/dynamic/entra_application.py +++ b/data_safe_haven/infrastructure/components/dynamic/entra_application.py @@ -40,9 +40,7 @@ def refresh(props: dict[str, Any]) -> dict[str, Any]: try: outs = dict(**props) with suppress(DataSafeHavenMicrosoftGraphError): - graph_api = GraphApi( - auth_token=outs["auth_token"], disable_logging=True - ) + graph_api = GraphApi(auth_token=outs["auth_token"]) if json_response := graph_api.get_application_by_name( outs["application_name"] ): @@ -68,7 +66,7 @@ def create(self, props: dict[str, Any]) -> CreateResult: """Create new Entra application.""" outs = dict(**props) try: - graph_api = GraphApi(auth_token=props["auth_token"], disable_logging=True) + graph_api = GraphApi(auth_token=props["auth_token"]) request_json = { "displayName": props["application_name"], "signInAudience": "AzureADMyOrg", @@ -124,7 +122,7 @@ def delete(self, id_: str, props: dict[str, Any]) -> None: # Use `id` as a no-op to avoid ARG002 while maintaining function signature id(id_) try: - graph_api = GraphApi(auth_token=props["auth_token"], disable_logging=True) + graph_api = GraphApi(auth_token=props["auth_token"]) graph_api.delete_application(props["application_name"]) except Exception as exc: msg = f"Failed to delete application [green]{props['application_name']}[/] from Entra ID.\n{exc}" diff --git a/data_safe_haven/infrastructure/components/dynamic/file_upload.py b/data_safe_haven/infrastructure/components/dynamic/file_upload.py index f634477da6..a9005741c5 100644 --- a/data_safe_haven/infrastructure/components/dynamic/file_upload.py +++ b/data_safe_haven/infrastructure/components/dynamic/file_upload.py @@ -42,7 +42,7 @@ class FileUploadProvider(DshResourceProvider): def create(self, props: dict[str, Any]) -> CreateResult: """Run a 
remote script to create a file on a VM""" outs = dict(**props) - azure_api = AzureApi(props["subscription_name"], disable_logging=True) + azure_api = AzureApi(props["subscription_name"]) script_contents = f""" target_dir=$(dirname "$target"); mkdir -p $target_dir 2> /dev/null; @@ -83,7 +83,7 @@ def delete(self, id_: str, props: dict[str, Any]) -> None: """Delete the remote file from the VM""" # Use `id` as a no-op to avoid ARG002 while maintaining function signature id(id_) - azure_api = AzureApi(props["subscription_name"], disable_logging=True) + azure_api = AzureApi(props["subscription_name"]) script_contents = """ rm -f "$target"; echo "Removed file at $target"; diff --git a/data_safe_haven/infrastructure/components/dynamic/ssl_certificate.py b/data_safe_haven/infrastructure/components/dynamic/ssl_certificate.py index deb63ce144..7cb967af95 100644 --- a/data_safe_haven/infrastructure/components/dynamic/ssl_certificate.py +++ b/data_safe_haven/infrastructure/components/dynamic/ssl_certificate.py @@ -48,7 +48,7 @@ def refresh(props: dict[str, Any]) -> dict[str, Any]: try: outs = dict(**props) with suppress(DataSafeHavenAzureError): - azure_api = AzureApi(outs["subscription_name"], disable_logging=True) + azure_api = AzureApi(outs["subscription_name"]) certificate = azure_api.get_keyvault_certificate( outs["certificate_secret_name"], outs["key_vault_name"] ) @@ -78,7 +78,7 @@ def create(self, props: dict[str, Any]) -> CreateResult: private_key_bytes = client.generate_private_key(key_type="rsa2048") client.generate_csr() # Request DNS verification tokens and add them to the DNS record - azure_api = AzureApi(props["subscription_name"], disable_logging=True) + azure_api = AzureApi(props["subscription_name"]) verification_tokens = client.request_verification_tokens().items() for record_name, record_values in verification_tokens: record_set = azure_api.ensure_dns_txt_record( @@ -153,7 +153,7 @@ def delete(self, id_: str, props: dict[str, Any]) -> None: id(id_) try: # 
Remove the DNS record - azure_api = AzureApi(props["subscription_name"], disable_logging=True) + azure_api = AzureApi(props["subscription_name"]) azure_api.remove_dns_txt_record( record_name="_acme_challenge", resource_group_name=props["networking_resource_group_name"], diff --git a/data_safe_haven/infrastructure/project_manager.py b/data_safe_haven/infrastructure/project_manager.py index e140129bd5..cf55762d31 100644 --- a/data_safe_haven/infrastructure/project_manager.py +++ b/data_safe_haven/infrastructure/project_manager.py @@ -20,7 +20,7 @@ ) from data_safe_haven.external import AzureApi, AzureCliSingleton from data_safe_haven.functions import replace_separators -from data_safe_haven.utility import LoggingSingleton +from data_safe_haven.logging import get_logger from .programs import DeclarativeSHM, DeclarativeSRE @@ -86,7 +86,7 @@ def __init__( self.create_project = create_project self.account = PulumiAccount(context, config) - self.logger = LoggingSingleton() + self.logger = get_logger() self._stack: automation.Stack | None = None self.stack_outputs_: automation.OutputMap | None = None self.options: dict[str, tuple[str, bool, bool]] = {} @@ -141,7 +141,7 @@ def pulumi_project(self) -> DSHPulumiProject: def stack(self) -> automation.Stack: """Load the Pulumi stack, creating if needed.""" if not self._stack: - self.logger.info(f"Creating/loading stack [green]{self.stack_name}[/].") + self.logger.debug(f"Creating/loading stack [green]{self.stack_name}[/].") try: self._stack = automation.create_or_select_stack( opts=automation.LocalWorkspaceOptions( @@ -173,7 +173,7 @@ def add_secret(self, name: str, value: str, *, replace: bool) -> None: def apply_config_options(self) -> None: """Set Pulumi config options""" try: - self.logger.info("Updating Pulumi configuration") + self.logger.debug("Updating Pulumi configuration") for name, (value, is_secret, replace) in self.options.items(): if replace: self.set_config(name, value, secret=is_secret) @@ -248,9 +248,10 @@ def 
destroy(self) -> None: raise DataSafeHavenPulumiError(msg) from exc # Remove stack JSON try: - self.logger.info(f"Removing Pulumi stack [green]{self.stack_name}[/].") + self.logger.debug(f"Removing Pulumi stack [green]{self.stack_name}[/].") if self._stack: self._stack.workspace.remove_stack(self.stack_name) + self.logger.info(f"Removed Pulumi stack [green]{self.stack_name}[/].") except automation.CommandError as exc: if "no stack named" not in str(exc): msg = f"Pulumi stack could not be removed.\n{exc}" @@ -258,7 +259,7 @@ def destroy(self) -> None: # Remove stack JSON backup stack_backup_name = f"{self.stack_name}.json.bak" try: - self.logger.info( + self.logger.debug( f"Removing Pulumi stack backup [green]{stack_backup_name}[/]." ) azure_api = AzureApi(self.context.subscription_name) @@ -268,6 +269,9 @@ def destroy(self) -> None: storage_account_name=self.context.storage_account_name, storage_container_name=self.context.pulumi_storage_container_name, ) + self.logger.debug( + f"Removed Pulumi stack backup [green]{stack_backup_name}[/]." 
+ ) except DataSafeHavenAzureError as exc: if "blob does not exist" in str(exc): self.logger.warning( @@ -299,7 +303,7 @@ def evaluate(self, result: str) -> None: def install_plugins(self) -> None: """For inline programs, we must manage plugins ourselves.""" try: - self.logger.info("Installing required Pulumi plugins") + self.logger.debug("Installing required Pulumi plugins") self.stack.workspace.install_plugin( "azure-native", metadata.version("pulumi-azure-native") ) diff --git a/data_safe_haven/logging/__init__.py b/data_safe_haven/logging/__init__.py new file mode 100644 index 0000000000..a8ea5fdaa9 --- /dev/null +++ b/data_safe_haven/logging/__init__.py @@ -0,0 +1,13 @@ +from .logger import ( + get_logger, + init_logging, + set_console_level, + show_console_level, +) + +__all__ = [ + "get_logger", + "init_logging", + "set_console_level", + "show_console_level", +] diff --git a/data_safe_haven/logging/logger.py b/data_safe_haven/logging/logger.py new file mode 100644 index 0000000000..46871a27e9 --- /dev/null +++ b/data_safe_haven/logging/logger.py @@ -0,0 +1,93 @@ +"""Custom logging classes and functions to interact with Python logging""" + +import logging +from datetime import UTC, datetime +from typing import Any + +from rich.logging import RichHandler +from rich.text import Text + +from data_safe_haven.directories import log_dir + + +class PlainFileHandler(logging.FileHandler): + """ + Logging handler that cleans messages before sending them to a log file. 
+ """ + + def __init__(self, *args: Any, **kwargs: Any): + """Constructor""" + super().__init__(*args, **kwargs) + + @staticmethod + def strip_formatting(input_string: str) -> str: + """Strip console markup formatting from a string""" + text = Text.from_markup(input_string) + text.spans = [] + return str(text) + + def emit(self, record: logging.LogRecord) -> None: + """Emit a record without formatting""" + record.msg = self.strip_formatting(record.msg) + super().emit(record) + + +def get_logger() -> logging.Logger: + return logging.getLogger("data_safe_haven") + + +def init_logging() -> None: + # Configure root logger + # By default logging level is WARNING + root_logger = logging.getLogger(None) + root_logger.setLevel(logging.NOTSET) + + # Configure DSH logger + logger = get_logger() + logger.setLevel(logging.NOTSET) + + console_handler = RichHandler( + level=logging.INFO, + markup=True, + rich_tracebacks=True, + show_time=False, + show_path=False, + show_level=False, + ) + console_handler.setFormatter(logging.Formatter(r"%(message)s")) + + file_handler = PlainFileHandler( + f"{log_dir()}/{logfile_name()}", + delay=True, + encoding="utf8", + mode="a", + ) + file_handler.setFormatter( + logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") + ) + file_handler.setLevel(logging.NOTSET) + + # Add handlers + logger.addHandler(console_handler) + logger.console_handler = console_handler # type: ignore [attr-defined] + logger.addHandler(file_handler) + logger.file_handler = file_handler # type: ignore [attr-defined] + + # Disable unnecessarily verbose external logging + logging.getLogger("azure.core.pipeline.policies").setLevel(logging.ERROR) + logging.getLogger("azure.identity._credentials").setLevel(logging.ERROR) + logging.getLogger("azure.identity._internal").setLevel(logging.ERROR) + logging.getLogger("azure.mgmt.core.policies").setLevel(logging.ERROR) + logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR) + + +def logfile_name() -> str: + 
return f"{datetime.now(UTC).date()}.log" + + +def set_console_level(level: int | str) -> None: + get_logger().console_handler.setLevel(level) # type: ignore [attr-defined] + + +def show_console_level() -> None: + get_logger().console_handler._log_render.show_level = True # type: ignore [attr-defined] diff --git a/data_safe_haven/provisioning/sre_provisioning_manager.py b/data_safe_haven/provisioning/sre_provisioning_manager.py index 1a3f734205..9b685cb71f 100644 --- a/data_safe_haven/provisioning/sre_provisioning_manager.py +++ b/data_safe_haven/provisioning/sre_provisioning_manager.py @@ -10,8 +10,8 @@ GraphApi, ) from data_safe_haven.infrastructure import SREProjectManager +from data_safe_haven.logging import get_logger from data_safe_haven.types import AzureLocation, AzureSubscriptionName -from data_safe_haven.utility import LoggingSingleton class SREProvisioningManager: @@ -29,7 +29,7 @@ def __init__( self._available_vm_skus: dict[str, dict[str, Any]] | None = None self.location = location self.graph_api = GraphApi(auth_token=graph_api_token) - self.logger = LoggingSingleton() + self.logger = get_logger() self.sre_name = sre_name self.subscription_name = subscription_name @@ -114,7 +114,7 @@ def update_remote_desktop_connections(self) -> None: "user_group_name": self.security_group_params["user_group_name"], } for details in connection_data["connections"]: - self.logger.info( + self.logger.debug( f"Adding connection [bold]{details['connection_name']}[/] at [green]{details['ip_address']}[/]." 
) postgres_script_path = ( diff --git a/data_safe_haven/utility/singleton.py b/data_safe_haven/singleton.py similarity index 100% rename from data_safe_haven/utility/singleton.py rename to data_safe_haven/singleton.py diff --git a/data_safe_haven/utility/__init__.py b/data_safe_haven/utility/__init__.py index 5cede18f95..b204bcd2f1 100644 --- a/data_safe_haven/utility/__init__.py +++ b/data_safe_haven/utility/__init__.py @@ -1,12 +1,8 @@ -from .directories import config_dir +from . import console, prompts from .file_reader import FileReader -from .logger import LoggingSingleton, NonLoggingSingleton -from .singleton import Singleton __all__ = [ - "config_dir", + "console", "FileReader", - "LoggingSingleton", - "NonLoggingSingleton", - "Singleton", + "prompts", ] diff --git a/data_safe_haven/utility/console.py b/data_safe_haven/utility/console.py new file mode 100644 index 0000000000..b9588f4e28 --- /dev/null +++ b/data_safe_haven/utility/console.py @@ -0,0 +1,25 @@ +from rich import print as rprint +from rich.table import Table + + +def tabulate( + header: list[str] | None = None, rows: list[list[str]] | None = None +) -> None: + """Generate a table from header and rows + + Args: + header: The table header + rows: The table rows + + Returns: + None; the table is printed to the console + """ + table = Table() + if header: + for item in header: + table.add_column(item) + if rows: + for row in rows: + table.add_row(*row) + + rprint(table) diff --git a/data_safe_haven/utility/directories.py b/data_safe_haven/utility/directories.py deleted file mode 100644 index 593f64bb50..0000000000 --- a/data_safe_haven/utility/directories.py +++ /dev/null @@ -1,15 +0,0 @@ -from os import getenv -from pathlib import Path - -import appdirs - - -def config_dir() -> Path: - if config_directory_env := getenv("DSH_CONFIG_DIRECTORY"): - config_directory = Path(config_directory_env).resolve() - else: - config_directory = Path( - appdirs.user_config_dir(appname="data_safe_haven") ).resolve()
- - return config_directory diff --git a/data_safe_haven/utility/logger.py b/data_safe_haven/utility/logger.py deleted file mode 100644 index 911db1c7ba..0000000000 --- a/data_safe_haven/utility/logger.py +++ /dev/null @@ -1,234 +0,0 @@ -"""Standalone logging class implemented as a singleton""" - -import io -import logging -from typing import Any, ClassVar - -from rich.console import Console -from rich.highlighter import RegexHighlighter -from rich.logging import RichHandler -from rich.prompt import Confirm, Prompt -from rich.table import Table -from rich.text import Text - -from data_safe_haven.types import PathType - -from .singleton import Singleton - - -class LoggingHandlerPlainFile(logging.FileHandler): - """ - Logging handler that cleans messages before sending them to a log file. - """ - - def __init__( - self, fmt: str, datefmt: str, filename: str, *args: Any, **kwargs: Any - ): - """Constructor""" - kwargs["filename"] = filename - super().__init__(*args, **kwargs) - self.setFormatter(logging.Formatter(self.strip_formatting(fmt), datefmt)) - - @staticmethod - def strip_formatting(input_string: str) -> str: - """Strip console markup formatting from a string""" - text = Text.from_markup(input_string) - text.spans = [] - return str(text) - - def emit(self, record: logging.LogRecord) -> None: - """Emit a record without formatting""" - record.msg = self.strip_formatting(record.msg) - super().emit(record) - - -class LoggingHandlerRichConsole(RichHandler): - """ - Logging handler that uses Rich. 
- """ - - def __init__(self, fmt: str, datefmt: str, *args: Any, **kwargs: Any): - super().__init__( - *args, - highlighter=LogLevelHighlighter(), - markup=True, - omit_repeated_times=False, - rich_tracebacks=True, - show_level=False, - show_time=False, - tracebacks_show_locals=True, - **kwargs, - ) - self.setFormatter(logging.Formatter(fmt, datefmt)) - - -class LogLevelHighlighter(RegexHighlighter): - """ - Highlighter that looks for [level-name] and applies default formatting. - """ - - base_style = "logging.level." - highlights: ClassVar[list[str]] = [ - r"(?P\[CRITICAL\])", - r"(?P\[ DEBUG\])", - r"(?P\[ ERROR\])", - r"(?P\[ INFO\])", - r"(?P\[ WARNING\])", - ] - - -class RichStringAdaptor: - """ - A wrapper to convert Rich objects into strings. - """ - - def __init__(self, *, coloured: bool): - """Constructor""" - self.string_io = io.StringIO() - self.console = Console(file=self.string_io, force_terminal=coloured) - - def to_string(self, *renderables: Any) -> str: - """Convert Rich renderables into a string""" - self.console.print(*renderables) - return self.string_io.getvalue() - - -class LoggingSingleton(logging.Logger, metaclass=Singleton): - """ - Logging singleton that can be used by anything needing logging - """ - - date_fmt = r"%Y-%m-%d %H:%M:%S" - rich_format = r"[log.time]%(asctime)s[/] [%(levelname)8s] %(message)s" - - def __init__(self) -> None: - super().__init__(name="data_safe_haven", level=logging.INFO) - # Initialise console handler - self.addHandler(LoggingHandlerRichConsole(self.rich_format, self.date_fmt)) - # Disable unnecessarily verbose external logging - logging.getLogger("azure.core.pipeline.policies").setLevel(logging.ERROR) - logging.getLogger("azure.identity._credentials").setLevel(logging.ERROR) - logging.getLogger("azure.identity._internal").setLevel(logging.ERROR) - logging.getLogger("azure.mgmt.core.policies").setLevel(logging.ERROR) - logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR) - - def ask(self, message: 
str, default: str | None = None) -> str: - """Ask user a question, formatted as a log message""" - formatted = self.format_msg(message, logging.INFO) - if default: - return str(Prompt.ask(formatted, default=default)) - return str(Prompt.ask(formatted)) - - def choose( - self, - message: str, - choices: list[str] | None = None, - default: str | None = None, - ) -> str: - """Ask a user to choose among options, formatted as a log message""" - formatted = self.format_msg(message, logging.INFO) - if default: - return str(Prompt.ask(formatted, choices=choices, default=default)) - return str(Prompt.ask(formatted, choices=choices)) - - def confirm(self, message: str, *, default_to_yes: bool) -> bool: - """Ask a user to confirm an action, formatted as a log message""" - formatted = self.format_msg(message, logging.INFO) - return bool(Confirm.ask(formatted, default=default_to_yes)) - - def format_msg(self, message: str, level: int = logging.INFO) -> str: - """Format a message using rich handler""" - for handler in self.handlers: - if isinstance(handler, RichHandler): - fn, lno, func, sinfo = self.findCaller(stack_info=False, stacklevel=1) - return str( - handler.format( - self.makeRecord( - name=self.name, - level=level, - fn=fn, - lno=lno, - msg=message, - args={}, - exc_info=None, - func=func, - sinfo=sinfo, - ) - ) - ) - return message - - def parse(self, message: str) -> None: - """ - Parse a message that starts with a log-level token. - - This function is designed to handle messages from non-Python code inside this package. 
- """ - tokens = message.split(":") - level, remainder = tokens[0].upper(), ":".join(tokens[1:]).strip() - if level == "CRITICAL": - return self.critical(remainder) - elif level == "ERROR": - return self.error(remainder) - elif level == "WARNING": - return self.warning(remainder) - elif level == "INFO": - return self.info(remainder) - elif level == "DEBUG": - return self.debug(remainder) - else: - return self.info(message.strip()) - - def set_log_file(self, file_path: PathType) -> None: - """Set a log file handler""" - file_handler = LoggingHandlerPlainFile( - self.rich_format, self.date_fmt, str(file_path) - ) - for h in self.handlers: - if isinstance(h, LoggingHandlerPlainFile): - self.removeHandler(h) - self.addHandler(file_handler) - - def set_verbosity(self, verbosity: int) -> None: - """Set verbosity""" - self.setLevel( - max(logging.INFO - 10 * (verbosity if verbosity else 0), logging.NOTSET) - ) - - def style(self, message: str) -> str: - """Apply logging style to a string""" - markup = self.format_msg(message) - return RichStringAdaptor(coloured=True).to_string(markup) - - def tabulate( - self, header: list[str] | None = None, rows: list[list[str]] | None = None - ) -> list[str]: - """Generate a table from header and rows - - Args: - header: The table header - rows: The table rows - - Returns: - A list of strings representing the table - """ - table = Table() - if header: - for item in header: - table.add_column(item) - if rows: - for row in rows: - table.add_row(*row) - adaptor = RichStringAdaptor(coloured=True) - return [line.strip() for line in adaptor.to_string(table).split("\n")] - - -class NonLoggingSingleton(logging.Logger, metaclass=Singleton): - """ - Non-logging singleton that can be used by anything needing logs to be consumed - """ - - def __init__(self) -> None: - super().__init__(name="non-logger", level=logging.CRITICAL + 10) - while self.handlers: - self.removeHandler(self.handlers[0]) diff --git a/data_safe_haven/utility/prompts.py 
b/data_safe_haven/utility/prompts.py new file mode 100644 index 0000000000..201d4fcc28 --- /dev/null +++ b/data_safe_haven/utility/prompts.py @@ -0,0 +1,14 @@ +from rich.prompt import Confirm + +from data_safe_haven.logging import get_logger + + +def confirm(message: str, *, default_to_yes: bool) -> bool: + """Ask a user to confirm an action, formatted as a log message""" + logger = get_logger() + + logger.debug(f"Prompting user to confirm '{message}'") + response: bool = Confirm.ask(message, default=default_to_yes) + response_text = "yes" if response else "no" + logger.debug(f"User responded '{response_text}'") + return response diff --git a/tests/commands/test_cli.py b/tests/commands/test_cli.py index c27a1e3e76..889382776f 100644 --- a/tests/commands/test_cli.py +++ b/tests/commands/test_cli.py @@ -1,4 +1,5 @@ from data_safe_haven.commands import application +from data_safe_haven.version import __version__ class TestHelp: @@ -6,8 +7,8 @@ def result_checker(self, result): assert result.exit_code == 0 assert "Usage: dsh [OPTIONS] COMMAND [ARGS]..." 
in result.stdout assert "Arguments to the main executable" in result.stdout - assert "│ --output" in result.stdout - assert "│ --verbosity" in result.stdout + assert "│ --verbose" in result.stdout + assert "│ --show-level" in result.stdout assert "│ --version" in result.stdout assert "│ --install-completion" in result.stdout assert "│ --show-completion" in result.stdout @@ -25,3 +26,10 @@ def test_help(self, runner): def test_help_short_code(self, runner): result = runner.invoke(application, ["-h"]) self.result_checker(result) + + +class TestVersion: + def test_version(self, runner): + result = runner.invoke(application, ["--version"]) + assert result.exit_code == 0 + assert f"Data Safe Haven {__version__}" in result.stdout diff --git a/tests/commands/test_context.py b/tests/commands/test_context.py index 4111dd592d..4374239052 100644 --- a/tests/commands/test_context.py +++ b/tests/commands/test_context.py @@ -247,7 +247,8 @@ def test_show_none(self, runner_none): def test_auth_failure(self, runner, mocker): def mock_login(self): # noqa: ARG001 - raise DataSafeHavenAzureAPIAuthenticationError + msg = "Failed to authenticate with Azure API." + raise DataSafeHavenAzureAPIAuthenticationError(msg) mocker.patch.object(AzureAuthenticator, "login", mock_login) @@ -279,7 +280,8 @@ def test_show_none(self, runner_none): def test_auth_failure(self, runner, mocker): def mock_login(self): # noqa: ARG001 - raise DataSafeHavenAzureAPIAuthenticationError + msg = "Failed to authenticate with Azure API." 
+ raise DataSafeHavenAzureAPIAuthenticationError(msg) mocker.patch.object(AzureAuthenticator, "login", mock_login) diff --git a/tests/conftest.py b/tests/conftest.py index 1ac030c2d8..8a75ec0c76 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ from pytest import fixture import data_safe_haven.context.context_settings as context_mod +import data_safe_haven.logging.logger from data_safe_haven.config import ( Config, DSHPulumiConfig, @@ -25,6 +26,7 @@ ProjectManager, PulumiAccount, ) +from data_safe_haven.logging import init_logging @fixture(autouse=True, scope="session") @@ -35,6 +37,16 @@ def local_pulumi_login(): run([pulumi_path, "logout"], check=False) +@fixture(autouse=True) +def log_directory(mocker, monkeypatch, tmp_path): + monkeypatch.setenv("DSH_LOG_DIRECTORY", tmp_path) + mocker.patch.object( + data_safe_haven.logging.logger, "logfile_name", return_value="test.log" + ) + init_logging() + return tmp_path + + @fixture def context_dict(): return { diff --git a/tests/logging/test_logger.py b/tests/logging/test_logger.py new file mode 100644 index 0000000000..0527424007 --- /dev/null +++ b/tests/logging/test_logger.py @@ -0,0 +1,78 @@ +import logging +from datetime import datetime +from pathlib import Path + +from rich.logging import RichHandler + +from data_safe_haven.logging.logger import ( + PlainFileHandler, + get_logger, + logfile_name, + set_console_level, + show_console_level, +) + + +class TestPlainFileHandler: + def test_strip_formatting(self): + assert PlainFileHandler.strip_formatting("[green]hello[/]") == "hello" + + +class TestLogFileName: + def test_logfile_name(self): + name = logfile_name() + assert name.endswith(".log") + date = name.split(".")[0] + assert datetime.strptime(date, "%Y-%m-%d") # noqa: DTZ007 + + +class TestGetLogger: + def test_get_logger(self): + logger = get_logger() + assert isinstance(logger, logging.Logger) + assert logger.name == "data_safe_haven" + assert hasattr(logger, "console_handler") + assert 
hasattr(logger, "file_handler") + + +class TestLogger: + def test_constructor(self, log_directory): + logger = get_logger() + + assert isinstance(logger.file_handler, PlainFileHandler) + assert isinstance(logger.console_handler, RichHandler) + + assert logger.file_handler.baseFilename == f"{log_directory}/test.log" + log_file = Path(logger.file_handler.baseFilename) + logger.info("hello") + assert log_file.is_file() + + +class TestSetConsoleLevel: + def test_set_console_level(self): + logger = get_logger() + assert logger.console_handler.level == logging.INFO + set_console_level(logging.DEBUG) + assert logger.console_handler.level == logging.DEBUG + + def test_set_console_level_stdout(self, capsys): + logger = get_logger() + set_console_level(logging.DEBUG) + logger.debug("hello") + out, _ = capsys.readouterr() + assert "hello" in out + + +class TestShowConsoleLevel: + def test_show_console_level(self): + logger = get_logger() + assert not logger.console_handler._log_render.show_level + show_console_level() + assert logger.console_handler._log_render.show_level + + def test_show_console_level_stdout(self, capsys): + logger = get_logger() + show_console_level() + logger.info("hello") + out, _ = capsys.readouterr() + assert "INFO" in out