class NodeGroup(schema.Base):
    """Provider-agnostic node group settings shared by the cloud node groups.

    Attributes are validated by pydantic:

    * ``instance``: instance/VM type name.
    * ``min_nodes`` / ``max_nodes``: autoscaling bounds (``>= 0`` / ``>= 1``).
    * ``taints``: taints for the group's nodes, or ``None`` when the user did
      not configure any.
    """

    instance: str
    min_nodes: Annotated[int, Field(ge=0)] = 0
    max_nodes: Annotated[int, Field(ge=1)] = 1
    taints: Optional[List[schema.Taint]] = None

    @field_validator("taints", mode="before")
    @classmethod
    def validate_taint_strings(cls, taints: Optional[List[Any]]):
        """Parse ``key=value:Effect`` strings into ``schema.Taint`` objects.

        Runs before the regular field validation, so entries that are not
        strings are passed through untouched for pydantic to validate.

        Args:
            taints: Raw ``taints`` value from the configuration, or ``None``.

        Returns:
            ``None`` when no taints were given, otherwise a new list in which
            every string entry has been replaced by a ``schema.Taint``.

        Raises:
            ValueError: If a string entry is not exactly of the form
                ``key=value:Effect`` (word characters only).
        """
        if taints is None:
            return taints

        taint_str_regex = re.compile(r"(\w+)=(\w+):(\w+)")
        parsed_taints = []
        for taint in taints:
            if not isinstance(taint, str):
                parsed_taints.append(taint)
                continue

            # fullmatch: the whole string must be a taint; match() would
            # silently ignore anything after the effect.
            match = taint_str_regex.fullmatch(taint)
            if match is None:
                raise ValueError(f"Invalid taint string: {taint}")
            key, value, effect = match.groups()
            parsed_taints.append(schema.Taint(key=key, value=value, effect=effect))

        return parsed_taints


# Taints applied when a node group does not configure any: none for the
# "general" node group, a dedicated NoSchedule taint for every other group.
DEFAULT_GENERAL_NODE_GROUP_TAINTS = []
DEFAULT_NODE_GROUP_TAINTS = [
    schema.Taint(key="dedicated", value="nebari", effect="NoSchedule")
]


def set_missing_taints_to_default_taints(
    node_groups: Dict[str, NodeGroup],
) -> Dict[str, NodeGroup]:
    """Fill in default taints for node groups whose ``taints`` is ``None``.

    The node groups are updated in place and the same mapping is returned so
    the function can be used as a pydantic ``AfterValidator``.

    Args:
        node_groups: Mapping of node group name to node group.

    Returns:
        The ``node_groups`` mapping that was passed in.
    """
    for node_group_name, node_group in node_groups.items():
        if node_group.taints is not None:
            continue
        defaults = (
            DEFAULT_GENERAL_NODE_GROUP_TAINTS
            if node_group_name == "general"
            else DEFAULT_NODE_GROUP_TAINTS
        )
        # Copy so that node groups never share (and cannot mutate) the
        # module-level default list.
        node_group.taints = list(defaults)
    return node_groups
@field_validator("node_taints", mode="before")
@classmethod
def convert_taints(cls, value: Optional[List[schema.Taint]]):
    """Convert ``schema.Taint`` objects into plain dicts for the input vars.

    Each taint becomes ``{"key": ..., "value": ..., "effect": ...}`` where the
    effect is rewritten from the ``TaintEffectEnum`` member to its
    upper-snake-case name (e.g. ``NoSchedule`` -> ``"NO_SCHEDULE"``).
    # NOTE(review): presumably the spelling expected by the provider's
    # Terraform taint block - confirm against the template.

    Args:
        value: Taints of the node group, or ``None`` when none were set.

    Returns:
        A list of dicts, empty when ``value`` is ``None`` or empty.

    Raises:
        KeyError: If a taint carries an effect that has no mapping.
    """
    if value is None:
        # The annotation allows None; treat it as "no taints" instead of
        # failing with a TypeError while iterating.
        return []

    # Built once per call rather than once per taint.
    effect_names = {
        schema.TaintEffectEnum.NoSchedule: "NO_SCHEDULE",
        schema.TaintEffectEnum.PreferNoSchedule: "PREFER_NO_SCHEDULE",
        schema.TaintEffectEnum.NoExecute: "NO_EXECUTE",
    }
    return [
        {
            "key": taint.key,
            "value": taint.value,
            "effect": effect_names[taint.effect],
        }
        for taint in value
    ]
+ schema.TaintEffectEnum.NoExecute: "NO_EXECUTE", + }[taint.effect], + ) + for taint in value + ] + class AWSInputVars(schema.Base): name: str @@ -271,19 +367,28 @@ class GCPGuestAccelerator(schema.Base): count: Annotated[int, Field(ge=1)] = 1 -class GCPNodeGroup(schema.Base): - instance: str - min_nodes: Annotated[int, Field(ge=0)] = 0 - max_nodes: Annotated[int, Field(ge=1)] = 1 +class GCPNodeGroup(NodeGroup): preemptible: bool = False labels: Dict[str, str] = {} guest_accelerators: List[GCPGuestAccelerator] = [] DEFAULT_GCP_NODE_GROUPS = { - "general": GCPNodeGroup(instance="e2-standard-8", min_nodes=1, max_nodes=1), - "user": GCPNodeGroup(instance="e2-standard-4", min_nodes=0, max_nodes=5), - "worker": GCPNodeGroup(instance="e2-standard-4", min_nodes=0, max_nodes=5), + "general": GCPNodeGroup( + instance="e2-standard-8", + min_nodes=1, + max_nodes=1, + ), + "user": GCPNodeGroup( + instance="e2-standard-4", + min_nodes=0, + max_nodes=5, + ), + "worker": GCPNodeGroup( + instance="e2-standard-4", + min_nodes=0, + max_nodes=5, + ), } @@ -296,7 +401,9 @@ class GoogleCloudPlatformProvider(schema.Base): kubernetes_version: str availability_zones: Optional[List[str]] = [] release_channel: str = constants.DEFAULT_GKE_RELEASE_CHANNEL - node_groups: Dict[str, GCPNodeGroup] = DEFAULT_GCP_NODE_GROUPS + node_groups: Annotated[ + Dict[str, GCPNodeGroup], AfterValidator(set_missing_taints_to_default_taints) + ] = Field(DEFAULT_GCP_NODE_GROUPS, validate_default=True) tags: Optional[List[str]] = [] networking_mode: str = "ROUTE" network: str = "default" @@ -346,16 +453,26 @@ def _check_input(cls, data: Any) -> Any: return data -class AzureNodeGroup(schema.Base): - instance: str - min_nodes: int - max_nodes: int +class AzureNodeGroup(NodeGroup): + pass DEFAULT_AZURE_NODE_GROUPS = { - "general": AzureNodeGroup(instance="Standard_D8_v3", min_nodes=1, max_nodes=1), - "user": AzureNodeGroup(instance="Standard_D4_v3", min_nodes=0, max_nodes=5), - "worker": 
AzureNodeGroup(instance="Standard_D4_v3", min_nodes=0, max_nodes=5), + "general": AzureNodeGroup( + instance="Standard_D8_v3", + min_nodes=1, + max_nodes=1, + ), + "user": AzureNodeGroup( + instance="Standard_D4_v3", + min_nodes=0, + max_nodes=5, + ), + "worker": AzureNodeGroup( + instance="Standard_D4_v3", + min_nodes=0, + max_nodes=5, + ), } @@ -365,7 +482,9 @@ class AzureProvider(schema.Base): storage_account_postfix: str authorized_ip_ranges: Optional[List[str]] = ["0.0.0.0/0"] resource_group_name: Optional[str] = None - node_groups: Dict[str, AzureNodeGroup] = DEFAULT_AZURE_NODE_GROUPS + node_groups: Annotated[ + Dict[str, AzureNodeGroup], AfterValidator(set_missing_taints_to_default_taints) + ] = Field(DEFAULT_AZURE_NODE_GROUPS, validate_default=True) storage_account_postfix: str vnet_subnet_id: Optional[str] = None private_cluster_enabled: bool = False @@ -419,10 +538,7 @@ def _validate_tags(cls, value: Optional[Dict[str, str]]) -> Dict[str, str]: return value if value is None else azure_cloud.validate_tags(value) -class AWSNodeGroup(schema.Base): - instance: str - min_nodes: int = 0 - max_nodes: int +class AWSNodeGroup(NodeGroup): gpu: bool = False single_subnet: bool = False permissions_boundary: Optional[str] = None @@ -439,12 +555,22 @@ def check_launch_template(cls, values): DEFAULT_AWS_NODE_GROUPS = { - "general": AWSNodeGroup(instance="m5.2xlarge", min_nodes=1, max_nodes=1), + "general": AWSNodeGroup( + instance="m5.2xlarge", + min_nodes=1, + max_nodes=1, + ), "user": AWSNodeGroup( - instance="m5.xlarge", min_nodes=0, max_nodes=5, single_subnet=False + instance="m5.xlarge", + min_nodes=0, + max_nodes=5, + single_subnet=False, ), "worker": AWSNodeGroup( - instance="m5.xlarge", min_nodes=0, max_nodes=5, single_subnet=False + instance="m5.xlarge", + min_nodes=0, + max_nodes=5, + single_subnet=False, ), } @@ -453,7 +579,9 @@ class AmazonWebServicesProvider(schema.Base): region: str kubernetes_version: str availability_zones: Optional[List[str]] - 
node_groups: Dict[str, AWSNodeGroup] = DEFAULT_AWS_NODE_GROUPS + node_groups: Annotated[ + Dict[str, AWSNodeGroup], AfterValidator(set_missing_taints_to_default_taints) + ] = Field(DEFAULT_AWS_NODE_GROUPS, validate_default=True) eks_endpoint_access: Optional[ Literal["private", "public", "public_and_private"] ] = "public" @@ -579,16 +707,8 @@ class ExistingProvider(schema.Base): schema.ProviderEnum.azure: AzureProvider, } -provider_enum_name_map: Dict[schema.ProviderEnum, str] = { - schema.ProviderEnum.local: "local", - schema.ProviderEnum.existing: "existing", - schema.ProviderEnum.gcp: "google_cloud_platform", - schema.ProviderEnum.aws: "amazon_web_services", - schema.ProviderEnum.azure: "azure", -} - provider_name_abbreviation_map: Dict[str, str] = { - value: key.value for key, value in provider_enum_name_map.items() + value: key.value for key, value in schema.provider_enum_name_map.items() } provider_enum_default_node_groups_map: Dict[schema.ProviderEnum, Any] = { @@ -628,7 +748,7 @@ def check_provider(cls, data: Any) -> Any: for provider in provider_name_abbreviation_map.keys() if provider in data and data[provider] } - expected_provider_config = provider_enum_name_map[provider] + expected_provider_config = schema.provider_enum_name_map[provider] extra_provider_config = set_providers - {expected_provider_config} if extra_provider_config: warnings.warn( @@ -776,6 +896,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): instance_type=node_group.instance, min_size=node_group.min_nodes, max_size=node_group.max_nodes, + node_taints=node_group.taints, preemptible=node_group.preemptible, guest_accelerators=node_group.guest_accelerators, ) @@ -808,6 +929,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): instance=node_group.instance, min_nodes=node_group.min_nodes, max_nodes=node_group.max_nodes, + node_taints=node_group.taints, ) for name, node_group in self.config.azure.node_groups.items() }, @@ -851,6 +973,7 @@ def 
input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): single_subnet=node_group.single_subnet, permissions_boundary=node_group.permissions_boundary, launch_template=None, + node_taints=node_group.taints, ami_type=construct_aws_ami_type( gpu_enabled=node_group.gpu, launch_template=None, diff --git a/src/_nebari/stages/infrastructure/template/aws/modules/kubernetes/main.tf b/src/_nebari/stages/infrastructure/template/aws/modules/kubernetes/main.tf index 2537b12dad..c246da0b0b 100644 --- a/src/_nebari/stages/infrastructure/template/aws/modules/kubernetes/main.tf +++ b/src/_nebari/stages/infrastructure/template/aws/modules/kubernetes/main.tf @@ -98,6 +98,15 @@ resource "aws_eks_node_group" "main" { max_size = var.node_groups[count.index].max_size } + dynamic "taint" { + for_each = var.node_groups[count.index].node_taints + content { + key = taint.value.key + value = taint.value.value + effect = taint.value.effect + } + } + # Only set launch_template if its node_group counterpart parameter is not null dynamic "launch_template" { for_each = var.node_groups[count.index].launch_template != null ? 
[0] : [] diff --git a/src/_nebari/stages/infrastructure/template/aws/modules/kubernetes/variables.tf b/src/_nebari/stages/infrastructure/template/aws/modules/kubernetes/variables.tf index 63558e550f..26b5b82865 100644 --- a/src/_nebari/stages/infrastructure/template/aws/modules/kubernetes/variables.tf +++ b/src/_nebari/stages/infrastructure/template/aws/modules/kubernetes/variables.tf @@ -53,6 +53,11 @@ variable "node_groups" { single_subnet = bool launch_template = map(any) ami_type = string + node_taints = list(object({ + key = string + value = string + effect = string + })) })) } diff --git a/src/_nebari/stages/infrastructure/template/aws/variables.tf b/src/_nebari/stages/infrastructure/template/aws/variables.tf index a71df81d0f..29d3546519 100644 --- a/src/_nebari/stages/infrastructure/template/aws/variables.tf +++ b/src/_nebari/stages/infrastructure/template/aws/variables.tf @@ -40,6 +40,11 @@ variable "node_groups" { single_subnet = bool launch_template = map(any) ami_type = string + node_taints = list(object({ + key = string + value = string + effect = string + })) })) } diff --git a/src/_nebari/stages/infrastructure/template/azure/main.tf b/src/_nebari/stages/infrastructure/template/azure/main.tf index 960b755f8c..1569844d45 100644 --- a/src/_nebari/stages/infrastructure/template/azure/main.tf +++ b/src/_nebari/stages/infrastructure/template/azure/main.tf @@ -39,6 +39,7 @@ module "kubernetes" { instance_type = config.instance min_size = config.min_nodes max_size = config.max_nodes + node_taints = config.node_taints } ] vnet_subnet_id = var.vnet_subnet_id diff --git a/src/_nebari/stages/infrastructure/template/azure/modules/kubernetes/main.tf b/src/_nebari/stages/infrastructure/template/azure/modules/kubernetes/main.tf index f97f1f6383..50e2f48ff6 100644 --- a/src/_nebari/stages/infrastructure/template/azure/modules/kubernetes/main.tf +++ b/src/_nebari/stages/infrastructure/template/azure/modules/kubernetes/main.tf @@ -42,6 +42,7 @@ resource 
"azurerm_kubernetes_cluster" "main" { min_count = var.node_groups[0].min_size max_count = var.node_groups[0].max_size max_pods = var.max_pods + # It's not possible to add node_taints to the default node pool. See https://github.com/hashicorp/terraform-provider-azurerm/issues/9183 for more info orchestrator_version = var.kubernetes_version node_labels = { @@ -87,4 +88,5 @@ resource "azurerm_kubernetes_cluster_node_pool" "node_group" { orchestrator_version = var.kubernetes_version tags = var.tags vnet_subnet_id = var.vnet_subnet_id + node_taints = each.value.node_taints } diff --git a/src/_nebari/stages/infrastructure/template/azure/modules/kubernetes/variables.tf b/src/_nebari/stages/infrastructure/template/azure/modules/kubernetes/variables.tf index 95d2045420..4140a4e486 100644 --- a/src/_nebari/stages/infrastructure/template/azure/modules/kubernetes/variables.tf +++ b/src/_nebari/stages/infrastructure/template/azure/modules/kubernetes/variables.tf @@ -29,10 +29,16 @@ variable "environment" { type = string } - variable "node_groups" { description = "Node pools to add to Azure Kubernetes Cluster" - type = list(map(any)) + type = list(object({ + name = string + auto_scale = bool + instance_type = string + min_size = number + max_size = number + node_taints = list(string) + })) } variable "vnet_subnet_id" { diff --git a/src/_nebari/stages/infrastructure/template/azure/variables.tf b/src/_nebari/stages/infrastructure/template/azure/variables.tf index 44ef90463f..eef4217061 100644 --- a/src/_nebari/stages/infrastructure/template/azure/variables.tf +++ b/src/_nebari/stages/infrastructure/template/azure/variables.tf @@ -21,9 +21,10 @@ variable "kubernetes_version" { variable "node_groups" { description = "Azure node groups" type = map(object({ - instance = string - min_nodes = number - max_nodes = number + instance = string + min_nodes = number + max_nodes = number + node_taints = list(string) })) } diff --git 
a/src/_nebari/stages/infrastructure/template/gcp/modules/kubernetes/main.tf b/src/_nebari/stages/infrastructure/template/gcp/modules/kubernetes/main.tf index 57e8d9fc88..182168fada 100644 --- a/src/_nebari/stages/infrastructure/template/gcp/modules/kubernetes/main.tf +++ b/src/_nebari/stages/infrastructure/template/gcp/modules/kubernetes/main.tf @@ -93,6 +93,15 @@ resource "google_container_node_pool" "main" { oauth_scopes = local.node_group_oauth_scopes + dynamic "taint" { + for_each = local.merged_node_groups[count.index].node_taints + content { + key = taint.value.key + value = taint.value.value + effect = taint.value.effect + } + } + metadata = { disable-legacy-endpoints = "true" } @@ -108,9 +117,4 @@ resource "google_container_node_pool" "main" { tags = var.tags } - lifecycle { - ignore_changes = [ - node_config[0].taint - ] - } } diff --git a/src/_nebari/stages/infrastructure/template/gcp/modules/kubernetes/variables.tf b/src/_nebari/stages/infrastructure/template/gcp/modules/kubernetes/variables.tf index 2ee2d78ed5..236a0b9017 100644 --- a/src/_nebari/stages/infrastructure/template/gcp/modules/kubernetes/variables.tf +++ b/src/_nebari/stages/infrastructure/template/gcp/modules/kubernetes/variables.tf @@ -50,6 +50,7 @@ variable "node_groups" { min_size = 1 max_size = 1 labels = {} + node_taints = [] }, { name = "user" @@ -57,6 +58,7 @@ variable "node_groups" { min_size = 0 max_size = 2 labels = {} + node_taints = [] }, { name = "worker" @@ -64,6 +66,7 @@ variable "node_groups" { min_size = 0 max_size = 5 labels = {} + node_taints = [] } ] } diff --git a/src/_nebari/stages/kubernetes_services/__init__.py b/src/_nebari/stages/kubernetes_services/__init__.py index fdc413bd40..41025b7737 100644 --- a/src/_nebari/stages/kubernetes_services/__init__.py +++ b/src/_nebari/stages/kubernetes_services/__init__.py @@ -453,6 +453,28 @@ def handle_units(cls, value: Optional[str]) -> float: return byte_unit_conversion(value, "GiB") +class TolerationOperatorEnum(str, 
def _node_taint_tolerations(node_group_name: str) -> List[Toleration]:
    """Build tolerations matching the taints of one node group.

    Looks up the active provider's configuration on ``self.config`` and
    converts every taint of the named node group with
    ``Toleration.from_taint``.

    Args:
        node_group_name: Name of the node group (e.g. ``"user"``).

    Returns:
        One ``Toleration`` per taint. The list is empty when the provider has
        no ``node_groups``, the node group does not exist, or it has no taints.
    """
    provider = getattr(
        self.config, schema.provider_enum_name_map[self.config.provider]
    )

    # Not every provider config exposes node groups; treat a missing or empty
    # attribute the same way.
    node_groups = getattr(provider, "node_groups", None)
    if not node_groups:
        return []

    node_group = node_groups.get(node_group_name)
    # `taints` may be absent (unknown group) or None (nothing configured).
    taints = getattr(node_group, "taints", None)
    return [Toleration.from_taint(taint) for taint in taints or []]
scheduler_extra_pod_config = { + tolerations = var.worker-taint-tolerations + } + worker_extra_pod_config = { + tolerations = var.worker-taint-tolerations + } + } + depends_on = [ module.kubernetes-nfs-server, module.rook-ceph diff --git a/src/_nebari/stages/kubernetes_services/template/jupyterhub.tf b/src/_nebari/stages/kubernetes_services/template/jupyterhub.tf index 121cff4b22..8759f13f43 100644 --- a/src/_nebari/stages/kubernetes_services/template/jupyterhub.tf +++ b/src/_nebari/stages/kubernetes_services/template/jupyterhub.tf @@ -85,6 +85,16 @@ variable "idle-culler-settings" { type = any } +variable "node-taint-tolerations" { + description = "Node taint toleration" + type = list(object({ + key = string + operator = string + value = string + effect = string + })) +} + variable "shared_fs_type" { type = string description = "Use NFS or Ceph" @@ -180,6 +190,7 @@ module "jupyterhub" { conda-store-service-name = module.kubernetes-conda-store-server.service_name conda-store-jhub-apps-token = module.kubernetes-conda-store-server.service-tokens.jhub-apps jhub-apps-enabled = var.jhub-apps-enabled + node-taint-tolerations = var.node-taint-tolerations jhub-apps-overrides = var.jhub-apps-overrides extra-mounts = { diff --git a/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/dask-gateway/files/gateway_config.py b/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/dask-gateway/files/gateway_config.py index c58e3aa90d..427b8734a7 100644 --- a/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/dask-gateway/files/gateway_config.py +++ b/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/dask-gateway/files/gateway_config.py @@ -15,7 +15,6 @@ def dask_gateway_config(path="/var/lib/dask-gateway/config.json"): config = dask_gateway_config() - c.DaskGateway.log_level = config["gateway"]["loglevel"] # Configure addresses @@ -26,6 +25,8 @@ def 
def options_handler(options, user):
    """Build the cluster configuration for the options a user selected.

    The layers below are folded together with ``deep_merge`` starting from an
    empty dict, in the listed order: node-group settings, conda-store mounts,
    the user's home mount, the selected profile, the requested environment
    variables and finally the cluster-wide extra pod config defaults.
    # NOTE(review): precedence between layers is decided by ``deep_merge``;
    # verify its semantics before reordering the layers.
    """
    conda_namespace, conda_env_name = options.conda_environment.split("/")

    layers = [
        base_node_group(options),
        base_conda_store_mounts(conda_namespace, conda_env_name),
        base_username_mount(user.name),
        config["profiles"][options.profile],
        {"environment": {**options.environment_vars}},
        # cluster-wide defaults for the extra pod configuration
        {
            "worker_extra_pod_config": config["cluster"]["worker_extra_pod_config"],
            "scheduler_extra_pod_config": config["cluster"][
                "scheduler_extra_pod_config"
            ],
        },
    ]

    merged = {}
    for layer in layers:
        merged = deep_merge(merged, layer)
    return merged
>>> value_1 = { 'a': [1, 2], 'b': {'c': 1, 'z': [5, 6]}, diff --git a/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/dask-gateway/variables.tf b/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/dask-gateway/variables.tf index 121405a322..0b3fbcab35 100644 --- a/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/dask-gateway/variables.tf +++ b/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/dask-gateway/variables.tf @@ -130,23 +130,23 @@ variable "cluster" { description = "dask gateway cluster defaults" type = object({ # scheduler configuration - scheduler_cores = number - scheduler_cores_limit = number - scheduler_memory = string - scheduler_memory_limit = string - scheduler_extra_container_config = any - scheduler_extra_pod_config = any + scheduler_cores = optional(number, 1) + scheduler_cores_limit = optional(number, 1) + scheduler_memory = optional(string, "2 G") + scheduler_memory_limit = optional(string, "2 G") + scheduler_extra_container_config = optional(any, {}) + scheduler_extra_pod_config = optional(any, {}) # worker configuration - worker_cores = number - worker_cores_limit = number - worker_memory = string - worker_memory_limit = string - worker_extra_container_config = any - worker_extra_pod_config = any + worker_cores = optional(number, 1) + worker_cores_limit = optional(number, 1) + worker_memory = optional(string, "2 G") + worker_memory_limit = optional(string, "2 G") + worker_extra_container_config = optional(any, {}) + worker_extra_pod_config = optional(any, {}) # additional fields - idle_timeout = number - image_pull_policy = string - environment = map(string) + idle_timeout = optional(number, 1800) # 30 minutes + image_pull_policy = optional(string, "IfNotPresent") + environment = optional(map(string), {}) }) default = { # scheduler configuration diff --git 
def node_taint_tolerations():
    """Return the pod tolerations configured under ``custom.node-taint-tolerations``.

    Returns:
        ``{}`` when the setting is missing or empty; otherwise
        ``{"tolerations": [...]}`` with one dict per configured entry holding
        its ``key``, ``operator``, ``value`` and ``effect``.
    """
    configured = z2jh.get_config("custom.node-taint-tolerations")
    if not configured:
        return {}

    # Copy only the known fields, in this order, from each configured entry.
    copied_fields = ("key", "operator", "value", "effect")
    pod_tolerations = []
    for entry in configured:
        pod_tolerations.append({field: entry[field] for field in copied_fields})

    return {"tolerations": pod_tolerations}
kubernetes_config_map.etc-skel.metadata.0.namespace diff --git a/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/jupyterhub/variables.tf b/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/jupyterhub/variables.tf index f395e08487..cc0c935872 100644 --- a/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/jupyterhub/variables.tf +++ b/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/jupyterhub/variables.tf @@ -219,3 +219,13 @@ variable "initial-repositories" { type = string default = "[]" } + +variable "node-taint-tolerations" { + description = "Node taint toleration" + type = list(object({ + key = string + operator = string + value = string + effect = string + })) +} diff --git a/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/monitoring/loki/main.tf b/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/monitoring/loki/main.tf index 8180d46fb8..3868de9cbf 100644 --- a/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/monitoring/loki/main.tf +++ b/src/_nebari/stages/kubernetes_services/template/modules/kubernetes/services/monitoring/loki/main.tf @@ -96,6 +96,22 @@ resource "helm_release" "grafana-promtail" { values = concat([ file("${path.module}/values_promtail.yaml"), jsonencode({ + tolerations = [ + { + key = "node-role.kubernetes.io/master" + operator = "Exists" + effect = "NoSchedule" + }, + { + key = "node-role.kubernetes.io/control-plane" + operator = "Exists" + effect = "NoSchedule" + }, + { + operator = "Exists" + effect = "NoSchedule" + }, + ] }) ], var.grafana-promtail-overrides) diff --git a/src/_nebari/stages/kubernetes_services/template/rook-ceph.tf b/src/_nebari/stages/kubernetes_services/template/rook-ceph.tf index c40b6fae33..96cf6131e4 100644 --- a/src/_nebari/stages/kubernetes_services/template/rook-ceph.tf +++ 
b/src/_nebari/stages/kubernetes_services/template/rook-ceph.tf @@ -41,6 +41,13 @@ resource "helm_release" "rook-ceph" { }, csi = { enableRbdDriver = false, # necessary to provision block storage, but saves some cpu and memory if not needed + provisionerReplicas : 1, # default is 2 on different nodes + pluginTolerations = [ + { + operator = "Exists" + effect = "NoSchedule" + } + ], }, }) ], diff --git a/src/_nebari/upgrade.py b/src/_nebari/upgrade.py index 71795dfa1e..5c07d55152 100644 --- a/src/_nebari/upgrade.py +++ b/src/_nebari/upgrade.py @@ -26,10 +26,6 @@ from _nebari.config import backup_configuration from _nebari.keycloak import get_keycloak_admin -from _nebari.stages.infrastructure import ( - provider_enum_default_node_groups_map, - provider_enum_name_map, -) from _nebari.utils import ( get_k8s_version_prefix, get_provider_config_block_name, @@ -37,7 +33,7 @@ yaml, ) from _nebari.version import __version__, rounded_ver_parse -from nebari.schema import ProviderEnum, is_version_accepted +from nebari.schema import ProviderEnum, is_version_accepted, provider_enum_name_map logger = logging.getLogger(__name__) @@ -1169,7 +1165,7 @@ def _version_specific_upgrade( provider_full_name, {} ): try: - default_node_groups = provider_enum_default_node_groups_map[ + default_node_groups = schema.provider_enum_default_node_groups_map[ provider ] continue_ = kwargs.get("attempt_fixes", False) or Confirm.ask( diff --git a/src/nebari/schema.py b/src/nebari/schema.py index b45af521be..bba6e9de11 100644 --- a/src/nebari/schema.py +++ b/src/nebari/schema.py @@ -104,3 +104,29 @@ def is_version_accepted(v): for deployment with the current Nebari package. 
""" return Main.is_version_accepted(v) + + +@yaml_object(yaml) +class TaintEffectEnum(str, enum.Enum): + NoSchedule: str = "NoSchedule" + PreferNoSchedule: str = "PreferNoSchedule" + NoExecute: str = "NoExecute" + + @classmethod + def to_yaml(cls, representer, node): + return representer.represent_str(node.value) + + +class Taint(Base): + key: str + value: str + effect: TaintEffectEnum + + +provider_enum_name_map: dict[ProviderEnum, str] = { + ProviderEnum.local: "local", + ProviderEnum.existing: "existing", + ProviderEnum.gcp: "google_cloud_platform", + ProviderEnum.aws: "amazon_web_services", + ProviderEnum.azure: "azure", +} diff --git a/tests/tests_unit/test_cli_init.py b/tests/tests_unit/test_cli_init.py index 03b22557ae..25cfcdbe0d 100644 --- a/tests/tests_unit/test_cli_init.py +++ b/tests/tests_unit/test_cli_init.py @@ -209,7 +209,7 @@ def assert_nebari_init_args( app, args + ["--output", tmp_file.resolve()], input=input ) - assert not result.exception + assert not result.exception, result.output assert 0 == result.exit_code assert tmp_file.exists() is True diff --git a/tests/tests_unit/test_schema.py b/tests/tests_unit/test_schema.py index e445ba37da..52acecd09c 100644 --- a/tests/tests_unit/test_schema.py +++ b/tests/tests_unit/test_schema.py @@ -3,6 +3,10 @@ import pytest from pydantic import ValidationError +from _nebari.stages.infrastructure import ( + DEFAULT_GENERAL_NODE_GROUP_TAINTS, + DEFAULT_NODE_GROUP_TAINTS, +) from nebari import schema from nebari.plugins import nebari_plugin_manager @@ -82,6 +86,51 @@ def test_provider_validation(config_schema, provider, exception): assert config.provider == provider +@pytest.mark.parametrize( + "provider, full_name, default_fields", + [ + ( + "aws", + "amazon_web_services", + {"region": "us-east-1", "kubernetes_version": "1.18"}, + ), + ( + "gcp", + "google_cloud_platform", + { + "region": "us-east1", + "project": "test-project", + "kubernetes_version": "1.18", + }, + ), + ( + "azure", + "azure", + { + 
"region": "eastus", + "kubernetes_version": "1.18", + "storage_account_postfix": "test", + }, + ), + ], +) +def test_node_group_default_taints_set( + config_schema, provider, full_name, default_fields +): + config_dict = { + "project_name": "test", + "provider": f"{provider}", + f"{full_name}": default_fields, + } + config = config_schema(**config_dict) + ng = getattr(config, schema.provider_enum_name_map[config.provider]).node_groups + for ng_name in ng: + if ng_name == "general": + assert ng[ng_name].taints == DEFAULT_GENERAL_NODE_GROUP_TAINTS + else: + assert ng[ng_name].taints == DEFAULT_NODE_GROUP_TAINTS + + @pytest.mark.parametrize( "provider, full_name, default_fields", [