diff --git a/.github/workflows/test_conda_build.yaml b/.github/workflows/test_conda_build.yaml index e34363d9a3..f7500f343e 100644 --- a/.github/workflows/test_conda_build.yaml +++ b/.github/workflows/test_conda_build.yaml @@ -33,7 +33,7 @@ jobs: uses: conda-incubator/setup-miniconda@v2 with: auto-update-conda: true - python-version: 3.8 + python-version: 3.11 channels: conda-forge activate-environment: nebari-dev diff --git a/pyproject.toml b/pyproject.toml index cb90bc52d0..4585ddcd9b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,7 +65,8 @@ dependencies = [ "kubernetes==27.2.0", "pluggy==1.3.0", "prompt-toolkit==3.0.36", - "pydantic==1.10.12", + "pydantic==2.4.2", + "typing-extensions==4.7.1; python_version < '3.9'", "pynacl==1.5.0", "python-keycloak==3.3.0", "questionary==2.0.0", diff --git a/src/_nebari/config.py b/src/_nebari/config.py index d1b2f42944..7c27274f36 100644 --- a/src/_nebari/config.py +++ b/src/_nebari/config.py @@ -2,19 +2,19 @@ import pathlib import re import sys -import typing +from typing import Any, Dict, List, Union import pydantic from _nebari.utils import yaml -def set_nested_attribute(data: typing.Any, attrs: typing.List[str], value: typing.Any): +def set_nested_attribute(data: Any, attrs: List[str], value: Any): """Takes an arbitrary set of attributes and accesses the deep nested object config to set value """ - def _get_attr(d: typing.Any, attr: str): + def _get_attr(d: Any, attr: str): if isinstance(d, list) and re.fullmatch(r"\d+", attr): return d[int(attr)] elif hasattr(d, "__getitem__"): @@ -22,7 +22,7 @@ def _get_attr(d: typing.Any, attr: str): else: return getattr(d, attr) - def _set_attr(d: typing.Any, attr: str, value: typing.Any): + def _set_attr(d: Any, attr: str, value: Any): if isinstance(d, list) and re.fullmatch(r"\d+", attr): d[int(attr)] = value elif hasattr(d, "__getitem__"): @@ -63,6 +63,15 @@ def set_config_from_environment_variables( return config +def dump_nested_model(model_dict: Dict[str, Union[pydantic.BaseModel, str]]): + result = {} + for key, value in model_dict.items(): + result[key] = ( + value.model_dump() if isinstance(value, pydantic.BaseModel) else value + ) + return result + + def read_configuration( config_filename: pathlib.Path, config_schema: pydantic.BaseModel, @@ -77,7 +86,8 @@ def read_configuration( ) with filename.open() as f: - config = config_schema(**yaml.load(f.read())) + config_dict = yaml.load(f) + config = config_schema(**config_dict) if read_environment: config = set_config_from_environment_variables(config) @@ -87,14 +97,15 @@ def read_configuration( def write_configuration( config_filename: pathlib.Path, - config: typing.Union[pydantic.BaseModel, typing.Dict], + config: Union[pydantic.BaseModel, Dict], mode: str = "w", ): """Write the nebari configuration file to disk""" with config_filename.open(mode) as f: if isinstance(config, pydantic.BaseModel): - yaml.dump(config.dict(), f) + yaml.dump(config.model_dump(), f) else: + config = dump_nested_model(config) yaml.dump(config, f) diff --git a/src/_nebari/initialize.py b/src/_nebari/initialize.py index 2f07647521..050556a39a 100644 --- a/src/_nebari/initialize.py +++ b/src/_nebari/initialize.py @@ -3,6 +3,7 @@ import re import tempfile from pathlib import Path +from typing import Any, Dict import pydantic import requests @@ -45,7 +46,7 @@ def render_config( region: str = None, disable_prompt: bool = False, ssl_cert_email: str = None, -): +) -> Dict[str, Any]: config = { "provider": cloud_provider, "namespace": namespace, @@ -189,7 +190,7 @@ def 
render_config( from nebari.plugins import nebari_plugin_manager try: - config_model = nebari_plugin_manager.config_schema.parse_obj(config) + config_model = nebari_plugin_manager.config_schema.model_validate(config) except pydantic.ValidationError as e: print(str(e)) diff --git a/src/_nebari/provider/cicd/github.py b/src/_nebari/provider/cicd/github.py index 7b58464c43..2563af6ad9 100644 --- a/src/_nebari/provider/cicd/github.py +++ b/src/_nebari/provider/cicd/github.py @@ -4,7 +4,7 @@ import requests from nacl import encoding, public -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field, RootModel from _nebari.constants import LATEST_SUPPORTED_PYTHON_VERSION from _nebari.provider.cicd.common import pip_install_nebari @@ -143,49 +143,34 @@ class GHA_on_extras(BaseModel): paths: List[str] -class GHA_on(BaseModel): - # to allow for dynamic key names - __root__: Dict[str, GHA_on_extras] - - # TODO: validate __root__ values - # `push`, `pull_request`, etc. - - -class GHA_job_steps_extras(BaseModel): - # to allow for dynamic key names - __root__: Union[str, float, int] +GHA_on = RootModel[Dict[str, GHA_on_extras]] +GHA_job_steps_extras = RootModel[Union[str, float, int]] class GHA_job_step(BaseModel): name: str - uses: Optional[str] - with_: Optional[Dict[str, GHA_job_steps_extras]] = Field(alias="with") - run: Optional[str] - env: Optional[Dict[str, GHA_job_steps_extras]] - - class Config: - allow_population_by_field_name = True + uses: Optional[str] = None + with_: Optional[Dict[str, GHA_job_steps_extras]] = Field(alias="with", default=None) + run: Optional[str] = None + env: Optional[Dict[str, GHA_job_steps_extras]] = None + model_config = ConfigDict(populate_by_name=True) class GHA_job_id(BaseModel): name: str runs_on_: str = Field(alias="runs-on") - permissions: Optional[Dict[str, str]] + permissions: Optional[Dict[str, str]] = None steps: List[GHA_job_step] - - class Config: - allow_population_by_field_name = True + model_config = ConfigDict(populate_by_name=True) -class GHA_jobs(BaseModel): - # to allow for dynamic key names - __root__: Dict[str, GHA_job_id] +GHA_jobs = RootModel[Dict[str, GHA_job_id]] class GHA(BaseModel): name: str on: GHA_on - env: Optional[Dict[str, str]] + env: Optional[Dict[str, str]] = None jobs: GHA_jobs @@ -204,11 +189,7 @@ def checkout_image_step(): return GHA_job_step( name="Checkout Image", uses="actions/checkout@v3", - with_={ - "token": GHA_job_steps_extras( - __root__="${{ secrets.REPOSITORY_ACCESS_TOKEN }}" - ) - }, + with_={"token": GHA_job_steps_extras("${{ secrets.REPOSITORY_ACCESS_TOKEN }}")}, ) @@ -216,11 +197,7 @@ def setup_python_step(): return GHA_job_step( name="Set up Python", uses="actions/setup-python@v4", - with_={ - "python-version": GHA_job_steps_extras( - __root__=LATEST_SUPPORTED_PYTHON_VERSION - ) - }, + with_={"python-version": GHA_job_steps_extras(LATEST_SUPPORTED_PYTHON_VERSION)}, ) @@ -242,7 +219,7 @@ def gen_nebari_ops(config): env_vars = gha_env_vars(config) push = GHA_on_extras(branches=[config.ci_cd.branch], paths=["nebari-config.yaml"]) - on = GHA_on(__root__={"push": push}) + on = GHA_on({"push": push}) step1 = checkout_image_step() step2 = setup_python_step() @@ -272,7 +249,7 @@ def gen_nebari_ops(config): ), env={ "COMMIT_MSG": GHA_job_steps_extras( - __root__="nebari-config.yaml automated commit: ${{ github.sha }}" + "nebari-config.yaml automated commit: ${{ github.sha }}" ) }, ) @@ -291,7 +268,7 @@ def gen_nebari_ops(config): }, steps=gha_steps, ) - jobs = GHA_jobs(__root__={"build": 
job1})
+    jobs = GHA_jobs({"build": job1})
 
     return NebariOps(
         name="nebari auto update",
         on=on,
         env=env_vars,
         jobs=jobs,
     )
@@ -312,18 +289,16 @@ def gen_nebari_linter(config):
     pull_request = GHA_on_extras(
         branches=[config.ci_cd.branch], paths=["nebari-config.yaml"]
     )
-    on = GHA_on(__root__={"pull_request": pull_request})
+    on = GHA_on({"pull_request": pull_request})
 
     step1 = checkout_image_step()
     step2 = setup_python_step()
     step3 = install_nebari_step(config.nebari_version)
 
     step4_envs = {
-        "PR_NUMBER": GHA_job_steps_extras(__root__="${{ github.event.number }}"),
-        "REPO_NAME": GHA_job_steps_extras(__root__="${{ github.repository }}"),
-        "GITHUB_TOKEN": GHA_job_steps_extras(
-            __root__="${{ secrets.REPOSITORY_ACCESS_TOKEN }}"
-        ),
+        "PR_NUMBER": GHA_job_steps_extras("${{ github.event.number }}"),
+        "REPO_NAME": GHA_job_steps_extras("${{ github.repository }}"),
+        "GITHUB_TOKEN": GHA_job_steps_extras("${{ secrets.REPOSITORY_ACCESS_TOKEN }}"),
     }
 
     step4 = GHA_job_step(
@@ -336,7 +311,7 @@ def gen_nebari_linter(config):
         name="nebari", runs_on_="ubuntu-latest", steps=[step1, step2, step3, step4]
     )
     jobs = GHA_jobs(
-        __root__={
+        {
             "nebari-validate": job1,
         }
     )
diff --git a/src/_nebari/provider/cicd/gitlab.py b/src/_nebari/provider/cicd/gitlab.py
index e2d02b388b..d5e944f36d 100644
--- a/src/_nebari/provider/cicd/gitlab.py
+++ b/src/_nebari/provider/cicd/gitlab.py
@@ -1,40 +1,34 @@
 from typing import Dict, List, Optional, Union
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field, RootModel
 
 from _nebari.constants import LATEST_SUPPORTED_PYTHON_VERSION
 from _nebari.provider.cicd.common import pip_install_nebari
 
-
-class GLCI_extras(BaseModel):
-    # to allow for dynamic key names
-    __root__: Union[str, float, int]
+GLCI_extras = RootModel[Union[str, float, int]]
 
 
 class GLCI_image(BaseModel):
     name: str
-    entrypoint: Optional[str]
+    entrypoint: Optional[str] = None
 
 
 class GLCI_rules(BaseModel):
-    if_: Optional[str] = Field(alias="if")
-    changes: Optional[List[str]]
-
-    class Config:
-        allow_population_by_field_name = True
+    if_: Optional[str] = Field(default=None, alias="if")
+    changes: Optional[List[str]] = None
+    model_config = ConfigDict(populate_by_name=True)
 
 
 class GLCI_job(BaseModel):
-    image: Optional[Union[str, GLCI_image]]
-    variables: Optional[Dict[str, str]]
-    before_script: Optional[List[str]]
-    after_script: Optional[List[str]]
+    image: Optional[Union[str, GLCI_image]] = None
+    variables: Optional[Dict[str, str]] = None
+    before_script: Optional[List[str]] = None
+    after_script: Optional[List[str]] = None
     script: List[str]
-    rules: Optional[List[GLCI_rules]]
+    rules: Optional[List[GLCI_rules]] = None
 
 
-class GLCI(BaseModel):
-    __root__: Dict[str, GLCI_job]
+GLCI = RootModel[Dict[str, GLCI_job]]
 
 
 def gen_gitlab_ci(config):
@@ -76,7 +70,7 @@ def gen_gitlab_ci(config):
     )
 
     return GLCI(
-        __root__={
+        {
             "render-nebari": render_nebari,
         }
     )
diff --git a/src/_nebari/provider/cloud/amazon_web_services.py b/src/_nebari/provider/cloud/amazon_web_services.py
index 2bf905bfcb..1123c07fe0 100644
--- a/src/_nebari/provider/cloud/amazon_web_services.py
+++ b/src/_nebari/provider/cloud/amazon_web_services.py
@@ -7,25 +7,18 @@
 import boto3
 from botocore.exceptions import ClientError, EndpointConnectionError
 
-from _nebari import constants
+from _nebari.constants import AWS_ENV_DOCS
 from _nebari.provider.cloud.commons import filter_by_highest_supported_k8s_version
+from _nebari.utils import check_environment_variables
 from nebari import schema
 
 MAX_RETRIES = 5
 DELAY = 5
 
 
-def check_credentials():
-    """Check for AWS credentials are set in the environment."""
-    for variable
in { - "AWS_ACCESS_KEY_ID", - "AWS_SECRET_ACCESS_KEY", - }: - if variable not in os.environ: - raise ValueError( - f"""Missing the following required environment variable: {variable}\n - Please see the documentation for more information: {constants.AWS_ENV_DOCS}""" - ) +def check_credentials() -> None: + required_variables = {"AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"} + check_environment_variables(required_variables, AWS_ENV_DOCS) @functools.lru_cache() diff --git a/src/_nebari/provider/cloud/azure_cloud.py b/src/_nebari/provider/cloud/azure_cloud.py index 992e5c1362..44ebdaaee6 100644 --- a/src/_nebari/provider/cloud/azure_cloud.py +++ b/src/_nebari/provider/cloud/azure_cloud.py @@ -9,10 +9,11 @@ from azure.mgmt.containerservice import ContainerServiceClient from azure.mgmt.resource import ResourceManagementClient -from _nebari import constants +from _nebari.constants import AZURE_ENV_DOCS from _nebari.provider.cloud.commons import filter_by_highest_supported_k8s_version from _nebari.utils import ( AZURE_TF_STATE_RESOURCE_GROUP_SUFFIX, + check_environment_variables, construct_azure_resource_group_name, ) from nebari import schema @@ -24,29 +25,18 @@ RETRIES = 10 -def check_credentials(): - """Check if credentials are valid.""" - - required_variables = { - "ARM_CLIENT_ID": os.environ.get("ARM_CLIENT_ID", None), - "ARM_SUBSCRIPTION_ID": os.environ.get("ARM_SUBSCRIPTION_ID", None), - "ARM_TENANT_ID": os.environ.get("ARM_TENANT_ID", None), - } - arm_client_secret = os.environ.get("ARM_CLIENT_SECRET", None) - - if not all(required_variables.values()): - raise ValueError( - f"""Missing the following required environment variables: {required_variables}\n - Please see the documentation for more information: {constants.AZURE_ENV_DOCS}""" - ) +def check_credentials() -> DefaultAzureCredential: + required_variables = {"ARM_CLIENT_ID", "ARM_SUBSCRIPTION_ID", "ARM_TENANT_ID"} + check_environment_variables(required_variables, AZURE_ENV_DOCS) + optional_variable = "ARM_CLIENT_SECRET" + arm_client_secret = os.environ.get(optional_variable, None) if arm_client_secret: logger.info("Authenticating as a service principal.") - return DefaultAzureCredential() else: - logger.info("No ARM_CLIENT_SECRET environment variable found.") + logger.info(f"No {optional_variable} environment variable found.") logger.info("Allowing Azure SDK to authenticate using OIDC or other methods.") - return DefaultAzureCredential() + return DefaultAzureCredential() @functools.lru_cache() diff --git a/src/_nebari/provider/cloud/digital_ocean.py b/src/_nebari/provider/cloud/digital_ocean.py index d64ca4c6de..3e4a507be6 100644 --- a/src/_nebari/provider/cloud/digital_ocean.py +++ b/src/_nebari/provider/cloud/digital_ocean.py @@ -7,24 +7,20 @@ import kubernetes.config import requests -from _nebari import constants +from _nebari.constants import DO_ENV_DOCS from _nebari.provider.cloud.amazon_web_services import aws_delete_s3_bucket from _nebari.provider.cloud.commons import filter_by_highest_supported_k8s_version -from _nebari.utils import set_do_environment +from _nebari.utils import check_environment_variables, set_do_environment from nebari import schema -def check_credentials(): - for variable in { +def check_credentials() -> None: + required_variables = { + "DIGITALOCEAN_TOKEN", "SPACES_ACCESS_KEY_ID", "SPACES_SECRET_ACCESS_KEY", - "DIGITALOCEAN_TOKEN", - }: - if variable not in os.environ: - raise ValueError( - f"""Missing the following required environment variable: {variable}\n - Please see the documentation for more 
information: {constants.DO_ENV_DOCS}""" - ) + } + check_environment_variables(required_variables, DO_ENV_DOCS) def digital_ocean_request(url, method="GET", json=None): @@ -63,7 +59,7 @@ def regions(): return _kubernetes_options()["options"]["regions"] -def kubernetes_versions(region) -> typing.List[str]: +def kubernetes_versions() -> typing.List[str]: """Return list of available kubernetes supported by cloud provider. Sorted from oldest to latest.""" supported_kubernetes_versions = sorted( [_["slug"].split("-")[0] for _ in _kubernetes_options()["options"]["versions"]] diff --git a/src/_nebari/provider/cloud/google_cloud.py b/src/_nebari/provider/cloud/google_cloud.py index ba95f713cf..c2beff5c7e 100644 --- a/src/_nebari/provider/cloud/google_cloud.py +++ b/src/_nebari/provider/cloud/google_cloud.py @@ -1,21 +1,17 @@ import functools import json -import os import subprocess from typing import Dict, List, Set -from _nebari import constants +from _nebari.constants import GCP_ENV_DOCS from _nebari.provider.cloud.commons import filter_by_highest_supported_k8s_version +from _nebari.utils import check_environment_variables from nebari import schema -def check_credentials(): - for variable in {"GOOGLE_CREDENTIALS", "PROJECT_ID"}: - if variable not in os.environ: - raise ValueError( - f"""Missing the following required environment variable: {variable}\n - Please see the documentation for more information: {constants.GCP_ENV_DOCS}""" - ) +def check_credentials() -> None: + required_variables = {"GOOGLE_APPLICATION_CREDENTIALS", "GOOGLE_PROJECT"} + check_environment_variables(required_variables, GCP_ENV_DOCS) @functools.lru_cache() @@ -282,7 +278,7 @@ def check_missing_service() -> None: if missing: raise ValueError( f"""Missing required services: {missing}\n - Please see the documentation for more information: {constants.GCP_ENV_DOCS}""" + Please see the documentation for more information: {GCP_ENV_DOCS}""" ) diff --git a/src/_nebari/stages/bootstrap/__init__.py b/src/_nebari/stages/bootstrap/__init__.py index 688146999b..97e754d9c8 100644 --- a/src/_nebari/stages/bootstrap/__init__.py +++ b/src/_nebari/stages/bootstrap/__init__.py @@ -96,7 +96,7 @@ def render(self) -> Dict[str, str]: for fn, workflow in gen_cicd(self.config).items(): stream = io.StringIO() schema.yaml.dump( - workflow.dict( + workflow.model_dump( by_alias=True, exclude_unset=True, exclude_defaults=True ), stream, diff --git a/src/_nebari/stages/infrastructure/__init__.py b/src/_nebari/stages/infrastructure/__init__.py index 5c1aa77f77..4568aa08b2 100644 --- a/src/_nebari/stages/infrastructure/__init__.py +++ b/src/_nebari/stages/infrastructure/__init__.py @@ -7,7 +7,7 @@ import tempfile from typing import Any, Dict, List, Optional, Tuple, Type, Union -import pydantic +from pydantic import Field, field_validator, model_validator from _nebari import constants from _nebari.provider import terraform @@ -27,6 +27,11 @@ from nebari import schema from nebari.hookspecs import NebariStage, hookimpl +if sys.version_info >= (3, 9): + from typing import Annotated +else: + from typing_extensions import Annotated + def get_kubeconfig_filename(): return str(pathlib.Path(tempfile.gettempdir()) / "NEBARI_KUBECONFIG") @@ -34,7 +39,7 @@ def get_kubeconfig_filename(): class LocalInputVars(schema.Base): kubeconfig_filename: str = get_kubeconfig_filename() - kube_context: Optional[str] + kube_context: Optional[str] = None class ExistingInputVars(schema.Base): @@ -219,13 +224,13 @@ class DigitalOceanNodeGroup(schema.Base): """ instance: str - min_nodes: 
pydantic.conint(ge=1) = 1
-    max_nodes: pydantic.conint(ge=1) = 1
+    min_nodes: Annotated[int, Field(ge=1)] = 1
+    max_nodes: Annotated[int, Field(ge=1)] = 1
 
 
 class DigitalOceanProvider(schema.Base):
     region: str
-    kubernetes_version: str
+    kubernetes_version: Optional[str] = None
     # Digital Ocean image slugs are listed here https://slugs.do-api.dev/
     node_groups: Dict[str, DigitalOceanNodeGroup] = {
         "general": DigitalOceanNodeGroup(
@@ -240,51 +245,39 @@ class DigitalOceanProvider(schema.Base):
     }
     tags: Optional[List[str]] = []
 
-    @pydantic.validator("region")
-    def _validate_region(cls, value):
+    @model_validator(mode="before")
+    @classmethod
+    def _check_input(cls, data: Any) -> Any:
         digital_ocean.check_credentials()
+
+        # check if region is valid
         available_regions = set(_["slug"] for _ in digital_ocean.regions())
-        if value not in available_regions:
+        if data["region"] not in available_regions:
             raise ValueError(
-                f"Digital Ocean region={value} is not one of {available_regions}"
+                f"Digital Ocean region={data['region']} is not one of {available_regions}"
             )
-        return value
-
-    @pydantic.validator("node_groups")
-    def _validate_node_group(cls, value):
-        digital_ocean.check_credentials()
-
-        available_instances = {_["slug"] for _ in digital_ocean.instances()}
-        for name, node_group in value.items():
-            if node_group.instance not in available_instances:
-                raise ValueError(
-                    f"Digital Ocean instance {node_group.instance} not one of available instance types={available_instances}"
-                )
-
-        return value
 
-    @pydantic.root_validator
-    def _validate_kubernetes_version(cls, values):
-        digital_ocean.check_credentials()
-
-        if "region" not in values:
-            raise ValueError("Region required in order to set kubernetes_version")
-
-        available_kubernetes_versions = digital_ocean.kubernetes_versions(
-            values["region"]
-        )
-        assert available_kubernetes_versions
-        if (
-            values["kubernetes_version"] is not None
-            and values["kubernetes_version"] not in available_kubernetes_versions
-        ):
+        # check if kubernetes version is valid
+        available_kubernetes_versions = digital_ocean.kubernetes_versions()
+        if len(available_kubernetes_versions) == 0:
             raise ValueError(
-                f"\nInvalid `kubernetes-version` provided: {values['kubernetes_version']}.\nPlease select from one of the following supported Kubernetes versions: {available_kubernetes_versions} or omit flag to use latest Kubernetes version available."
+                "Request to Digital Ocean for available Kubernetes versions failed."
             )
-        else:
-            values["kubernetes_version"] = available_kubernetes_versions[-1]
-        return values
+        if data.get("kubernetes_version") is None:
+            data["kubernetes_version"] = available_kubernetes_versions[-1]
+        elif data["kubernetes_version"] not in available_kubernetes_versions:
+            raise ValueError(
+                f"\nInvalid `kubernetes-version` provided: {data['kubernetes_version']}.\nPlease select from one of the following supported Kubernetes versions: {available_kubernetes_versions} or omit flag to use latest Kubernetes version available."
+            )
+
+        available_instances = {_["slug"] for _ in digital_ocean.instances()}
+        if "node_groups" in data:
+            for _, node_group in data["node_groups"].items():
+                if node_group["instance"] not in available_instances:
+                    raise ValueError(
+                        f"Digital Ocean instance {node_group['instance']} not one of available instance types={available_instances}"
+                    )
+        return data
 
 
 class GCPIPAllocationPolicy(schema.Base):
@@ -317,13 +310,13 @@ class GCPGuestAccelerator(schema.Base):
     """
 
     name: str
-    count: pydantic.conint(ge=1) = 1
+    count: Annotated[int, Field(ge=1)] = 1
 
 
 class GCPNodeGroup(schema.Base):
     instance: str
-    min_nodes: pydantic.conint(ge=0) = 0
-    max_nodes: pydantic.conint(ge=1) = 1
+    min_nodes: Annotated[int, Field(ge=0)] = 0
+    max_nodes: Annotated[int, Field(ge=1)] = 1
     preemptible: bool = False
     labels: Dict[str, str] = {}
     guest_accelerators: List[GCPGuestAccelerator] = []
@@ -348,31 +341,24 @@ class GoogleCloudPlatformProvider(schema.Base):
     master_authorized_networks_config: Optional[Union[GCPCIDRBlock, None]] = None
     private_cluster_config: Optional[Union[GCPPrivateClusterConfig, None]] = None
 
-    @pydantic.root_validator
-    def validate_all(cls, values):
-        region = values.get("region")
-        project_id = values.get("project")
-
-        if project_id is None:
-            raise ValueError("The `google_cloud_platform.project` field is required.")
-
-        if region is None:
-            raise ValueError("The `google_cloud_platform.region` field is required.")
-
-        # validate region
-        google_cloud.validate_region(region)
-
-        # validate kubernetes version
-        kubernetes_version = values.get("kubernetes_version")
-        available_kubernetes_versions = google_cloud.kubernetes_versions(region)
-        if kubernetes_version is None:
-            values["kubernetes_version"] = available_kubernetes_versions[-1]
-        elif kubernetes_version not in available_kubernetes_versions:
+    @model_validator(mode="before")
+    @classmethod
+    def _check_input(cls, data: Any) -> Any:
+        google_cloud.check_credentials()
+        available_regions = google_cloud.regions(data["project"])
+        if data["region"] not in available_regions:
             raise ValueError(
-                f"\nInvalid `kubernetes-version` provided: {values['kubernetes_version']}.\nPlease select from one of the following supported Kubernetes versions: {available_kubernetes_versions} or omit flag to use latest Kubernetes version available."
+                f"Google Cloud region={data['region']} is not one of {available_regions}"
             )
-        return values
+        available_kubernetes_versions = google_cloud.kubernetes_versions(data["region"])
+        if data.get("kubernetes_version") is None:
+            data["kubernetes_version"] = available_kubernetes_versions[-1]
+        elif data["kubernetes_version"] not in available_kubernetes_versions:
+            raise ValueError(
+                f"\nInvalid `kubernetes-version` provided: {data['kubernetes_version']}.\nPlease select from one of the following supported Kubernetes versions: {available_kubernetes_versions} or omit flag to use latest Kubernetes version available."
+            )
+        return data
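# A minimal, self-contained sketch of the pydantic v2 pattern the provider
# models above all follow: a single `@model_validator(mode="before")` that
# normalizes the raw input mapping before field parsing. The names here
# (Provider, SUPPORTED_VERSIONS, "nyc3") are illustrative stand-ins, not
# part of the Nebari codebase.
from typing import Any, Optional

from pydantic import BaseModel, model_validator

SUPPORTED_VERSIONS = ["1.26", "1.27", "1.28"]


class Provider(BaseModel):
    region: str
    kubernetes_version: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def _check_input(cls, data: Any) -> Any:
        # mode="before" receives the raw mapping parsed from YAML, so an
        # optional key may be absent entirely -- use .get(), not [].
        if data.get("kubernetes_version") is None:
            data["kubernetes_version"] = SUPPORTED_VERSIONS[-1]
        elif data["kubernetes_version"] not in SUPPORTED_VERSIONS:
            raise ValueError(f"unsupported version {data['kubernetes_version']}")
        return data


assert Provider(region="nyc3").kubernetes_version == "1.28"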
 
 
 class AzureNodeGroup(schema.Base):
@@ -383,24 +369,31 @@ class AzureNodeGroup(schema.Base):
 
 class AzureProvider(schema.Base):
     region: str
-    kubernetes_version: str
+    kubernetes_version: Optional[str] = Field(default=None, validate_default=True)
     storage_account_postfix: str
-    resource_group_name: str = None
+    resource_group_name: Optional[str] = None
     node_groups: Dict[str, AzureNodeGroup] = {
         "general": AzureNodeGroup(instance="Standard_D8_v3", min_nodes=1, max_nodes=1),
         "user": AzureNodeGroup(instance="Standard_D4_v3", min_nodes=0, max_nodes=5),
         "worker": AzureNodeGroup(instance="Standard_D4_v3", min_nodes=0, max_nodes=5),
     }
     storage_account_postfix: str
-    vnet_subnet_id: Optional[Union[str, None]] = None
+    vnet_subnet_id: Optional[str] = None
     private_cluster_enabled: bool = False
     resource_group_name: Optional[str] = None
     tags: Optional[Dict[str, str]] = {}
     network_profile: Optional[Dict[str, str]] = None
     max_pods: Optional[int] = None
 
-    @pydantic.validator("kubernetes_version")
-    def _validate_kubernetes_version(cls, value):
+    @model_validator(mode="before")
+    @classmethod
+    def _check_credentials(cls, data: Any) -> Any:
+        azure_cloud.check_credentials()
+        return data
+
+    @field_validator("kubernetes_version")
+    @classmethod
+    def _validate_kubernetes_version(cls, value: Optional[str]) -> str:
         available_kubernetes_versions = azure_cloud.kubernetes_versions()
         if value is None:
             value = available_kubernetes_versions[-1]
@@ -410,7 +403,8 @@ def _validate_kubernetes_version(cls, value):
             )
         return value
 
-    @pydantic.validator("resource_group_name")
+    @field_validator("resource_group_name")
+    @classmethod
     def _validate_resource_group_name(cls, value):
         if value is None:
             return value
@@ -428,9 +422,12 @@ def _validate_resource_group_name(cls, value):
 
         return value
 
-    @pydantic.validator("tags")
-    def _validate_tags(cls, tags):
-        return azure_cloud.validate_tags(tags)
+    @field_validator("tags")
+    @classmethod
+    def _validate_tags(
+        cls, value: Optional[Dict[str, str]]
+    ) -> Optional[Dict[str, str]]:
+        return value if value is None else azure_cloud.validate_tags(value)
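# AzureProvider above relies on a field validator to replace a missing
# kubernetes_version with the latest release. In pydantic v2, field
# validators only run on defaults when validate_default=True (set on the
# field above); a hedged sketch of that behaviour, with stand-in version
# data instead of the azure_cloud API:
from typing import Optional

from pydantic import BaseModel, Field, field_validator


class Cluster(BaseModel):
    kubernetes_version: Optional[str] = Field(default=None, validate_default=True)

    @field_validator("kubernetes_version")
    @classmethod
    def _default_to_latest(cls, value: Optional[str]) -> str:
        available = ["1.27", "1.28"]  # stand-in for azure_cloud.kubernetes_versions()
        if value is None:
            return available[-1]
        if value not in available:
            raise ValueError(f"{value} is not one of {available}")
        return value


assert Cluster().kubernetes_version == "1.28"  # the validator ran on the default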
 
 
 class AWSNodeGroup(schema.Base):
@@ -455,49 +452,66 @@ class AmazonWebServicesProvider(schema.Base):
             instance="m5.xlarge", min_nodes=0, max_nodes=5, single_subnet=False
         ),
     }
-    existing_subnet_ids: List[str] = None
-    existing_security_group_id: str = None
+    existing_subnet_ids: Optional[List[str]] = None
+    existing_security_group_id: Optional[str] = None
     vpc_cidr_block: str = "10.10.0.0/16"
     permissions_boundary: Optional[str] = None
     tags: Optional[Dict[str, str]] = {}
 
-    @pydantic.root_validator
-    def validate_all(cls, values):
-        region = values.get("region")
-        if region is None:
-            raise ValueError("The `amazon_web_services.region` field is required.")
-
-        # validate region
-        amazon_web_services.validate_region(region)
-
-        # validate kubernetes version
-        kubernetes_version = values.get("kubernetes_version")
-        available_kubernetes_versions = amazon_web_services.kubernetes_versions(region)
-        if kubernetes_version is None:
-            values["kubernetes_version"] = available_kubernetes_versions[-1]
-        elif kubernetes_version not in available_kubernetes_versions:
+    @model_validator(mode="before")
+    @classmethod
+    def _check_input(cls, data: Any) -> Any:
+        amazon_web_services.check_credentials()
+
+        # check if region is valid
+        available_regions = amazon_web_services.regions(data["region"])
+        if data["region"] not in available_regions:
             raise ValueError(
-                f"\nInvalid `kubernetes-version` provided: {values['kubernetes_version']}.\nPlease select from one of the following supported Kubernetes versions: {available_kubernetes_versions} or omit flag to use latest Kubernetes version available."
+                f"Amazon Web Services region={data['region']} is not one of {available_regions}"
             )
 
-        # validate node groups
-        node_groups = values["node_groups"]
-        available_instances = amazon_web_services.instances(region)
-        for name, node_group in node_groups.items():
-            if node_group.instance not in available_instances:
-                raise ValueError(
-                    f"Instance {node_group.instance} not available out of available instances {available_instances.keys()}"
-                )
+        # check if kubernetes version is valid
+        available_kubernetes_versions = amazon_web_services.kubernetes_versions(
+            data["region"]
+        )
+        if len(available_kubernetes_versions) == 0:
+            raise ValueError("Request to AWS for available Kubernetes versions failed.")
+        if data.get("kubernetes_version") is None:
+            data["kubernetes_version"] = available_kubernetes_versions[-1]
+        elif data["kubernetes_version"] not in available_kubernetes_versions:
+            raise ValueError(
+                f"\nInvalid `kubernetes-version` provided: {data['kubernetes_version']}.\nPlease select from one of the following supported Kubernetes versions: {available_kubernetes_versions} or omit flag to use latest Kubernetes version available."
+            )
 
-        if values["availability_zones"] is None:
-            zones = amazon_web_services.zones(region)
-            values["availability_zones"] = list(sorted(zones))[:2]
+        # check if availability zones are valid
+        available_zones = amazon_web_services.zones(data["region"])
+        if "availability_zones" not in data:
+            data["availability_zones"] = available_zones
+        else:
+            for zone in data["availability_zones"]:
+                if zone not in available_zones:
+                    raise ValueError(
+                        f"Amazon Web Services availability zone={zone} is not one of {available_zones}"
+                    )
 
-        return values
+        # check if instances are valid
+        available_instances = amazon_web_services.instances(data["region"])
+        if "node_groups" in data:
+            for _, node_group in data["node_groups"].items():
+                instance = (
+                    node_group["instance"]
+                    if hasattr(node_group, "__getitem__")
+                    else node_group.instance
+                )
+                if instance not in available_instances:
+                    raise ValueError(
+                        f"Amazon Web Services instance {instance} not one of available instance types={available_instances}"
+                    )
+        return data
 
 
 class LocalProvider(schema.Base):
-    kube_context: Optional[str]
+    kube_context: Optional[str] = None
     node_selectors: Dict[str, KeyValueDict] = {
         "general": KeyValueDict(key="kubernetes.io/os", value="linux"),
         "user": KeyValueDict(key="kubernetes.io/os", value="linux"),
@@ -506,7 +520,7 @@ class ExistingProvider(schema.Base):
-    kube_context: Optional[str]
+    kube_context: Optional[str] = None
     node_selectors: Dict[str, KeyValueDict] = {
         "general": KeyValueDict(key="kubernetes.io/os", value="linux"),
         "user": KeyValueDict(key="kubernetes.io/os", value="linux"),
@@ -538,23 +552,24 @@ class ExistingProvider(schema.Base):
 
 class InputSchema(schema.Base):
-    local: Optional[LocalProvider]
-    existing: Optional[ExistingProvider]
-    google_cloud_platform: Optional[GoogleCloudPlatformProvider]
-    amazon_web_services: Optional[AmazonWebServicesProvider]
-    azure: Optional[AzureProvider]
-    digital_ocean: Optional[DigitalOceanProvider]
-
-    @pydantic.root_validator(pre=True)
-    def check_provider(cls, values):
-        if "provider" in values:
-            provider: str = values["provider"]
+    local: Optional[LocalProvider] = None
+    existing: Optional[ExistingProvider] = None
+    google_cloud_platform: Optional[GoogleCloudPlatformProvider] =
None + amazon_web_services: Optional[AmazonWebServicesProvider] = None + azure: Optional[AzureProvider] = None + digital_ocean: Optional[DigitalOceanProvider] = None + + @model_validator(mode="before") + @classmethod + def check_provider(cls, data: Any) -> Any: + if "provider" in data: + provider: str = data["provider"] if hasattr(schema.ProviderEnum, provider): # TODO: all cloud providers has required fields, but local and existing don't. # And there is no way to initialize a model without user input here. # We preserve the original behavior here, but we should find a better way to do this. - if provider in ["local", "existing"] and provider not in values: - values[provider] = provider_enum_model_map[provider]() + if provider in ["local", "existing"] and provider not in data: + data[provider] = provider_enum_model_map[provider]() else: # if the provider field is invalid, it won't be set when this validator is called # so we need to check for it explicitly here, and set the `pre` to True @@ -566,16 +581,16 @@ def check_provider(cls, values): setted_providers = [ provider for provider in provider_name_abbreviation_map.keys() - if provider in values + if provider in data ] num_providers = len(setted_providers) if num_providers > 1: raise ValueError(f"Multiple providers set: {setted_providers}") elif num_providers == 1: - values["provider"] = provider_name_abbreviation_map[setted_providers[0]] + data["provider"] = provider_name_abbreviation_map[setted_providers[0]] elif num_providers == 0: - values["provider"] = schema.ProviderEnum.local.value - return values + data["provider"] = schema.ProviderEnum.local.value + return data class NodeSelectorKeyValue(schema.Base): @@ -586,20 +601,20 @@ class NodeSelectorKeyValue(schema.Base): class KubernetesCredentials(schema.Base): host: str cluster_ca_certifiate: str - token: Optional[str] - username: Optional[str] - password: Optional[str] - client_certificate: Optional[str] - client_key: Optional[str] - config_path: Optional[str] - config_context: Optional[str] + token: Optional[str] = None + username: Optional[str] = None + password: Optional[str] = None + client_certificate: Optional[str] = None + client_key: Optional[str] = None + config_path: Optional[str] = None + config_context: Optional[str] = None class OutputSchema(schema.Base): node_selectors: Dict[str, NodeSelectorKeyValue] kubernetes_credentials: KubernetesCredentials kubeconfig_filename: str - nfs_endpoint: Optional[str] + nfs_endpoint: Optional[str] = None class KubernetesInfrastructureStage(NebariTerraformStage): @@ -687,7 +702,9 @@ def tf_objects(self) -> List[Dict]: def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): if self.config.provider == schema.ProviderEnum.local: - return LocalInputVars(kube_context=self.config.local.kube_context).dict() + return LocalInputVars( + kube_context=self.config.local.kube_context + ).model_dump() elif self.config.provider == schema.ProviderEnum.existing: return ExistingInputVars( kube_context=self.config.existing.kube_context diff --git a/src/_nebari/stages/kubernetes_ingress/__init__.py b/src/_nebari/stages/kubernetes_ingress/__init__.py index 2c55e0cae9..ab717c1ba3 100644 --- a/src/_nebari/stages/kubernetes_ingress/__init__.py +++ b/src/_nebari/stages/kubernetes_ingress/__init__.py @@ -3,8 +3,7 @@ import socket import sys import time -import typing -from typing import Any, Dict, List, Type +from typing import Any, Dict, List, Optional, Type from _nebari import constants from _nebari.provider.dns.cloudflare import update_record @@ -128,23 
+127,23 @@ def to_yaml(cls, representer, node): class Certificate(schema.Base): type: CertificateEnum = CertificateEnum.selfsigned # existing - secret_name: typing.Optional[str] + secret_name: Optional[str] = None # lets-encrypt - acme_email: typing.Optional[str] + acme_email: Optional[str] = None acme_server: str = "https://acme-v02.api.letsencrypt.org/directory" class DnsProvider(schema.Base): - provider: typing.Optional[str] - auto_provision: typing.Optional[bool] = False + provider: Optional[str] = None + auto_provision: Optional[bool] = False class Ingress(schema.Base): - terraform_overrides: typing.Dict = {} + terraform_overrides: Dict = {} class InputSchema(schema.Base): - domain: typing.Optional[str] + domain: Optional[str] = None certificate: Certificate = Certificate() ingress: Ingress = Ingress() dns: DnsProvider = DnsProvider() @@ -156,7 +155,7 @@ class IngressEndpoint(schema.Base): class OutputSchema(schema.Base): - load_balancer_address: typing.List[IngressEndpoint] + load_balancer_address: List[IngressEndpoint] domain: str diff --git a/src/_nebari/stages/kubernetes_initialize/__init__.py b/src/_nebari/stages/kubernetes_initialize/__init__.py index e40a69ed0f..1810f81e1a 100644 --- a/src/_nebari/stages/kubernetes_initialize/__init__.py +++ b/src/_nebari/stages/kubernetes_initialize/__init__.py @@ -1,8 +1,7 @@ import sys -import typing -from typing import Any, Dict, List, Type, Union +from typing import Any, Dict, List, Optional, Type -import pydantic +from pydantic import model_validator from _nebari.stages.base import NebariTerraformStage from _nebari.stages.tf_objects import ( @@ -16,37 +15,34 @@ class ExtContainerReg(schema.Base): enabled: bool = False - access_key_id: typing.Optional[str] - secret_access_key: typing.Optional[str] - extcr_account: typing.Optional[str] - extcr_region: typing.Optional[str] - - @pydantic.root_validator - def enabled_must_have_fields(cls, values): - if values["enabled"]: + access_key_id: Optional[str] = None + secret_access_key: Optional[str] = None + extcr_account: Optional[str] = None + extcr_region: Optional[str] = None + + @model_validator(mode="after") + def enabled_must_have_fields(self): + if self.enabled: for fldname in ( "access_key_id", "secret_access_key", "extcr_account", "extcr_region", ): - if ( - fldname not in values - or values[fldname] is None - or values[fldname].strip() == "" - ): + value = getattr(self, fldname) + if value is None or value.strip() == "": raise ValueError( f"external_container_reg must contain a non-blank {fldname} when enabled is true" ) - return values + return self class InputVars(schema.Base): name: str environment: str cloud_provider: str - aws_region: Union[str, None] = None - external_container_reg: Union[ExtContainerReg, None] = None + aws_region: Optional[str] = None + external_container_reg: Optional[ExtContainerReg] = None gpu_enabled: bool = False gpu_node_group_names: List[str] = [] diff --git a/src/_nebari/stages/kubernetes_keycloak/__init__.py b/src/_nebari/stages/kubernetes_keycloak/__init__.py index 767c83189b..59d3ee0f50 100644 --- a/src/_nebari/stages/kubernetes_keycloak/__init__.py +++ b/src/_nebari/stages/kubernetes_keycloak/__init__.py @@ -6,11 +6,9 @@ import string import sys import time -import typing -from abc import ABC -from typing import Any, Dict, List, Type +from typing import Any, Dict, List, Optional, Type, Union -import pydantic +from pydantic import Field, ValidationInfo, field_validator from _nebari.stages.base import NebariTerraformStage from _nebari.stages.tf_objects 
import ( @@ -62,93 +60,79 @@ def to_yaml(cls, representer, node): class GitHubConfig(schema.Base): - client_id: str = pydantic.Field( - default_factory=lambda: os.environ.get("GITHUB_CLIENT_ID") + client_id: str = Field( + default_factory=lambda: os.environ.get("GITHUB_CLIENT_ID"), + validate_default=True, ) - client_secret: str = pydantic.Field( - default_factory=lambda: os.environ.get("GITHUB_CLIENT_SECRET") + client_secret: str = Field( + default_factory=lambda: os.environ.get("GITHUB_CLIENT_SECRET"), + validate_default=True, ) - @pydantic.root_validator(allow_reuse=True) - def validate_required(cls, values): - missing = [] - for k, v in { + @field_validator("client_id", "client_secret", mode="before") + @classmethod + def validate_credentials(cls, value: Optional[str], info: ValidationInfo) -> str: + variable_mapping = { "client_id": "GITHUB_CLIENT_ID", "client_secret": "GITHUB_CLIENT_SECRET", - }.items(): - if not values.get(k): - missing.append(v) - - if len(missing) > 0: + } + if value is None: raise ValueError( - f"Missing the following required environment variable(s): {', '.join(missing)}" + f"{variable_mapping[info.field_name]} is not set in the environment" ) - - return values + return value class Auth0Config(schema.Base): - client_id: str = pydantic.Field( - default_factory=lambda: os.environ.get("AUTH0_CLIENT_ID") + client_id: str = Field( + default_factory=lambda: os.environ.get("AUTH0_CLIENT_ID"), + validate_default=True, ) - client_secret: str = pydantic.Field( - default_factory=lambda: os.environ.get("AUTH0_CLIENT_SECRET") + client_secret: str = Field( + default_factory=lambda: os.environ.get("AUTH0_CLIENT_SECRET"), + validate_default=True, ) - auth0_subdomain: str = pydantic.Field( - default_factory=lambda: os.environ.get("AUTH0_DOMAIN") + auth0_subdomain: str = Field( + default_factory=lambda: os.environ.get("AUTH0_DOMAIN"), + validate_default=True, ) - @pydantic.root_validator(allow_reuse=True) - def validate_required(cls, values): - missing = [] - for k, v in { + @field_validator("client_id", "client_secret", "auth0_subdomain", mode="before") + @classmethod + def validate_credentials(cls, value: Optional[str], info: ValidationInfo) -> str: + variable_mapping = { "client_id": "AUTH0_CLIENT_ID", "client_secret": "AUTH0_CLIENT_SECRET", "auth0_subdomain": "AUTH0_DOMAIN", - }.items(): - if not values.get(k): - missing.append(v) - - if len(missing) > 0: + } + if value is None: raise ValueError( - f"Missing the following required environment variable(s): {', '.join(missing)}" + f"{variable_mapping[info.field_name]} is not set in the environment" ) + return value - return values - - -class Authentication(schema.Base, ABC): - _types: typing.Dict[str, type] = {} +class BaseAuthentication(schema.Base): type: AuthenticationEnum - # Based on https://github.com/samuelcolvin/pydantic/issues/2177#issuecomment-739578307 - # This allows type field to determine which subclass of Authentication should be used for validation. +class PasswordAuthentication(BaseAuthentication): + type: AuthenticationEnum = AuthenticationEnum.password - # Used to register automatically all the submodels in `_types`. 
- def __init_subclass__(cls): - cls._types[cls._typ.value] = cls - @classmethod - def __get_validators__(cls): - yield cls.validate +class Auth0Authentication(BaseAuthentication): + type: AuthenticationEnum = AuthenticationEnum.auth0 + config: Auth0Config = Field(default_factory=lambda: Auth0Config()) - @classmethod - def validate(cls, value: typing.Dict[str, typing.Any]) -> "Authentication": - if "type" not in value: - raise ValueError("type field is missing from security.authentication") - specified_type = value.get("type") - sub_class = cls._types.get(specified_type, None) +class GitHubAuthentication(BaseAuthentication): + type: AuthenticationEnum = AuthenticationEnum.github + config: GitHubConfig = Field(default_factory=lambda: GitHubConfig()) - if not sub_class: - raise ValueError( - f"No registered Authentication type called {specified_type}" - ) - # init with right submodel - return sub_class(**value) +Authentication = Union[ + PasswordAuthentication, Auth0Authentication, GitHubAuthentication +] def random_secure_string( @@ -157,33 +141,56 @@ def random_secure_string( return "".join(secrets.choice(chars) for i in range(length)) -class PasswordAuthentication(Authentication): - _typ = AuthenticationEnum.password - +class Keycloak(schema.Base): + initial_root_password: str = Field(default_factory=random_secure_string) + overrides: Dict = {} + realm_display_name: str = "Nebari" -class Auth0Authentication(Authentication): - _typ = AuthenticationEnum.auth0 - config: Auth0Config = pydantic.Field(default_factory=lambda: Auth0Config()) +auth_enum_to_model = { + AuthenticationEnum.password: PasswordAuthentication, + AuthenticationEnum.auth0: Auth0Authentication, + AuthenticationEnum.github: GitHubAuthentication, +} -class GitHubAuthentication(Authentication): - _typ = AuthenticationEnum.github - config: GitHubConfig = pydantic.Field(default_factory=lambda: GitHubConfig()) - - -class Keycloak(schema.Base): - initial_root_password: str = pydantic.Field(default_factory=random_secure_string) - overrides: typing.Dict = {} - realm_display_name: str = "Nebari" +auth_enum_to_config = { + AuthenticationEnum.auth0: Auth0Config, + AuthenticationEnum.github: GitHubConfig, +} class Security(schema.Base): - authentication: Authentication = PasswordAuthentication( - type=AuthenticationEnum.password - ) + authentication: Authentication = PasswordAuthentication() shared_users_group: bool = True keycloak: Keycloak = Keycloak() + @field_validator("authentication", mode="before") + @classmethod + def validate_authentication(cls, value: Optional[Dict]) -> Authentication: + if value is None: + return PasswordAuthentication() + if "type" not in value: + raise ValueError( + "Authentication type must be specified if authentication is set" + ) + auth_type = value["type"] if hasattr(value, "__getitem__") else value.type + if auth_type in auth_enum_to_model: + if auth_type == AuthenticationEnum.password: + return auth_enum_to_model[auth_type]() + else: + if "config" in value: + config_dict = ( + value["config"] + if hasattr(value, "__getitem__") + else value.config + ) + config = auth_enum_to_config[auth_type](**config_dict) + else: + config = auth_enum_to_config[auth_type]() + return auth_enum_to_model[auth_type](config=config) + else: + raise ValueError(f"Unsupported authentication type {auth_type}") + class InputSchema(schema.Base): security: Security = Security() diff --git a/src/_nebari/stages/kubernetes_keycloak_configuration/__init__.py b/src/_nebari/stages/kubernetes_keycloak_configuration/__init__.py index 
4cb0c23aeb..1c33429e37 100644 --- a/src/_nebari/stages/kubernetes_keycloak_configuration/__init__.py +++ b/src/_nebari/stages/kubernetes_keycloak_configuration/__init__.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Type from _nebari.stages.base import NebariTerraformStage +from _nebari.stages.kubernetes_keycloak import Authentication from _nebari.stages.tf_objects import NebariTerraformState from nebari import schema from nebari.hookspecs import NebariStage, hookimpl @@ -14,7 +15,7 @@ class InputVars(schema.Base): realm: str = "nebari" realm_display_name: str - authentication: Dict[str, Any] + authentication: Authentication keycloak_groups: List[str] = ["superadmin", "admin", "developer", "analyst"] default_groups: List[str] = ["analyst"] @@ -39,7 +40,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): input_vars.keycloak_groups += users_group input_vars.default_groups += users_group - return input_vars.dict() + return input_vars.model_dump() def check( self, stage_outputs: Dict[str, Dict[str, Any]], disable_prompt: bool = False diff --git a/src/_nebari/stages/kubernetes_services/__init__.py b/src/_nebari/stages/kubernetes_services/__init__.py index 8542ceb9d5..51fdfa9a2d 100644 --- a/src/_nebari/stages/kubernetes_services/__init__.py +++ b/src/_nebari/stages/kubernetes_services/__init__.py @@ -2,12 +2,10 @@ import json import sys import time -import typing -from typing import Any, Dict, List, Type +from typing import Any, Dict, List, Optional, Type, Union from urllib.parse import urlencode -import pydantic -from pydantic import Field +from pydantic import ConfigDict, Field, field_validator, model_validator from _nebari import constants from _nebari.stages.base import NebariTerraformStage @@ -76,12 +74,10 @@ class Theme(schema.Base): class KubeSpawner(schema.Base): cpu_limit: int - cpu_guarantee: int + cpu_guarantee: float mem_limit: str mem_guarantee: str - - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") class JupyterLabProfile(schema.Base): @@ -89,36 +85,31 @@ class JupyterLabProfile(schema.Base): display_name: str description: str default: bool = False - users: typing.Optional[typing.List[str]] - groups: typing.Optional[typing.List[str]] - kubespawner_override: typing.Optional[KubeSpawner] - - @pydantic.root_validator - def only_yaml_can_have_groups_and_users(cls, values): - if values["access"] != AccessEnum.yaml: - if ( - values.get("users", None) is not None - or values.get("groups", None) is not None - ): + users: Optional[List[str]] = None + groups: Optional[List[str]] = None + kubespawner_override: Optional[KubeSpawner] = None + + @model_validator(mode="after") + def only_yaml_can_have_groups_and_users(self): + if self.access != AccessEnum.yaml: + if self.users is not None or self.groups is not None: raise ValueError( "Profile must not contain groups or users fields unless access = yaml" ) - return values + return self class DaskWorkerProfile(schema.Base): worker_cores_limit: int - worker_cores: int + worker_cores: float worker_memory_limit: str worker_memory: str worker_threads: int = 1 - - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") class Profiles(schema.Base): - jupyterlab: typing.List[JupyterLabProfile] = [ + jupyterlab: List[JupyterLabProfile] = [ JupyterLabProfile( display_name="Small Instance", description="Stable environment with 2 cpu / 8 GB ram", @@ -141,7 +132,7 @@ class Profiles(schema.Base): ), ), ] - dask_worker: typing.Dict[str, DaskWorkerProfile] = { + dask_worker: Dict[str, 
DaskWorkerProfile] = {
         "Small Worker": DaskWorkerProfile(
             worker_cores_limit=2,
             worker_cores=1.5,
@@ -158,25 +149,26 @@ class Profiles(schema.Base):
         ),
     }
 
-    @pydantic.validator("jupyterlab")
-    def check_default(cls, v, values):
+    @field_validator("jupyterlab")
+    @classmethod
+    def check_default(cls, value):
         """Check if only one default value is present."""
-        default = [attrs["default"] for attrs in v if "default" in attrs]
+        default = [profile.default for profile in value]
         if default.count(True) > 1:
             raise TypeError(
                 "Multiple default Jupyterlab profiles may cause unexpected problems."
             )
-        return v
+        return value
 
 
 class CondaEnvironment(schema.Base):
     name: str
-    channels: typing.Optional[typing.List[str]]
-    dependencies: typing.List[typing.Union[str, typing.Dict[str, typing.List[str]]]]
+    channels: Optional[List[str]] = None
+    dependencies: List[Union[str, Dict[str, List[str]]]]
 
 
 class CondaStore(schema.Base):
-    extra_settings: typing.Dict[str, typing.Any] = {}
+    extra_settings: Dict[str, Any] = {}
     extra_config: str = ""
     image: str = "quansight/conda-store-server"
     image_tag: str = constants.DEFAULT_CONDA_STORE_IMAGE_TAG
@@ -191,7 +183,7 @@ class NebariWorkflowController(schema.Base):
 
 class ArgoWorkflows(schema.Base):
     enabled: bool = True
-    overrides: typing.Dict = {}
+    overrides: Dict = {}
     nebari_workflow_controller: NebariWorkflowController = NebariWorkflowController()
 
 
@@ -213,7 +205,7 @@ class Telemetry(schema.Base):
 
 class JupyterHub(schema.Base):
-    overrides: typing.Dict = {}
+    overrides: Dict = {}
 
 
 class IdleCuller(schema.Base):
@@ -236,7 +228,7 @@ class InputSchema(schema.Base):
     storage: Storage = Storage()
     theme: Theme = Theme()
     profiles: Profiles = Profiles()
-    environments: typing.Dict[str, CondaEnvironment] = {
+    environments: Dict[str, CondaEnvironment] = {
         "environment-dask.yaml": CondaEnvironment(
             name="dask",
             channels=["conda-forge"],
@@ -355,7 +347,9 @@ class JupyterhubInputVars(schema.Base):
     initial_repositories: str = Field(alias="initial-repositories")
     jupyterhub_overrides: List[str] = Field(alias="jupyterhub-overrides")
     jupyterhub_stared_storage: str = Field(alias="jupyterhub-shared-storage")
-    jupyterhub_shared_endpoint: str = Field(None, alias="jupyterhub-shared-endpoint")
+    jupyterhub_shared_endpoint: Optional[str] = Field(
+        alias="jupyterhub-shared-endpoint", default=None
+    )
     jupyterhub_profiles: List[JupyterLabProfile] = Field(alias="jupyterlab-profiles")
     jupyterhub_image: ImageNameTag = Field(alias="jupyterhub-image")
     jupyterhub_hub_extraEnv: str = Field(alias="jupyterhub-hub-extraEnv")
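# JupyterhubInputVars above pairs Field(alias=...) with pydantic v2's
# ConfigDict(populate_by_name=True) (the rename of v1's
# allow_population_by_field_name). A small sketch of the aliased
# round-trip; the field and alias names below are illustrative:
from typing import Optional

from pydantic import BaseModel, ConfigDict, Field


class InputVars(BaseModel):
    model_config = ConfigDict(populate_by_name=True)

    shared_endpoint: Optional[str] = Field(
        default=None, alias="jupyterhub-shared-endpoint"
    )


iv = InputVars(shared_endpoint="nfs.example.com")  # python name accepted
assert iv.model_dump(by_alias=True) == {
    "jupyterhub-shared-endpoint": "nfs.example.com"
}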
diff --git a/src/_nebari/stages/nebari_tf_extensions/__init__.py b/src/_nebari/stages/nebari_tf_extensions/__init__.py
index b5bfdbec4f..33adb588c5 100644
--- a/src/_nebari/stages/nebari_tf_extensions/__init__.py
+++ b/src/_nebari/stages/nebari_tf_extensions/__init__.py
@@ -1,5 +1,4 @@
-import typing
-from typing import Any, Dict, List, Type
+from typing import Any, Dict, List, Optional, Type
 
 from _nebari.stages.base import NebariTerraformStage
 from _nebari.stages.tf_objects import (
@@ -25,8 +24,8 @@ class NebariExtension(schema.Base):
     keycloakadmin: bool = False
     jwt: bool = False
     nebariconfigyaml: bool = False
-    logout: typing.Optional[str]
-    envs: typing.Optional[typing.List[NebariExtensionEnv]]
+    logout: Optional[str] = None
+    envs: Optional[List[NebariExtensionEnv]] = None
 
 
 class HelmExtension(schema.Base):
@@ -34,12 +33,12 @@ class HelmExtension(schema.Base):
     repository: str
     chart: str
     version: str
-    overrides: typing.Dict = {}
+    overrides: Dict = {}
 
 
 class InputSchema(schema.Base):
-    helm_extensions: typing.List[HelmExtension] = []
-    tf_extensions: typing.List[NebariExtension] = []
+    helm_extensions: List[HelmExtension] = []
+    tf_extensions: List[NebariExtension] = []
 
 
 class OutputSchema(schema.Base):
diff --git a/src/_nebari/stages/terraform_state/__init__.py b/src/_nebari/stages/terraform_state/__init__.py
index 094231e967..ac554496ab 100644
--- a/src/_nebari/stages/terraform_state/__init__.py
+++ b/src/_nebari/stages/terraform_state/__init__.py
@@ -4,10 +4,9 @@
 import os
 import pathlib
 import re
-import typing
-from typing import Any, Dict, List, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type
 
-import pydantic
+from pydantic import field_validator
 
 from _nebari.provider import terraform
 from _nebari.provider.cloud import azure_cloud
@@ -39,10 +38,11 @@ class AzureInputVars(schema.Base):
     region: str
     storage_account_postfix: str
     state_resource_group_name: str
-    tags: Dict[str, str] = {}
+    tags: Dict[str, str]
 
-    @pydantic.validator("state_resource_group_name")
-    def _validate_resource_group_name(cls, value):
+    @field_validator("state_resource_group_name")
+    @classmethod
+    def _validate_resource_group_name(cls, value: str) -> str:
         if value is None:
             return value
         length = len(value) + len(AZURE_TF_STATE_RESOURCE_GROUP_SUFFIX)
@@ -59,9 +59,10 @@ def _validate_resource_group_name(cls, value):
 
         return value
 
-    @pydantic.validator("tags")
-    def _validate_tags(cls, tags):
-        return azure_cloud.validate_tags(tags)
+    @field_validator("tags")
+    @classmethod
+    def _validate_tags(cls, value: Dict[str, str]) -> Dict[str, str]:
+        return azure_cloud.validate_tags(value)
 
 
 class AWSInputVars(schema.Base):
@@ -82,8 +83,8 @@ def to_yaml(cls, representer, node):
 
 class TerraformState(schema.Base):
     type: TerraformStateEnum = TerraformStateEnum.remote
-    backend: typing.Optional[str]
-    config: typing.Dict[str, str] = {}
+    backend: Optional[str] = None
+    config: Dict[str, str] = {}
 
 
 class InputSchema(schema.Base):
diff --git a/src/_nebari/upgrade.py b/src/_nebari/upgrade.py
index ef933f48ea..2cd7241302 100644
--- a/src/_nebari/upgrade.py
+++ b/src/_nebari/upgrade.py
@@ -9,7 +9,7 @@
 import rich
 from packaging.version import Version
-from pydantic.error_wrappers import ValidationError
+from pydantic import ValidationError
 from rich.prompt import Prompt
 
 from _nebari.config import backup_configuration
diff --git a/src/_nebari/utils.py b/src/_nebari/utils.py
index 3378116a1d..d68b96ee85 100644
--- a/src/_nebari/utils.py
+++ b/src/_nebari/utils.py
@@ -11,7 +11,7 @@
 import time
 import warnings
 from pathlib import Path
-from typing import Dict, List
+from typing import Dict, List, Set
 
 from ruamel.yaml import YAML
 
@@ -350,3 +350,18 @@ def get_provider_config_block_name(provider):
         return PROVIDER_CONFIG_NAMES[provider]
     else:
         return provider
+
+
+def check_environment_variables(variables: Set[str], reference: str) -> None:
+    """Check that environment variables are set."""
+    required_variables = {
+        variable: os.environ.get(variable, None) for variable in variables
+    }
+    missing_variables = {
+        variable for variable, value in required_variables.items() if value is None
+    }
+    if missing_variables:
+        raise ValueError(
+            f"""Missing the following required environment variables: {missing_variables}\n
+            Please see the documentation for more information: {reference}"""
+        )
diff --git a/src/nebari/schema.py b/src/nebari/schema.py
index ab90f8ebc7..bceea0b539 100644
--- a/src/nebari/schema.py
+++ b/src/nebari/schema.py
@@ -1,33 +1,38 @@
 import enum
+import sys
 
 import
pydantic +from pydantic import ConfigDict, Field, StringConstraints, field_validator from ruamel.yaml import yaml_object from _nebari.utils import escape_string, yaml from _nebari.version import __version__, rounded_ver_parse +if sys.version_info >= (3, 9): + from typing import Annotated +else: + from typing_extensions import Annotated + + # Regex for suitable project names project_name_regex = r"^[A-Za-z][A-Za-z0-9\-_]{1,14}[A-Za-z0-9]$" -project_name_pydantic = pydantic.constr(regex=project_name_regex) +project_name_pydantic = Annotated[str, StringConstraints(pattern=project_name_regex)] # Regex for suitable namespaces namespace_regex = r"^[A-Za-z][A-Za-z\-_]*[A-Za-z]$" -namespace_pydantic = pydantic.constr(regex=namespace_regex) +namespace_pydantic = Annotated[str, StringConstraints(pattern=namespace_regex)] email_regex = "^[^ @]+@[^ @]+\\.[^ @]+$" -email_pydantic = pydantic.constr(regex=email_regex) +email_pydantic = Annotated[str, StringConstraints(pattern=email_regex)] github_url_regex = "^(https://)?github.com/([^/]+)/([^/]+)/?$" -github_url_pydantic = pydantic.constr(regex=github_url_regex) +github_url_pydantic = Annotated[str, StringConstraints(pattern=github_url_regex)] class Base(pydantic.BaseModel): - ... - - class Config: - extra = "forbid" - validate_assignment = True - allow_population_by_field_name = True + model_config = ConfigDict( + extra="forbid", validate_assignment=True, populate_by_name=True + ) @yaml_object(yaml) @@ -49,7 +54,7 @@ class Main(Base): namespace: namespace_pydantic = "dev" provider: ProviderEnum = ProviderEnum.local # In nebari_version only use major.minor.patch version - drop any pre/post/dev suffixes - nebari_version: str = __version__ + nebari_version: Annotated[str, Field(validate_default=True)] = __version__ prevent_deploy: bool = ( False # Optional, but will be given default value if not present @@ -57,19 +62,13 @@ class Main(Base): # If the nebari_version in the schema is old # we must tell the user to first run nebari upgrade - @pydantic.validator("nebari_version", pre=True, always=True) - def check_default(cls, v): - """ - Always called even if nebari_version is not supplied at all (so defaults to ''). That way we can give a more helpful error message. - """ - if not cls.is_version_accepted(v): - if v == "": - v = "not supplied" - raise ValueError( - f"nebari_version in the config file must be equivalent to {__version__} to be processed by this version of nebari (your config file version is {v})." - " Install a different version of nebari or run nebari upgrade to ensure your config file is compatible." - ) - return v + @field_validator("nebari_version") + @classmethod + def check_default(cls, value): + assert cls.is_version_accepted( + value + ), f"nebari_version={value} is not an accepted version, it must be equivalent to {__version__}.\nInstall a different version of nebari or run nebari upgrade to ensure your config file is compatible." 
+ return value @classmethod def is_version_accepted(cls, v): diff --git a/tests/tests_unit/conftest.py b/tests/tests_unit/conftest.py index e98661c214..aed1eaa3e9 100644 --- a/tests/tests_unit/conftest.py +++ b/tests/tests_unit/conftest.py @@ -2,7 +2,9 @@ from unittest.mock import Mock import pytest +from typer.testing import CliRunner +from _nebari.cli import create_cli from _nebari.config import write_configuration from _nebari.constants import ( AWS_DEFAULT_REGION, @@ -13,8 +15,6 @@ from _nebari.initialize import render_config from _nebari.render import render_template from _nebari.stages.bootstrap import CiEnum -from _nebari.stages.kubernetes_keycloak import AuthenticationEnum -from _nebari.stages.terraform_state import TerraformStateEnum from nebari import schema from nebari.plugins import nebari_plugin_manager @@ -100,81 +100,42 @@ def _mock_return_value(return_value): @pytest.fixture( params=[ - # project, namespace, domain, cloud_provider, region, ci_provider, auth_provider + # cloud_provider, region ( - "pytestdo", - "dev", - "do.nebari.dev", schema.ProviderEnum.do, DO_DEFAULT_REGION, - CiEnum.github_actions, - AuthenticationEnum.password, ), ( - "pytestaws", - "dev", - "aws.nebari.dev", schema.ProviderEnum.aws, AWS_DEFAULT_REGION, - CiEnum.github_actions, - AuthenticationEnum.password, ), ( - "pytestgcp", - "dev", - "gcp.nebari.dev", schema.ProviderEnum.gcp, GCP_DEFAULT_REGION, - CiEnum.github_actions, - AuthenticationEnum.password, ), ( - "pytestazure", - "dev", - "azure.nebari.dev", schema.ProviderEnum.azure, AZURE_DEFAULT_REGION, - CiEnum.github_actions, - AuthenticationEnum.password, ), ] ) -def nebari_config_options(request) -> schema.Main: +def nebari_config_options(request): """This fixtures creates a set of nebari configurations for tests""" - DEFAULT_GH_REPO = "github.com/test/test" - DEFAULT_TERRAFORM_STATE = TerraformStateEnum.remote - - ( - project, - namespace, - domain, - cloud_provider, - region, - ci_provider, - auth_provider, - ) = request.param - - return dict( - project_name=project, - namespace=namespace, - nebari_domain=domain, - cloud_provider=cloud_provider, - region=region, - ci_provider=ci_provider, - auth_provider=auth_provider, - repository=DEFAULT_GH_REPO, - repository_auto_provision=False, - auth_auto_provision=False, - terraform_state=DEFAULT_TERRAFORM_STATE, - disable_prompt=True, - ) + cloud_provider, region = request.param + return { + "project_name": "testproject", + "nebari_domain": "test.nebari.dev", + "cloud_provider": cloud_provider, + "region": region, + "ci_provider": CiEnum.github_actions, + "repository": "github.com/test/test", + "disable_prompt": True, + } @pytest.fixture -def nebari_config(nebari_config_options): - return nebari_plugin_manager.config_schema.parse_obj( - render_config(**nebari_config_options) - ) +def nebari_config(nebari_config_options, config_schema): + return config_schema.model_validate(render_config(**nebari_config_options)) @pytest.fixture @@ -207,3 +168,13 @@ def new_upgrade_cls(): @pytest.fixture def config_schema(): return nebari_plugin_manager.config_schema + + +@pytest.fixture +def cli(): + return create_cli() + + +@pytest.fixture(scope="session") +def runner(): + return CliRunner() diff --git a/tests/tests_unit/test_cli.py b/tests/tests_unit/test_cli.py deleted file mode 100644 index d8a4e423b9..0000000000 --- a/tests/tests_unit/test_cli.py +++ /dev/null @@ -1,67 +0,0 @@ -import subprocess - -import pytest - -from _nebari.subcommands.init import InitInputs -from nebari.plugins import nebari_plugin_manager - 
-PROJECT_NAME = "clitest" -DOMAIN_NAME = "clitest.dev" - - -@pytest.mark.parametrize( - "namespace, auth_provider, ci_provider, ssl_cert_email", - ( - [None, None, None, None], - ["prod", "password", "github-actions", "it@acme.org"], - ), -) -def test_nebari_init(tmp_path, namespace, auth_provider, ci_provider, ssl_cert_email): - """Test `nebari init` CLI command.""" - command = [ - "nebari", - "init", - "local", - f"--project={PROJECT_NAME}", - f"--domain={DOMAIN_NAME}", - "--disable-prompt", - ] - - default_values = InitInputs() - - if namespace: - command.append(f"--namespace={namespace}") - else: - namespace = default_values.namespace - if auth_provider: - command.append(f"--auth-provider={auth_provider}") - else: - auth_provider = default_values.auth_provider - if ci_provider: - command.append(f"--ci-provider={ci_provider}") - else: - ci_provider = default_values.ci_provider - if ssl_cert_email: - command.append(f"--ssl-cert-email={ssl_cert_email}") - else: - ssl_cert_email = default_values.ssl_cert_email - - subprocess.run(command, cwd=tmp_path, check=True) - - config = nebari_plugin_manager.read_config(tmp_path / "nebari-config.yaml") - - assert config.namespace == namespace - assert config.security.authentication.type.lower() == auth_provider - assert config.ci_cd.type == ci_provider - assert config.certificate.acme_email == ssl_cert_email - - -@pytest.mark.parametrize( - "command", - ( - ["nebari", "--version"], - ["nebari", "info"], - ), -) -def test_nebari_commands_no_args(command): - subprocess.run(command, check=True, capture_output=True, text=True).stdout.strip() diff --git a/tests/tests_unit/test_cli_deploy.py b/tests/tests_unit/test_cli_deploy.py index 2a33b4e39e..cb393ed662 100644 --- a/tests/tests_unit/test_cli_deploy.py +++ b/tests/tests_unit/test_cli_deploy.py @@ -1,14 +1,6 @@ -from typer.testing import CliRunner - -from _nebari.cli import create_cli - -runner = CliRunner() - - -def test_dns_option(config_gcp): - app = create_cli() +def test_dns_option(config_gcp, runner, cli): result = runner.invoke( - app, + cli, [ "deploy", "-c", diff --git a/tests/tests_unit/test_cli_dev.py b/tests/tests_unit/test_cli_dev.py index 4a4d58ef22..5c795391d4 100644 --- a/tests/tests_unit/test_cli_dev.py +++ b/tests/tests_unit/test_cli_dev.py @@ -1,15 +1,10 @@ import json -import tempfile -from pathlib import Path from typing import Any, List from unittest.mock import Mock, patch import pytest import requests.exceptions import yaml -from typer.testing import CliRunner - -from _nebari.cli import create_cli TEST_KEYCLOAKAPI_REQUEST = "GET /" # get list of realms @@ -27,8 +22,6 @@ {"id": "master", "realm": "master"}, ] -runner = CliRunner() - @pytest.mark.parametrize( "args, exit_code, content", @@ -47,9 +40,8 @@ (["keycloak-api", "-r"], 2, ["requires an argument"]), ], ) -def test_cli_dev_stdout(args: List[str], exit_code: int, content: List[str]): - app = create_cli() - result = runner.invoke(app, ["dev"] + args) +def test_cli_dev_stdout(runner, cli, args, exit_code, content): + result = runner.invoke(cli, ["dev"] + args) assert result.exit_code == exit_code for c in content: assert c in result.stdout @@ -100,9 +92,9 @@ def mock_api_request( ), ) def test_cli_dev_keycloakapi_happy_path_from_env( - _mock_requests_post, _mock_requests_request + _mock_requests_post, _mock_requests_request, runner, cli, tmp_path ): - result = run_cli_dev(use_env=True) + result = run_cli_dev(runner, cli, tmp_path, use_env=True) assert 0 == result.exit_code assert not result.exception @@ -125,9 +117,9 @@ def 
test_cli_dev_keycloakapi_happy_path_from_env( ), ) def test_cli_dev_keycloakapi_happy_path_from_config( - _mock_requests_post, _mock_requests_request + _mock_requests_post, _mock_requests_request, runner, cli, tmp_path ): - result = run_cli_dev(use_env=False) + result = run_cli_dev(runner, cli, tmp_path, use_env=False) assert 0 == result.exit_code assert not result.exception @@ -143,8 +135,10 @@ def test_cli_dev_keycloakapi_happy_path_from_config( MOCK_KEYCLOAK_ENV["KEYCLOAK_ADMIN_PASSWORD"], url, headers, data, verify ), ) -def test_cli_dev_keycloakapi_error_bad_request(_mock_requests_post): - result = run_cli_dev(request="malformed") +def test_cli_dev_keycloakapi_error_bad_request( + _mock_requests_post, runner, cli, tmp_path +): + result = run_cli_dev(runner, cli, tmp_path, request="malformed") assert 1 == result.exit_code assert result.exception @@ -157,8 +151,10 @@ def test_cli_dev_keycloakapi_error_bad_request(_mock_requests_post): "invalid_admin_password", url, headers, data, verify ), ) -def test_cli_dev_keycloakapi_error_authentication(_mock_requests_post): - result = run_cli_dev() +def test_cli_dev_keycloakapi_error_authentication( + _mock_requests_post, runner, cli, tmp_path +): + result = run_cli_dev(runner, cli, tmp_path) assert 1 == result.exit_code assert result.exception @@ -179,9 +175,9 @@ def test_cli_dev_keycloakapi_error_authentication(_mock_requests_post): ), ) def test_cli_dev_keycloakapi_error_authorization( - _mock_requests_post, _mock_requests_request + _mock_requests_post, _mock_requests_request, runner, cli, tmp_path ): - result = run_cli_dev() + result = run_cli_dev(runner, cli, tmp_path) assert 1 == result.exit_code assert result.exception @@ -192,62 +188,66 @@ def test_cli_dev_keycloakapi_error_authorization( @patch( "_nebari.keycloak.requests.post", side_effect=requests.exceptions.RequestException() ) -def test_cli_dev_keycloakapi_request_exception(_mock_requests_post): - result = run_cli_dev() +def test_cli_dev_keycloakapi_request_exception( + _mock_requests_post, runner, cli, tmp_path +): + result = run_cli_dev(runner, cli, tmp_path) assert 1 == result.exit_code assert result.exception @patch("_nebari.keycloak.requests.post", side_effect=Exception()) -def test_cli_dev_keycloakapi_unhandled_error(_mock_requests_post): - result = run_cli_dev() +def test_cli_dev_keycloakapi_unhandled_error( + _mock_requests_post, runner, cli, tmp_path +): + result = run_cli_dev(runner, cli, tmp_path) assert 1 == result.exit_code assert result.exception def run_cli_dev( + runner, + cli, + tmp_path, request: str = TEST_KEYCLOAKAPI_REQUEST, use_env: bool = True, extra_args: List[str] = [], ): - with tempfile.TemporaryDirectory() as tmp: - tmp_file = Path(tmp).resolve() / "nebari-config.yaml" - assert tmp_file.exists() is False - - extra_config = ( - { - "domain": TEST_DOMAIN, - "security": { - "keycloak": { - "initial_root_password": MOCK_KEYCLOAK_ENV[ - "KEYCLOAK_ADMIN_PASSWORD" - ] - } - }, - } - if not use_env - else {} - ) - config = {**{"project_name": "dev"}, **extra_config} - with open(tmp_file.resolve(), "w") as f: - yaml.dump(config, f) - - assert tmp_file.exists() is True - - app = create_cli() - - args = [ - "dev", - "keycloak-api", - "--config", - tmp_file.resolve(), - "--request", - request, - ] + extra_args - - env = MOCK_KEYCLOAK_ENV if use_env else {} - result = runner.invoke(app, args=args, env=env) - - return result + tmp_file = tmp_path.resolve() / "nebari-config.yaml" + assert tmp_file.exists() is False + + extra_config = ( + { + "domain": TEST_DOMAIN, + 
"security": { + "keycloak": { + "initial_root_password": MOCK_KEYCLOAK_ENV[ + "KEYCLOAK_ADMIN_PASSWORD" + ] + } + }, + } + if not use_env + else {} + ) + config = {**{"project_name": "dev"}, **extra_config} + with tmp_file.open("w") as f: + yaml.dump(config, f) + + assert tmp_file.exists() + + args = [ + "dev", + "keycloak-api", + "--config", + tmp_file.resolve(), + "--request", + request, + ] + extra_args + + env = MOCK_KEYCLOAK_ENV if use_env else {} + result = runner.invoke(cli, args=args, env=env) + + return result diff --git a/tests/tests_unit/test_cli_init.py b/tests/tests_unit/test_cli_init.py index 0cd0fe03d2..3025e37930 100644 --- a/tests/tests_unit/test_cli_init.py +++ b/tests/tests_unit/test_cli_init.py @@ -1,18 +1,10 @@ -import tempfile from collections.abc import MutableMapping -from pathlib import Path -from typing import List import pytest import yaml -from typer import Typer -from typer.testing import CliRunner -from _nebari.cli import create_cli from _nebari.constants import AZURE_DEFAULT_REGION -runner = CliRunner() - MOCK_KUBERNETES_VERSIONS = { "aws": ["1.20"], "azure": ["1.20"], @@ -53,9 +45,8 @@ (["-o"], 2, ["requires an argument"]), ], ) -def test_cli_init_stdout(args: List[str], exit_code: int, content: List[str]): - app = create_cli() - result = runner.invoke(app, ["init"] + args) +def test_cli_init_stdout(runner, cli, args, exit_code, content): + result = runner.invoke(cli, ["init"] + args) assert result.exit_code == exit_code for c in content: assert c in result.stdout @@ -121,18 +112,20 @@ def generate_test_data_test_cli_init_happy_path(): def test_cli_init_happy_path( - provider: str, - region: str, - project_name: str, - domain_name: str, - namespace: str, - auth_provider: str, - ci_provider: str, - terraform_state: str, - email: str, - kubernetes_version: str, + runner, + cli, + provider, + region, + project_name, + domain_name, + namespace, + auth_provider, + ci_provider, + terraform_state, + email, + kubernetes_version, + tmp_path, ): - app = create_cli() args = [ "init", provider, @@ -160,57 +153,39 @@ def test_cli_init_happy_path( region, ] - expected_yaml = f""" - provider: {provider} - namespace: {namespace} - project_name: {project_name} - domain: {domain_name} - ci_cd: - type: {ci_provider} - terraform_state: - type: {terraform_state} - security: - authentication: - type: {auth_provider} - certificate: - type: lets-encrypt - acme_email: {email} - """ + expected = { + "provider": provider, + "namespace": namespace, + "project_name": project_name, + "domain": domain_name, + "ci_cd": {"type": ci_provider}, + "terraform_state": {"type": terraform_state}, + "security": {"authentication": {"type": auth_provider}}, + "certificate": { + "type": "lets-encrypt", + "acme_email": email, + }, + } provider_section = get_provider_section_header(provider) if provider_section != "" and kubernetes_version != "latest": - expected_yaml += f""" - {provider_section}: - kubernetes_version: '{kubernetes_version}' - region: '{region}' - """ - - assert_nebari_init_args(app, args, expected_yaml) - - -def assert_nebari_init_args( - app: Typer, args: List[str], expected_yaml: str, input: str = None -): - """ - Run nebari init with happy path assertions and verify the generated yaml contains - all values in expected_yaml. 
- """ - with tempfile.TemporaryDirectory() as tmp: - tmp_file = Path(tmp).resolve() / "nebari-config.yaml" - assert tmp_file.exists() is False - - result = runner.invoke( - app, args + ["--output", tmp_file.resolve()], input=input - ) - - assert not result.exception - assert 0 == result.exit_code - assert tmp_file.exists() is True - - with open(tmp_file.resolve(), "r") as config_yaml: - config = flatten_dict(yaml.safe_load(config_yaml)) - expected = flatten_dict(yaml.safe_load(expected_yaml)) - assert expected.items() <= config.items() + expected[provider_section] = { + "kubernetes_version": kubernetes_version, + "region": region, + } + + tmp_file = tmp_path / "nebari-config.yaml" + assert not tmp_file.exists() + + result = runner.invoke(cli, args + ["--output", tmp_file.resolve()]) + assert not result.exception + assert 0 == result.exit_code + assert tmp_file.exists() + + with tmp_file.open() as f: + config = flatten_dict(yaml.safe_load(f)) + expected = flatten_dict(expected) + assert expected.items() <= config.items() def pytest_generate_tests(metafunc): diff --git a/tests/tests_unit/test_cli_init_repository.py b/tests/tests_unit/test_cli_init_repository.py index 6bc0d4e7d4..3aa65a1522 100644 --- a/tests/tests_unit/test_cli_init_repository.py +++ b/tests/tests_unit/test_cli_init_repository.py @@ -1,18 +1,11 @@ import logging -import tempfile -from pathlib import Path from unittest.mock import Mock, patch -import pytest import requests.auth import requests.exceptions -from typer.testing import CliRunner -from _nebari.cli import create_cli from _nebari.provider.cicd.github import GITHUB_BASE_URL -runner = CliRunner() - TEST_GITHUB_USERNAME = "test-nebari-github-user" TEST_GITHUB_TOKEN = "nebari-super-secret" @@ -69,22 +62,21 @@ def test_cli_init_repository_auto_provision( _mock_requests_post, _mock_requests_put, _mock_git, - monkeypatch: pytest.MonkeyPatch, + runner, + cli, + monkeypatch, + tmp_path, ): monkeypatch.setenv("GITHUB_USERNAME", TEST_GITHUB_USERNAME) monkeypatch.setenv("GITHUB_TOKEN", TEST_GITHUB_TOKEN) - app = create_cli() + tmp_file = tmp_path / "nebari-config.yaml" - with tempfile.TemporaryDirectory() as tmp: - tmp_file = Path(tmp).resolve() / "nebari-config.yaml" - assert tmp_file.exists() is False + result = runner.invoke(cli, DEFAULT_ARGS + ["--output", tmp_file.resolve()]) - result = runner.invoke(app, DEFAULT_ARGS + ["--output", tmp_file.resolve()]) - - assert 0 == result.exit_code - assert not result.exception - assert tmp_file.exists() is True + # assert 0 == result.exit_code + assert not result.exception + assert tmp_file.exists() is True @patch( @@ -124,9 +116,12 @@ def test_cli_init_repository_repo_exists( _mock_requests_post, _mock_requests_put, _mock_git, - monkeypatch: pytest.MonkeyPatch, + runner, + cli, + monkeypatch, capsys, caplog, + tmp_path, ): monkeypatch.setenv("GITHUB_USERNAME", TEST_GITHUB_USERNAME) monkeypatch.setenv("GITHUB_TOKEN", TEST_GITHUB_TOKEN) @@ -134,21 +129,18 @@ def test_cli_init_repository_repo_exists( with capsys.disabled(): caplog.set_level(logging.WARNING) - app = create_cli() - - with tempfile.TemporaryDirectory() as tmp: - tmp_file = Path(tmp).resolve() / "nebari-config.yaml" - assert tmp_file.exists() is False + tmp_file = tmp_path / "nebari-config.yaml" + assert not tmp_file.exists() - result = runner.invoke(app, DEFAULT_ARGS + ["--output", tmp_file.resolve()]) + result = runner.invoke(cli, DEFAULT_ARGS + ["--output", tmp_file.resolve()]) - assert 0 == result.exit_code - assert not result.exception - assert tmp_file.exists() is True 
- assert "already exists" in caplog.text + assert 0 == result.exit_code + assert not result.exception + assert tmp_file.exists() + assert "already exists" in caplog.text -def test_cli_init_error_repository_missing_env(monkeypatch: pytest.MonkeyPatch): +def test_cli_init_error_repository_missing_env(runner, cli, monkeypatch, tmp_path): for e in [ "GITHUB_USERNAME", "GITHUB_TOKEN", @@ -158,28 +150,23 @@ def test_cli_init_error_repository_missing_env(monkeypatch: pytest.MonkeyPatch): except Exception as e: pass - app = create_cli() + tmp_file = tmp_path / "nebari-config.yaml" + assert not tmp_file.exists() - with tempfile.TemporaryDirectory() as tmp: - tmp_file = Path(tmp).resolve() / "nebari-config.yaml" - assert tmp_file.exists() is False + result = runner.invoke(cli, DEFAULT_ARGS + ["--output", tmp_file.resolve()]) - result = runner.invoke(app, DEFAULT_ARGS + ["--output", tmp_file.resolve()]) + assert 1 == result.exit_code + assert result.exception + assert "Environment variable(s) required for GitHub automation" in str( + result.exception + ) + assert not tmp_file.exists() - assert 1 == result.exit_code - assert result.exception - assert "Environment variable(s) required for GitHub automation" in str( - result.exception - ) - assert tmp_file.exists() is False - -def test_cli_init_error_invalid_repo(monkeypatch: pytest.MonkeyPatch): +def test_cli_init_error_invalid_repo(runner, cli, monkeypatch, tmp_path): monkeypatch.setenv("GITHUB_USERNAME", TEST_GITHUB_USERNAME) monkeypatch.setenv("GITHUB_TOKEN", TEST_GITHUB_TOKEN) - app = create_cli() - args = [ "init", "local", @@ -190,16 +177,15 @@ def test_cli_init_error_invalid_repo(monkeypatch: pytest.MonkeyPatch): "https://notgithub.com", ] - with tempfile.TemporaryDirectory() as tmp: - tmp_file = Path(tmp).resolve() / "nebari-config.yaml" - assert tmp_file.exists() is False + tmp_file = tmp_path / "nebari-config.yaml" + assert not tmp_file.exists() - result = runner.invoke(app, args + ["--output", tmp_file.resolve()]) + result = runner.invoke(cli, args + ["--output", tmp_file.resolve()]) - assert 2 == result.exit_code - assert result.exception - assert "repository URL" in str(result.stdout) - assert tmp_file.exists() is False + assert 2 == result.exit_code + assert result.exception + assert "repository URL" in str(result.stdout) + assert not tmp_file.exists() def mock_api_request( diff --git a/tests/tests_unit/test_cli_keycloak.py b/tests/tests_unit/test_cli_keycloak.py index a82c4cd044..4040bf7405 100644 --- a/tests/tests_unit/test_cli_keycloak.py +++ b/tests/tests_unit/test_cli_keycloak.py @@ -57,7 +57,7 @@ (["listusers", "-c"], 2, ["requires an argument"]), ], ) -def test_cli_keycloak_stdout(args: List[str], exit_code: int, content: List[str]): +def test_cli_keycloak_stdout(args, exit_code, content): app = create_cli() result = runner.invoke(app, ["keycloak"] + args) assert result.exit_code == exit_code diff --git a/tests/tests_unit/test_cli_support.py b/tests/tests_unit/test_cli_support.py index 66822d165d..30c2dc85e3 100644 --- a/tests/tests_unit/test_cli_support.py +++ b/tests/tests_unit/test_cli_support.py @@ -1,5 +1,3 @@ -import tempfile -from pathlib import Path from typing import List from unittest.mock import Mock, patch from zipfile import ZipFile @@ -8,11 +6,6 @@ import kubernetes.client.exceptions import pytest import yaml -from typer.testing import CliRunner - -from _nebari.cli import create_cli - -runner = CliRunner() class MockPod: @@ -63,9 +56,8 @@ def mock_read_namespaced_pod_log(name: str, namespace: str, container: str): 
(["-o"], 2, ["requires an argument"]), ], ) -def test_cli_support_stdout(args: List[str], exit_code: int, content: List[str]): - app = create_cli() - result = runner.invoke(app, ["support"] + args) +def test_cli_support_stdout(runner, cli, args, exit_code, content): + result = runner.invoke(cli, ["support"] + args) assert result.exit_code == exit_code for c in content: assert c in result.stdout @@ -96,59 +88,55 @@ def test_cli_support_stdout(args: List[str], exit_code: int, content: List[str]) ), ) def test_cli_support_happy_path( - _mock_k8s_corev1api, _mock_config, monkeypatch: pytest.MonkeyPatch + _mock_k8s_corev1api, _mock_config, runner, cli, monkeypatch, tmp_path ): - with tempfile.TemporaryDirectory() as tmp: - # NOTE: The support command leaves the ./log folder behind after running, - # relative to wherever the tests were run from. - # Changing context to the tmp dir so this will be cleaned up properly. - monkeypatch.chdir(Path(tmp).resolve()) - - tmp_file = Path(tmp).resolve() / "nebari-config.yaml" - assert tmp_file.exists() is False - - with open(tmp_file.resolve(), "w") as f: - yaml.dump({"project_name": "support", "namespace": "test-ns"}, f) - - assert tmp_file.exists() is True - - app = create_cli() - - log_zip_file = Path(tmp).resolve() / "test-support.zip" - assert log_zip_file.exists() is False - - result = runner.invoke( - app, - [ - "support", - "--config", - tmp_file.resolve(), - "--output", - log_zip_file.resolve(), - ], - ) + # NOTE: The support command leaves the ./log folder behind after running, + # relative to wherever the tests were run from. + # Changing context to the tmp dir so this will be cleaned up properly. + monkeypatch.chdir(tmp_path) + + tmp_file = tmp_path / "nebari-config.yaml" + assert not tmp_file.exists() + + with tmp_file.open("w") as f: + yaml.dump({"project_name": "support", "namespace": "test-ns"}, f) + assert tmp_file.exists() + + log_zip_file = tmp_path / "test-support.zip" + assert not log_zip_file.exists() + + result = runner.invoke( + cli, + [ + "support", + "--config", + tmp_file.resolve(), + "--output", + log_zip_file.resolve(), + ], + ) - assert log_zip_file.exists() is True + assert log_zip_file.exists() - assert 0 == result.exit_code - assert not result.exception - assert "log/test-ns" in result.stdout + assert 0 == result.exit_code + assert not result.exception + assert "log/test-ns" in result.stdout - # open the zip and check a sample file for the expected formatting - with ZipFile(log_zip_file.resolve(), "r") as log_zip: - # expect 1 log file per pod - assert 2 == len(log_zip.namelist()) - with log_zip.open("log/test-ns/pod-1.txt") as log_file: - content = str(log_file.read(), "UTF-8") - # expect formatted header + logs for each container - expected = """ + # open the zip and check a sample file for the expected formatting + with ZipFile(log_zip_file.resolve(), "r") as log_zip: + # expect 1 log file per pod + assert 2 == len(log_zip.namelist()) + with log_zip.open("log/test-ns/pod-1.txt") as log_file: + content = str(log_file.read(), "UTF-8") + # expect formatted header + logs for each container + expected = """ 10.0.0.1\ttest-ns\tpod-1 Container: container-1-1 Test log entry: pod-1 -- test-ns -- container-1-1 Container: container-1-2 Test log entry: pod-1 -- test-ns -- container-1-2 """ - assert expected.strip() == content.strip() + assert expected.strip() == content.strip() @patch("kubernetes.config.kube_config.load_kube_config", return_value=Mock()) @@ -161,50 +149,44 @@ def test_cli_support_happy_path( ), ) def 
test_cli_support_error_apiexception( - _mock_k8s_corev1api, _mock_config, monkeypatch: pytest.MonkeyPatch + _mock_k8s_corev1api, _mock_config, runner, cli, monkeypatch, tmp_path ): - with tempfile.TemporaryDirectory() as tmp: - monkeypatch.chdir(Path(tmp).resolve()) + monkeypatch.chdir(tmp_path) - tmp_file = Path(tmp).resolve() / "nebari-config.yaml" - assert tmp_file.exists() is False + tmp_file = tmp_path / "nebari-config.yaml" + assert not tmp_file.exists() - with open(tmp_file.resolve(), "w") as f: - yaml.dump({"project_name": "support", "namespace": "test-ns"}, f) + with tmp_file.open("w") as f: + yaml.dump({"project_name": "support", "namespace": "test-ns"}, f) - assert tmp_file.exists() is True + assert tmp_file.exists() is True - app = create_cli() + log_zip_file = tmp_path / "test-support.zip" - log_zip_file = Path(tmp).resolve() / "test-support.zip" - - result = runner.invoke( - app, - [ - "support", - "--config", - tmp_file.resolve(), - "--output", - log_zip_file.resolve(), - ], - ) - - assert log_zip_file.exists() is False + result = runner.invoke( + cli, + [ + "support", + "--config", + tmp_file.resolve(), + "--output", + log_zip_file.resolve(), + ], + ) - assert 1 == result.exit_code - assert result.exception - assert "Reason: unit testing" in str(result.exception) + assert not log_zip_file.exists() + assert 1 == result.exit_code + assert result.exception + assert "Reason: unit testing" in str(result.exception) -def test_cli_support_error_missing_config(): - with tempfile.TemporaryDirectory() as tmp: - tmp_file = Path(tmp).resolve() / "nebari-config.yaml" - assert tmp_file.exists() is False - app = create_cli() +def test_cli_support_error_missing_config(runner, cli, tmp_path): + tmp_file = tmp_path / "nebari-config.yaml" + assert not tmp_file.exists() - result = runner.invoke(app, ["support", "--config", tmp_file.resolve()]) + result = runner.invoke(cli, ["support", "--config", tmp_file.resolve()]) - assert 1 == result.exit_code - assert result.exception - assert "nebari-config.yaml does not exist" in str(result.exception) + assert 1 == result.exit_code + assert result.exception + assert "nebari-config.yaml does not exist" in str(result.exception) diff --git a/tests/tests_unit/test_cli_upgrade.py b/tests/tests_unit/test_cli_upgrade.py index aa79838bee..c4a750dfce 100644 --- a/tests/tests_unit/test_cli_upgrade.py +++ b/tests/tests_unit/test_cli_upgrade.py @@ -1,14 +1,11 @@ import re -import tempfile from pathlib import Path from typing import Any, Dict, List import pytest import yaml -from typer.testing import CliRunner import _nebari.upgrade -from _nebari.cli import create_cli from _nebari.constants import AZURE_DEFAULT_REGION from _nebari.upgrade import UPGRADE_KUBERNETES_MESSAGE from _nebari.utils import get_provider_config_block_name @@ -53,8 +50,6 @@ class Test_Cli_Upgrade_2023_5_1(_nebari.upgrade.UpgradeStep): ### end dummy upgrade classes -runner = CliRunner() - @pytest.mark.parametrize( "args, exit_code, content", @@ -74,28 +69,36 @@ class Test_Cli_Upgrade_2023_5_1(_nebari.upgrade.UpgradeStep): ), ], ) -def test_cli_upgrade_stdout(args: List[str], exit_code: int, content: List[str]): - app = create_cli() - result = runner.invoke(app, ["upgrade"] + args) +def test_cli_upgrade_stdout(runner, cli, args, exit_code, content): + result = runner.invoke(cli, ["upgrade"] + args) assert result.exit_code == exit_code for c in content: assert c in result.stdout -def test_cli_upgrade_2022_10_1_to_2022_11_1(monkeypatch: pytest.MonkeyPatch): - 
assert_nebari_upgrade_success(monkeypatch, "2022.10.1", "2022.11.1")
+def test_cli_upgrade_2022_10_1_to_2022_11_1(runner, cli, monkeypatch, tmp_path):
+    assert_nebari_upgrade_success(
+        runner, cli, tmp_path, monkeypatch, "2022.10.1", "2022.11.1"
+    )
 
 
-def test_cli_upgrade_2022_11_1_to_2023_1_1(monkeypatch: pytest.MonkeyPatch):
-    assert_nebari_upgrade_success(monkeypatch, "2022.11.1", "2023.1.1")
+def test_cli_upgrade_2022_11_1_to_2023_1_1(runner, cli, monkeypatch, tmp_path):
+    assert_nebari_upgrade_success(
+        runner, cli, tmp_path, monkeypatch, "2022.11.1", "2023.1.1"
+    )
 
 
-def test_cli_upgrade_2023_1_1_to_2023_4_1(monkeypatch: pytest.MonkeyPatch):
-    assert_nebari_upgrade_success(monkeypatch, "2023.1.1", "2023.4.1")
+def test_cli_upgrade_2023_1_1_to_2023_4_1(runner, cli, monkeypatch, tmp_path):
+    assert_nebari_upgrade_success(
+        runner, cli, tmp_path, monkeypatch, "2023.1.1", "2023.4.1"
+    )
 
 
-def test_cli_upgrade_2023_4_1_to_2023_5_1(monkeypatch: pytest.MonkeyPatch):
+def test_cli_upgrade_2023_4_1_to_2023_5_1(runner, cli, monkeypatch, tmp_path):
     assert_nebari_upgrade_success(
+        runner,
+        cli,
+        tmp_path,
         monkeypatch,
         "2023.4.1",
         "2023.5.1",
@@ -108,11 +111,9 @@
     "provider",
     ["aws", "azure", "do", "gcp"],
 )
-def test_cli_upgrade_2023_5_1_to_2023_7_1(
-    monkeypatch: pytest.MonkeyPatch, provider: str
-):
+def test_cli_upgrade_2023_5_1_to_2023_7_1(runner, cli, monkeypatch, provider, tmp_path):
     config = assert_nebari_upgrade_success(
-        monkeypatch, "2023.5.1", "2023.7.1", provider=provider
+        runner, cli, tmp_path, monkeypatch, "2023.5.1", "2023.7.1", provider=provider
     )
     prevent_deploy = config.get("prevent_deploy")
     if provider == "aws":
@@ -126,9 +127,12 @@
     [(True, True), (True, False), (False, None), (None, None)],
 )
 def test_cli_upgrade_2023_7_1_to_2023_7_2(
-    monkeypatch: pytest.MonkeyPatch,
-    workflows_enabled: bool,
-    workflow_controller_enabled: bool,
+    runner,
+    cli,
+    tmp_path,
+    monkeypatch,
+    workflows_enabled,
+    workflow_controller_enabled,
 ):
     addl_config = {}
     inputs = []
@@ -139,6 +143,9 @@
     inputs.append("y" if workflow_controller_enabled else "n")
 
     upgraded = assert_nebari_upgrade_success(
+        runner,
+        cli,
+        tmp_path,
         monkeypatch,
         "2023.7.1",
         "2023.7.2",
@@ -164,41 +171,58 @@
     assert "argo_workflows" not in upgraded
 
 
-def test_cli_upgrade_image_tags(monkeypatch: pytest.MonkeyPatch):
+def test_cli_upgrade_image_tags(runner, cli, monkeypatch, tmp_path):
     start_version = "2023.5.1"
     end_version = "2023.7.1"
+    addl_config = {
+        "default_images": {
+            "jupyterhub": f"quay.io/nebari/nebari-jupyterhub:{start_version}",
+            "jupyterlab": f"quay.io/nebari/nebari-jupyterlab:{start_version}",
+            "dask_worker": f"quay.io/nebari/nebari-dask-worker:{start_version}",
+        },
+        "profiles": {
+            "jupyterlab": [
+                {
+                    "display_name": "base",
+                    "kubespawner_override": {
+                        "image": f"quay.io/nebari/nebari-jupyterlab:{start_version}"
+                    },
+                },
+                {
+                    "display_name": "gpu",
+                    "kubespawner_override": {
+                        "image": f"quay.io/nebari/nebari-jupyterlab-gpu:{start_version}"
+                    },
+                },
+                {
+                    "display_name": "any-other-version",
+                    "kubespawner_override": {
+                        "image": "quay.io/nebari/nebari-jupyterlab:1955.11.5"
+                    },
+                },
+                {
+                    "display_name": "leave-me-alone",
+                    "kubespawner_override": {
+                        "image": f"quay.io/nebari/leave-me-alone:{start_version}"
+                    },
+                },
+            ],
+            "dask_worker": {
+                "test": {"image": f"quay.io/nebari/nebari-dask-worker:{start_version}"}
+            },
+        },
+    }
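+
+    # the images above that still carry the old start_version tag are the
+    # ones the upgrade step offers to bump; the "y" answers below are fed
+    # to the CLI newline-joined through CliRunner's `input=` argument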
f"quay.io/nebari/nebari-dask-worker:{end_version}"} + }, + }, + } upgraded = assert_nebari_upgrade_success( + runner, + cli, + tmp_path, monkeypatch, start_version, end_version, # # number of "y" inputs directly corresponds to how many matching images are found in yaml inputs=["y", "y", "y", "y", "y", "y", "y"], - addl_config=yaml.safe_load( - f""" -default_images: - jupyterhub: quay.io/nebari/nebari-jupyterhub:{start_version} - jupyterlab: quay.io/nebari/nebari-jupyterlab:{start_version} - dask_worker: quay.io/nebari/nebari-dask-worker:{start_version} -profiles: - jupyterlab: - - display_name: base - kubespawner_override: - image: quay.io/nebari/nebari-jupyterlab:{start_version} - - display_name: gpu - kubespawner_override: - image: quay.io/nebari/nebari-jupyterlab-gpu:{start_version} - - display_name: any-other-version - kubespawner_override: - image: quay.io/nebari/nebari-jupyterlab:1955.11.5 - - display_name: leave-me-alone - kubespawner_override: - image: quay.io/nebari/leave-me-alone:{start_version} - dask_worker: - test: - image: quay.io/nebari/nebari-dask-worker:{start_version} -""" - ), + addl_config=addl_config, ) for _, v in upgraded["default_images"].items(): @@ -216,101 +240,71 @@ def test_cli_upgrade_image_tags(monkeypatch: pytest.MonkeyPatch): assert profile["image"].endswith(end_version) -def test_cli_upgrade_fail_on_missing_file(): - with tempfile.TemporaryDirectory() as tmp: - tmp_file = Path(tmp).resolve() / "nebari-config.yaml" - assert tmp_file.exists() is False - - app = create_cli() - - result = runner.invoke(app, ["upgrade", "--config", tmp_file.resolve()]) - - assert 1 == result.exit_code - assert result.exception - assert ( - f"passed in configuration filename={tmp_file.resolve()} must exist" - in str(result.exception) - ) - - -def test_cli_upgrade_fail_on_downgrade(): - start_version = "9999.9.9" # way in the future - end_version = _nebari.upgrade.__version__ - - with tempfile.TemporaryDirectory() as tmp: - tmp_file = Path(tmp).resolve() / "nebari-config.yaml" - assert tmp_file.exists() is False - - nebari_config = yaml.safe_load( - f""" -project_name: test -provider: local -domain: test.example.com -namespace: dev -nebari_version: {start_version} - """ - ) - - with open(tmp_file.resolve(), "w") as f: - yaml.dump(nebari_config, f) - - assert tmp_file.exists() is True - app = create_cli() +def test_cli_upgrade_fail_on_missing_file(runner, cli, tmp_path): + tmp_file = tmp_path / "nebari-config.yaml" - result = runner.invoke(app, ["upgrade", "--config", tmp_file.resolve()]) + result = runner.invoke(cli, ["upgrade", "--config", tmp_file.resolve()]) - assert 1 == result.exit_code - assert result.exception - assert ( - f"already belongs to a later version ({start_version}) than the installed version of Nebari ({end_version})" - in str(result.exception) - ) - - # make sure the file is unaltered - with open(tmp_file.resolve(), "r") as c: - assert yaml.safe_load(c) == nebari_config + assert 1 == result.exit_code + assert result.exception + assert f"passed in configuration filename={tmp_file.resolve()} must exist" in str( + result.exception + ) -def test_cli_upgrade_does_nothing_on_same_version(): +def test_cli_upgrade_does_nothing_on_same_version(runner, cli, tmp_path): # this test only seems to work against the actual current version, any # mocked earlier versions trigger an actual update start_version = _nebari.upgrade.__version__ + tmp_file = tmp_path / "nebari-config.yaml" + nebari_config = { + "project_name": "test", + "provider": "local", + "domain": 
"test.example.com", + "namespace": "dev", + "nebari_version": start_version, + } - with tempfile.TemporaryDirectory() as tmp: - tmp_file = Path(tmp).resolve() / "nebari-config.yaml" - assert tmp_file.exists() is False - - nebari_config = yaml.safe_load( - f""" -project_name: test -provider: local -domain: test.example.com -namespace: dev -nebari_version: {start_version} - """ - ) - - with open(tmp_file.resolve(), "w") as f: - yaml.dump(nebari_config, f) + with tmp_file.open("w") as f: + yaml.dump(nebari_config, f) - assert tmp_file.exists() is True - app = create_cli() + assert tmp_file.exists() - result = runner.invoke(app, ["upgrade", "--config", tmp_file.resolve()]) + result = runner.invoke(cli, ["upgrade", "--config", tmp_file.resolve()]) - # feels like this should return a non-zero exit code if the upgrade is not happening - assert 0 == result.exit_code - assert not result.exception - assert "up-to-date" in result.stdout + # feels like this should return a non-zero exit code if the upgrade is not happening + assert 0 == result.exit_code + assert not result.exception + assert "up-to-date" in result.stdout - # make sure the file is unaltered - with open(tmp_file.resolve(), "r") as c: - assert yaml.safe_load(c) == nebari_config + # make sure the file is unaltered + with tmp_file.open() as f: + assert yaml.safe_load(f) == nebari_config -def test_cli_upgrade_0_3_12_to_0_4_0(monkeypatch: pytest.MonkeyPatch): +def test_cli_upgrade_0_3_12_to_0_4_0(runner, cli, monkeypatch, tmp_path): start_version = "0.3.12" end_version = "0.4.0" + addl_config = { + "security": { + "authentication": { + "type": "custom", + "config": { + "oauth_callback_url": "", + "scope": "", + }, + }, + "users": {}, + "groups": { + "users": {}, + }, + }, + "terraform_modules": [], + "default_images": { + "conda_store": "", + "dask_gateway": "", + }, + } def callback(tmp_file: Path, _result: Any): users_import_file = tmp_file.parent / "nebari-users-import.json" @@ -320,27 +314,14 @@ def callback(tmp_file: Path, _result: Any): # custom authenticators removed in 0.4.0, should be replaced by password upgraded = assert_nebari_upgrade_success( + runner, + cli, + tmp_path, monkeypatch, start_version, end_version, addl_args=["--attempt-fixes"], - addl_config=yaml.safe_load( - """ -security: - authentication: - type: custom - config: - oauth_callback_url: "" - scope: "" - users: {} - groups: - users: {} -terraform_modules: [] -default_images: - conda_store: "" - dask_gateway: "" -""" - ), + addl_config=addl_config, callback=callback, ) @@ -355,61 +336,62 @@ def callback(tmp_file: Path, _result: Any): assert True is upgraded["prevent_deploy"] -def test_cli_upgrade_to_0_4_0_fails_for_custom_auth_without_attempt_fixes(): +def test_cli_upgrade_to_0_4_0_fails_for_custom_auth_without_attempt_fixes( + runner, cli, tmp_path +): start_version = "0.3.12" + tmp_file = tmp_path / "nebari-config.yaml" + nebari_config = { + "project_name": "test", + "provider": "local", + "domain": "test.example.com", + "namespace": "dev", + "nebari_version": start_version, + "security": { + "authentication": { + "type": "custom", + }, + }, + } - with tempfile.TemporaryDirectory() as tmp: - tmp_file = Path(tmp).resolve() / "nebari-config.yaml" - assert tmp_file.exists() is False - - nebari_config = yaml.safe_load( - f""" -project_name: test -provider: local -domain: test.example.com -namespace: dev -nebari_version: {start_version} -security: - authentication: - type: custom - """ - ) - - with open(tmp_file.resolve(), "w") as f: - yaml.dump(nebari_config, f) + 
with tmp_file.open("w") as f: + yaml.dump(nebari_config, f) - assert tmp_file.exists() is True - app = create_cli() + assert tmp_file.exists() - result = runner.invoke(app, ["upgrade", "--config", tmp_file.resolve()]) + result = runner.invoke(cli, ["upgrade", "--config", tmp_file.resolve()]) - assert 1 == result.exit_code - assert result.exception - assert "Custom Authenticators are no longer supported" in str(result.exception) + assert 1 == result.exit_code + assert result.exception + assert "Custom Authenticators are no longer supported" in str(result.exception) - # make sure the file is unaltered - with open(tmp_file.resolve(), "r") as c: - assert yaml.safe_load(c) == nebari_config + # make sure the file is unaltered + with tmp_file.open() as f: + assert yaml.safe_load(f) == nebari_config @pytest.mark.skipif( rounded_ver_parse(_nebari.upgrade.__version__) < rounded_ver_parse("2023.10.1"), reason="This test is only valid for versions >= 2023.10.1", ) -def test_cli_upgrade_to_2023_10_1_cdsdashboard_removed(monkeypatch: pytest.MonkeyPatch): +def test_cli_upgrade_to_2023_10_1_cdsdashboard_removed( + runner, cli, monkeypatch, tmp_path +): start_version = "2023.7.2" end_version = "2023.10.1" - addl_config = yaml.safe_load( - """ -cdsdashboards: - enabled: true - cds_hide_user_named_servers: true - cds_hide_user_dashboard_servers: false - """ - ) + addl_config = { + "cdsdashboards": { + "enabled": True, + "cds_hide_user_named_servers": True, + "cds_hide_user_dashboard_servers": False, + } + } upgraded = assert_nebari_upgrade_success( + runner, + cli, + tmp_path, monkeypatch, start_version, end_version, @@ -443,7 +425,7 @@ def test_cli_upgrade_to_2023_10_1_cdsdashboard_removed(monkeypatch: pytest.Monke ], ) def test_cli_upgrade_to_2023_10_1_kubernetes_validations( - monkeypatch: pytest.MonkeyPatch, provider: str, k8s_status: str + runner, cli, monkeypatch, provider, k8s_status, tmp_path ): start_version = "2023.7.2" end_version = "2023.10.1" @@ -460,62 +442,60 @@ def test_cli_upgrade_to_2023_10_1_kubernetes_validations( "gcp": {"incompatible": "1.23", "compatible": "1.26", "invalid": "badname"}, } - with tempfile.TemporaryDirectory() as tmp: - tmp_file = Path(tmp).resolve() / "nebari-config.yaml" - assert tmp_file.exists() is False - - nebari_config = yaml.safe_load( - f""" -project_name: test -provider: {provider} -domain: test.example.com -namespace: dev -nebari_version: {start_version} -cdsdashboards: - enabled: true - cds_hide_user_named_servers: true - cds_hide_user_dashboard_servers: false -{get_provider_config_block_name(provider)}: - region: {MOCK_CLOUD_REGIONS.get(provider, {})[0]} - kubernetes_version: {kubernetes_configs[provider][k8s_status]} - """ - ) - with open(tmp_file.resolve(), "w") as f: - yaml.dump(nebari_config, f) + tmp_file = tmp_path / "nebari-config.yaml" + + nebari_config = { + "project_name": "test", + "provider": provider, + "domain": "test.example.com", + "namespace": "dev", + "nebari_version": start_version, + "cdsdashboards": { + "enabled": True, + "cds_hide_user_named_servers": True, + "cds_hide_user_dashboard_servers": False, + }, + get_provider_config_block_name(provider): { + "region": MOCK_CLOUD_REGIONS.get(provider, {})[0], + "kubernetes_version": kubernetes_configs[provider][k8s_status], + }, + } - assert tmp_file.exists() is True - app = create_cli() + if provider == "gcp": + nebari_config["google_cloud_platform"]["project"] = "test-project" - result = runner.invoke(app, ["upgrade", "--config", tmp_file.resolve()]) + with tmp_file.open("w") as f: + 
yaml.dump(nebari_config, f) - if k8s_status == "incompatible": - UPGRADE_KUBERNETES_MESSAGE_WO_BRACKETS = re.sub( - r"\[.*?\]", "", UPGRADE_KUBERNETES_MESSAGE - ) - assert UPGRADE_KUBERNETES_MESSAGE_WO_BRACKETS in result.stdout.replace( - "\n", "" - ) + result = runner.invoke(cli, ["upgrade", "--config", tmp_file.resolve()]) - if k8s_status == "compatible": - assert 0 == result.exit_code - assert not result.exception - assert "Saving new config file" in result.stdout + if k8s_status == "incompatible": + UPGRADE_KUBERNETES_MESSAGE_WO_BRACKETS = re.sub( + r"\[.*?\]", "", UPGRADE_KUBERNETES_MESSAGE + ) + assert UPGRADE_KUBERNETES_MESSAGE_WO_BRACKETS in result.stdout.replace("\n", "") - # load the modified nebari-config.yaml and check the new version has changed - with open(tmp_file.resolve(), "r") as f: - upgraded = yaml.safe_load(f) - assert end_version == upgraded["nebari_version"] + if k8s_status == "compatible": + assert 0 == result.exit_code + assert not result.exception + assert "Saving new config file" in result.stdout - if k8s_status == "invalid": - assert ( - "Unable to detect Kubernetes version for provider {}".format( - provider - ) - in result.stdout - ) + # load the modified nebari-config.yaml and check the new version has changed + with tmp_file.open() as f: + upgraded = yaml.safe_load(f) + assert end_version == upgraded["nebari_version"] + + if k8s_status == "invalid": + assert ( + f"Unable to detect Kubernetes version for provider {provider}" + in result.stdout + ) def assert_nebari_upgrade_success( + runner, + cli, + tmp_path: Path, monkeypatch: pytest.MonkeyPatch, start_version: str, end_version: str, @@ -528,65 +508,57 @@ def assert_nebari_upgrade_success( monkeypatch.setattr(_nebari.upgrade, "__version__", end_version) # create a tmp dir and clean up when done - with tempfile.TemporaryDirectory() as tmp: - tmp_file = Path(tmp).resolve() / "nebari-config.yaml" - assert tmp_file.exists() is False - - # merge basic config with any test case specific values provided - nebari_config = { - **yaml.safe_load( - f""" -project_name: test -provider: {provider} -domain: test.example.com -namespace: dev -nebari_version: {start_version} - """ - ), - **addl_config, - } - - # write the test nebari-config.yaml file to tmp location - with open(tmp_file.resolve(), "w") as f: - yaml.dump(nebari_config, f) - - assert tmp_file.exists() is True - app = create_cli() + tmp_file = tmp_path / "nebari-config.yaml" + assert not tmp_file.exists() + + # merge basic config with any test case specific values provided + nebari_config = { + "project_name": "test", + "provider": provider, + "domain": "test.example.com", + "namespace": "dev", + "nebari_version": start_version, + **addl_config, + } - if inputs is not None and len(inputs) > 0: - inputs.append("") # trailing newline for last input + # write the test nebari-config.yaml file to tmp location + with tmp_file.open("w") as f: + yaml.dump(nebari_config, f) - # run nebari upgrade -c tmp/nebari-config.yaml - result = runner.invoke( - app, - ["upgrade", "--config", tmp_file.resolve()] + addl_args, - input="\n".join(inputs), - ) + assert tmp_file.exists() - enable_default_assertions = True + if inputs is not None and len(inputs) > 0: + inputs.append("") # trailing newline for last input - if callback is not None: - enable_default_assertions = callback(tmp_file, result) + # run nebari upgrade -c tmp/nebari-config.yaml + result = runner.invoke( + cli, + ["upgrade", "--config", tmp_file.resolve()] + addl_args, + input="\n".join(inputs), + ) - if 
enable_default_assertions: - assert 0 == result.exit_code - assert not result.exception - assert "Saving new config file" in result.stdout + enable_default_assertions = True - # load the modified nebari-config.yaml and check the new version has changed - with open(tmp_file.resolve(), "r") as f: - upgraded = yaml.safe_load(f) - assert end_version == upgraded["nebari_version"] + if callback is not None: + enable_default_assertions = callback(tmp_file, result) - # check backup matches original - backup_file = ( - Path(tmp).resolve() / f"nebari-config.yaml.{start_version}.backup" - ) - assert backup_file.exists() is True - with open(backup_file.resolve(), "r") as b: - backup = yaml.safe_load(b) - assert backup == nebari_config - - # pass the parsed nebari-config.yaml with upgrade mods back to caller for - # additional assertions - return upgraded + if enable_default_assertions: + assert 0 == result.exit_code + assert not result.exception + assert "Saving new config file" in result.stdout + + # load the modified nebari-config.yaml and check the new version has changed + with tmp_file.open() as f: + upgraded = yaml.safe_load(f) + assert end_version == upgraded["nebari_version"] + + # check backup matches original + backup_file = tmp_path / f"nebari-config.yaml.{start_version}.backup" + assert backup_file.exists() + with backup_file.open() as b: + backup = yaml.safe_load(b) + assert backup == nebari_config + + # pass the parsed nebari-config.yaml with upgrade mods back to caller for + # additional assertions + return upgraded diff --git a/tests/tests_unit/test_cli_validate.py b/tests/tests_unit/test_cli_validate.py index 00c46c2cd6..81e65ac166 100644 --- a/tests/tests_unit/test_cli_validate.py +++ b/tests/tests_unit/test_cli_validate.py @@ -1,22 +1,16 @@ import re import shutil -import tempfile from pathlib import Path -from typing import Any, Dict, List import pytest import yaml -from typer.testing import CliRunner from _nebari._version import __version__ -from _nebari.cli import create_cli TEST_DATA_DIR = Path(__file__).resolve().parent / "cli_validate" -runner = CliRunner() - -def _update_yaml_file(file_path: Path, key: str, value: Any): +def _update_yaml_file(file_path, key, value): """Utility function to update a yaml file with a new key/value pair.""" with open(file_path, "r") as f: yaml_data = yaml.safe_load(f) @@ -44,9 +38,8 @@ def _update_yaml_file(file_path: Path, key: str, value: Any): ), # https://github.com/nebari-dev/nebari/issues/1937 ], ) -def test_cli_validate_stdout(args: List[str], exit_code: int, content: List[str]): - app = create_cli() - result = runner.invoke(app, ["validate"] + args) +def test_cli_validate_stdout(runner, cli, args, exit_code, content): + result = runner.invoke(cli, ["validate"] + args) assert result.exit_code == exit_code for c in content: assert c in result.stdout @@ -71,70 +64,66 @@ def generate_test_data_test_cli_validate_local_happy_path(): return {"keys": keys, "test_data": test_data} -def test_cli_validate_local_happy_path(config_yaml: str): - test_file = TEST_DATA_DIR / config_yaml +def test_cli_validate_local_happy_path(runner, cli, config_yaml, config_path, tmp_path): + test_file = config_path / config_yaml assert test_file.exists() is True - with tempfile.TemporaryDirectory() as tmpdirname: - temp_test_file = shutil.copy(test_file, tmpdirname) - - # update the copied test file with the current version if necessary - _update_yaml_file(temp_test_file, "nebari_version", __version__) - - app = create_cli() - result = runner.invoke(app, ["validate", 
"--config", temp_test_file]) - assert not result.exception - assert 0 == result.exit_code - assert "Successfully validated configuration" in result.stdout + temp_test_file = shutil.copy(test_file, tmp_path) + # update the copied test file with the current version if necessary + _update_yaml_file(temp_test_file, "nebari_version", __version__) -def test_cli_validate_from_env(): - with tempfile.TemporaryDirectory() as tmp: - tmp_file = Path(tmp).resolve() / "nebari-config.yaml" - assert tmp_file.exists() is False - - nebari_config = yaml.safe_load( - """ -provider: aws -project_name: test -amazon_web_services: - region: us-east-1 - kubernetes_version: '1.19' - """ - ) + result = runner.invoke(cli, ["validate", "--config", temp_test_file]) + assert not result.exception + assert 0 == result.exit_code + assert "Successfully validated configuration" in result.stdout - with open(tmp_file.resolve(), "w") as f: - yaml.dump(nebari_config, f) - assert tmp_file.exists() is True - app = create_cli() +def test_cli_validate_from_env(runner, cli, tmp_path): + tmp_file = tmp_path / "nebari-config.yaml" - valid_result = runner.invoke( - app, - ["validate", "--config", tmp_file.resolve()], - env={"NEBARI_SECRET__amazon_web_services__kubernetes_version": "1.20"}, - ) + nebari_config = { + "provider": "aws", + "project_name": "test", + "amazon_web_services": { + "region": "us-east-1", + "kubernetes_version": "1.19", + }, + } - assert 0 == valid_result.exit_code - assert not valid_result.exception - assert "Successfully validated configuration" in valid_result.stdout + with tmp_file.open("w") as f: + yaml.dump(nebari_config, f) - invalid_result = runner.invoke( - app, - ["validate", "--config", tmp_file.resolve()], - env={"NEBARI_SECRET__amazon_web_services__kubernetes_version": "1.0"}, - ) + valid_result = runner.invoke( + cli, + ["validate", "--config", tmp_file.resolve()], + env={"NEBARI_SECRET__amazon_web_services__kubernetes_version": "1.18"}, + ) + assert 0 == valid_result.exit_code + assert not valid_result.exception + assert "Successfully validated configuration" in valid_result.stdout - assert 1 == invalid_result.exit_code - assert invalid_result.exception - assert "Invalid `kubernetes-version`" in invalid_result.stdout + invalid_result = runner.invoke( + cli, + ["validate", "--config", tmp_file.resolve()], + env={"NEBARI_SECRET__amazon_web_services__kubernetes_version": "1.0"}, + ) + assert 1 == invalid_result.exit_code + assert invalid_result.exception + assert "Invalid `kubernetes-version`" in invalid_result.stdout @pytest.mark.parametrize( "key, value, provider, expected_message, addl_config", [ ("NEBARI_SECRET__project_name", "123invalid", "local", "validation error", {}), - ("NEBARI_SECRET__this_is_an_error", "true", "local", "object has no field", {}), + ( + "NEBARI_SECRET__this_is_an_error", + "true", + "local", + "Object has no attribute", + {}, + ), ( "NEBARI_SECRET__amazon_web_services__kubernetes_version", "1.0", @@ -150,137 +139,42 @@ def test_cli_validate_from_env(): ], ) def test_cli_validate_error_from_env( - key: str, - value: str, - provider: str, - expected_message: str, - addl_config: Dict[str, Any], + runner, + cli, + key, + value, + provider, + expected_message, + addl_config, + tmp_path, ): - with tempfile.TemporaryDirectory() as tmp: - tmp_file = Path(tmp).resolve() / "nebari-config.yaml" - assert tmp_file.exists() is False - - nebari_config = { - **yaml.safe_load( - f""" -provider: {provider} -project_name: test - """ - ), - **addl_config, - } - - with open(tmp_file.resolve(), 
"w") as f: - yaml.dump(nebari_config, f) - - assert tmp_file.exists() is True - app = create_cli() - - # confirm the file is otherwise valid without environment variable overrides - pre = runner.invoke(app, ["validate", "--config", tmp_file.resolve()]) - assert 0 == pre.exit_code - assert not pre.exception - - # run validate again with environment variables that are expected to trigger - # validation errors - result = runner.invoke( - app, ["validate", "--config", tmp_file.resolve()], env={key: value} - ) + tmp_file = tmp_path / "nebari-config.yaml" - assert 1 == result.exit_code - assert result.exception - assert expected_message in result.stdout + nebari_config = { + "provider": provider, + "project_name": "test", + } + nebari_config.update(addl_config) + with tmp_file.open("w") as f: + yaml.dump(nebari_config, f) -@pytest.mark.parametrize( - "provider, addl_config", - [ - ( - "aws", - { - "amazon_web_services": { - "kubernetes_version": "1.20", - "region": "us-east-1", - } - }, - ), - ("azure", {"azure": {"kubernetes_version": "1.20", "region": "Central US"}}), - ( - "gcp", - { - "google_cloud_platform": { - "kubernetes_version": "1.20", - "region": "us-east1", - "project": "test", - } - }, - ), - ("do", {"digital_ocean": {"kubernetes_version": "1.20", "region": "nyc3"}}), - pytest.param( - "local", - {"security": {"authentication": {"type": "Auth0"}}}, - id="auth-provider-auth0", - ), - pytest.param( - "local", - {"security": {"authentication": {"type": "GitHub"}}}, - id="auth-provider-github", - ), - ], -) -def test_cli_validate_error_missing_cloud_env( - monkeypatch: pytest.MonkeyPatch, provider: str, addl_config: Dict[str, Any] -): - # cloud methods are all globally mocked, need to reset so the env variables will be checked - monkeypatch.undo() - for e in [ - "AWS_ACCESS_KEY_ID", - "AWS_SECRET_ACCESS_KEY", - "GOOGLE_CREDENTIALS", - "PROJECT_ID", - "ARM_SUBSCRIPTION_ID", - "ARM_TENANT_ID", - "ARM_CLIENT_ID", - "ARM_CLIENT_SECRET", - "DIGITALOCEAN_TOKEN", - "SPACES_ACCESS_KEY_ID", - "SPACES_SECRET_ACCESS_KEY", - "AUTH0_CLIENT_ID", - "AUTH0_CLIENT_SECRET", - "AUTH0_DOMAIN", - "GITHUB_CLIENT_ID", - "GITHUB_CLIENT_SECRET", - ]: - try: - monkeypatch.delenv(e) - except Exception: - pass - - with tempfile.TemporaryDirectory() as tmp: - tmp_file = Path(tmp).resolve() / "nebari-config.yaml" - assert tmp_file.exists() is False - - nebari_config = { - **yaml.safe_load( - f""" -provider: {provider} -project_name: test - """ - ), - **addl_config, - } - - with open(tmp_file.resolve(), "w") as f: - yaml.dump(nebari_config, f) - - assert tmp_file.exists() is True - app = create_cli() - - result = runner.invoke(app, ["validate", "--config", tmp_file.resolve()]) - - assert 1 == result.exit_code - assert result.exception - assert "Missing the following required environment variable" in result.stdout + assert tmp_file.exists() + + # confirm the file is otherwise valid without environment variable overrides + pre = runner.invoke(cli, ["validate", "--config", tmp_file.resolve()]) + assert 0 == pre.exit_code + assert not pre.exception + + # run validate again with environment variables that are expected to trigger + # validation errors + result = runner.invoke( + cli, ["validate", "--config", tmp_file.resolve()], env={key: value} + ) + + assert 1 == result.exit_code + assert result.exception + assert expected_message in result.stdout def generate_test_data_test_cli_validate_error(): @@ -309,21 +203,20 @@ def generate_test_data_test_cli_validate_error(): return {"keys": keys, "test_data": test_data} -def 
test_cli_validate_error(config_yaml: str, expected_message: str): - test_file = TEST_DATA_DIR / config_yaml +def test_cli_validate_error(runner, cli, config_yaml, config_path, expected_message): + test_file = config_path / config_yaml assert test_file.exists() is True - app = create_cli() - result = runner.invoke(app, ["validate", "--config", test_file]) + result = runner.invoke(cli, ["validate", "--config", test_file]) assert result.exception assert 1 == result.exit_code assert "ERROR validating configuration" in result.stdout if expected_message: # since this will usually come from a parsed filename, assume spacing/hyphenation/case is optional - assert (expected_message in result.stdout.lower()) or ( - expected_message.replace("-", " ").replace("_", " ") - in result.stdout.lower() + actual_message = result.stdout.lower().replace("\n", "") + assert (expected_message in actual_message) or ( + expected_message.replace("-", " ").replace("_", " ") in actual_message ) diff --git a/tests/tests_unit/test_config.py b/tests/tests_unit/test_config.py index ccc52543d7..bf01d703e9 100644 --- a/tests/tests_unit/test_config.py +++ b/tests/tests_unit/test_config.py @@ -1,7 +1,10 @@ import os import pathlib +from typing import Optional import pytest +import yaml +from pydantic import BaseModel from _nebari.config import ( backup_configuration, @@ -12,6 +15,23 @@ ) +def test_parse_env_config(monkeypatch): + keyword = "NEBARI_SECRET__amazon_web_services__kubernetes_version" + value = "1.20" + monkeypatch.setenv(keyword, value) + + class DummyAWSModel(BaseModel): + kubernetes_version: Optional[str] = None + + class DummmyModel(BaseModel): + amazon_web_services: DummyAWSModel = DummyAWSModel() + + model = DummmyModel() + + model_updated = set_config_from_environment_variables(model) + assert model_updated.amazon_web_services.kubernetes_version == value + + def test_set_nested_attribute(): data = {"a": {"b": {"c": 1}}} set_nested_attribute(data, ["a", "b", "c"], 2) @@ -62,6 +82,27 @@ def test_set_config_from_environment_variables(nebari_config): del os.environ[secret_key_nested] +def test_set_config_from_env(monkeypatch, tmp_path, config_schema): + keyword = "NEBARI_SECRET__amazon_web_services__kubernetes_version" + value = "1.20" + monkeypatch.setenv(keyword, value) + + config_dict = { + "provider": "aws", + "project_name": "test", + "amazon_web_services": {"region": "us-east-1", "kubernetes_version": "1.19"}, + } + + config_file = tmp_path / "nebari-config.yaml" + with config_file.open("w") as f: + yaml.dump(config_dict, f) + + from _nebari.config import read_configuration + + config = read_configuration(config_file, config_schema) + assert config.amazon_web_services.kubernetes_version == value + + def test_set_config_from_environment_invalid_secret(nebari_config): invalid_secret_key = "NEBARI_SECRET__nonexistent__attribute" os.environ[invalid_secret_key] = "some_value" @@ -97,7 +138,7 @@ def test_read_configuration_non_existent_file(nebari_config): def test_write_configuration_with_dict(nebari_config, tmp_path): config_file = tmp_path / "nebari-config-dict.yaml" - config_dict = nebari_config.dict() + config_dict = nebari_config.model_dump() write_configuration(config_file, config_dict) read_config = read_configuration(config_file, nebari_config.__class__) diff --git a/tests/tests_unit/test_render.py b/tests/tests_unit/test_render.py index 73c4fb5ca1..e0fd6636fe 100644 --- a/tests/tests_unit/test_render.py +++ b/tests/tests_unit/test_render.py @@ -1,7 +1,6 @@ import os from _nebari.stages.bootstrap import 
-from nebari import schema
 from nebari.plugins import nebari_plugin_manager


@@ -22,18 +21,12 @@ def test_render_config(nebari_render):
         "03-kubernetes-initialize",
     }.issubset(os.listdir(output_directory / "stages"))

-    if config.provider == schema.ProviderEnum.do:
-        assert (output_directory / "stages" / "01-terraform-state/do").is_dir()
-        assert (output_directory / "stages" / "02-infrastructure/do").is_dir()
-    elif config.provider == schema.ProviderEnum.aws:
-        assert (output_directory / "stages" / "01-terraform-state/aws").is_dir()
-        assert (output_directory / "stages" / "02-infrastructure/aws").is_dir()
-    elif config.provider == schema.ProviderEnum.gcp:
-        assert (output_directory / "stages" / "01-terraform-state/gcp").is_dir()
-        assert (output_directory / "stages" / "02-infrastructure/gcp").is_dir()
-    elif config.provider == schema.ProviderEnum.azure:
-        assert (output_directory / "stages" / "01-terraform-state/azure").is_dir()
-        assert (output_directory / "stages" / "02-infrastructure/azure").is_dir()
+    assert (
+        output_directory / "stages" / f"01-terraform-state/{config.provider.value}"
+    ).is_dir()
+    assert (
+        output_directory / "stages" / f"02-infrastructure/{config.provider.value}"
+    ).is_dir()

     if config.ci_cd.type == CiEnum.github_actions:
         assert (output_directory / ".github/workflows/").is_dir()
diff --git a/tests/tests_unit/test_schema.py b/tests/tests_unit/test_schema.py
index b4fb58bc62..91d16b6051 100644
--- a/tests/tests_unit/test_schema.py
+++ b/tests/tests_unit/test_schema.py
@@ -1,9 +1,8 @@
 from contextlib import nullcontext

 import pytest
-from pydantic.error_wrappers import ValidationError
+from pydantic import ValidationError

-from nebari import schema
 from nebari.plugins import nebari_plugin_manager


@@ -49,12 +48,6 @@ def test_minimal_schema_from_file_without_env(tmp_path, monkeypatch):
     assert config.storage.conda_store == "200Gi"


-def test_render_schema(nebari_config):
-    assert isinstance(nebari_config, schema.Main)
-    assert nebari_config.project_name == f"pytest{nebari_config.provider.value}"
-    assert nebari_config.namespace == "dev"
-
-
 @pytest.mark.parametrize(
     "provider, exception",
     [
@@ -125,7 +118,7 @@ def test_no_provider(config_schema, provider, full_name, default_fields):
     }
     config = config_schema(**config_dict)
     assert config.provider == provider
-    assert full_name in config.dict()
+    assert full_name in config.model_dump()


 def test_multiple_providers(config_schema):
@@ -164,6 +157,145 @@ def test_setted_provider(config_schema, provider):
     }
     config = config_schema(**config_dict)
     assert config.provider == provider
-    result_config_dict = config.dict()
+    result_config_dict = config.model_dump()
     assert provider in result_config_dict
     assert result_config_dict[provider]["kube_context"] == "some_context"
+
+
+def test_invalid_nebari_version(config_schema):
+    nebari_version = "9999.99.9"
+    config_dict = {
+        "project_name": "test",
+        "provider": "local",
+        "nebari_version": nebari_version,
+    }
+    with pytest.raises(
+        ValidationError,
+        match=rf".* Assertion failed, nebari_version={nebari_version} is not an accepted version.*",
+    ):
+        config_schema(**config_dict)
+
+
+def test_unsupported_kubernetes_version(config_schema):
+    # the mocked available kubernetes versions are 1.18, 1.19, 1.20
+    unsupported_version = "1.23"
+    config_dict = {
+        "project_name": "test",
+        "provider": "gcp",
+        "google_cloud_platform": {
+            "project": "test",
+            "region": "us-east1",
+            "kubernetes_version": unsupported_version,
+        },
+    }
+    with pytest.raises(
+        ValidationError,
match=rf"Invalid `kubernetes-version` provided: {unsupported_version}..*", + ): + config_schema(**config_dict) + + +@pytest.mark.parametrize( + "auth_provider, env_vars", + [ + ( + "Auth0", + [ + "AUTH0_CLIENT_ID", + "AUTH0_CLIENT_SECRET", + "AUTH0_DOMAIN", + ], + ), + ( + "GitHub", + [ + "GITHUB_CLIENT_ID", + "GITHUB_CLIENT_SECRET", + ], + ), + ], +) +def test_missing_auth_env_var(monkeypatch, config_schema, auth_provider, env_vars): + # auth related variables are all globally mocked, reset here + monkeypatch.undo() + for env_var in env_vars: + monkeypatch.delenv(env_var, raising=False) + + config_dict = { + "provider": "local", + "project_name": "test", + "security": {"authentication": {"type": auth_provider}}, + } + with pytest.raises( + ValidationError, + match=r".* is not set in the environment", + ): + config_schema(**config_dict) + + +@pytest.mark.parametrize( + "provider, addl_config, env_vars", + [ + ( + "aws", + { + "amazon_web_services": { + "kubernetes_version": "1.20", + "region": "us-east-1", + } + }, + ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"], + ), + ( + "azure", + { + "azure": { + "kubernetes_version": "1.20", + "region": "Central US", + "storage_account_postfix": "test", + } + }, + [ + "ARM_SUBSCRIPTION_ID", + "ARM_TENANT_ID", + "ARM_CLIENT_ID", + "ARM_CLIENT_SECRET", + ], + ), + ( + "gcp", + { + "google_cloud_platform": { + "kubernetes_version": "1.20", + "region": "us-east1", + "project": "test", + } + }, + ["GOOGLE_CREDENTIALS", "PROJECT_ID"], + ), + ( + "do", + {"digital_ocean": {"kubernetes_version": "1.20", "region": "nyc3"}}, + ["DIGITALOCEAN_TOKEN", "SPACES_ACCESS_KEY_ID", "SPACES_SECRET_ACCESS_KEY"], + ), + ], +) +def test_missing_cloud_env_var( + monkeypatch, config_schema, provider, addl_config, env_vars +): + # cloud methods are all globally mocked, need to reset so the env variables will be checked + monkeypatch.undo() + for env_var in env_vars: + monkeypatch.delenv(env_var, raising=False) + + nebari_config = { + "provider": provider, + "project_name": "test", + } + nebari_config.update(addl_config) + + with pytest.raises( + ValidationError, + match=r".* Missing the following required environment variables: .*", + ): + config_schema(**nebari_config)