Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-enable AWS tags support #2096

Merged
merged 6 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 47 additions & 51 deletions src/_nebari/stages/infrastructure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
import re
import sys
import tempfile
import typing
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import pydantic

Expand Down Expand Up @@ -52,9 +51,9 @@ class DigitalOceanInputVars(schema.Base):
name: str
environment: str
region: str
tags: typing.List[str]
tags: List[str]
kubernetes_version: str
node_groups: typing.Dict[str, DigitalOceanNodeGroup]
node_groups: Dict[str, DigitalOceanNodeGroup]
kubeconfig_filename: str = get_kubeconfig_filename()


Expand Down Expand Up @@ -143,6 +142,7 @@ class AWSInputVars(schema.Base):
vpc_cidr_block: str
permissions_boundary: Optional[str] = None
kubeconfig_filename: str = get_kubeconfig_filename()
tags: Dict[str, str] = {}


def _calculate_node_groups(config: schema.Main):
Expand Down Expand Up @@ -216,7 +216,7 @@ class DigitalOceanProvider(schema.Base):
region: str
kubernetes_version: str
# Digital Ocean image slugs are listed here https://slugs.do-api.dev/
node_groups: typing.Dict[str, DigitalOceanNodeGroup] = {
node_groups: Dict[str, DigitalOceanNodeGroup] = {
"general": DigitalOceanNodeGroup(
instance="g-8vcpu-32gb", min_nodes=1, max_nodes=1
),
Expand All @@ -227,7 +227,7 @@ class DigitalOceanProvider(schema.Base):
instance="g-4vcpu-16gb", min_nodes=1, max_nodes=5
),
}
tags: typing.Optional[typing.List[str]] = []
tags: Optional[List[str]] = []

@pydantic.validator("region")
def _validate_region(cls, value):
Expand Down Expand Up @@ -289,7 +289,7 @@ class GCPCIDRBlock(schema.Base):


class GCPMasterAuthorizedNetworksConfig(schema.Base):
cidr_blocks: typing.List[GCPCIDRBlock]
cidr_blocks: List[GCPCIDRBlock]


class GCPPrivateClusterConfig(schema.Base):
Expand All @@ -314,34 +314,28 @@ class GCPNodeGroup(schema.Base):
min_nodes: pydantic.conint(ge=0) = 0
max_nodes: pydantic.conint(ge=1) = 1
preemptible: bool = False
labels: typing.Dict[str, str] = {}
guest_accelerators: typing.List[GCPGuestAccelerator] = []
labels: Dict[str, str] = {}
guest_accelerators: List[GCPGuestAccelerator] = []


class GoogleCloudPlatformProvider(schema.Base):
region: str
project: str
kubernetes_version: str
availability_zones: typing.Optional[typing.List[str]] = []
availability_zones: Optional[List[str]] = []
release_channel: str = constants.DEFAULT_GKE_RELEASE_CHANNEL
node_groups: typing.Dict[str, GCPNodeGroup] = {
node_groups: Dict[str, GCPNodeGroup] = {
"general": GCPNodeGroup(instance="n1-standard-8", min_nodes=1, max_nodes=1),
"user": GCPNodeGroup(instance="n1-standard-4", min_nodes=0, max_nodes=5),
"worker": GCPNodeGroup(instance="n1-standard-4", min_nodes=0, max_nodes=5),
}
tags: typing.Optional[typing.List[str]] = []
tags: Optional[List[str]] = []
networking_mode: str = "ROUTE"
network: str = "default"
subnetwork: typing.Optional[typing.Union[str, None]] = None
ip_allocation_policy: typing.Optional[
typing.Union[GCPIPAllocationPolicy, None]
] = None
master_authorized_networks_config: typing.Optional[
typing.Union[GCPCIDRBlock, None]
] = None
private_cluster_config: typing.Optional[
typing.Union[GCPPrivateClusterConfig, None]
] = None
subnetwork: Optional[Union[str, None]] = None
ip_allocation_policy: Optional[Union[GCPIPAllocationPolicy, None]] = None
master_authorized_networks_config: Optional[Union[GCPCIDRBlock, None]] = None
private_cluster_config: Optional[Union[GCPPrivateClusterConfig, None]] = None

@pydantic.root_validator
def validate_all(cls, values):
Expand Down Expand Up @@ -381,18 +375,18 @@ class AzureProvider(schema.Base):
kubernetes_version: str
storage_account_postfix: str
resource_group_name: str = None
node_groups: typing.Dict[str, AzureNodeGroup] = {
node_groups: Dict[str, AzureNodeGroup] = {
"general": AzureNodeGroup(instance="Standard_D8_v3", min_nodes=1, max_nodes=1),
"user": AzureNodeGroup(instance="Standard_D4_v3", min_nodes=0, max_nodes=5),
"worker": AzureNodeGroup(instance="Standard_D4_v3", min_nodes=0, max_nodes=5),
}
storage_account_postfix: str
vnet_subnet_id: typing.Optional[typing.Union[str, None]] = None
vnet_subnet_id: Optional[Union[str, None]] = None
private_cluster_enabled: bool = False
resource_group_name: typing.Optional[str] = None
tags: typing.Optional[typing.Dict[str, str]] = {}
network_profile: typing.Optional[typing.Dict[str, str]] = None
max_pods: typing.Optional[int] = None
resource_group_name: Optional[str] = None
tags: Optional[Dict[str, str]] = {}
network_profile: Optional[Dict[str, str]] = None
max_pods: Optional[int] = None

@pydantic.validator("kubernetes_version")
def _validate_kubernetes_version(cls, value):
Expand Down Expand Up @@ -440,8 +434,8 @@ class AWSNodeGroup(schema.Base):
class AmazonWebServicesProvider(schema.Base):
region: str
kubernetes_version: str
availability_zones: typing.Optional[typing.List[str]]
node_groups: typing.Dict[str, AWSNodeGroup] = {
availability_zones: Optional[List[str]]
node_groups: Dict[str, AWSNodeGroup] = {
"general": AWSNodeGroup(instance="m5.2xlarge", min_nodes=1, max_nodes=1),
"user": AWSNodeGroup(
instance="m5.xlarge", min_nodes=1, max_nodes=5, single_subnet=False
Expand All @@ -450,10 +444,11 @@ class AmazonWebServicesProvider(schema.Base):
instance="m5.xlarge", min_nodes=1, max_nodes=5, single_subnet=False
),
}
existing_subnet_ids: typing.List[str] = None
existing_security_group_ids: str = None
existing_subnet_ids: List[str] = None
existing_security_group_id: str = None
vpc_cidr_block: str = "10.10.0.0/16"
permissions_boundary: Optional[str] = None
tags: Optional[Dict[str, str]] = {}

@pydantic.root_validator
def validate_all(cls, values):
Expand Down Expand Up @@ -491,17 +486,17 @@ def validate_all(cls, values):


class LocalProvider(schema.Base):
kube_context: typing.Optional[str]
node_selectors: typing.Dict[str, KeyValueDict] = {
kube_context: Optional[str]
node_selectors: Dict[str, KeyValueDict] = {
"general": KeyValueDict(key="kubernetes.io/os", value="linux"),
"user": KeyValueDict(key="kubernetes.io/os", value="linux"),
"worker": KeyValueDict(key="kubernetes.io/os", value="linux"),
}


class ExistingProvider(schema.Base):
kube_context: typing.Optional[str]
node_selectors: typing.Dict[str, KeyValueDict] = {
kube_context: Optional[str]
node_selectors: Dict[str, KeyValueDict] = {
"general": KeyValueDict(key="kubernetes.io/os", value="linux"),
"user": KeyValueDict(key="kubernetes.io/os", value="linux"),
"worker": KeyValueDict(key="kubernetes.io/os", value="linux"),
Expand Down Expand Up @@ -532,12 +527,12 @@ class ExistingProvider(schema.Base):


class InputSchema(schema.Base):
local: typing.Optional[LocalProvider]
existing: typing.Optional[ExistingProvider]
google_cloud_platform: typing.Optional[GoogleCloudPlatformProvider]
amazon_web_services: typing.Optional[AmazonWebServicesProvider]
azure: typing.Optional[AzureProvider]
digital_ocean: typing.Optional[DigitalOceanProvider]
local: Optional[LocalProvider]
existing: Optional[ExistingProvider]
google_cloud_platform: Optional[GoogleCloudPlatformProvider]
amazon_web_services: Optional[AmazonWebServicesProvider]
azure: Optional[AzureProvider]
digital_ocean: Optional[DigitalOceanProvider]

@pydantic.root_validator(pre=True)
def check_provider(cls, values):
Expand Down Expand Up @@ -580,20 +575,20 @@ class NodeSelectorKeyValue(schema.Base):
class KubernetesCredentials(schema.Base):
host: str
cluster_ca_certifiate: str
token: typing.Optional[str]
username: typing.Optional[str]
password: typing.Optional[str]
client_certificate: typing.Optional[str]
client_key: typing.Optional[str]
config_path: typing.Optional[str]
config_context: typing.Optional[str]
token: Optional[str]
username: Optional[str]
password: Optional[str]
client_certificate: Optional[str]
client_key: Optional[str]
config_path: Optional[str]
config_context: Optional[str]


class OutputSchema(schema.Base):
node_selectors: Dict[str, NodeSelectorKeyValue]
kubernetes_credentials: KubernetesCredentials
kubeconfig_filename: str
nfs_endpoint: typing.Optional[str]
nfs_endpoint: Optional[str]


class KubernetesInfrastructureStage(NebariTerraformStage):
Expand Down Expand Up @@ -760,7 +755,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
name=self.config.escaped_project_name,
environment=self.config.namespace,
existing_subnet_ids=self.config.amazon_web_services.existing_subnet_ids,
existing_security_group_id=self.config.amazon_web_services.existing_security_group_ids,
existing_security_group_id=self.config.amazon_web_services.existing_security_group_id,
region=self.config.amazon_web_services.region,
kubernetes_version=self.config.amazon_web_services.kubernetes_version,
node_groups=[
Expand All @@ -779,6 +774,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
availability_zones=self.config.amazon_web_services.availability_zones,
vpc_cidr_block=self.config.amazon_web_services.vpc_cidr_block,
permissions_boundary=self.config.amazon_web_services.permissions_boundary,
tags=self.config.amazon_web_services.tags,
).dict()
else:
raise ValueError(f"Unknown provider: {self.config.provider}")
Expand Down
14 changes: 8 additions & 6 deletions src/_nebari/stages/infrastructure/template/aws/locals.tf
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
locals {
additional_tags = {
Project = var.name
Owner = "terraform"
Environment = var.environment
}

additional_tags = merge(
{
Project = var.name
Owner = "terraform"
Environment = var.environment
},
var.tags,
)
cluster_name = "${var.name}-${var.environment}"
}
6 changes: 6 additions & 0 deletions src/_nebari/stages/infrastructure/template/aws/variables.tf
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,9 @@ variable "permissions_boundary" {
type = string
default = null
}

variable "tags" {
description = "Additional tags to add to resources"
type = map(string)
default = {}
}
Loading