Skip to content

Commit

Permalink
[Data] Make sure progress bars always finish at 100% (ray-project#36679)
Browse files Browse the repository at this point in the history
Today the progress bar's total progress is defined by the number of estimated blocks. If the estimation is not correct, the progress bar may go over 100% or finishes half way.  This PR make sure the progress bar always finishes at 100%.

Signed-off-by: e428265 <arvind.chandramouli@lmco.com>
  • Loading branch information
raulchen authored and arvind-chandra committed Aug 31, 2023
1 parent 939d1d2 commit 695f3ee
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 0 deletions.
8 changes: 8 additions & 0 deletions python/ray/data/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,14 @@ py_test(
deps = ["//:ray_lib", ":conftest"],
)

py_test(
name = "test_progress_bar",
size = "small",
srcs = ["tests/test_progress_bar.py"],
tags = ["team:data", "exclusive"],
deps = ["//:ray_lib", ":conftest"],
)

py_test(
name = "test_random_access",
size = "small",
Expand Down
9 changes: 9 additions & 0 deletions python/ray/data/_internal/progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(
self, name: str, total: int, position: int = 0, enabled: Optional[bool] = None
):
self._desc = name
self._position = position
if enabled is None:
from ray.data import DataContext

Expand Down Expand Up @@ -106,10 +107,18 @@ def set_description(self, name: str) -> None:

def update(self, i: int) -> None:
if self._bar and i != 0:
self._position += i
if self._bar.total is not None and self._position > self._bar.total:
# If the progress goes over 100%, update the total.
self._bar.total = self._position
self._bar.update(i)

def close(self):
if self._bar:
if self._bar.total is not None and self._position != self._bar.total:
# If the progress is not complete, update the total.
self._bar.total = self._position
self._bar.refresh()
self._bar.close()
self._bar = None

Expand Down
68 changes: 68 additions & 0 deletions python/ray/data/tests/test_progress_bar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import functools

from pytest import fixture

import ray
from ray.data._internal.progress_bar import ProgressBar


@fixture(params=[True, False])
def enable_tqdm_ray(request):
context = ray.data.DataContext.get_current()
original_use_ray_tqdm = context.use_ray_tqdm
context.use_ray_tqdm = request.param
yield request.param
context.use_ray_tqdm = original_use_ray_tqdm


def test_progress_bar(enable_tqdm_ray):
total = 10
# Used to record the total value of the bar at close
total_at_close = 0

def patch_close(bar):
nonlocal total_at_close
total_at_close = 0
original_close = bar.close

@functools.wraps(original_close)
def wrapped_close():
nonlocal total_at_close
total_at_close = bar.total

bar.close = wrapped_close

# Test basic usage
pb = ProgressBar("", total, enabled=True)
assert pb._bar is not None
patch_close(pb._bar)
for _ in range(total):
pb.update(1)
pb.close()

assert pb._position == total
assert total_at_close == total

# Test if update() exceeds the original total, the total will be updated.
pb = ProgressBar("", total, enabled=True)
assert pb._bar is not None
patch_close(pb._bar)
new_total = total * 2
for _ in range(new_total):
pb.update(1)
pb.close()

assert pb._position == new_total
assert total_at_close == new_total

# Test that if the bar is not complete at close(), the total will be updated.
pb = ProgressBar("", total)
assert pb._bar is not None
patch_close(pb._bar)
new_total = total // 2
for _ in range(new_total):
pb.update(1)
pb.close()

assert pb._position == new_total
assert total_at_close == new_total
15 changes: 15 additions & 0 deletions python/ray/experimental/tqdm_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,18 @@ def close(self):
if ray is not None:
self._dump_state()

def refresh(self):
"""Implements tqdm.tqdm.refresh."""
self._dump_state()

@property
def total(self) -> Optional[int]:
return self._total

@total.setter
def total(self, total: int):
self._total = total

def _dump_state(self) -> None:
if ray._private.worker.global_worker.mode == ray.WORKER_MODE:
# Include newline in payload to avoid split prints.
Expand Down Expand Up @@ -160,6 +172,9 @@ def update(self, state: ProgressBarState) -> None:
"""Apply the updated worker progress bar state."""
if state["desc"] != self.state["desc"]:
self.bar.set_description(state["desc"])
if state["total"] != self.state["total"]:
self.bar.total = state["total"]
self.bar.refresh()
delta = state["x"] - self.state["x"]
if delta:
self.bar.update(delta)
Expand Down

0 comments on commit 695f3ee

Please sign in to comment.