Skip to content

Commit

Permalink
add generic snowplow tracker with file logger for testing
Browse files Browse the repository at this point in the history
  • Loading branch information
mikealfare committed Sep 16, 2024
1 parent e671471 commit b3494cc
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 66 deletions.
16 changes: 16 additions & 0 deletions dbt_common/events/base_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,22 @@ def msg_from_base_event(event: BaseEvent, level: Optional[EventLevel] = None):
return new_event


def msg_to_dict(msg: EventMsg) -> dict:
    """Serialize an event message to a plain dict.

    The conversion keeps the proto field names and includes fields that hold
    default values. When the result has a ``data.node_info`` entry whose
    ``node_name`` is the empty string, that entry is removed.
    """
    serialized = MessageToDict(
        msg,
        preserving_proto_field_name=True,
        including_default_value_fields=True,  # type: ignore
    )
    if "data" not in serialized:
        return serialized

    payload = serialized["data"]
    # We don't want an empty NodeInfo in output
    if "node_info" in payload and payload["node_info"]["node_name"] == "":
        del payload["node_info"]
    return serialized


# DynamicLevel requires that the level be supplied on the
# event construction call using the "info" function from functions.py
class DynamicLevel(BaseEvent):
Expand Down
88 changes: 60 additions & 28 deletions dbt_common/events/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,20 @@
import sys
from typing import Any, Callable, Dict, Optional, TextIO, Union

from google.protobuf.json_format import MessageToDict
from snowplow_tracker import Subject
from snowplow_tracker.typing import FailureCallback

from dbt_common.helper_types import WarnErrorOptions
from dbt_common.invocation import get_invocation_id
from dbt_common.events.base_types import BaseEvent, EventLevel, EventMsg
from dbt_common.events.base_types import (
BaseEvent,
EventLevel,
EventMsg,
msg_to_dict as _msg_to_dict,
)
from dbt_common.events.cookie import Cookie
from dbt_common.events.event_manager_client import get_event_manager
from dbt_common.events.logger import LoggerConfig, LineFormat
from dbt_common.events.tracker import TrackerConfig
from dbt_common.events.tracker import FileTracker, SnowplowTracker, Tracker, TrackerConfig
from dbt_common.events.types import DisableTracking, Note
from dbt_common.events.user import User
from dbt_common.exceptions import EventCompilationError, scrub_secrets, env_secrets
Expand Down Expand Up @@ -117,26 +120,14 @@ def msg_to_json(msg: EventMsg) -> str:


def msg_to_dict(msg: EventMsg) -> dict:
msg_dict = dict()
try:
msg_dict = MessageToDict(
msg,
preserving_proto_field_name=True,
including_default_value_fields=True, # type: ignore
)
return _msg_to_dict(msg)
except Exception as exc:
event_type = type(msg).__name__
fire_event(
Note(msg=f"type {event_type} is not serializable. {str(exc)}"), level=EventLevel.WARN
)
# We don't want an empty NodeInfo in output
if (
"data" in msg_dict
and "node_info" in msg_dict["data"]
and msg_dict["data"]["node_info"]["node_name"] == ""
):
del msg_dict["data"]["node_info"]
return msg_dict
return {}


def warn_or_error(event, node=None) -> None:
Expand Down Expand Up @@ -190,32 +181,61 @@ def _default_on_failure(num_ok, unsent):
fire_event(DisableTracking())


def snowplow_config(
def tracker_factory(
    user: User,
    endpoint: Optional[str],
    protocol: Optional[str] = "https",
    on_failure: Optional[FailureCallback] = _default_on_failure,
    name: Optional[str] = None,
    output_file_name: Optional[str] = None,
    output_file_max_bytes: Optional[int] = None,
) -> Tracker:
    """Build the tracker that matches the supplied settings.

    A user together with an endpoint selects the Snowplow tracker. Failing
    that, a user together with both ``name`` and ``output_file_name`` selects
    the file tracker.

    Raises:
        Exception: If neither combination of settings is complete.
    """
    # The endpoint-based configuration takes precedence over the file-based one.
    if user and endpoint:
        return snowplow_tracker(user, endpoint, protocol, on_failure)
    if user and name and output_file_name:
        return file_tracker(user, name, output_file_name, output_file_max_bytes)
    raise Exception("Invalid tracking configuration.")


def snowplow_tracker(
user: User,
endpoint: str,
protocol: Optional[str] = "https",
on_failure: Optional[FailureCallback] = _default_on_failure,
) -> TrackerConfig:
return TrackerConfig(
) -> Tracker:
config = TrackerConfig(
invocation_id=user.invocation_id,
endpoint=endpoint,
protocol=protocol,
on_failure=on_failure,
)
return SnowplowTracker.from_config(config)


def enable_tracking(tracker, user: User):
def file_tracker(
    user: User,
    name: str,
    output_file_name: str,
    output_file_max_bytes: Optional[int],
) -> Tracker:
    """Build a file-backed tracker for the given user.

    Args:
        user: Supplies the invocation id recorded on the tracker config.
        name: Name of the logger the tracker writes through.
        output_file_name: Path of the file that receives the tracked events.
        output_file_max_bytes: Size limit used for file rotation. ``None``
            means "not specified" and keeps the default declared on
            ``TrackerConfig``.
    """
    config = TrackerConfig(
        invocation_id=user.invocation_id,
        name=name,
        output_file_name=output_file_name,
    )
    # Passing None through would replace the TrackerConfig default and end up
    # as RotatingFileHandler(maxBytes=None), which cannot be compared to 0 when
    # the handler decides whether to roll over.
    if output_file_max_bytes is not None:
        config.output_file_max_bytes = output_file_max_bytes
    return FileTracker.from_config(config)


def enable_tracking(tracker: Tracker, user: User):
cookie = _get_cookie(user)
user.enable_tracking(cookie)

subject = Subject()
subject.set_user_id(cookie.get("id"))
tracker.set_subject(subject)
tracker.enable_tracking(cookie)


def disable_tracking(tracker, user: User):
def disable_tracking(tracker: Tracker, user: User):
user.disable_tracking()
tracker.set_subject(None)
tracker.disable_tracking()


def _get_cookie(user: User) -> Dict[str, Any]:
Expand All @@ -240,3 +260,15 @@ def _set_cookie(user: User) -> Dict[str, Any]:
user.cookie = cookie.as_dict()
return user.cookie
return {}


def track(tracker: Tracker, user: User, msg: EventMsg) -> None:
    """Forward ``msg`` to ``tracker`` unless the user has opted out.

    Tracking is best-effort: any exception raised by the tracker is
    suppressed and nothing is reported to the caller.
    """
    if not user.do_not_track:
        # TODO: the original draft sketched (commented out) a SendingEvent
        # fired before tracking and a SendEventFailure fired on error;
        # neither is wired up yet.
        try:
            tracker.track(msg)
        except Exception:
            pass
126 changes: 88 additions & 38 deletions dbt_common/events/tracker.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,121 @@
from dataclasses import dataclass
import logging
from logging.handlers import RotatingFileHandler
from typing import Optional
from typing import Any, Dict, Optional, Protocol, Self

from snowplow_tracker import Emitter, Tracker
import snowplow_tracker
from snowplow_tracker.typing import FailureCallback

from dbt_common.events.base_types import EventMsg
from dbt_common.events.base_types import EventMsg, msg_to_dict
from dbt_common.events.format import timestamp_to_datetime_string


@dataclass
class TrackerConfig:
invocation_id: Optional[str] = None
msg_schemas: Optional[Dict[str, str]] = None
endpoint: Optional[str] = None
protocol: Optional[str] = None
protocol: Optional[str] = "https"
on_failure: Optional[FailureCallback] = None
name: Optional[str] = None
output_file_name: Optional[str] = None
output_file_max_bytes: Optional[int] = 10 * 1024 * 1024 # 10 mb


class _Tracker:
def __init__(self, config: TrackerConfig) -> None:
self.invocation_id: Optional[str] = config.invocation_id
class Tracker(Protocol):
@classmethod
def from_config(cls, config: TrackerConfig) -> Self:
...

if all([config.name, config.output_file_name]):
file_handler = RotatingFileHandler(
filename=str(config.output_file_name),
encoding="utf8",
maxBytes=config.output_file_max_bytes, # type: ignore
backupCount=5,
)
self._tracker = self._python_file_logger(config.name, file_handler)
def track(self, msg: EventMsg) -> None:
...

elif all([config.endpoint, config.protocol]):
self._tracker = self._snowplow_tracker(config.endpoint, config.protocol)
def enable_tracking(self, cookie: Dict[str, Any]) -> None:
...

def track(self, msg: EventMsg) -> str:
raise NotImplementedError()
def disable_tracking(self) -> None:
...

def _python_file_logger(self, name: str, handler: logging.Handler) -> logging.Logger:
log = logging.getLogger(name)
log.setLevel(logging.DEBUG)
handler.setFormatter(logging.Formatter(fmt="%(message)s"))
log.handlers.clear()
log.propagate = False
log.addHandler(handler)
return log

def _snowplow_tracker(
class FileTracker(Tracker):
    """Tracker that writes one line per event to a size-rotated log file."""

    def __init__(self, logger: logging.Logger, invocation_id: Optional[str]) -> None:
        self.logger = logger
        self.invocation_id = invocation_id

    @classmethod
    def from_config(cls, config: TrackerConfig) -> Self:
        """Build a tracker whose logger writes only to ``config.output_file_name``.

        The named logger is reset: existing handlers are closed and removed,
        propagation is turned off, and a single rotating file handler
        (five backups, UTF-8) is attached.

        Raises:
            ValueError: If ``config.name`` or ``config.output_file_name`` is unset.
        """
        # logging.getLogger(None) returns the root logger; clearing its handlers
        # and disabling propagation below would hijack process-wide logging.
        if not config.name:
            raise ValueError("FileTracker requires a logger name.")
        if not config.output_file_name:
            raise ValueError("FileTracker requires an output file name.")

        max_bytes = config.output_file_max_bytes
        if max_bytes is None:
            # An explicit None would break RotatingFileHandler's rollover check.
            # Fall back to the default declared on TrackerConfig (a dataclass
            # keeps plain field defaults as class attributes).
            max_bytes = TrackerConfig.output_file_max_bytes

        file_handler = RotatingFileHandler(
            filename=config.output_file_name,
            maxBytes=max_bytes,  # type: ignore
            backupCount=5,
            encoding="utf8",
        )
        file_handler.setFormatter(logging.Formatter(fmt="%(message)s"))

        logger = logging.getLogger(config.name)
        logger.setLevel(logging.DEBUG)
        # Close handlers as they are detached so previously opened log files
        # do not stay open after the logger is reconfigured.
        for stale_handler in list(logger.handlers):
            logger.removeHandler(stale_handler)
            stale_handler.close()
        logger.propagate = False
        logger.addHandler(file_handler)
        return cls(logger, config.invocation_id)

    def track(self, msg: EventMsg) -> None:
        """Write ``<timestamp> | <message>`` for ``msg`` at DEBUG level."""
        ts: str = timestamp_to_datetime_string(msg.info.ts)
        log_line = f"{ts} | {msg.info.msg}"
        self.logger.debug(log_line)

    def enable_tracking(self, cookie: Dict[str, Any]) -> None:
        """No-op: the file tracker does nothing with the cookie."""
        pass

    def disable_tracking(self) -> None:
        """No-op: the file tracker has nothing to reset."""
        pass


class SnowplowTracker(Tracker):
def __init__(
self,
endpoint: str,
protocol: Optional[str] = "https",
on_failure: Optional[FailureCallback] = None,
) -> Tracker:
emitter = Emitter(
endpoint,
protocol,
tracker: snowplow_tracker.Tracker,
msg_schemas: Dict[str, str],
invocation_id: Optional[str],
) -> None:
self.tracker = tracker
self.msg_schemas = msg_schemas
self.invocation_id = invocation_id

@classmethod
def from_config(cls, config: TrackerConfig) -> Self:
emitter = snowplow_tracker.Emitter(
config.endpoint,
config.protocol,
method="post",
batch_size=30,
on_failure=on_failure,
on_failure=config.on_failure,
byte_limit=None,
request_timeout=5.0,
)
tracker = Tracker(
tracker = snowplow_tracker.Tracker(
emitters=emitter,
namespace="cf",
app_id="dbt",
)
return tracker
return cls(tracker, config.msg_schemas, config.invocation_id)

def track(self, msg: EventMsg) -> None:
    """Send ``msg`` to Snowplow as a structured event.

    The event uses category ``"dbt"``, the event name as the action and the
    invocation id as the label. When a schema is registered for the event
    name, the serialized message is attached as a self-describing context;
    otherwise the event is sent without a context.
    """
    # msg_schemas is taken from TrackerConfig, where it defaults to None, so
    # guard against a missing mapping as well as a missing entry.
    schema = (self.msg_schemas or {}).get(msg.info.name)
    context = None
    if schema is not None:
        context = [snowplow_tracker.SelfDescribingJson(schema, msg_to_dict(msg))]
    event = snowplow_tracker.StructuredEvent(
        category="dbt",
        action=msg.info.name,
        label=self.invocation_id,
        context=context,
    )
    self.tracker.track(event)

def enable_tracking(self, cookie: Dict[str, Any]) -> None:
subject = snowplow_tracker.Subject()
subject.set_user_id(cookie.get("id"))
self.tracker.set_subject(subject)

def disable_tracking(self) -> None:
self.tracker.set_subject(None)
Empty file added tests/unit/test_tracker.py
Empty file.

0 comments on commit b3494cc

Please sign in to comment.