diff --git a/aws_lambda_powertools/metrics/base.py b/aws_lambda_powertools/metrics/base.py index 2c45aa1fb3e..5e3b9c84733 100644 --- a/aws_lambda_powertools/metrics/base.py +++ b/aws_lambda_powertools/metrics/base.py @@ -378,11 +378,11 @@ def handler(event, context): ) @functools.wraps(lambda_handler) - def decorate(event, context): + def decorate(event, context, *args, **kwargs): try: if default_dimensions: self.set_default_dimensions(**default_dimensions) - response = lambda_handler(event, context) + response = lambda_handler(event, context, *args, **kwargs) if capture_cold_start_metric: self._add_cold_start_metric(context=context) finally: diff --git a/aws_lambda_powertools/metrics/provider/base.py b/aws_lambda_powertools/metrics/provider/base.py index 702b4b3d2ba..edcc8e07ec3 100644 --- a/aws_lambda_powertools/metrics/provider/base.py +++ b/aws_lambda_powertools/metrics/provider/base.py @@ -199,9 +199,9 @@ def handler(event, context): ) @functools.wraps(lambda_handler) - def decorate(event, context): + def decorate(event, context, *args, **kwargs): try: - response = lambda_handler(event, context) + response = lambda_handler(event, context, *args, **kwargs) if capture_cold_start_metric: self._add_cold_start_metric(context=context) finally: diff --git a/tests/functional/metrics/test_metrics_cloudwatch_emf.py b/tests/functional/metrics/test_metrics_cloudwatch_emf.py index 5c4a1de1128..d3da81798b6 100644 --- a/tests/functional/metrics/test_metrics_cloudwatch_emf.py +++ b/tests/functional/metrics/test_metrics_cloudwatch_emf.py @@ -355,6 +355,22 @@ def lambda_handler(evt, context): assert lambda_handler({}, {}) is True +def test_log_metrics_decorator_with_additional_handler_args(namespace, service): + # GIVEN Metrics is initialized + my_metrics = Metrics(service=service, namespace=namespace) + + # WHEN log_metrics is used to serialize metrics + # AND the wrapped function uses additional parameters + @my_metrics.log_metrics + def lambda_handler(evt, context, additional_arg, additional_kw_arg="default_value"): + return additional_arg, additional_kw_arg + + # THEN the decorator should not raise any errors when + # the wrapped function is passed additional arguments + assert lambda_handler({}, {}, "arg_value", additional_kw_arg="kw_arg_value") == ("arg_value", "kw_arg_value") + assert lambda_handler({}, {}, "arg_value") == ("arg_value", "default_value") + + def test_schema_validation_incorrect_metric_resolution(metric, dimension): # GIVEN we pass a metric resolution that is not supported by CloudWatch metric["resolution"] = 10 # metric resolution must be 1 (High) or 60 (Standard) diff --git a/tests/functional/metrics/test_metrics_datadog.py b/tests/functional/metrics/test_metrics_datadog.py index 0900bb851b4..abedfd99424 100644 --- a/tests/functional/metrics/test_metrics_datadog.py +++ b/tests/functional/metrics/test_metrics_datadog.py @@ -136,6 +136,22 @@ def lambda_handler(evt, context): ) +def test_datadog_log_metrics_decorator_with_additional_handler_args(): + # GIVEN DatadogMetrics is initialized + my_metrics = DatadogMetrics(flush_to_log=True) + + # WHEN log_metrics is used to serialize metrics + # AND the wrapped function uses additional parameters + @my_metrics.log_metrics + def lambda_handler(evt, context, additional_arg, additional_kw_arg="default_value"): + return additional_arg, additional_kw_arg + + # THEN the decorator should not raise any errors when + # the wrapped function is passed additional arguments + assert lambda_handler({}, {}, "arg_value", additional_kw_arg="kw_arg_value") == ("arg_value", "kw_arg_value") + assert lambda_handler({}, {}, "arg_value") == ("arg_value", "default_value") + + def test_metrics_with_default_namespace(capsys, namespace): # GIVEN DatadogMetrics is initialized with default namespace metrics = DatadogMetrics(flush_to_log=True) diff --git a/tests/functional/metrics/test_metrics_provider.py b/tests/functional/metrics/test_metrics_provider.py index 2ed84a23a21..c9b627c1709 100644 --- a/tests/functional/metrics/test_metrics_provider.py +++ b/tests/functional/metrics/test_metrics_provider.py @@ -60,3 +60,19 @@ def lambda_handler(evt, context): # THEN log_metrics should invoke the function it decorates # and return no error if we have a namespace and dimension assert lambda_handler({}, {}) is True + + +def test_metrics_provider_class_decorator_with_additional_handler_args(): + # GIVEN Metrics is initialized + my_metrics = Metrics() + + # WHEN log_metrics is used to serialize metrics + # AND the wrapped function uses additional parameters + @my_metrics.log_metrics + def lambda_handler(evt, context, additional_arg, additional_kw_arg="default_value"): + return additional_arg, additional_kw_arg + + # THEN the decorator should not raise any errors when + # the wrapped function is passed additional arguments + assert lambda_handler({}, {}, "arg_value", additional_kw_arg="kw_arg_value") == ("arg_value", "kw_arg_value") + assert lambda_handler({}, {}, "arg_value") == ("arg_value", "default_value")