diff --git a/compiler/bindings/python/CMakeLists.txt b/compiler/bindings/python/CMakeLists.txt index e606e7309aef..5251319b7132 100644 --- a/compiler/bindings/python/CMakeLists.txt +++ b/compiler/bindings/python/CMakeLists.txt @@ -268,6 +268,7 @@ SOURCES net_actions.py onnx_actions.py target_machine.py + test_actions.py ) add_mlir_python_modules(IREECompilerBuildPythonModules diff --git a/compiler/bindings/python/iree/build/executor.py b/compiler/bindings/python/iree/build/executor.py index 532f235ba6ca..8554e207976a 100644 --- a/compiler/bindings/python/iree/build/executor.py +++ b/compiler/bindings/python/iree/build/executor.py @@ -276,8 +276,17 @@ def __str__(self) -> str: return self.value -class BuildAction(BuildDependency, abc.ABC): - """An action that must be carried out.""" +class BuildAction(BuildDependency): + """An action that must be carried out. + + This class is designed to be subclassed by concrete actions. In-process + only actions should override `_invoke`, whereas those that can be executed + out-of-process must override `_remotable_thunk`. + + Note that even actions that are marked for `PROCESS` concurrency will + run on a dedicated thread within the host process. Only the `_remotable_thunk` + result will be scheduled out of process. + """ def __init__( self, @@ -289,7 +298,7 @@ def __init__( ): super().__init__(executor=executor, deps=deps) self.desc = desc - self.concurrnecy = concurrency + self.concurrency = concurrency def __str__(self): return self.desc @@ -297,12 +306,35 @@ def __str__(self): def __repr__(self): return f"Action[{type(self).__name__}]('{self.desc}')" - def invoke(self): - self._invoke() + def invoke(self, scheduler: "Scheduler"): + # Invoke is run within whatever in-process execution context was requested: + # - On the scheduler thread for NONE + # - On a worker thread for THREAD or PROCESS + # For PROCESS concurrency, we have to create a compatible invocation + # thunk, schedule that on the process pool and wait for it. + if self.concurrency == ActionConcurrency.PROCESS: + thunk = self._remotable_thunk() + fut = scheduler.process_pool_executor.submit(thunk) + fut.result() + else: + self._invoke() - @abc.abstractmethod def _invoke(self): - ... + self._remotable_thunk()() + + def _remotable_thunk(self) -> Callable[[], None]: + """Creates a remotable no-arg thunk that will execute this out of process. + + This must return a no arg/result callable that can be pickled. While there + are various ways to ensure this, here are a few guidelines: + + * Must be a type/function defined at a module level. + * Cannot be decorated. + * Must only contain attributes with the same constraints. + """ + raise NotImplementedError( + f"Action '{self}' does not implement remotable invocation" + ) class BuildContext(BuildDependency): @@ -513,19 +545,20 @@ def _schedule_action(self, dep: BuildDependency): if isinstance(dep, BuildAction): def invoke(): - dep.invoke() + dep.invoke(self) return dep print(f"Scheduling action: {dep}", file=self.stderr) - if dep.concurrnecy == ActionConcurrency.NONE: + if dep.concurrency == ActionConcurrency.NONE: invoke() - elif dep.concurrnecy == ActionConcurrency.THREAD: + elif ( + dep.concurrency == ActionConcurrency.THREAD + or dep.concurrency == ActionConcurrency.PROCESS + ): dep.start(self.thread_pool_executor.submit(invoke)) - elif dep.concurrnecy == ActionConcurrency.PROCESS: - dep.start(self.process_pool_executor.submit(invoke)) else: raise AssertionError( - f"Unhandled ActionConcurrency value: {dep.concurrnecy}" + f"Unhandled ActionConcurrency value: {dep.concurrency}" ) else: # Not schedulable. Just mark it as done. diff --git a/compiler/bindings/python/iree/build/test_actions.py b/compiler/bindings/python/iree/build/test_actions.py new file mode 100644 index 000000000000..e4b9c55e0eac --- /dev/null +++ b/compiler/bindings/python/iree/build/test_actions.py @@ -0,0 +1,31 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Callable +from iree.build.executor import ActionConcurrency, BuildAction + + +class _ThunkTrampoline: + def __init__(self, thunk, args): + self.thunk = thunk + self.args = args + + def __call__(self): + self.thunk(*self.args) + + +class ExecuteOutOfProcessThunkAction(BuildAction): + """Executes a callback thunk with arguments. + + Both the thunk and args must be pickleable. + """ + + def __init__(self, thunk, args, concurrency=ActionConcurrency.PROCESS, **kwargs): + super().__init__(concurrency=concurrency, **kwargs) + self.trampoline = _ThunkTrampoline(thunk, args) + + def _remotable_thunk(self) -> Callable[[], None]: + return self.trampoline diff --git a/compiler/bindings/python/test/build_api/CMakeLists.txt b/compiler/bindings/python/test/build_api/CMakeLists.txt index b8bd81759ddc..5c8f97123d63 100644 --- a/compiler/bindings/python/test/build_api/CMakeLists.txt +++ b/compiler/bindings/python/test/build_api/CMakeLists.txt @@ -13,3 +13,10 @@ if(IREE_INPUT_TORCH) "mnist_builder_test.py" ) endif() + +iree_py_test( + NAME + concurrency_test + SRCS + "concurrency_test.py" +) diff --git a/compiler/bindings/python/test/build_api/concurrency_test.py b/compiler/bindings/python/test/build_api/concurrency_test.py new file mode 100644 index 000000000000..498179b73188 --- /dev/null +++ b/compiler/bindings/python/test/build_api/concurrency_test.py @@ -0,0 +1,61 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os +from pathlib import Path +import tempfile +import unittest + +from iree.build import * +from iree.build.executor import BuildContext +from iree.build.test_actions import ExecuteOutOfProcessThunkAction + + +@entrypoint +def write_out_of_process_pid(): + context = BuildContext.current() + output_file = context.allocate_file("pid.txt") + action = ExecuteOutOfProcessThunkAction( + _write_pid_file, + args=[output_file.get_fs_path()], + desc="Writing pid file", + executor=context.executor, + ) + output_file.deps.add(action) + return output_file + + +def _write_pid_file(output_path: Path): + pid = os.getpid() + print(f"Running action out of process: pid={pid}") + output_path.write_text(str(pid)) + + +class ConcurrencyTest(unittest.TestCase): + def setUp(self): + self._temp_dir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True) + self._temp_dir.__enter__() + self.output_path = Path(self._temp_dir.name) + + def tearDown(self) -> None: + self._temp_dir.__exit__(None, None, None) + + def testProcessConcurrency(self): + parent_pid = os.getpid() + print(f"Testing out of process concurrency: pid={parent_pid}") + iree_build_main( + args=["write_out_of_process_pid", "--output-dir", str(self.output_path)] + ) + pid_file = ( + self.output_path / "genfiles" / "write_out_of_process_pid" / "pid.txt" + ) + child_pid = int(pid_file.read_text()) + print(f"Got child pid={child_pid}") + self.assertNotEqual(parent_pid, child_pid) + + +if __name__ == "__main__": + unittest.main()