Commit
perf: move serialization to background threads (#994)
hassiebp authored Nov 6, 2024
1 parent 05895c9 commit 87f98f0
Showing 5 changed files with 147 additions and 58 deletions.
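
This commit moves event serialization off the caller's thread. `TaskManager.add_task` no longer samples, masks, or JSON-serializes events; it only timestamps them and enqueues the raw pydantic body. The background `Consumer` now converts bodies to dicts, applies sampling, truncation, and masking, and drops items that fail serialization while it assembles a batch. Debug logging in `langfuse/client.py` also stops interpolating full event payloads and instead logs bodies filtered through a new `_filter_io_from_event_body` helper. Below is a minimal sketch of the producer/consumer split, using simplified, illustrative class and queue names rather than the SDK's actual types.

# Minimal sketch of the pattern introduced here (illustrative names only).
import json
import queue
import threading

import pydantic


class SpanBody(pydantic.BaseModel):
    name: str
    input: str = ""
    output: str = ""


def add_task(q: "queue.Queue[dict]", body: SpanBody) -> None:
    # Hot path after this change: enqueue the pydantic model as-is,
    # with no .dict() or json.dumps() on the caller's thread.
    q.put({"type": "span-create", "body": body}, block=False)


class Consumer(threading.Thread):
    def __init__(self, q: "queue.Queue[dict]"):
        super().__init__(daemon=True)
        self._queue = q

    def run(self) -> None:
        while True:
            item = self._queue.get()
            # Serialization now happens here, on the background thread.
            if isinstance(item["body"], pydantic.BaseModel):
                item["body"] = item["body"].dict(exclude_none=True)
            json.dumps(item)  # surfaces serialization errors before upload
            self._queue.task_done()
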
6 changes: 0 additions & 6 deletions langfuse/callback/langchain.py
@@ -991,14 +991,8 @@ def _log_debug_event(
parent_run_id: Optional[UUID] = None,
**kwargs,
):
kwargs_log = (
", " + ", ".join([f"{key}: {value}" for key, value in kwargs.items()])
if len(kwargs) > 0
else ""
)
self.log.debug(
f"Event: {event_name}, run_id: {str(run_id)[:5]}, parent_run_id: {str(parent_run_id)[:5]}"
+ kwargs_log
)


70 changes: 43 additions & 27 deletions langfuse/client.py
@@ -1334,11 +1334,11 @@ def trace(

new_body = TraceBody(**new_dict)

self.log.debug(f"Creating trace {new_body}")
self.log.debug(f"Creating trace {_filter_io_from_event_body(new_dict)}")
event = {
"id": str(uuid.uuid4()),
"type": "trace-create",
"body": new_body.dict(exclude_none=True),
"body": new_body,
}

self.task_manager.add_task(
@@ -1503,7 +1503,7 @@ def score(
event = {
"id": str(uuid.uuid4()),
"type": "score-create",
"body": new_body.dict(exclude_none=True),
"body": new_body,
}
self.task_manager.add_task(event)

@@ -1604,17 +1604,16 @@ def span(
if trace_id is None:
self._generate_trace(new_trace_id, name or new_trace_id)

self.log.debug(f"Creating span {span_body}...")
self.log.debug(f"Creating span {_filter_io_from_event_body(span_body)}...")

span_body = CreateSpanBody(**span_body)

event = {
"id": str(uuid.uuid4()),
"type": "span-create",
"body": span_body.dict(exclude_none=True),
"body": span_body,
}

self.log.debug(f"Creating span {event}...")
self.task_manager.add_task(event)

except Exception as e:
@@ -1710,10 +1709,12 @@ def event(
event = {
"id": str(uuid.uuid4()),
"type": "event-create",
"body": request.dict(exclude_none=True),
"body": request,
}

self.log.debug(f"Creating event {event}...")
self.log.debug(
f"Creating event {_filter_io_from_event_body(event_body)} ..."
)
self.task_manager.add_task(event)

except Exception as e:
@@ -1835,23 +1836,24 @@ def generation(
event = {
"id": str(uuid.uuid4()),
"type": "trace-create",
"body": request.dict(exclude_none=True),
"body": request,
}

self.log.debug(f"Creating trace {event}...")
self.log.debug("Creating trace...")

self.task_manager.add_task(event)

self.log.debug(f"Creating generation max {generation_body} {usage}...")
self.log.debug(
f"Creating generation max {_filter_io_from_event_body(generation_body)}..."
)
request = CreateGenerationBody(**generation_body)

event = {
"id": str(uuid.uuid4()),
"type": "generation-create",
"body": request.dict(exclude_none=True),
"body": request,
}

self.log.debug(f"Creating top-level generation {event} ...")
self.task_manager.add_task(event)

except Exception as e:
@@ -1877,10 +1879,10 @@ def _generate_trace(self, trace_id: str, name: str):
event = {
"id": str(uuid.uuid4()),
"type": "trace-create",
"body": trace_body.dict(exclude_none=True),
"body": trace_body,
}

self.log.debug(f"Creating trace {event}...")
self.log.debug(f"Creating trace {_filter_io_from_event_body(trace_dict)}...")
self.task_manager.add_task(event)

def join(self):
@@ -2087,7 +2089,9 @@ def generation(
"body": new_body.dict(exclude_none=True, exclude_unset=False),
}

self.log.debug(f"Creating generation {new_body}...")
self.log.debug(
f"Creating generation {_filter_io_from_event_body(generation_body)}..."
)
self.task_manager.add_task(event)

except Exception as e:
@@ -2165,7 +2169,7 @@ def span(
**kwargs,
}

self.log.debug(f"Creating span {span_body}...")
self.log.debug(f"Creating span {_filter_io_from_event_body(span_body)}...")

new_dict = self._add_state_to_event(span_body)
new_body = self._add_default_values(new_dict)
@@ -2175,7 +2179,7 @@
event = {
"id": str(uuid.uuid4()),
"type": "span-create",
"body": event.dict(exclude_none=True),
"body": event,
}

self.task_manager.add_task(event)
@@ -2284,7 +2288,7 @@ def score(
event = {
"id": str(uuid.uuid4()),
"type": "score-create",
"body": request.dict(exclude_none=True),
"body": request,
}

self.task_manager.add_task(event)
@@ -2369,10 +2373,12 @@ def event(
event = {
"id": str(uuid.uuid4()),
"type": "event-create",
"body": request.dict(exclude_none=True),
"body": request,
}

self.log.debug(f"Creating event {event}...")
self.log.debug(
f"Creating event {_filter_io_from_event_body(event_body)}..."
)
self.task_manager.add_task(event)

except Exception as e:
@@ -2497,7 +2503,9 @@ def update(
**kwargs,
}

self.log.debug(f"Update generation {generation_body}...")
self.log.debug(
f"Update generation {_filter_io_from_event_body(generation_body)}..."
)

request = UpdateGenerationBody(**generation_body)

@@ -2507,7 +2515,9 @@
"body": request.dict(exclude_none=True, exclude_unset=False),
}

self.log.debug(f"Update generation {event}...")
self.log.debug(
f"Update generation {_filter_io_from_event_body(generation_body)}..."
)
self.task_manager.add_task(event)

except Exception as e:
@@ -2684,14 +2694,14 @@ def update(
"end_time": end_time,
**kwargs,
}
self.log.debug(f"Update span {span_body}...")
self.log.debug(f"Update span {_filter_io_from_event_body(span_body)}...")

request = UpdateSpanBody(**span_body)

event = {
"id": str(uuid.uuid4()),
"type": "span-update",
"body": request.dict(exclude_none=True),
"body": request,
}

self.task_manager.add_task(event)
@@ -2888,14 +2898,14 @@ def update(
"tags": tags,
**kwargs,
}
self.log.debug(f"Update trace {trace_body}...")
self.log.debug(f"Update trace {_filter_io_from_event_body(trace_body)}...")

request = TraceBody(**trace_body)

event = {
"id": str(uuid.uuid4()),
"type": "trace-create",
"body": request.dict(exclude_none=True),
"body": request,
}

self.task_manager.add_task(event)
@@ -3350,3 +3360,9 @@ def __init__(self, dataset: Dataset, items: typing.List[DatasetItemClient]):
self.created_at = dataset.created_at
self.updated_at = dataset.updated_at
self.items = items


def _filter_io_from_event_body(event_body: Dict[str, Any]):
return {
k: v for k, v in event_body.items() if k not in ("input", "output", "metadata")
}
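
The new `_filter_io_from_event_body` helper keeps debug logging cheap: the potentially large `input`, `output`, and `metadata` fields are dropped before a body is interpolated into a log message, so logging no longer forces full payloads through string formatting. An illustrative call (values are made up):

body = {
    "name": "my-span",
    "input": "<large prompt>",
    "output": "<large completion>",
    "metadata": {"source": "example"},
}
_filter_io_from_event_body(body)  # -> {"name": "my-span"}
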
70 changes: 48 additions & 22 deletions langfuse/task_manager.py
@@ -58,6 +58,8 @@ class Consumer(threading.Thread):
_sdk_name: str
_sdk_version: str
_sdk_integration: str
_mask: Optional[MaskFunction]
_sampler: Sampler

def __init__(
self,
@@ -71,6 +73,8 @@ def __init__(
sdk_name: str,
sdk_version: str,
sdk_integration: str,
sample_rate: float,
mask: Optional[MaskFunction] = None,
):
"""Create a consumer thread."""
threading.Thread.__init__(self)
@@ -91,6 +95,8 @@ def __init__(
self._sdk_name = sdk_name
self._sdk_version = sdk_version
self._sdk_integration = sdk_integration
self._mask = mask
self._sampler = Sampler(sample_rate)

def _next(self):
"""Return the next batch of items to upload."""
@@ -107,13 +113,37 @@
try:
item = queue.get(block=True, timeout=self._flush_interval - elapsed)

# convert pydantic models to dicts
if "body" in item and isinstance(item["body"], pydantic.BaseModel):
item["body"] = item["body"].dict(exclude_none=True)

# sample event
if not self._sampler.sample_event(item):
queue.task_done()

continue

# truncate item if it exceeds size limit
item_size = self._truncate_item_in_place(
item=item,
max_size=MAX_MSG_SIZE,
log_message="<truncated due to size exceeding limit>",
)

# apply mask
self._apply_mask_in_place(item)

# check for serialization errors
try:
json.dumps(item, cls=EventSerializer)
except Exception as e:
self._log.error(f"Error serializing item, skipping: {e}")
queue.task_done()

continue

items.append(item)

total_size += item_size
if total_size >= BATCH_SIZE_LIMIT:
self._log.debug("hit batch size limit (size: %d)", total_size)
@@ -190,6 +220,20 @@ def _get_item_size(self, item: typing.Any) -> int:
"""Return the size of the item in bytes."""
return len(json.dumps(item, cls=EventSerializer).encode())

def _apply_mask_in_place(self, event: dict):
"""Apply the mask function to the event. This is done in place."""
if not self._mask:
return

body = event["body"] if "body" in event else {}
for key in ("input", "output"):
if key in body:
try:
body[key] = self._mask(data=body[key])
except Exception as e:
self._log.error(f"Mask function failed with error: {e}")
body[key] = "<fully masked due to failed mask function>"

def run(self):
"""Runs the consumer."""
self._log.debug("consumer is running...")
@@ -261,7 +305,7 @@ class TaskManager(object):
_sdk_name: str
_sdk_version: str
_sdk_integration: str
_sampler: Sampler
_sample_rate: float
_mask: Optional[MaskFunction]

def __init__(
Expand Down Expand Up @@ -293,7 +337,7 @@ def __init__(
self._sdk_version = sdk_version
self._sdk_integration = sdk_integration
self._enabled = enabled
self._sampler = Sampler(sample_rate)
self._sample_rate = sample_rate
self._mask = mask

self.init_resources()
@@ -314,6 +358,8 @@ def init_resources(self):
sdk_name=self._sdk_name,
sdk_version=self._sdk_version,
sdk_integration=self._sdk_integration,
sample_rate=self._sample_rate,
mask=self._mask,
)
consumer.start()
self._consumers.append(consumer)
@@ -323,12 +369,6 @@ def add_task(self, event: dict):
return

try:
if not self._sampler.sample_event(event):
return # event was sampled out

self._apply_mask_in_place(event)

json.dumps(event, cls=EventSerializer)
event["timestamp"] = _get_timestamp()

self._queue.put(event, block=False)
@@ -340,20 +380,6 @@

return False

def _apply_mask_in_place(self, event: dict):
"""Apply the mask function to the event. This is done in place."""
if not self._mask:
return

body = event["body"] if "body" in event else {}
for key in ("input", "output"):
if key in body:
try:
body[key] = self._mask(data=body[key])
except Exception as e:
self._log.error(f"Mask function failed with error: {e}")
body[key] = "<fully masked due to failed mask function>"

def flush(self):
"""Force a flush from the internal queue to the server."""
self._log.debug("flushing queue")
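
Taken together, `Consumer._next` now runs the full per-item pipeline on the background thread: convert pydantic bodies to dicts, drop sampled-out events, truncate oversized items, mask `input`/`output`, and skip anything that still fails JSON serialization. A condensed sketch of that ordering follows, assuming the module-level names from `langfuse/task_manager.py` (`MAX_MSG_SIZE`, `EventSerializer`) are in scope; error handling, batching limits, and flush intervals are omitted.

def process_item(consumer, item: dict):
    # Condensed ordering of the background-thread pipeline from _next();
    # returns None when an item should be dropped from the batch.
    if "body" in item and isinstance(item["body"], pydantic.BaseModel):
        item["body"] = item["body"].dict(exclude_none=True)

    if not consumer._sampler.sample_event(item):
        return None  # sampled out

    consumer._truncate_item_in_place(
        item=item,
        max_size=MAX_MSG_SIZE,
        log_message="<truncated due to size exceeding limit>",
    )
    consumer._apply_mask_in_place(item)

    try:
        json.dumps(item, cls=EventSerializer)
    except Exception:
        return None  # unserializable, skipped

    return item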