diff --git a/.gitignore b/.gitignore index b61afb451912..58f783104044 100644 --- a/.gitignore +++ b/.gitignore @@ -22,6 +22,8 @@ build/ build-*/ Testing/ +# Include iree.build package +!compiler/bindings/python/iree/compiler/build/ # Bazel artifacts **/bazel-* diff --git a/compiler/bindings/python/CMakeLists.txt b/compiler/bindings/python/CMakeLists.txt index 0a0b36474885..bc8119f07b61 100644 --- a/compiler/bindings/python/CMakeLists.txt +++ b/compiler/bindings/python/CMakeLists.txt @@ -234,6 +234,37 @@ add_mlir_python_modules(IREECompilerPythonModules ) +################################################################################ +# iree.build package +# This is a pure Python part of the namespace, not rooted under iree.compiler +# like the above. It is only using the same build support for compatibility +# with the existing development flow. +# If the build system for Python code is ever redone, this can just be +# source namespace in the project definition. +################################################################################ + +# The iree.build package. +declare_mlir_python_sources(IREECompilerBuildPythonPackage +ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/iree/build" +SOURCES + __init__.py + __main__.py + executor.py + lang.py + main.py + net_actions.py + onnx_actions.py +) + +add_mlir_python_modules(IREECompilerBuildPythonModules + ROOT_PREFIX "${_PYTHON_BUILD_PREFIX}/iree/build" + INSTALL_PREFIX "${_PYTHON_INSTALL_PREFIX}/iree/build" + DECLARED_SOURCES + IREECompilerBuildPythonPackage +) + +add_dependencies(IREECompilerPythonModules IREECompilerBuildPythonModules) + ################################################################################ # Tools linked against the shared CAPI library ################################################################################ diff --git a/compiler/bindings/python/iree/build/__init__.py b/compiler/bindings/python/iree/build/__init__.py new file mode 100644 index 000000000000..95ee3f74b26b --- /dev/null +++ b/compiler/bindings/python/iree/build/__init__.py @@ -0,0 +1,12 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import argparse + +from iree.build.lang import * +from iree.build.main import * +from iree.build.net_actions import * +from iree.build.onnx_actions import * diff --git a/compiler/bindings/python/iree/build/__main__.py b/compiler/bindings/python/iree/build/__main__.py new file mode 100644 index 000000000000..bbaa9e1efb89 --- /dev/null +++ b/compiler/bindings/python/iree/build/__main__.py @@ -0,0 +1,11 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .main import CliMain + + +if __name__ == "__main__": + CliMain().run() diff --git a/compiler/bindings/python/iree/build/executor.py b/compiler/bindings/python/iree/build/executor.py new file mode 100644 index 000000000000..dcd02058b10e --- /dev/null +++ b/compiler/bindings/python/iree/build/executor.py @@ -0,0 +1,516 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Callable, Collection, Generator, IO + +import abc +import argparse +import concurrent.futures +import enum +import inspect +import multiprocessing +import sys +import time +import traceback +from pathlib import Path +import threading + +_locals = threading.local() + + +class FileNamespace(enum.StrEnum): + # Transient generated files go into the GEN namespace. These are typically + # not packaged for distribution. + GEN = enum.auto() + + # Distributable parameter files. + PARAMS = enum.auto() + + # Distributable, platform-neutral binaries. + BIN = enum.auto() + + # Distributable, platform specific binaries. + PLATFORM_BIN = enum.auto() + + +FileNamespaceToPath = { + FileNamespace.GEN: lambda executor: executor.output_dir / "genfiles", + FileNamespace.PARAMS: lambda executor: executor.output_dir / "params", + FileNamespace.BIN: lambda executor: executor.output_dir / "bin", + # TODO: This isn't right. Need to resolve platform dynamically. + FileNamespace.PLATFORM_BIN: lambda executor: executor.output_dir / "platform", +} + + +def join_namespace(prefix: str, suffix: str) -> str: + """Joins two namespace components, taking care of the root namespace (empty).""" + if not prefix: + return suffix + return f"{prefix}/{suffix}" + + +class ClArg: + def __init__(self, name, dest: str, **add_argument_kw): + self.name = name + self.dest = dest + self.add_argument_kw = add_argument_kw + + def define_arg(self, parser: argparse.ArgumentParser): + parser.add_argument(f"--{self.name}", dest=self.dest, **self.add_argument_kw) + + def resolve(self, arg_namespace: argparse.Namespace): + try: + return getattr(arg_namespace, self.dest) + except AttributeError as e: + raise RuntimeError( + f"Unable to resolve command line argument '{self.dest}' in namespace" + ) from e + + +class Entrypoint: + def __init__( + self, + name: str, + wrapped: Callable, + description: str | None = None, + ): + self.name = name + self.description = description + self._wrapped = wrapped + + def cl_args(self) -> Generator[ClArg, None, None]: + sig = inspect.signature(self._wrapped) + for p in sig.parameters.values(): + def_value = p.default + if isinstance(def_value, ClArg): + yield def_value + + def __call__(self, *args, **kwargs): + parent_context = BuildContext.current() + bep = BuildEntrypoint( + join_namespace(parent_context.path, self.name), + parent_context.executor, + self, + ) + parent_context.executor.entrypoints.append(bep) + with bep: + sig = inspect.signature(self._wrapped) + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + + def filter(arg): + if isinstance(arg, ClArg): + return arg.resolve(parent_context.executor.args_namespace) + return arg + + new_args = [filter(arg) for arg in bound.args] + new_kwargs = {k: filter(v) for k, v in bound.kwargs.items()} + results = self._wrapped(*new_args, **new_kwargs) + if results is not None: + files = bep.files(results) + bep.deps.update(files) + bep.outputs.extend(files) + return files + + +class Executor: + """Executor that all build contexts share.""" + + def __init__( + self, output_dir: Path, args_namespace: argparse.Namespace, stderr: IO + ): + self.output_dir = output_dir + self.verbose_level = 0 + # Keyed by path + self.all: dict[str, "BuildContext" | "BuildFile"] = {} + self.entrypoints: list["BuildEntrypoint"] = [] + self.args_namespace = args_namespace + self.stderr = stderr + BuildContext("", self) + + def check_path_not_exists(self, path: str, for_entity): + existing = self.all.get(path) + if existing is not None: + formatted_stack = "".join(traceback.format_list(existing.def_stack)) + raise RuntimeError( + f"Cannot add {for_entity} because an entity with that name was " + f"already defined at:\n{formatted_stack}" + ) + + def get_context(self, path: str) -> "BuildContext": + existing = self.all.get(path) + if existing is None: + raise RuntimeError(f"Context at path {path} not found") + if not isinstance(existing, BuildContext): + raise RuntimeError( + f"Entity at path {path} is not a context. It is: {existing}" + ) + return existing + + def get_file(self, path: str) -> "BuildFile": + existing = self.all.get(path) + if existing is None: + raise RuntimeError(f"File at path {path} not found") + if not isinstance(existing, BuildFile): + raise RuntimeError( + f"Entity at path {path} is not a file. It is: {existing}" + ) + return existing + + def write_status(self, message: str): + print(message, file=self.stderr) + + def get_root(self, namespace: FileNamespace) -> Path: + return FileNamespaceToPath[namespace](self) + + def analyze(self, *entrypoints: Entrypoint): + """Analyzes all entrypoints building the graph.""" + for entrypoint in entrypoints: + if self.verbose_level > 1: + self.write_status(f"Analyzing entrypoint {entrypoint.name}") + with self.get_context("") as context: + entrypoint() + + def build(self, *initial_deps: "BuildDependency"): + """Transitively builds the given deps.""" + scheduler = Scheduler(stderr=self.stderr) + success = False + try: + for d in initial_deps: + scheduler.add_initial_dep(d) + scheduler.build() + success = True + finally: + if not success: + print("Waiting for background tasks to complete...", file=self.stderr) + scheduler.shutdown() + + +class BuildDependency: + """Base class of entities that can act as a build dependency.""" + + def __init__( + self, *, executor: Executor, deps: set["BuildDependency"] | None = None + ): + self.executor = executor + self.deps: set[BuildDependency] = set() + if deps: + self.deps.update(deps) + + # Scheduling state. + self.future: concurrent.futures.Future | None = None + self.start_time: float | None = None + self.finish_time: float | None = None + + @property + def is_scheduled(self) -> bool: + return self.future is not None + + @property + def execution_time(self) -> float: + if self.start_time is None: + return 0.0 + if self.finish_time is None: + return time.time() - self.start_time + return self.finish_time - self.start_time + + def start(self, future: concurrent.futures.Future): + assert not self.is_scheduled, f"Cannot start an already scheduled dep: {self}" + self.future = future + self.start_time = time.time() + + def finish(self): + assert self.is_scheduled, "Cannot finish an unstarted dep" + self.finish_time = time.time() + self.future.set_result(self) + + +class BuildFile(BuildDependency): + """Generated file in the build tree.""" + + def __init__( + self, + *, + executor: Executor, + path: str, + namespace: FileNamespace = FileNamespace.GEN, + deps: set[BuildDependency] | None = None, + ): + super().__init__(executor=executor, deps=deps) + self.def_stack = traceback.extract_stack()[0:-2] + self.executor = executor + self.path = path + self.namespace = namespace + # Set of build files that must be made available to any transitive user + # of this build file at runtime. + self.runfiles: set["BuildFile"] = set() + + executor.check_path_not_exists(path, self) + executor.all[path] = self + + def get_fs_path(self) -> Path: + path = self.executor.get_root(self.namespace) / self.path + path.parent.mkdir(parents=True, exist_ok=True) + return path + + def __repr__(self): + return f"BuildFile[{self.namespace}]({self.path})" + + +class ActionConcurrency(enum.StrEnum): + THREAD = enum.auto() + PROCESS = enum.auto() + NONE = enum.auto() + + +class BuildAction(BuildDependency, abc.ABC): + """An action that must be carried out.""" + + def __init__( + self, + *, + desc: str, + executor: Executor, + concurrency: ActionConcurrency = ActionConcurrency.THREAD, + deps: set[BuildDependency] | None = None, + ): + super().__init__(executor=executor, deps=deps) + self.desc = desc + self.concurrnecy = concurrency + + def __str__(self): + return self.desc + + def __repr__(self): + return f"Action[{type(self).__name__}]('{self.desc}')" + + @abc.abstractmethod + def invoke(self): + ... + + +class BuildContext(BuildDependency): + """Manages a build graph under construction.""" + + def __init__(self, path: str, executor: Executor): + super().__init__(executor=executor) + self.def_stack = traceback.extract_stack()[0:-2] + self.executor = executor + self.path = path + executor.check_path_not_exists(path, self) + executor.all[path] = self + self.analyzed = False + + def __repr__(self): + return f"{type(self).__name__}(path='{self.path}')" + + def allocate_file( + self, path: str, namespace: FileNamespace = FileNamespace.GEN + ) -> BuildFile: + """Allocates a file in the build tree with local path |path|. + + If |path| is absoluate (starts with '/'), then it is used as-is. Otherwise, + it is joined with the path of this context. + """ + if not path.startswith("/"): + path = join_namespace(self.path, path) + return BuildFile(executor=self.executor, path=path, namespace=namespace) + + def file(self, file: str | BuildFile) -> BuildFile: + """Accesses a BuildFile by either string (path) or BuildFile. + + It must already exist. + """ + if isinstance(file, BuildFile): + return file + path = file + if not path.startswith("/"): + path = join_namespace(self.path, path) + existing = self.executor.all.get(path) + if not isinstance(existing, BuildFile): + all_files = [ + f.path for f in self.executor.all.values() if isinstance(f, BuildFile) + ] + raise RuntimeError( + f"File with path '{path}' is not known in the build graph. Available:\n" + f" {'\n '.join(all_files)}" + ) + return existing + + def files( + self, files: str | BuildFile | Collection[str | BuildFile] + ) -> list[BuildFile]: + """Accesses a collection of files (or single) as a list of BuildFiles.""" + if isinstance(files, (str, BuildFile)): + return [self.file(files)] + return [self.file(f) for f in files] + + @staticmethod + def current() -> "BuildContext": + try: + return _locals.context_stack[-1] + except (AttributeError, IndexError): + raise RuntimeError( + "The current code can only be evaluated within an active BuildContext" + ) + + def __enter__(self) -> "BuildContext": + try: + stack = _locals.context_stack + except AttributeError: + stack = _locals.context_stack = [] + stack.append(self) + return self + + def __exit__(self, *args): + try: + stack = _locals.context_stack + except AttributeError: + raise AssertionError("BuildContext exit without enter") + existing = stack.pop() + assert existing is self, "Unbalanced BuildContext enter/exit" + + def populate_arg_parser(self, parser: argparse.ArgumentParser): + ... + + +class BuildEntrypoint(BuildContext): + def __init__(self, path: str, executor: Executor, entrypoint: Entrypoint): + super().__init__(path, executor) + self.entrypoint = entrypoint + self.outputs: list[BuildFile] = [] + + +class Scheduler: + """Holds resources related to scheduling.""" + + def __init__(self, stderr: IO): + self.stderr = stderr + + # Inverted producer-consumer graph nodes mapping a producer dep to + # all deps which directly depend on it and will be unblocked by it + # beins satisfied. + self.producer_graph: dict[BuildDependency, list[BuildDependency]] = {} + + # Set of build dependencies that have been scheduled. These will all + # have a future set on them prior to adding to the set. + self.in_flight_deps: set[BuildDependency] = set() + + self.thread_pool_executor = concurrent.futures.ThreadPoolExecutor( + max_workers=10, thread_name_prefix="iree.build" + ) + self.process_pool_executor = concurrent.futures.ProcessPoolExecutor( + max_workers=10, mp_context=multiprocessing.get_context("spawn") + ) + + def shutdown(self): + self.thread_pool_executor.shutdown(cancel_futures=True) + self.process_pool_executor.shutdown(cancel_futures=True) + + def add_initial_dep(self, initial_dep: BuildDependency): + assert isinstance(initial_dep, BuildDependency) + if initial_dep in self.producer_graph: + # Already in the graph. + return + + # At this point nothing depends on this initial dep, so just note it + # as producing nothing. + self.producer_graph[initial_dep] = [] + + # Adds a dep requested by some top-level caller. + stack: set[BuildDependency] = set() + stack.add(initial_dep) + for producer_dep in initial_dep.deps: + self._add_dep(producer_dep, initial_dep, stack) + + def _add_dep( + self, + producer_dep: BuildDependency, + consumer_dep: BuildDependency, + stack: set[BuildDependency], + ): + if producer_dep in stack: + raise RuntimeError( + f"Circular dependency: '{producer_dep}' depends on itself: {stack}" + ) + plist = self.producer_graph.get(producer_dep) + if plist is None: + plist = [] + self.producer_graph[producer_dep] = plist + plist.append(consumer_dep) + next_stack = set(stack) + next_stack.add(producer_dep) + if producer_dep.deps: + # Intermediate dep. + for next_dep in producer_dep.deps: + self._add_dep(next_dep, producer_dep, next_stack) + + def build(self): + # Build all deps until the graph is satisfied. + # Schedule any deps that have no dependencies to start things off. + for eligible_dep in self.producer_graph.keys(): + if len(eligible_dep.deps) == 0: + self._schedule_action(eligible_dep) + self.in_flight_deps.add(eligible_dep) + + while self.producer_graph: + print( + f"Servicing {len(self.producer_graph)} outstanding tasks", + file=self.stderr, + ) + self._service_graph() + + def _service_graph(self): + completed_deps: set[BuildDependency] = set() + try: + for completed_fut in concurrent.futures.as_completed( + (d.future for d in self.in_flight_deps), 0 + ): + completed_dep = completed_fut.result() + assert isinstance(completed_dep, BuildDependency) + print(f"Completed {completed_dep}", file=self.stderr) + completed_deps.add(completed_dep) + except TimeoutError: + pass + + # Purge done from in-flight list. + self.in_flight_deps.difference_update(completed_deps) + + # Schedule any available. + for completed_dep in completed_deps: + ready_list = self.producer_graph.get(completed_dep) + if ready_list is None: + continue + del self.producer_graph[completed_dep] + for ready_dep in ready_list: + self._schedule_action(ready_dep) + self.in_flight_deps.add(ready_dep) + + # Do a blocking wait for at least one ready. + concurrent.futures.wait( + (d.future for d in self.in_flight_deps), + return_when=concurrent.futures.FIRST_COMPLETED, + ) + + def _schedule_action(self, dep: BuildDependency): + if dep.is_scheduled: + return + if isinstance(dep, BuildAction): + + def invoke(): + dep.invoke() + return dep + + print(f"Scheduling action: {dep}", file=self.stderr) + dep.start(self.thread_pool_executor.submit(invoke)) + else: + # Not schedulable. Just mark it as done. + dep.start(concurrent.futures.Future()) + dep.finish() + + +# Type aliases. +BuildFileLike = BuildFile | str diff --git a/compiler/bindings/python/iree/build/lang.py b/compiler/bindings/python/iree/build/lang.py new file mode 100644 index 000000000000..5cb87790aec3 --- /dev/null +++ b/compiler/bindings/python/iree/build/lang.py @@ -0,0 +1,52 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Callable + +import argparse +import functools + +from iree.build.executor import ClArg, Entrypoint + +__all__ = [ + "cl_arg", + "entrypoint", +] + + +def entrypoint( + f=None, + *, + description: str | None = None, +): + """Function decorator to turn it into a build entrypoint.""" + if f is None: + return functools.partial(entrypoint, description=description) + target = Entrypoint(f.__name__, f, description=description) + functools.wraps(target, f) + return target + + +def cl_arg(name: str, *, action=None, default=None, type=None, help=None): + """Used to define or reference a command-line argument from within actions + and entry-points. + + Keywords have the same interpretation as `ArgumentParser.add_argument()`. + + Any ClArg set as a default value for an argument to an `entrypoint` will be + added to the global argument parser. Any particular argument name can only be + registered once and must not conflict with a built-in command line option. + The implication of this is that for single-use arguments, the `=cl_arg(...)` + can just be added as a default argument. Otherwise, for shared arguments, + it should be created at the module level and referenced. + + When called, any entrypoint arguments that do not have an explicit keyword + set will get their value from the command line environment. + """ + if name.startswith("-"): + raise ValueError("cl_arg name must not be prefixed with dashes") + dest = name.replace("-", "_") + return ClArg(name, action=action, default=default, type=type, dest=dest, help=help) diff --git a/compiler/bindings/python/iree/build/net_actions.py b/compiler/bindings/python/iree/build/net_actions.py new file mode 100644 index 000000000000..1d2e1587e63a --- /dev/null +++ b/compiler/bindings/python/iree/build/net_actions.py @@ -0,0 +1,39 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import urllib.error +import urllib.request + +from iree.build.executor import BuildAction, BuildContext, BuildFile + +__all__ = [ + "fetch_http", +] + + +def fetch_http(*, name: str, url: str) -> BuildFile: + context = BuildContext.current() + output_file = context.allocate_file(name) + action = FetchHttpAction( + url=url, output_file=output_file, desc=f"Fetch {url}", executor=context.executor + ) + output_file.deps.add(action) + return output_file + + +class FetchHttpAction(BuildAction): + def __init__(self, url: str, output_file: BuildFile, **kwargs): + super().__init__(**kwargs) + self.url = url + self.output_file = output_file + + def invoke(self): + path = self.output_file.get_fs_path() + self.executor.write_status(f"Fetching URL: {self.url} -> {path}") + try: + urllib.request.urlretrieve(self.url, str(path)) + except urllib.error.HTTPError as e: + raise IOError(f"Failed to fetch URL '{self.url}': {e}") from None diff --git a/compiler/bindings/python/iree/build/onnx_actions.py b/compiler/bindings/python/iree/build/onnx_actions.py new file mode 100644 index 000000000000..522089006116 --- /dev/null +++ b/compiler/bindings/python/iree/build/onnx_actions.py @@ -0,0 +1,90 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from iree.build.executor import BuildAction, BuildContext, BuildFile, BuildFileLike + +__all__ = [ + "onnx_import", +] + + +def onnx_import( + *, + # Name of the rule and output of the final artifact. + name: str, + # Source onnx file. + source: BuildFileLike, + upgrade: bool = True, +) -> BuildFile: + context = BuildContext.current() + input_file = context.file(source) + output_file = context.allocate_file(name) + + # Chain through an upgrade if requested. + if upgrade: + processed_file = context.allocate_file(f"{name}__upgrade.onnx") + UpgradeOnnxAction( + input_file=input_file, + output_file=processed_file, + executor=context.executor, + desc=f"Upgrading ONNX {name}", + deps=[ + input_file, + ], + ) + input_file = processed_file + + # Import. + ImportOnnxAction( + input_file=input_file, + output_file=output_file, + desc=f"Importing ONNX {name}", + executor=context.executor, + deps=[ + input_file, + ], + ) + + output_file.deps.add(processed_file) + return output_file + + +class UpgradeOnnxAction(BuildAction): + def __init__(self, input_file: BuildFile, output_file: BuildFile, **kwargs): + super().__init__(**kwargs) + self.input_file = input_file + self.output_file = output_file + output_file.deps.add(self) + + def invoke(self): + import onnx + + input_path = self.input_file.get_fs_path() + output_path = self.output_file.get_fs_path() + + original_model = onnx.load_model(str(input_path)) + converted_model = onnx.version_converter.convert_version(original_model, 17) + onnx.save(converted_model, str(output_path)) + + +class ImportOnnxAction(BuildAction): + def __init__(self, input_file: BuildFile, output_file: BuildFile, **kwargs): + super().__init__(**kwargs) + self.input_file = input_file + self.output_file = output_file + output_file.deps.add(self) + + def invoke(self): + import iree.compiler.tools.import_onnx.__main__ as m + + args = m.parse_arguments( + [ + str(self.input_file.get_fs_path()), + "-o", + str(self.output_file.get_fs_path()), + ] + ) + m.main(args) diff --git a/compiler/bindings/python/test/CMakeLists.txt b/compiler/bindings/python/test/CMakeLists.txt index 6f6cdb913de9..ca809a8016c4 100644 --- a/compiler/bindings/python/test/CMakeLists.txt +++ b/compiler/bindings/python/test/CMakeLists.txt @@ -5,6 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception add_subdirectory(api) +add_subdirectory(build_api) add_subdirectory(extras) add_subdirectory(ir) add_subdirectory(tools) diff --git a/compiler/bindings/python/test/build_api/CMakeLists.txt b/compiler/bindings/python/test/build_api/CMakeLists.txt new file mode 100644 index 000000000000..b8bd81759ddc --- /dev/null +++ b/compiler/bindings/python/test/build_api/CMakeLists.txt @@ -0,0 +1,15 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# The mnist builder depends on onnx, which needs the torch input support. +if(IREE_INPUT_TORCH) + iree_py_test( + NAME + mnist_builder_test + SRCS + "mnist_builder_test.py" + ) +endif() diff --git a/compiler/bindings/python/test/build_api/mnist_builder.py b/compiler/bindings/python/test/build_api/mnist_builder.py new file mode 100644 index 000000000000..54d6e3064693 --- /dev/null +++ b/compiler/bindings/python/test/build_api/mnist_builder.py @@ -0,0 +1,30 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from iree.build import * + + +@entrypoint(description="Compiles an mnist model") +def mnist( + url=cl_arg( + "mnist-onnx-url", + default="https://github.com/onnx/models/raw/main/validated/vision/classification/mnist/model/mnist-12.onnx", + help="URL from which to download mnist", + ), +): + fetch_http( + name="mnist.onnx", + url=url, + ) + onnx_import( + name="mnist.mlir", + source="mnist.onnx", + ) + return "mnist.mlir" + + +if __name__ == "__main__": + iree_build_main() diff --git a/compiler/bindings/python/test/build_api/mnist_builder_test.py b/compiler/bindings/python/test/build_api/mnist_builder_test.py new file mode 100644 index 000000000000..02a914025882 --- /dev/null +++ b/compiler/bindings/python/test/build_api/mnist_builder_test.py @@ -0,0 +1,108 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import io +from pathlib import Path +import re +import subprocess +import unittest +import tempfile +import sys + +from iree.build import * + +THIS_DIR = Path(__file__).resolve().parent + + +class MnistBuilderTest(unittest.TestCase): + def setUp(self): + self._temp_dir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True) + self._temp_dir.__enter__() + self.output_path = Path(self._temp_dir.name) + + def tearDown(self) -> None: + self._temp_dir.__exit__(None, None, None) + + # Tests that invoking via the tool works: + # python -m iree.build {path to py file} + # We execute this out of process in order to verify the full flow. + def testBuildEntrypoint(self): + output = subprocess.check_output( + [ + sys.executable, + "-m", + "iree.build", + str(THIS_DIR / "mnist_builder.py"), + "--output-dir", + str(self.output_path), + ] + ).decode() + print("OUTPUT:", output) + output_paths = output.splitlines() + self.assertEqual(len(output_paths), 1) + output_path = Path(output_paths[0]) + self.assertTrue(output_path.is_relative_to(self.output_path)) + contents = output_path.read_text() + self.assertIn("module", contents) + + # Tests that invoking via the build module itself works + # python {path to py file} + # We execute this out of process in order to verify the full flow. + def testTargetModuleEntrypoint(self): + output = subprocess.check_output( + [ + sys.executable, + str(THIS_DIR / "mnist_builder.py"), + "--output-dir", + str(self.output_path), + ] + ).decode() + print("OUTPUT:", output) + output_paths = output.splitlines() + self.assertEqual(len(output_paths), 1) + + def testListCommand(self): + mod = load_build_module(THIS_DIR / "mnist_builder.py") + out_file = io.StringIO() + iree_build_main(mod, args=["--list"], stdout=out_file) + output = out_file.getvalue().strip() + self.assertEqual(output, "mnist") + + def testListAllCommand(self): + mod = load_build_module(THIS_DIR / "mnist_builder.py") + out_file = io.StringIO() + iree_build_main(mod, args=["--list-all"], stdout=out_file) + output = out_file.getvalue().splitlines() + self.assertIn("mnist", output) + self.assertIn("mnist/mnist.onnx", output) + + def testActionCLArg(self): + mod = load_build_module(THIS_DIR / "mnist_builder.py") + out_file = io.StringIO() + err_file = io.StringIO() + with self.assertRaisesRegex( + IOError, + re.escape("Failed to fetch URL 'https://github.com/iree-org/doesnotexist'"), + ): + iree_build_main( + mod, + args=[ + "--mnist-onnx-url", + "https://github.com/iree-org/doesnotexist", + ], + stdout=out_file, + stderr=err_file, + ) + + +if __name__ == "__main__": + try: + import onnx + except ModuleNotFoundError: + print(f"Skipping test {__file__} because Python dependency `onnx` is not found") + sys.exit(0) + + unittest.main() diff --git a/compiler/setup.py b/compiler/setup.py index 8204e820e251..bf5d734a2a88 100644 --- a/compiler/setup.py +++ b/compiler/setup.py @@ -454,6 +454,7 @@ def find_git_submodule_revision(submodule_path): packages=packages, entry_points={ "console_scripts": [ + "iree-build = iree.build.__main__:main", "iree-compile = iree.compiler.tools.scripts.iree_compile.__main__:main", "iree-import-onnx = iree.compiler.tools.import_onnx.__main__:_cli_main", "iree-ir-tool = iree.compiler.tools.ir_tool.__main__:_cli_main",