diff --git a/airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py b/airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py
index 4f57b72d96aef..c0b6ea5dd553f 100644
--- a/airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py
+++ b/airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py
@@ -35,6 +35,8 @@
 from airflow.configuration import conf
 from airflow.providers.amazon.aws.executors.ecs.utils import (
     CONFIG_GROUP_NAME,
+    ECS_LAUNCH_TYPE_EC2,
+    ECS_LAUNCH_TYPE_FARGATE,
     AllEcsConfigKeys,
     RunTaskKwargsConfigKeys,
     camelize_dict_keys,
@@ -56,13 +58,15 @@ def _fetch_config_values() -> dict[str, str]:
 
 
 def build_task_kwargs() -> dict:
+    all_config_keys = AllEcsConfigKeys()
     # This will put some kwargs at the root of the dictionary that do NOT belong there. However,
     # the code below expects them to be there and will rearrange them as necessary.
     task_kwargs = _fetch_config_values()
     task_kwargs.update(_fetch_templated_kwargs())
 
-    has_launch_type: bool = "launch_type" in task_kwargs
-    has_capacity_provider: bool = "capacity_provider_strategy" in task_kwargs
+    has_launch_type: bool = all_config_keys.LAUNCH_TYPE in task_kwargs
+    has_capacity_provider: bool = all_config_keys.CAPACITY_PROVIDER_STRATEGY in task_kwargs
+    is_launch_type_ec2: bool = task_kwargs.get(all_config_keys.LAUNCH_TYPE, None) == ECS_LAUNCH_TYPE_EC2
 
     if has_capacity_provider and has_launch_type:
         raise ValueError(
@@ -75,7 +79,12 @@ def build_task_kwargs() -> dict:
         # the final fallback.
         cluster = EcsHook().conn.describe_clusters(clusters=[task_kwargs["cluster"]])["clusters"][0]
         if not cluster.get("defaultCapacityProviderStrategy"):
-            task_kwargs["launch_type"] = "FARGATE"
+            task_kwargs[all_config_keys.LAUNCH_TYPE] = ECS_LAUNCH_TYPE_FARGATE
+
+    # If you're using the EC2 launch type, you should not/can not provide the platform_version. In this
+    # case we'll drop it on the floor on behalf of the user, instead of throwing an exception.
+    if is_launch_type_ec2:
+        task_kwargs.pop(all_config_keys.PLATFORM_VERSION, None)
 
     # There can only be 1 count of these containers
     task_kwargs["count"] = 1  # type: ignore
@@ -105,7 +114,7 @@ def build_task_kwargs() -> dict:
                 "awsvpcConfiguration": {
                     "subnets": str(subnets).split(",") if subnets else None,
                     "securityGroups": str(security_groups).split(",") if security_groups else None,
-                    "assignPublicIp": parse_assign_public_ip(assign_public_ip),
+                    "assignPublicIp": parse_assign_public_ip(assign_public_ip, is_launch_type_ec2),
                 }
             }
         )
diff --git a/airflow/providers/amazon/aws/executors/ecs/utils.py b/airflow/providers/amazon/aws/executors/ecs/utils.py
index cb730acac9355..10d9162d5a6b5 100644
--- a/airflow/providers/amazon/aws/executors/ecs/utils.py
+++ b/airflow/providers/amazon/aws/executors/ecs/utils.py
@@ -40,6 +40,9 @@
 ExecutorConfigFunctionType = Callable[[CommandType], dict]
 ExecutorConfigType = Dict[str, Any]
 
+ECS_LAUNCH_TYPE_EC2 = "EC2"
+ECS_LAUNCH_TYPE_FARGATE = "FARGATE"
+
 CONFIG_GROUP_NAME = "aws_ecs_executor"
 
 CONFIG_DEFAULTS = {
@@ -247,9 +250,12 @@ def _recursive_flatten_dict(nested_dict):
     return dict(items)
 
 
-def parse_assign_public_ip(assign_public_ip):
+def parse_assign_public_ip(assign_public_ip, is_launch_type_ec2=False):
     """Convert "assign_public_ip" from True/False to ENABLE/DISABLE."""
-    return "ENABLED" if assign_public_ip == "True" else "DISABLED"
+    # If the launch type is EC2, you cannot/should not provide the assignPublicIp parameter (which is
+    # specific to Fargate)
+    if not is_launch_type_ec2:
+        return "ENABLED" if assign_public_ip == "True" else "DISABLED"
 
 
 def camelize_dict_keys(nested_dict) -> dict: