diff --git a/api/src/data_migration/command/load_transform.py b/api/src/data_migration/command/load_transform.py
index a213aeadf..7f41c5d17 100644
--- a/api/src/data_migration/command/load_transform.py
+++ b/api/src/data_migration/command/load_transform.py
@@ -10,6 +10,7 @@
 import src.adapters.db.flask_db as flask_db
 import src.db.foreign
 import src.db.models.staging
+from src.task.ecs_background_task import ecs_background_task
 from src.task.opportunities.set_current_opportunities_task import SetCurrentOpportunitiesTask
 
 from ..data_migration_blueprint import data_migration_blueprint
@@ -32,6 +33,7 @@
 )
 @click.option("--tables-to-load", "-t", help="table to load", multiple=True)
 @flask_db.with_db_session()
+@ecs_background_task(task_name="load-transform")
 def load_transform(
     db_session: db.Session,
     load: bool,
diff --git a/api/src/logging/flask_logger.py b/api/src/logging/flask_logger.py
index 6bddcee90..4a848a48a 100644
--- a/api/src/logging/flask_logger.py
+++ b/api/src/logging/flask_logger.py
@@ -18,6 +18,7 @@
 """
 
 import logging
+import os
 import time
 import uuid
 
@@ -26,6 +27,8 @@
 logger = logging.getLogger(__name__)
 
 EXTRA_LOG_DATA_ATTR = "extra_log_data"
+_GLOBAL_LOG_CONTEXT: dict = {}
+
 
 def init_app(app_logger: logging.Logger, app: flask.Flask) -> None:
     """Initialize the Flask app logger.
@@ -50,7 +53,7 @@ def init_app(app_logger: logging.Logger, app: flask.Flask) -> None:
     # set on the ancestors.
     # See https://docs.python.org/3/library/logging.html#logging.Logger.propagate
     for handler in app_logger.handlers:
-        handler.addFilter(_add_app_context_info_to_log_record)
+        handler.addFilter(_add_global_context_info_to_log_record)
         handler.addFilter(_add_request_context_info_to_log_record)
 
     # Add request context data to every log record for the current request
@@ -63,6 +66,11 @@ def init_app(app_logger: logging.Logger, app: flask.Flask) -> None:
     app.before_request(_log_start_request)
     app.after_request(_log_end_request)
 
+    # Add some metadata to all log messages globally
+    add_extra_data_to_global_logs(
+        {"app.name": app.name, "environment": os.environ.get("ENVIRONMENT")}
+    )
+
     app_logger.info("initialized flask logger")
 
 
@@ -77,6 +85,12 @@ def add_extra_data_to_current_request_logs(
     setattr(flask.g, EXTRA_LOG_DATA_ATTR, extra_log_data)
 
 
+def add_extra_data_to_global_logs(data: dict[str, str | int | float | bool | None]) -> None:
+    """Add metadata to all logs for the rest of the lifecycle of this app process"""
+    global _GLOBAL_LOG_CONTEXT
+    _GLOBAL_LOG_CONTEXT.update(data)
+
+
 def _track_request_start_time() -> None:
     """Store the request start time in flask.g"""
     flask.g.request_start_time = time.perf_counter()
@@ -117,20 +131,6 @@ def _log_end_request(response: flask.Response) -> flask.Response:
     return response
 
 
-def _add_app_context_info_to_log_record(record: logging.LogRecord) -> bool:
-    """Add app context data to the log record.
-
-    If there is no app context, then do not add any data.
-    """
-    if not flask.has_app_context():
-        return True
-
-    assert flask.current_app is not None
-    record.__dict__ |= _get_app_context_info(flask.current_app)
-
-    return True
-
-
 def _add_request_context_info_to_log_record(record: logging.LogRecord) -> bool:
     """Add request context data to the log record.
@@ -146,8 +146,11 @@ def _add_request_context_info_to_log_record(record: logging.LogRecord) -> bool:
     return True
 
 
-def _get_app_context_info(app: flask.Flask) -> dict:
-    return {"app.name": app.name}
+def _add_global_context_info_to_log_record(record: logging.LogRecord) -> bool:
+    global _GLOBAL_LOG_CONTEXT
+    record.__dict__ |= _GLOBAL_LOG_CONTEXT
+
+    return True
 
 
 def _get_request_context_info(request: flask.Request) -> dict:
diff --git a/api/src/search/backend/load_search_data.py b/api/src/search/backend/load_search_data.py
index 5b82e5a6d..38f26a7f4 100644
--- a/api/src/search/backend/load_search_data.py
+++ b/api/src/search/backend/load_search_data.py
@@ -5,6 +5,7 @@
 from src.adapters.db import flask_db
 from src.search.backend.load_opportunities_to_index import LoadOpportunitiesToIndex
 from src.search.backend.load_search_data_blueprint import load_search_data_blueprint
+from src.task.ecs_background_task import ecs_background_task
 
 
 @load_search_data_blueprint.cli.command(
@@ -16,6 +17,7 @@
     help="Whether to run a full refresh, or only incrementally update opportunities",
 )
 @flask_db.with_db_session()
+@ecs_background_task(task_name="load-opportunity-data-opensearch")
 def load_opportunity_data(db_session: db.Session, full_refresh: bool) -> None:
     search_client = search.SearchClient()
diff --git a/api/src/task/ecs_background_task.py b/api/src/task/ecs_background_task.py
new file mode 100644
index 000000000..3b7fa2a66
--- /dev/null
+++ b/api/src/task/ecs_background_task.py
@@ -0,0 +1,137 @@
+import contextlib
+import logging
+import os
+import time
+import uuid
+from functools import wraps
+from typing import Callable, Generator, ParamSpec, TypeVar
+
+import requests
+
+from src.logging.flask_logger import add_extra_data_to_global_logs
+
+logger = logging.getLogger(__name__)
+
+P = ParamSpec("P")
+T = TypeVar("T")
+
+
+def ecs_background_task(task_name: str) -> Callable[[Callable[P, T]], Callable[P, T]]:
+    """
+    Decorator for any ECS task entrypoint function.
+
+    This encapsulates the setup required by all ECS tasks, making it easy to:
+    - add new shared initialization steps for logging
+    - write new ECS task code without thinking about the boilerplate
+
+    Usage:
+
+        TASK_NAME = "my-cool-task"
+
+        @task_blueprint.cli.command(TASK_NAME, help="For running my cool task")
+        @ecs_background_task(TASK_NAME)
+        @flask_db.with_db_session()
+        def entrypoint(db_session: db.Session):
+            do_cool_stuff()
+
+    Parameters:
+        task_name (str): Name of the ECS task
+
+    IMPORTANT: Do not apply this decorator above the task command.
+    Click effectively rewrites your function to be a main function,
+    and any decorators above the "task_blueprint.cli.command(...)"
+    line are discarded.
+    See: https://click.palletsprojects.com/en/8.1.x/quickstart/#basic-concepts-creating-a-command
+    """
+
+    def decorator(f: Callable[P, T]) -> Callable[P, T]:
+        @wraps(f)
+        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
+            with _ecs_background_task_impl(task_name):
+                return f(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+@contextlib.contextmanager
+def _ecs_background_task_impl(task_name: str) -> Generator[None, None, None]:
+    # The actual implementation; see the docstring on the
+    # decorator above for details on usage
+
+    start = time.perf_counter()
+    _add_log_metadata(task_name)
+
+    # initialize New Relic here when we add that
+
+    logger.info("Starting ECS task %s", task_name)
+
+    try:
+        yield
+    except Exception:
+        # We want to make certain that any exception will always
+        # be logged as an error.
+        # logger.exception is just an alias for logger.error(..., exc_info=True)
+        logger.exception("ECS task failed", extra={"status": "error"})
+        raise
+
+    end = time.perf_counter()
+    duration = round(end - start, 3)
+    logger.info(
+        "Completed ECS task %s",
+        task_name,
+        extra={"ecs_task_duration_sec": duration, "status": "success"},
+    )
+
+
+def _add_log_metadata(task_name: str) -> None:
+    # Note that we also set an "aws.ecs.task_name" pulled from ECS,
+    # which may differ: that value comes from our infra setup, while
+    # this one is whatever was passed to the @ecs_background_task decorator.
+    add_extra_data_to_global_logs({"task_name": task_name, "task_uuid": str(uuid.uuid4())})
+    add_extra_data_to_global_logs(_get_ecs_metadata())
+
+
+def _get_ecs_metadata() -> dict:
+    """
+    Retrieves ECS metadata from an AWS-provided metadata URI.
+    This URI is injected into all ECS tasks by AWS as an environment variable.
+    See https://docs.aws.amazon.com/AmazonECS/latest/userguide/task-metadata-endpoint-v4-fargate.html for details.
+    """
+    ecs_metadata_uri = os.environ.get("ECS_CONTAINER_METADATA_URI_V4")
+
+    if os.environ.get("ENVIRONMENT", "local") == "local" or ecs_metadata_uri is None:
+        logger.info(
+            "ECS metadata not available for local environments. Run this task on ECS to see metadata."
+        )
+        return {}
+
+    task_metadata = requests.get(ecs_metadata_uri, timeout=1)  # 1 sec timeout
+    logger.info("Retrieved task metadata from ECS")
+    metadata_json = task_metadata.json()
+
+    ecs_task_name = metadata_json["Name"]
+    ecs_task_id = metadata_json["Labels"]["com.amazonaws.ecs.task-arn"].split("/")[-1]
+    ecs_taskdef = ":".join(
+        [
+            metadata_json["Labels"]["com.amazonaws.ecs.task-definition-family"],
+            metadata_json["Labels"]["com.amazonaws.ecs.task-definition-version"],
+        ]
+    )
+    cloudwatch_log_group = metadata_json["LogOptions"]["awslogs-group"]
+    cloudwatch_log_stream = metadata_json["LogOptions"]["awslogs-stream"]
+
+    # Step Function executions only
+    sfn_execution_id = os.environ.get("SFN_EXECUTION_ID")
+    sfn_id = sfn_execution_id.split(":")[-2] if sfn_execution_id is not None else None
+
+    return {
+        "aws.ecs.task_name": ecs_task_name,
+        "aws.ecs.task_id": ecs_task_id,
+        "aws.ecs.task_definition": ecs_taskdef,
+        # these will be added automatically by the New Relic log ingester, but
+        # declare them explicitly to be safe and for non-log usages
+        "aws.cloudwatch.log_group": cloudwatch_log_group,
+        "aws.cloudwatch.log_stream": cloudwatch_log_stream,
+        "aws.step_function.id": sfn_id,
+    }
diff --git a/api/src/task/opportunities/export_opportunity_data_task.py b/api/src/task/opportunities/export_opportunity_data_task.py
index 6c729dde8..5ae6d2b66 100644
--- a/api/src/task/opportunities/export_opportunity_data_task.py
+++ b/api/src/task/opportunities/export_opportunity_data_task.py
@@ -14,6 +14,7 @@
 from src.api.opportunities_v1.opportunity_schemas import OpportunityV1Schema
 from src.db.models.opportunity_models import CurrentOpportunitySummary, Opportunity
 from src.services.opportunities_v1.opportunity_to_csv import opportunities_to_csv
+from src.task.ecs_background_task import ecs_background_task
 from src.task.task import Task
 from src.task.task_blueprint import task_blueprint
 from src.util.datetime_util import get_now_us_eastern_datetime
@@ -27,6 +28,7 @@
     help="Generate JSON and CSV files containing an export of all opportunity data",
 )
 @flask_db.with_db_session()
+@ecs_background_task(task_name="export-opportunity-data")
 def export_opportunity_data(db_session: db.Session) -> None:
     ExportOpportunityDataTask(db_session).run()
diff --git a/api/tests/src/task/test_ecs_background_task.py b/api/tests/src/task/test_ecs_background_task.py
new file mode 100644
index 000000000..46759fcd2
--- /dev/null
+++ b/api/tests/src/task/test_ecs_background_task.py
@@ -0,0 +1,59 @@
+import logging
+import time
+
+import pytest
+
+from src.logging.flask_logger import add_extra_data_to_global_logs
+from src.task.ecs_background_task import ecs_background_task
+
+
+def test_ecs_background_task(app, caplog):
+    # We pull in the app so it's initialized
+    # Global logging params like the task name are stored on the app
+    caplog.set_level(logging.INFO)
+
+    @ecs_background_task(task_name="my_test_task_name")
+    def my_test_func(param1, param2):
+        # Add a brief sleep so that we can test the duration logic
+        time.sleep(0.2)  # 0.2s
+        add_extra_data_to_global_logs({"example_param": 12345})
+
+        return param1 + param2
+
+    # Verify the function works uneventfully
+    assert my_test_func(1, 2) == 3
+
+    for record in caplog.records:
+        extra = record.__dict__
+        assert extra["task_name"] == "my_test_task_name"
+
+    last_record = caplog.records[-1].__dict__
+    # Make sure the ECS task duration was tracked
+    allowed_error = 0.1
+    assert last_record["ecs_task_duration_sec"] == pytest.approx(0.2, abs=allowed_error)
+    # Make sure the extra data we added was included automatically
+    assert last_record["example_param"] == 12345
+    assert last_record["message"] == "Completed ECS task my_test_task_name"
+
+
+def test_ecs_background_task_when_erroring(app, caplog):
+    caplog.set_level(logging.INFO)
+
+    @ecs_background_task(task_name="my_error_test_task_name")
+    def my_test_error_func():
+        add_extra_data_to_global_logs({"another_param": "hello"})
+
+        raise ValueError("I am an error")
+
+    with pytest.raises(ValueError, match="I am an error"):
+        my_test_error_func()
+
+    for record in caplog.records:
+        extra = record.__dict__
+        assert extra["task_name"] == "my_error_test_task_name"
+
+    last_record = caplog.records[-1].__dict__
+
+    assert last_record["another_param"] == "hello"
+    assert last_record["levelname"] == "ERROR"
+    assert last_record["message"] == "ECS task failed"
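
Note: the flask_logger change boils down to a module-level dict that a handler filter merges into every LogRecord, replacing the old per-app-context filter. Below is a minimal, self-contained sketch of that pattern; the handler, formatter, and logger names here are illustrative, not part of this change.

    import logging

    _GLOBAL_LOG_CONTEXT: dict = {}

    def add_extra_data_to_global_logs(data: dict) -> None:
        # Merge new metadata into the process-wide context
        _GLOBAL_LOG_CONTEXT.update(data)

    def _add_global_context_info_to_log_record(record: logging.LogRecord) -> bool:
        # Copy the global context onto the record; returning True keeps the record
        record.__dict__ |= _GLOBAL_LOG_CONTEXT
        return True

    handler = logging.StreamHandler()
    handler.addFilter(_add_global_context_info_to_log_record)
    handler.setFormatter(logging.Formatter("%(levelname)s task=%(task_name)s %(message)s"))

    logger = logging.getLogger("sketch")
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)

    add_extra_data_to_global_logs({"task_name": "demo-task"})
    logger.info("hello")  # prints: INFO task=demo-task hello

Because handler filters run before formatting, every record carries the context keys by the time a formatter or log shipper sees it.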
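The new tests exercise the decorator but not _get_ecs_metadata itself, which short-circuits when ENVIRONMENT is "local". If coverage of the metadata parsing is wanted, a hypothetical test could stub the endpoint as sketched below; the fake payload contains only the fields the function reads, and the URI and ARN values are made up.

    from unittest import mock

    from src.task.ecs_background_task import _get_ecs_metadata

    FAKE_METADATA = {
        "Name": "api",
        "Labels": {
            "com.amazonaws.ecs.task-arn": "arn:aws:ecs:us-east-1:123456789012:task/cluster/abc123",
            "com.amazonaws.ecs.task-definition-family": "api-task",
            "com.amazonaws.ecs.task-definition-version": "42",
        },
        "LogOptions": {"awslogs-group": "service/api", "awslogs-stream": "api/abc123"},
    }

    def test_get_ecs_metadata(monkeypatch):
        monkeypatch.setenv("ENVIRONMENT", "dev")
        monkeypatch.setenv("ECS_CONTAINER_METADATA_URI_V4", "http://169.254.170.2/v4/abc")
        monkeypatch.delenv("SFN_EXECUTION_ID", raising=False)

        fake_response = mock.Mock()
        fake_response.json.return_value = FAKE_METADATA

        with mock.patch("src.task.ecs_background_task.requests.get", return_value=fake_response):
            metadata = _get_ecs_metadata()

        assert metadata["aws.ecs.task_id"] == "abc123"
        assert metadata["aws.ecs.task_definition"] == "api-task:42"
        assert metadata["aws.step_function.id"] is None  # SFN_EXECUTION_ID not set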