Skip to content

Commit

Permalink
[CHORE] Remove the concept of runner configs (#3276)
Browse files Browse the repository at this point in the history
Removes the concept of RunnerConfigs.

Our Runners are now materialized in 2 places:

1. When a user calls `daft.context.set_runner_*`, eagerly!
2. When a user first interacts with a Daft API that calls
`get_context().get_or_create_runner()` under the hood

---------

Co-authored-by: Jay Chia <jaychia94@gmail.com@users.noreply.github.com>
Co-authored-by: Colin Ho <colinho@Colins-MacBook-Pro.local>
  • Loading branch information
3 people authored Nov 13, 2024
1 parent 2c59675 commit e430942
Show file tree
Hide file tree
Showing 4 changed files with 278 additions and 67 deletions.
2 changes: 1 addition & 1 deletion benchmarking/tpch/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def warmup_environment(requirements: str | None, parquet_folder: str):
runtime_env = get_ray_runtime_env(requirements)

ray.init(
address=ctx._runner_config.address,
address=ctx._runner.ray_address,
runtime_env=runtime_env,
)

Expand Down
100 changes: 40 additions & 60 deletions daft/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@ class DaftContext:
# Non-execution calls (e.g. creation of a dataframe, logical plan building etc) directly reference values in this config
_daft_planning_config: PyDaftPlanningConfig = PyDaftPlanningConfig.from_env()

_runner_config: _RunnerConfig | None = None
_runner: Runner | None = None

_instance: ClassVar[DaftContext | None] = None
Expand All @@ -152,7 +151,34 @@ def get_or_create_runner(self) -> Runner:
WARNING: This will set the runner if it has not yet been set.
"""
with self._lock:
return self._get_or_create_runner()
if self._runner is not None:
return self._runner

runner_config = _get_runner_config_from_env()
if runner_config.name == "ray":
from daft.runners.ray_runner import RayRunner

assert isinstance(runner_config, _RayRunnerConfig)
self._runner = RayRunner(
address=runner_config.address,
max_task_backlog=runner_config.max_task_backlog,
force_client_mode=runner_config.force_client_mode,
)
elif runner_config.name == "py":
from daft.runners.pyrunner import PyRunner

assert isinstance(runner_config, _PyRunnerConfig)
self._runner = PyRunner(use_thread_pool=runner_config.use_thread_pool)
elif runner_config.name == "native":
from daft.runners.native_runner import NativeRunner

assert isinstance(runner_config, _NativeRunnerConfig)
self._runner = NativeRunner()

else:
raise NotImplementedError(f"Runner config not implemented: {runner_config.name}")

return self._runner

@property
def daft_execution_config(self) -> PyDaftExecutionConfig:
Expand All @@ -164,55 +190,6 @@ def daft_planning_config(self) -> PyDaftPlanningConfig:
with self._lock:
return self._daft_planning_config

def _get_or_create_runner_config(self) -> _RunnerConfig:
"""Gets the runner config."""
if self._runner_config is not None:
return self._runner_config
self._runner_config = _get_runner_config_from_env()
return self._runner_config

def _get_or_create_runner(self) -> Runner:
"""Gets the runner."""
if self._runner is not None:
return self._runner

runner_config = self._get_or_create_runner_config()
if runner_config.name == "ray":
from daft.runners.ray_runner import RayRunner

assert isinstance(runner_config, _RayRunnerConfig)
self._runner = RayRunner(
address=runner_config.address,
max_task_backlog=runner_config.max_task_backlog,
force_client_mode=runner_config.force_client_mode,
)
elif runner_config.name == "py":
from daft.runners.pyrunner import PyRunner

assert isinstance(runner_config, _PyRunnerConfig)
self._runner = PyRunner(use_thread_pool=runner_config.use_thread_pool)
elif runner_config.name == "native":
from daft.runners.native_runner import NativeRunner

assert isinstance(runner_config, _NativeRunnerConfig)
self._runner = NativeRunner()

else:
raise NotImplementedError(f"Runner config not implemented: {runner_config.name}")

return self._runner

def _can_set_runner(self, new_runner_name: str) -> bool:
# If the runner has not been set yet, we can set it
if self._runner_config is None:
return True
# If the runner has been set to the ray runner, we can't set it again
elif self._runner_config.name == "ray":
return False
# If the runner has been set to a local runner, we can set it to a new local runner
else:
return new_runner_name in {"py", "native"}


_DaftContext = DaftContext()

Expand Down Expand Up @@ -248,20 +225,21 @@ def set_runner_ray(

ctx = get_context()
with ctx._lock:
if not ctx._can_set_runner("ray"):
if ctx._runner is not None:
if noop_if_initialized:
warnings.warn(
"Calling daft.context.set_runner_ray(noop_if_initialized=True) multiple times has no effect beyond the first call."
)
return ctx
raise RuntimeError("Cannot set runner more than once")

ctx._runner_config = _RayRunnerConfig(
from daft.runners.ray_runner import RayRunner

ctx._runner = RayRunner(
address=address,
max_task_backlog=max_task_backlog,
force_client_mode=force_client_mode,
)
ctx._runner = None
return ctx


Expand All @@ -275,11 +253,12 @@ def set_runner_py(use_thread_pool: bool | None = None) -> DaftContext:
"""
ctx = get_context()
with ctx._lock:
if not ctx._can_set_runner("py"):
if ctx._runner is not None and ctx._runner.name not in {"py", "native"}:
raise RuntimeError("Cannot set runner more than once")

ctx._runner_config = _PyRunnerConfig(use_thread_pool=use_thread_pool)
ctx._runner = None
from daft.runners.pyrunner import PyRunner

ctx._runner = PyRunner(use_thread_pool=use_thread_pool)
return ctx


Expand All @@ -293,11 +272,12 @@ def set_runner_native() -> DaftContext:
"""
ctx = get_context()
with ctx._lock:
if not ctx._can_set_runner("native"):
if ctx._runner is not None and ctx._runner.name not in {"py", "native"}:
raise RuntimeError("Cannot set runner more than once")

ctx._runner_config = _NativeRunnerConfig()
ctx._runner = None
from daft.runners.native_runner import NativeRunner

ctx._runner = NativeRunner()
return ctx


Expand Down
3 changes: 3 additions & 0 deletions daft/runners/ray_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1189,6 +1189,9 @@ def __init__(
force_client_mode: bool = False,
) -> None:
super().__init__()

self.ray_address = address

if ray.is_initialized():
if address is not None:
logger.warning(
Expand Down
Loading

0 comments on commit e430942

Please sign in to comment.