diff --git a/python/ray/data/_internal/progress_bar.py b/python/ray/data/_internal/progress_bar.py
index e117dc1be481..e7cfa667010a 100644
--- a/python/ray/data/_internal/progress_bar.py
+++ b/python/ray/data/_internal/progress_bar.py
@@ -1,3 +1,4 @@
+import logging
 import threading
 from typing import Any, List, Optional
 
@@ -5,6 +6,9 @@
 from ray.experimental import tqdm_ray
 from ray.types import ObjectRef
 from ray.util.annotations import Deprecated
+from ray.util.debug import log_once
+
+logger = logging.getLogger(__name__)
 
 try:
     import tqdm
@@ -44,6 +48,10 @@ class ProgressBar:
     because no tasks have finished yet), doesn't display the full
     progress bar. Still displays basic progress stats from tqdm."""
 
+    # If the name/description of the progress bar exceeds this length,
+    # it will be truncated.
+    MAX_NAME_LENGTH = 100
+
     def __init__(
         self,
         name: str,
@@ -52,7 +60,7 @@ def __init__(
         position: int = 0,
         enabled: Optional[bool] = None,
     ):
-        self._desc = name
+        self._desc = self._truncate_name(name)
         self._progress = 0
         # Prepend a space to the unit for better formatting.
         if unit[0] != " ":
@@ -83,6 +91,64 @@ def __init__(
                 needs_warning = False
             self._bar = None
 
+    def _truncate_name(self, name: str) -> str:
+        """Shorten a long chained operator name for display in the bar.
+
+        The name is split on "->". The first and last operator names are
+        always kept; middle operators are kept in order until adding the
+        next one would push the name past `MAX_NAME_LENGTH`, at which
+        point the remaining ones are replaced by a single "...". The
+        "..." itself is not counted, so the result may slightly exceed
+        the limit.
+
+        Args:
+            name: The progress bar name, typically operator names joined
+                by "->".
+
+        Returns:
+            `name` unchanged if truncation is disabled in the current
+            `DataContext`, if it already fits in `MAX_NAME_LENGTH`, or if
+            it has no middle operators to drop; otherwise the shortened
+            name.
+        """
+        ctx = ray.data.context.DataContext.get_current()
+        if (
+            not ctx.enable_progress_bar_name_truncation
+            or len(name) <= self.MAX_NAME_LENGTH
+        ):
+            return name
+
+        op_names = name.split("->")
+        if len(op_names) <= 2:
+            # There are no middle operators to elide, so the name is shown
+            # in full; skip the warning since nothing is truncated.
+            return name
+
+        if log_once("ray_data_truncate_operator_name"):
+            logger.warning(
+                f"Truncating long operator name to {self.MAX_NAME_LENGTH} characters. "
+                "To disable this behavior, set `ray.data.DataContext.get_current()."
+                "enable_progress_bar_name_truncation = False`."
+            )
+
+        # Include as many operators as possible without approximately
+        # exceeding `MAX_NAME_LENGTH`. Always include the first and
+        # last operator names so it is easy to identify the DAG.
+        truncated_op_names = [op_names[0]]
+        for op_name in op_names[1:-1]:
+            if (
+                len("->".join(truncated_op_names))
+                + len("->")
+                + len(op_name)
+                + len("->")
+                + len(op_names[-1])
+            ) > self.MAX_NAME_LENGTH:
+                truncated_op_names.append("...")
+                break
+            truncated_op_names.append(op_name)
+        truncated_op_names.append(op_names[-1])
+        return "->".join(truncated_op_names)
+
     def block_until_complete(self, remaining: List[ObjectRef]) -> None:
         t = threading.current_thread()
         while remaining:
@@ -117,6 +183,7 @@ def fetch_until_complete(self, refs: List[ObjectRef]) -> List[Any]:
         return [ref_to_result[ref] for ref in refs]
 
     def set_description(self, name: str) -> None:
+        name = self._truncate_name(name)
         if self._bar and name != self._desc:
             self._desc = name
             self._bar.set_description(self._desc)
diff --git a/python/ray/data/context.py b/python/ray/data/context.py
index 4f1c7d508c50..ad266336eb9b 100644
--- a/python/ray/data/context.py
+++ b/python/ray/data/context.py
@@ -83,6 +83,9 @@
 DEFAULT_ENABLE_PROGRESS_BARS = not bool(
     env_integer("RAY_DATA_DISABLE_PROGRESS_BARS", 0)
 )
+DEFAULT_ENABLE_PROGRESS_BAR_NAME_TRUNCATION = env_bool(
+    "RAY_DATA_ENABLE_PROGRESS_BAR_NAME_TRUNCATION", True
+)
 
 DEFAULT_ENABLE_GET_OBJECT_LOCATIONS_FOR_METRICS = False
 
@@ -209,6 +212,9 @@ class DataContext:
             to use.
         use_ray_tqdm: Whether to enable distributed tqdm.
         enable_progress_bars: Whether to enable progress bars.
+        enable_progress_bar_name_truncation: If True, the name of the progress bar
+            (often the operator name) will be truncated if it exceeds
+            `ProgressBar.MAX_NAME_LENGTH`. Otherwise, the full operator name is shown.
         enable_get_object_locations_for_metrics: Whether to enable
             ``get_object_locations`` for metrics.
         write_file_retry_on_errors: A list of substrings of error messages that should
@@ -271,6 +277,9 @@ class DataContext:
     )
     use_ray_tqdm: bool = DEFAULT_USE_RAY_TQDM
     enable_progress_bars: bool = DEFAULT_ENABLE_PROGRESS_BARS
+    enable_progress_bar_name_truncation: bool = (
+        DEFAULT_ENABLE_PROGRESS_BAR_NAME_TRUNCATION
+    )
     enable_get_object_locations_for_metrics: bool = (
         DEFAULT_ENABLE_GET_OBJECT_LOCATIONS_FOR_METRICS
     )