Skip to content

Commit

Permalink
Cap the number of stats kept in StatsActor and purge in FIFO order if…
Browse files Browse the repository at this point in the history
… the limit exceeded (#27964)

There is a risk of using too much memory in StatsActor, because its lifetime is the same as the cluster's lifetime.
This change puts a cap on how many stats to keep, and purges the stats in FIFO order if this cap is exceeded.
  • Loading branch information
jianoaix authored Aug 18, 2022
1 parent 24aeea8 commit 440ae62
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 6 deletions.
32 changes: 26 additions & 6 deletions python/ray/data/_internal/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,23 +88,40 @@ class _StatsActor:
"""Actor holding stats for blocks created by LazyBlockList.
This actor is shared across all datasets created in the same cluster.
The stats data for each dataset is small, but this actor lives as long as
the cluster does. In order to cap memory usage, we set a max number of stats
to keep in the actor. When this limit is exceeded, the oldest stats are
purged in FIFO order.
TODO(ekl) we should consider refactoring LazyBlockList so stats can be
extracted without using an out-of-band actor."""

def __init__(self, max_stats=1000):
    """Set up empty stats tables and the FIFO queue used for purging.

    Args:
        max_stats: Maximum number of stats entries (one per stats uuid)
            to keep. The default of 1000 applies when the actor is
            created without arguments.
    """
    # Mapping from uuid -> dataset-specific stats.
    self.metadata = collections.defaultdict(dict)
    # Mappings from uuid -> timestamps; both are keyed by stats uuid.
    self.last_time = {}
    self.start_time = {}
    self.max_stats = max_stats
    # Stats uuids in arrival order; the head is the first to be purged.
    self.fifo_queue = []

def record_start(self, stats_uuid):
    """Record the start time for ``stats_uuid`` and enforce the stats cap.

    The uuid is queued for FIFO purging the first time it is seen. If the
    queue then holds more than ``self.max_stats`` uuids, the oldest one is
    dropped together with its start time, last time and metadata.

    Args:
        stats_uuid: Identifier of the stats entry being started.
    """
    # Enqueue a uuid only once: a repeated start must not occupy two queue
    # slots, otherwise fewer than max_stats distinct entries would be kept.
    if stats_uuid not in self.start_time:
        self.fifo_queue.append(stats_uuid)
    self.start_time[stats_uuid] = time.perf_counter()
    # Purge the oldest stats if the limit is exceeded.
    if len(self.fifo_queue) > self.max_stats:
        oldest = self.fifo_queue.pop(0)
        # The oldest entry may have no last_time/metadata yet (no task was
        # recorded for it), so remove each table entry only if present.
        self.start_time.pop(oldest, None)
        self.last_time.pop(oldest, None)
        self.metadata.pop(oldest, None)

def record_task(self, stats_uuid, task_idx, metadata):
    """Store the metadata of one task under ``stats_uuid``.

    Args:
        stats_uuid: Identifier of the stats entry the task belongs to.
        task_idx: Index of the task within that stats entry.
        metadata: Metadata object for the task; its ``schema`` attribute
            is set to ``None`` in place before it is stored.
    """
    # Null out the schema to keep the stats size small.
    metadata.schema = None
    # Only record for uuids that currently have a start time; an entry
    # whose start time was removed must not reappear in the tables.
    if stats_uuid in self.start_time:
        self.metadata[stats_uuid][task_idx] = metadata
        self.last_time[stats_uuid] = time.perf_counter()

def get(self, stats_uuid):
if stats_uuid not in self.metadata:
Expand All @@ -114,6 +131,9 @@ def get(self, stats_uuid):
self.last_time[stats_uuid] - self.start_time[stats_uuid],
)

def _get_stats_dict_size(self):
    """Return the entry counts of the three stats tables.

    Returns:
        A tuple ``(len(start_time), len(last_time), len(metadata))``.
    """
    return len(self.start_time), len(self.last_time), len(self.metadata)


def _get_or_create_stats_actor():
ctx = DatasetContext.get_current()
Expand Down
38 changes: 38 additions & 0 deletions python/ray/data/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import ray
from ray._private.test_utils import wait_for_condition
from ray.data._internal.stats import _StatsActor
from ray.data._internal.arrow_block import ArrowRow
from ray.data._internal.block_builder import BlockBuilder
from ray.data._internal.lazy_block_list import LazyBlockList
Expand Down Expand Up @@ -4688,6 +4689,43 @@ def get_node_id():
assert set(locations) == {node1_id, node2_id}


def test_stats_actor_cap_num_stats(ray_start_cluster):
    """Check that _StatsActor keeps at most ``max_stats`` entries.

    Fills an actor capped at 3 stats, verifies the table sizes after each
    call, then starts a fourth stats entry and checks that the first one
    (uuid=0) was purged.
    """
    actor = _StatsActor.remote(3)
    metadatas = []
    task_idx = 0
    for uuid in range(3):
        metadatas.append(
            BlockMetadata(
                num_rows=uuid,
                size_bytes=None,
                schema=None,
                input_files=None,
                exec_stats=None,
            )
        )
        num_stats = uuid + 1
        actor.record_start.remote(uuid)
        # After record_start only start_time has grown.
        assert ray.get(actor._get_stats_dict_size.remote()) == (
            num_stats,
            num_stats - 1,
            num_stats - 1,
        )
        actor.record_task.remote(uuid, task_idx, metadatas[-1])
        # After record_task, last_time and metadata have caught up.
        assert ray.get(actor._get_stats_dict_size.remote()) == (
            num_stats,
            num_stats,
            num_stats,
        )
    for uuid in range(3):
        assert ray.get(actor.get.remote(uuid))[0][task_idx] == metadatas[uuid]
    # Add the fourth stats to exceed the limit.
    actor.record_start.remote(3)
    # The first stats (with uuid=0) should have been purged.
    assert ray.get(actor.get.remote(0))[0] == {}
    # The start_time has 3 entries because we just added it above with record_start().
    assert ray.get(actor._get_stats_dict_size.remote()) == (3, 2, 2)


@ray.remote
class Counter:
def __init__(self):
Expand Down

0 comments on commit 440ae62

Please sign in to comment.