Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve progress bar display in Batch operations #2135

Merged
merged 1 commit into from
Dec 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Double precision mode solver is now supported in EME.
- `estimate_cost` is now called at the end of every `web.upload` call.
- Internal refactor of adjoint shape gradients using `GradientSurfaceMesh`.
- Enhanced progress bar display in batch operations with better formatting, colors, and status tracking.

### Fixed
- Significant speedup for field projection computations.
Expand Down
134 changes: 90 additions & 44 deletions tidy3d/web/api/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing import Dict, Optional, Tuple

import pydantic.v1 as pd
from rich.progress import Progress
from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeElapsedColumn

from ...components.base import Tidy3dBaseModel, cached_property
from ...components.types import annotate_type
Expand Down Expand Up @@ -627,12 +627,19 @@ def upload(self) -> None:
# progressbar (number of tasks uploaded)
if self.verbose:
console = get_logging_console()
with Progress(console=console) as progress:
pbar_message = f"Uploading data for {self.num_jobs} tasks."
pbar = progress.add_task(pbar_message, total=self.num_jobs - 1)
progress_columns = (
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
TimeElapsedColumn(),
)
with Progress(*progress_columns, console=console, transient=False) as progress:
pbar_message = f"Uploading data for {self.num_jobs} tasks"
pbar = progress.add_task(pbar_message, total=self.num_jobs)
completed = 0
for _ in concurrent.futures.as_completed(futures):
progress.update(pbar, advance=1)
progress.update(pbar, completed=self.num_jobs - 1, refresh=True)
completed += 1
progress.update(pbar, completed=completed)

def get_info(self) -> Dict[TaskName, TaskInfo]:
"""Get information about each task in the :class:`Batch`.
Expand Down Expand Up @@ -678,22 +685,31 @@ def get_run_info(self) -> Dict[TaskName, RunInfo]:
return run_info_dict

def monitor(self) -> None:
"""Monitor progress of each of the running tasks.
"""Monitor progress of each of the running tasks."""

Note
----
To loop through the data of completed simulations, can call :meth:`Batch.items`.
"""

def pbar_description(task_name: str, status: str) -> str:
def pbar_description(
task_name: str, status: str, max_name_length: int, status_width: int
) -> str:
"""Make a progressbar description based on the status."""
description = f"{task_name}: status = {status}"
# if task name too long, truncate and add ...
if len(task_name) > max_name_length - 3: # -3 to leave room for ...
task_name = task_name[: (max_name_length - 3)] + "..."

# right-align status
task_part = f"{task_name:<{max_name_length}}"

# if something went wrong, make it red
if "error" in status or "diverge" in status:
description = f"[red]{description}"
status_part = f"→ [red]{status:<{status_width}}"
elif status == "success":
status_part = f"→ [green]{status:<{status_width}}"
elif status == "queued" or status == "queued_solver":
status_part = f"→ [yellow]{status:<{status_width}}"
elif status in ["preprocess", "postprocess", "running"]:
status_part = f"→ [blue]{status:<{status_width}}"
else:
status_part = f"→ {status:<{status_width}}"

return description
return f"{task_part} {status_part}"

run_statuses = [
"draft",
Expand All @@ -707,6 +723,12 @@ def pbar_description(task_name: str, status: str) -> str:
]
end_statuses = ("success", "error", "errored", "diverged", "diverge", "deleted", "draft")

max_task_name = max(len(task_name) for task_name in self.jobs.keys())
max_name_length = min(30, max(max_task_name, 15))
status_width = max(
max(len(status) for status in run_statuses), max(len(status) for status in end_statuses)
)

if self.verbose:
console = get_logging_console()

Expand All @@ -716,45 +738,64 @@ def pbar_description(task_name: str, status: str) -> str:
"get the billed FlexCredit cost after the Batch has completed."
)

with Progress(console=console) as progress:
# create progressbars
progress_columns = (
TextColumn("[progress.description]{task.description}"),
BarColumn(bar_width=25),
TaskProgressColumn(),
TimeElapsedColumn(),
)

with Progress(*progress_columns, console=console, transient=False) as progress:
# create progress bars
pbar_tasks = {}
for task_name, job in self.jobs.items():
status = job.status
description = pbar_description(task_name, status)
description = pbar_description(task_name, status, max_name_length, status_width)
completed = run_statuses.index(status) if status in run_statuses else 0
pbar = progress.add_task(
description,
total=len(run_statuses) - 1,
completed=completed,
description, total=len(run_statuses) - 1, completed=completed
)
pbar_tasks[task_name] = pbar

while any(job.status not in end_statuses for job in self.jobs.values()):
updates = []
for task_name, job in self.jobs.items():
pbar = pbar_tasks[task_name]
status = job.status
description = pbar_description(task_name, status)

if status in run_statuses:
completed = run_statuses.index(status)
progress.update(pbar, description=description, completed=completed)

updates.append(
(
pbar_tasks[task_name],
pbar_description(
task_name, status, max_name_length, status_width
),
run_statuses.index(status),
)
)

for pbar, description, completed in updates:
progress.update(
pbar, description=description, completed=completed, refresh=False
)

progress.refresh()
time.sleep(BATCH_MONITOR_PROGRESS_REFRESH_TIME)

# set all to 100% completed (if error or diverge, will be red)
updates = []
for task_name, job in self.jobs.items():
pbar = pbar_tasks[task_name]
status = job.status
description = pbar_description(task_name, status)
updates.append(
(
pbar_tasks[task_name],
pbar_description(task_name, job.status, max_name_length, status_width),
len(run_statuses) - 1,
)
)

for pbar, description, completed in updates:
progress.update(
pbar,
description=description,
completed=len(run_statuses) - 1,
refresh=True,
pbar, description=description, completed=completed, refresh=False
)

progress.refresh()
console.log("Batch complete.")

else:
Expand Down Expand Up @@ -825,22 +866,27 @@ def download(self, path_dir: str = DEFAULT_DATA_DIR) -> None:
continue

def fn(job=job, job_path=job_path) -> None:
"""Function to submit by executor, local variables bound by making kwargs."""
return job.download(path=job_path)

fns.append(fn)

futures = [executor.submit(fn) for fn in fns]

# progressbar (number of eligible tasks downloaded)
if self.verbose:
console = get_logging_console()
with Progress(console=console) as progress:
pbar_message = f"Downloading data for {len(fns)} tasks."
pbar = progress.add_task(pbar_message, total=len(fns) - 1)
progress_columns = (
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
TimeElapsedColumn(),
)
with Progress(*progress_columns, console=console, transient=False) as progress:
pbar_message = f"Downloading data for {len(fns)} tasks"
pbar = progress.add_task(pbar_message, total=len(fns))
completed = 0
for _ in concurrent.futures.as_completed(futures):
progress.update(pbar, advance=1)
progress.update(pbar, completed=len(fns) - 1, refresh=True)
completed += 1
progress.update(pbar, completed=completed)

def load(self, path_dir: str = DEFAULT_DATA_DIR) -> BatchData:
"""Download results and load them into :class:`.BatchData` object.
Expand Down
Loading