Skip to content

Commit

Permalink
Add template fields tests to aws operators (2) (apache#42202)
Browse files Browse the repository at this point in the history
  • Loading branch information
gopidesupavan authored and joaopamaral committed Oct 21, 2024
1 parent ddd2042 commit 6e18da7
Show file tree
Hide file tree
Showing 26 changed files with 499 additions and 0 deletions.
28 changes: 28 additions & 0 deletions tests/providers/amazon/aws/operators/test_eventbridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
EventBridgePutEventsOperator,
EventBridgePutRuleOperator,
)
from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields

if TYPE_CHECKING:
from unittest.mock import MagicMock
Expand Down Expand Up @@ -96,6 +97,13 @@ def test_failed_to_send(self, mock_conn: MagicMock):
with pytest.raises(AirflowException):
operator.execute(context={})

def test_template_fields(self):
operator = EventBridgePutEventsOperator(
task_id="failed_put_events_job",
entries=ENTRIES,
)
validate_template_fields(operator)


class TestEventBridgePutRuleOperator:
def test_init(self):
Expand Down Expand Up @@ -150,6 +158,12 @@ def test_put_rule_with_bad_json_fails(self):
with pytest.raises(ValueError):
operator.execute(None)

def test_template_fields(self):
operator = EventBridgePutRuleOperator(
task_id="events_put_rule_job", name=RULE_NAME, event_pattern=EVENT_PATTERN
)
validate_template_fields(operator)


class TestEventBridgeEnableRuleOperator:
def test_init(self):
Expand Down Expand Up @@ -186,6 +200,13 @@ def test_enable_rule(self, mock_conn: MagicMock):
enable_rule.execute(context={})
mock_conn.enable_rule.assert_called_with(Name=RULE_NAME)

def test_template_fields(self):
operator = EventBridgeEnableRuleOperator(
task_id="events_enable_rule_job",
name=RULE_NAME,
)
validate_template_fields(operator)


class TestEventBridgeDisableRuleOperator:
def test_init(self):
Expand Down Expand Up @@ -221,3 +242,10 @@ def test_disable_rule(self, mock_conn: MagicMock):

disable_rule.execute(context={})
mock_conn.disable_rule.assert_called_with(Name=RULE_NAME)

def test_template_fields(self):
operator = EventBridgeDisableRuleOperator(
task_id="events_disable_rule_job",
name=RULE_NAME,
)
validate_template_fields(operator)
9 changes: 9 additions & 0 deletions tests/providers/amazon/aws/operators/test_glacier.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
GlacierCreateJobOperator,
GlacierUploadArchiveOperator,
)
from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields

if TYPE_CHECKING:
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
Expand Down Expand Up @@ -78,6 +79,10 @@ def test_execute(self, hook_mock):
op.execute(mock.MagicMock())
hook_mock.return_value.retrieve_inventory.assert_called_once_with(vault_name=VAULT_NAME)

def test_template_fields(self):
operator = self.op_class(**self.default_op_kwargs)
validate_template_fields(operator)


class TestGlacierUploadArchiveOperator(BaseGlacierOperatorsTests):
op_class = GlacierUploadArchiveOperator
Expand All @@ -97,3 +102,7 @@ def test_execute(self):
body=b"Test Data",
checksum=None,
)

def test_template_fields(self):
operator = self.op_class(**self.default_op_kwargs)
validate_template_fields(operator)
24 changes: 24 additions & 0 deletions tests/providers/amazon/aws/operators/test_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
GlueDataQualityRuleSetEvaluationRunOperator,
GlueJobOperator,
)
from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields

if TYPE_CHECKING:
from airflow.models import TaskInstance
Expand Down Expand Up @@ -307,6 +308,17 @@ def test_replace_script_file(
"folder/file", "artifacts/glue-scripts/file", bucket_name="bucket_name", replace=True
)

def test_template_fields(self):
operator = GlueJobOperator(
task_id=TASK_ID,
job_name=JOB_NAME,
script_location="folder/file",
s3_bucket="bucket_name",
iam_role_name="role_arn",
replace_script_file=True,
)
validate_template_fields(operator)


class TestGlueDataQualityOperator:
RULE_SET_NAME = "TestRuleSet"
Expand Down Expand Up @@ -435,6 +447,12 @@ def test_validate_inputs_error(self):
with pytest.raises(AttributeError, match="RuleSet must starts with Rules = \\[ and ends with \\]"):
self.operator.validate_inputs()

def test_template_fields(self):
operator = GlueDataQualityOperator(
task_id="create_data_quality_ruleset", name=self.RULE_SET_NAME, ruleset=self.RULE_SET
)
validate_template_fields(operator)


class TestGlueDataQualityRuleSetEvaluationRunOperator:
RUN_ID = "1234567890"
Expand Down Expand Up @@ -538,6 +556,9 @@ def test_start_data_quality_ruleset_evaluation_run_wait_combinations(
assert glue_data_quality_hook.get_waiter.call_count == wait_for_completion
assert self.operator.defer.call_count == deferrable

def test_template_fields(self):
validate_template_fields(self.operator)


class TestGlueDataQualityRuleRecommendationRunOperator:
RUN_ID = "1234567890"
Expand Down Expand Up @@ -643,3 +664,6 @@ def test_start_data_quality_rule_recommendation_run_wait_combinations(
assert response == self.RUN_ID
assert glue_data_quality_hook.get_waiter.call_count == wait_for_completion
assert self.operator.defer.call_count == deferrable

def test_template_fields(self):
validate_template_fields(self.operator)
4 changes: 4 additions & 0 deletions tests/providers/amazon/aws/operators/test_glue_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from airflow.providers.amazon.aws.hooks.glue_crawler import GlueCrawlerHook
from airflow.providers.amazon.aws.hooks.sts import StsHook
from airflow.providers.amazon.aws.operators.glue_crawler import GlueCrawlerOperator
from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields

if TYPE_CHECKING:
from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection
Expand Down Expand Up @@ -173,3 +174,6 @@ def test_crawler_wait_combinations(self, _, wait_for_completion, deferrable, moc
assert response == mock_crawler_name
assert crawler_hook.get_waiter.call_count == wait_for_completion
assert self.op.defer.call_count == deferrable

def test_template_fields(self):
validate_template_fields(self.op)
5 changes: 5 additions & 0 deletions tests/providers/amazon/aws/operators/test_glue_databrew.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.glue_databrew import GlueDataBrewHook
from airflow.providers.amazon.aws.operators.glue_databrew import GlueDataBrewStartJobOperator
from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields

JOB_NAME = "test_job"

Expand Down Expand Up @@ -101,3 +102,7 @@ def test_start_job_with_deprecation_parameters(self, mock_hook_get_waiter, mock_
assert operator.waiter_delay == 15
operator.execute(None)
mock_hook_get_waiter.assert_not_called()

def test_template_fields(self):
operator = GlueDataBrewStartJobOperator(task_id="fake_task_id", job_name=JOB_NAME)
validate_template_fields(operator)
16 changes: 16 additions & 0 deletions tests/providers/amazon/aws/operators/test_kinesis_analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
KinesisAnalyticsV2StartApplicationOperator,
KinesisAnalyticsV2StopApplicationOperator,
)
from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields

if TYPE_CHECKING:
from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection
Expand Down Expand Up @@ -159,6 +160,15 @@ def test_create_application_throw_error_when_invalid_arguments_provided(
with pytest.raises(AirflowException, match=error_message):
operator.execute({})

def test_template_fields(self):
operator = KinesisAnalyticsV2CreateApplicationOperator(
task_id="create_application_operator",
application_name="demo",
runtime_environment="FLINK_18_9",
service_execution_role="arn",
)
validate_template_fields(operator)


class TestKinesisAnalyticsV2StartApplicationOperator:
APPLICATION_ARN = "arn:aws:kinesisanalytics:us-east-1:123456789012:application/demo"
Expand Down Expand Up @@ -327,6 +337,9 @@ def test_execute_complete_failure(self, kinesis_analytics_mock_conn):
):
self.operator.execute_complete(context=None, event=event)

def test_template_fields(self):
validate_template_fields(self.operator)


class TestKinesisAnalyticsV2StopApplicationOperator:
APPLICATION_ARN = "arn:aws:kinesisanalytics:us-east-1:123456789012:application/demo"
Expand Down Expand Up @@ -483,3 +496,6 @@ def test_execute_complete_failure(self, kinesis_analytics_mock_conn):
AirflowException, match="Error while stopping AWS Managed Service for Apache Flink application"
):
self.operator.execute_complete(context=None, event=event)

def test_template_fields(self):
validate_template_fields(self.operator)
19 changes: 19 additions & 0 deletions tests/providers/amazon/aws/operators/test_lambda_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
LambdaCreateFunctionOperator,
LambdaInvokeFunctionOperator,
)
from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields

FUNCTION_NAME = "function_name"
PAYLOADS = [
Expand Down Expand Up @@ -160,6 +161,17 @@ def test_create_lambda_using_config_argument(self, mock_hook_conn, mock_hook_cre
assert operator.config.get("snap_start") == config.get("snap_start")
assert operator.config.get("ephemeral_storage") == config.get("ephemeral_storage")

def test_template_fields(self):
operator = LambdaCreateFunctionOperator(
task_id="task_test",
function_name=FUNCTION_NAME,
role=ROLE_ARN,
code={
"ImageUri": IMAGE_URI,
},
)
validate_template_fields(operator)


class TestLambdaInvokeFunctionOperator:
@pytest.mark.parametrize("payload", PAYLOADS)
Expand Down Expand Up @@ -280,3 +292,10 @@ def test_invoke_lambda_function_error(self, hook_mock):

with pytest.raises(ValueError):
operator.execute(None)

def test_template_fields(self):
operator = LambdaInvokeFunctionOperator(
task_id="task_test",
function_name="a",
)
validate_template_fields(operator)
21 changes: 21 additions & 0 deletions tests/providers/amazon/aws/operators/test_neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
NeptuneStartDbClusterOperator,
NeptuneStopDbClusterOperator,
)
from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields

CLUSTER_ID = "test_cluster"

Expand Down Expand Up @@ -201,6 +202,16 @@ def test_start_cluster_instances_not_ready_defer(self, mock_conn, mock_defer):
# mock_defer.assert_has_calls(calls)
assert mock_defer.call_count == 2

def test_template_fields(self):
operator = NeptuneStartDbClusterOperator(
task_id="task_test",
db_cluster_id=CLUSTER_ID,
deferrable=True,
wait_for_completion=False,
aws_conn_id="aws_default",
)
validate_template_fields(operator)


class TestNeptuneStopClusterOperator:
@mock.patch.object(NeptuneHook, "conn")
Expand Down Expand Up @@ -368,3 +379,13 @@ def test_stop_cluster_deferrable(self, mock_conn):

with pytest.raises(TaskDeferred):
operator.execute(None)

def test_template_fields(self):
operator = NeptuneStopDbClusterOperator(
task_id="task_test",
db_cluster_id=CLUSTER_ID,
deferrable=True,
wait_for_completion=False,
aws_conn_id="aws_default",
)
validate_template_fields(operator)
5 changes: 5 additions & 0 deletions tests/providers/amazon/aws/operators/test_quicksight.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from airflow.providers.amazon.aws.hooks.quicksight import QuickSightHook
from airflow.providers.amazon.aws.operators.quicksight import QuickSightCreateIngestionOperator
from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields

DATA_SET_ID = "DemoDataSet"
INGESTION_ID = "DemoDataSet_Ingestion"
Expand Down Expand Up @@ -80,3 +81,7 @@ def test_execute(self, mock_create_ingestion):
wait_for_completion=True,
check_interval=30,
)

def test_template_fields(self):
operator = QuickSightCreateIngestionOperator(**self.default_op_kwargs)
validate_template_fields(operator)
Loading

0 comments on commit 6e18da7

Please sign in to comment.