diff --git a/dbt_common/events/base_types.py b/dbt_common/events/base_types.py index 78b0368..f5a30f1 100644 --- a/dbt_common/events/base_types.py +++ b/dbt_common/events/base_types.py @@ -151,6 +151,22 @@ def msg_from_base_event(event: BaseEvent, level: Optional[EventLevel] = None): return new_event +def msg_to_dict(msg: EventMsg) -> dict: + msg_dict = MessageToDict( + msg, + preserving_proto_field_name=True, + including_default_value_fields=True, # type: ignore + ) + # We don't want an empty NodeInfo in output + if ( + "data" in msg_dict + and "node_info" in msg_dict["data"] + and msg_dict["data"]["node_info"]["node_name"] == "" + ): + del msg_dict["data"]["node_info"] + return msg_dict + + # DynamicLevel requires that the level be supplied on the # event construction call using the "info" function from functions.py class DynamicLevel(BaseEvent): diff --git a/dbt_common/events/functions.py b/dbt_common/events/functions.py index f7984dc..27e86cd 100644 --- a/dbt_common/events/functions.py +++ b/dbt_common/events/functions.py @@ -5,17 +5,20 @@ import sys from typing import Any, Callable, Dict, Optional, TextIO, Union -from google.protobuf.json_format import MessageToDict -from snowplow_tracker import Subject from snowplow_tracker.typing import FailureCallback from dbt_common.helper_types import WarnErrorOptions from dbt_common.invocation import get_invocation_id -from dbt_common.events.base_types import BaseEvent, EventLevel, EventMsg +from dbt_common.events.base_types import ( + BaseEvent, + EventLevel, + EventMsg, + msg_to_dict as _msg_to_dict, +) from dbt_common.events.cookie import Cookie from dbt_common.events.event_manager_client import get_event_manager from dbt_common.events.logger import LoggerConfig, LineFormat -from dbt_common.events.tracker import TrackerConfig +from dbt_common.events.tracker import FileTracker, SnowplowTracker, Tracker, TrackerConfig from dbt_common.events.types import DisableTracking, Note from dbt_common.events.user import User from dbt_common.exceptions import EventCompilationError, scrub_secrets, env_secrets @@ -117,26 +120,14 @@ def msg_to_json(msg: EventMsg) -> str: def msg_to_dict(msg: EventMsg) -> dict: - msg_dict = dict() try: - msg_dict = MessageToDict( - msg, - preserving_proto_field_name=True, - including_default_value_fields=True, # type: ignore - ) + return _msg_to_dict(msg) except Exception as exc: event_type = type(msg).__name__ fire_event( Note(msg=f"type {event_type} is not serializable. {str(exc)}"), level=EventLevel.WARN ) - # We don't want an empty NodeInfo in output - if ( - "data" in msg_dict - and "node_info" in msg_dict["data"] - and msg_dict["data"]["node_info"]["node_name"] == "" - ): - del msg_dict["data"]["node_info"] - return msg_dict + return {} def warn_or_error(event, node=None) -> None: @@ -190,32 +181,61 @@ def _default_on_failure(num_ok, unsent): fire_event(DisableTracking()) -def snowplow_config( +def tracker_factory( + user: User, + endpoint: Optional[str], + protocol: Optional[str] = "https", + on_failure: Optional[FailureCallback] = _default_on_failure, + name: Optional[str] = None, + output_file_name: Optional[str] = None, + output_file_max_bytes: Optional[int] = None, +) -> Tracker: + if all([user, endpoint]): + return snowplow_tracker(user, endpoint, protocol, on_failure) + elif all([user, name, output_file_name]): + return file_tracker(user, name, output_file_name, output_file_max_bytes) + raise Exception("Invalid tracking configuration.") + + +def snowplow_tracker( user: User, endpoint: str, protocol: Optional[str] = "https", on_failure: Optional[FailureCallback] = _default_on_failure, -) -> TrackerConfig: - return TrackerConfig( +) -> Tracker: + config = TrackerConfig( invocation_id=user.invocation_id, endpoint=endpoint, protocol=protocol, on_failure=on_failure, ) + return SnowplowTracker.from_config(config) -def enable_tracking(tracker, user: User): +def file_tracker( + user: User, + name: str, + output_file_name: str, + output_file_max_bytes: Optional[int], +) -> Tracker: + config = TrackerConfig( + invocation_id=user.invocation_id, + name=name, + output_file_name=output_file_name, + output_file_max_bytes=output_file_max_bytes, + ) + return FileTracker.from_config(config) + + +def enable_tracking(tracker: Tracker, user: User): cookie = _get_cookie(user) user.enable_tracking(cookie) - - subject = Subject() - subject.set_user_id(cookie.get("id")) - tracker.set_subject(subject) + tracker.enable_tracking(cookie) -def disable_tracking(tracker, user: User): +def disable_tracking(tracker: Tracker, user: User): user.disable_tracking() - tracker.set_subject(None) + tracker.disable_tracking() def _get_cookie(user: User) -> Dict[str, Any]: @@ -240,3 +260,15 @@ def _set_cookie(user: User) -> Dict[str, Any]: user.cookie = cookie.as_dict() return user.cookie return {} + + +def track(tracker: Tracker, user: User, msg: EventMsg) -> None: + if user.do_not_track: + return + + # fire_event(SendingEvent(kwargs=str(**msg_to_dict(msg)))) + try: + tracker.track(msg) + except Exception: + # fire_event(SendEventFailure()) + pass diff --git a/dbt_common/events/tracker.py b/dbt_common/events/tracker.py index 43b700a..a99f3af 100644 --- a/dbt_common/events/tracker.py +++ b/dbt_common/events/tracker.py @@ -1,71 +1,121 @@ from dataclasses import dataclass import logging from logging.handlers import RotatingFileHandler -from typing import Optional +from typing import Any, Dict, Optional, Protocol, Self -from snowplow_tracker import Emitter, Tracker +import snowplow_tracker from snowplow_tracker.typing import FailureCallback -from dbt_common.events.base_types import EventMsg +from dbt_common.events.base_types import EventMsg, msg_to_dict +from dbt_common.events.format import timestamp_to_datetime_string @dataclass class TrackerConfig: invocation_id: Optional[str] = None + msg_schemas: Optional[Dict[str, str]] = None endpoint: Optional[str] = None - protocol: Optional[str] = None + protocol: Optional[str] = "https" on_failure: Optional[FailureCallback] = None name: Optional[str] = None output_file_name: Optional[str] = None output_file_max_bytes: Optional[int] = 10 * 1024 * 1024 # 10 mb -class _Tracker: - def __init__(self, config: TrackerConfig) -> None: - self.invocation_id: Optional[str] = config.invocation_id +class Tracker(Protocol): + @classmethod + def from_config(cls, config: TrackerConfig) -> Self: + ... - if all([config.name, config.output_file_name]): - file_handler = RotatingFileHandler( - filename=str(config.output_file_name), - encoding="utf8", - maxBytes=config.output_file_max_bytes, # type: ignore - backupCount=5, - ) - self._tracker = self._python_file_logger(config.name, file_handler) + def track(self, msg: EventMsg) -> None: + ... - elif all([config.endpoint, config.protocol]): - self._tracker = self._snowplow_tracker(config.endpoint, config.protocol) + def enable_tracking(self, cookie: Dict[str, Any]) -> None: + ... - def track(self, msg: EventMsg) -> str: - raise NotImplementedError() + def disable_tracking(self) -> None: + ... - def _python_file_logger(self, name: str, handler: logging.Handler) -> logging.Logger: - log = logging.getLogger(name) - log.setLevel(logging.DEBUG) - handler.setFormatter(logging.Formatter(fmt="%(message)s")) - log.handlers.clear() - log.propagate = False - log.addHandler(handler) - return log - def _snowplow_tracker( +class FileTracker(Tracker): + def __init__(self, logger: logging.Logger, invocation_id: Optional[str]) -> None: + self.logger = logger + self.invocation_id = invocation_id + + @classmethod + def from_config(cls, config: TrackerConfig) -> Self: + file_handler = RotatingFileHandler( + filename=config.output_file_name, + maxBytes=config.output_file_max_bytes, # type: ignore + backupCount=5, + encoding="utf8", + ) + file_handler.setFormatter(logging.Formatter(fmt="%(message)s")) + + logger = logging.getLogger(config.name) + logger.setLevel(logging.DEBUG) + logger.handlers.clear() + logger.propagate = False + logger.addHandler(file_handler) + return cls(logger, config.invocation_id) + + def track(self, msg: EventMsg) -> None: + ts: str = timestamp_to_datetime_string(msg.info.ts) + log_line = f"{ts} | {msg.info.msg}" + self.logger.debug(log_line) + + def enable_tracking(self, cookie: Dict[str, Any]) -> None: + pass + + def disable_tracking(self) -> None: + pass + + +class SnowplowTracker(Tracker): + def __init__( self, - endpoint: str, - protocol: Optional[str] = "https", - on_failure: Optional[FailureCallback] = None, - ) -> Tracker: - emitter = Emitter( - endpoint, - protocol, + tracker: snowplow_tracker.Tracker, + msg_schemas: Dict[str, str], + invocation_id: Optional[str], + ) -> None: + self.tracker = tracker + self.msg_schemas = msg_schemas + self.invocation_id = invocation_id + + @classmethod + def from_config(cls, config: TrackerConfig) -> Self: + emitter = snowplow_tracker.Emitter( + config.endpoint, + config.protocol, method="post", batch_size=30, - on_failure=on_failure, + on_failure=config.on_failure, byte_limit=None, request_timeout=5.0, ) - tracker = Tracker( + tracker = snowplow_tracker.Tracker( emitters=emitter, namespace="cf", app_id="dbt", ) - return tracker + return cls(tracker, config.msg_schemas, config.invocation_id) + + def track(self, msg: EventMsg) -> None: + data = msg_to_dict(msg) + schema = self.msg_schemas.get(msg.info.name) + context = [snowplow_tracker.SelfDescribingJson(schema, data)] + event = snowplow_tracker.StructuredEvent( + category="dbt", + action=msg.info.name, + label=self.invocation_id, + context=context, + ) + self.tracker.track(event) + + def enable_tracking(self, cookie: Dict[str, Any]) -> None: + subject = snowplow_tracker.Subject() + subject.set_user_id(cookie.get("id")) + self.tracker.set_subject(subject) + + def disable_tracking(self) -> None: + self.tracker.set_subject(None) diff --git a/tests/unit/test_tracker.py b/tests/unit/test_tracker.py new file mode 100644 index 0000000..e69de29