Skip to content

Commit

Permalink
Filter out private subnets for default aws vpc
Browse files Browse the repository at this point in the history
  • Loading branch information
r4victor committed Feb 19, 2024
1 parent ad1966c commit 0e26d28
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 24 deletions.
58 changes: 35 additions & 23 deletions src/dstack/_internal/core/backends/aws/compute.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple

import boto3
import botocore.client
Expand Down Expand Up @@ -113,25 +113,11 @@ def create_instance(
{"Key": "dstack_user", "Value": instance_config.user},
]
try:
subnet_id = None
vpc_id = None
if self.config.vpc_name is not None:
vpc_id = aws_resources.get_vpc_id_by_name(
ec2_client=ec2_client,
vpc_name=self.config.vpc_name,
)
if vpc_id is None:
raise ComputeError(
f"No VPC named {self.config.vpc_name} in region {instance_offer.region}"
)
subnet_id = aws_resources.get_subnet_id_for_vpc(
ec2_client=ec2_client,
vpc_id=vpc_id,
)
if subnet_id is None:
raise ComputeError(
f"Failed to find public subnet for VPC {self.config.vpc_name} in region {instance_offer.region}"
)
vpc_id, subnet_id = _get_vpc_id_subnet_id_or_error(
ec2_client=ec2_client,
vpc_name=self.config.vpc_name,
region=instance_offer.region,
)
disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024)
response = ec2.create_instances(
**aws_resources.create_instances_struct(
Expand Down Expand Up @@ -159,7 +145,6 @@ def create_instance(
instance = response[0]
instance.wait_until_running()
instance.reload() # populate instance.public_ip_address

if instance_offer.instance.resources.spot: # it will not terminate the instance
ec2_client.cancel_spot_instance_requests(
SpotInstanceRequestIds=[instance.spot_instance_request_id]
Expand Down Expand Up @@ -224,15 +209,15 @@ def create_gateway(
user_data=get_gateway_user_data(ssh_key_pub),
tags=tags,
security_group_id=aws_resources.create_gateway_security_group(
ec2_client, project_id
ec2_client=ec2_client,
project_id=project_id,
),
spot=False,
)
)
instance = response[0]
instance.wait_until_running()
instance.reload() # populate instance.public_ip_address

return LaunchedGatewayInfo(
instance_id=instance.instance_id,
region=region,
Expand All @@ -253,3 +238,30 @@ def _supported_instances(offer: InstanceOffer) -> bool:
if offer.instance.name.startswith(family):
return True
return False


def _get_vpc_id_subnet_id_or_error(
ec2_client: botocore.client.BaseClient,
vpc_name: Optional[str],
region: str,
) -> Tuple[str, str]:
if vpc_name is not None:
vpc_id = aws_resources.get_vpc_id_by_name(
ec2_client=ec2_client,
vpc_name=vpc_name,
)
if vpc_id is None:
raise ComputeError(f"No VPC named {vpc_name} in region {region}")
else:
vpc_id = aws_resources.get_default_vpc_id(ec2_client=ec2_client)
if vpc_id is None:
raise ComputeError(f"No default VPC in region {region}")
subnet_id = aws_resources.get_subnet_id_for_vpc(
ec2_client=ec2_client,
vpc_id=vpc_id,
)
if subnet_id is not None:
return vpc_id, subnet_id
if vpc_name is not None:
raise ComputeError(f"Failed to find public subnet for VPC {vpc_name} in region {region}")
raise ComputeError(f"Failed to find public subnet for default VPC in region {region}")
33 changes: 32 additions & 1 deletion src/dstack/_internal/core/backends/aws/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,13 @@ def get_vpc_id_by_name(
return response["Vpcs"][0]["VpcId"]


def get_default_vpc_id(ec2_client: botocore.client.BaseClient) -> Optional[str]:
response = ec2_client.describe_vpcs(Filters=[{"Name": "isDefault", "Values": ["true"]}])
if "Vpcs" in response and len(response["Vpcs"]) > 0:
return response["Vpcs"][0]["VpcId"]
return None


def get_subnet_id_for_vpc(
ec2_client: botocore.client.BaseClient,
vpc_id: str,
Expand All @@ -318,7 +325,10 @@ def get_subnet_id_for_vpc(
# Return first public subnet
for subnet in subnets:
subnet_id = subnet["SubnetId"]
if _is_public_subnet(ec2_client=ec2_client, subnet_id=subnet_id):
is_public_subnet = _is_public_subnet(
ec2_client=ec2_client, vpc_id=vpc_id, subnet_id=subnet_id
)
if is_public_subnet:
return subnet_id
return None

Expand All @@ -333,15 +343,36 @@ def _get_subnets_by_vpc_id(

def _is_public_subnet(
ec2_client: botocore.client.BaseClient,
vpc_id: str,
subnet_id: str,
) -> bool:
# Public subnet – The subnet has a direct route to an internet gateway.
# Private subnet – The subnet does not have a direct route to an internet gateway.

# Check explicitly associated route tables
response = ec2_client.describe_route_tables(
Filters=[{"Name": "association.subnet-id", "Values": [subnet_id]}]
)
for route_table in response["RouteTables"]:
for route in route_table["Routes"]:
if "GatewayId" in route and route["GatewayId"].startswith("igw-"):
return True

# Main route table controls the routing of all subnetes
# that are not explicitly associated with any other route table.
if len(response["RouteTables"]) > 0:
return False

# Check implicitly associated main route table
response = ec2_client.describe_route_tables(
Filters=[
{"Name": "association.main", "Values": ["true"]},
{"Name": "vpc-id", "Values": [vpc_id]},
]
)
for route_table in response["RouteTables"]:
for route in route_table["Routes"]:
if "GatewayId" in route and route["GatewayId"].startswith("igw-"):
return True

return False

0 comments on commit 0e26d28

Please sign in to comment.