forked from ray-project/ray
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Data] Make sure progress bars always finish at 100% (ray-project#36679)
Today the progress bar's total progress is defined by the number of estimated blocks. If the estimation is not correct, the progress bar may go over 100% or finishes half way. This PR make sure the progress bar always finishes at 100%. Signed-off-by: e428265 <arvind.chandramouli@lmco.com>
- Loading branch information
1 parent
939d1d2
commit 695f3ee
Showing
4 changed files
with
100 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import functools | ||
|
||
from pytest import fixture | ||
|
||
import ray | ||
from ray.data._internal.progress_bar import ProgressBar | ||
|
||
|
||
@fixture(params=[True, False]) | ||
def enable_tqdm_ray(request): | ||
context = ray.data.DataContext.get_current() | ||
original_use_ray_tqdm = context.use_ray_tqdm | ||
context.use_ray_tqdm = request.param | ||
yield request.param | ||
context.use_ray_tqdm = original_use_ray_tqdm | ||
|
||
|
||
def test_progress_bar(enable_tqdm_ray): | ||
total = 10 | ||
# Used to record the total value of the bar at close | ||
total_at_close = 0 | ||
|
||
def patch_close(bar): | ||
nonlocal total_at_close | ||
total_at_close = 0 | ||
original_close = bar.close | ||
|
||
@functools.wraps(original_close) | ||
def wrapped_close(): | ||
nonlocal total_at_close | ||
total_at_close = bar.total | ||
|
||
bar.close = wrapped_close | ||
|
||
# Test basic usage | ||
pb = ProgressBar("", total, enabled=True) | ||
assert pb._bar is not None | ||
patch_close(pb._bar) | ||
for _ in range(total): | ||
pb.update(1) | ||
pb.close() | ||
|
||
assert pb._position == total | ||
assert total_at_close == total | ||
|
||
# Test if update() exceeds the original total, the total will be updated. | ||
pb = ProgressBar("", total, enabled=True) | ||
assert pb._bar is not None | ||
patch_close(pb._bar) | ||
new_total = total * 2 | ||
for _ in range(new_total): | ||
pb.update(1) | ||
pb.close() | ||
|
||
assert pb._position == new_total | ||
assert total_at_close == new_total | ||
|
||
# Test that if the bar is not complete at close(), the total will be updated. | ||
pb = ProgressBar("", total) | ||
assert pb._bar is not None | ||
patch_close(pb._bar) | ||
new_total = total // 2 | ||
for _ in range(new_total): | ||
pb.update(1) | ||
pb.close() | ||
|
||
assert pb._position == new_total | ||
assert total_at_close == new_total |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters