Skip to content

Commit

Permalink
Allow getting all backend names (pytorch#8520)
Browse files Browse the repository at this point in the history
Summary:

Allow getting all backends name in both python and c++

Reviewed By: tarun292

Differential Revision: D69691354
  • Loading branch information
cccclai authored and facebook-github-bot committed Feb 18, 2025
1 parent cb2b174 commit 44978f1
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 0 deletions.
1 change: 1 addition & 0 deletions extension/pybindings/portable_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
# Disable "imported but unused" (F401) checks.
_create_profile_block, # noqa: F401
_dump_profile_results, # noqa: F401
_get_backend_names, # noqa: F401
_get_operator_names, # noqa: F401
_load_bundled_program_from_buffer, # noqa: F401
_load_for_executorch, # noqa: F401
Expand Down
15 changes: 15 additions & 0 deletions extension/pybindings/pybindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <executorch/extension/data_loader/buffer_data_loader.h>
#include <executorch/extension/data_loader/mmap_data_loader.h>
#include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/core/data_loader.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/executor/method.h>
Expand Down Expand Up @@ -87,10 +88,12 @@ using ::executorch::extension::BufferDataLoader;
using ::executorch::extension::MallocMemoryAllocator;
using ::executorch::extension::MmapDataLoader;
using ::executorch::runtime::ArrayRef;
using ::executorch::runtime::Backend;
using ::executorch::runtime::DataLoader;
using ::executorch::runtime::Error;
using ::executorch::runtime::EValue;
using ::executorch::runtime::EventTracerDebugLogLevel;
using ::executorch::runtime::get_registered_backends;
using ::executorch::runtime::get_registered_kernels;
using ::executorch::runtime::HierarchicalAllocator;
using ::executorch::runtime::Kernel;
Expand Down Expand Up @@ -975,6 +978,17 @@ py::list get_operator_names() {
return res;
}

py::list get_backend_names() {
Span<const Backend> backends = get_registered_backends();
py::list res;
for (const Backend& backend : backends) {
if (backend.name != nullptr) {
res.append(py::cast(backend.name));
}
}
return res;
}

} // namespace

PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
Expand Down Expand Up @@ -1029,6 +1043,7 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
},
call_guard);
m.def("_get_operator_names", &get_operator_names);
m.def("_get_backend_names", &get_backend_names);
m.def("_create_profile_block", &create_profile_block, call_guard);
m.def(
"_reset_profile_results",
Expand Down
9 changes: 9 additions & 0 deletions extension/pybindings/pybindings.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,15 @@ def _load_bundled_program_from_buffer(
"""
...

@experimental("This API is experimental and subject to change without notice.")
def _get_backend_names() -> List[str]:
"""
.. warning::
This API is experimental and subject to change without notice.
"""
...

@experimental("This API is experimental and subject to change without notice.")
def _get_operator_names() -> List[str]:
"""
Expand Down
8 changes: 8 additions & 0 deletions extension/pybindings/test/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,11 @@ runtime.python_test(
"//executorch/kernels/quantized:aot_lib",
],
)

runtime.python_test(
name = "test_backend_pybinding",
srcs = ["test_backend_pybinding.py"],
deps = [
"//executorch/extension/pybindings:portable_lib",
],
)
12 changes: 12 additions & 0 deletions extension/pybindings/test/test_backend_pybinding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import unittest

from executorch.extension.pybindings.portable_lib import _get_backend_names


class TestBackendsPybinding(unittest.TestCase):
def test_backend_name_list(
self,
) -> None:
all_backend_name = _get_backend_names()
self.assertGreater(len(all_backend_name), 1)
self.assertIn("XnnpackBackend", all_backend_name)
4 changes: 4 additions & 0 deletions runtime/backend/interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,9 @@ Error register_backend(const Backend& backend) {
return Error::Ok;
}

Span<const Backend> get_registered_backends() {
return {registered_backends, num_registered_backends};
}

} // namespace runtime
} // namespace executorch
6 changes: 6 additions & 0 deletions runtime/backend/interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <executorch/runtime/core/freeable_buffer.h>
#include <executorch/runtime/core/memory_allocator.h>
#include <executorch/runtime/core/result.h>
#include <executorch/runtime/core/span.h>
#include <executorch/runtime/platform/compiler.h>

namespace executorch {
Expand Down Expand Up @@ -139,6 +140,11 @@ struct Backend {
*/
ET_NODISCARD Error register_backend(const Backend& backend);

/**
* Returns all registered backends.
*/
Span<const Backend> get_registered_backends();

} // namespace runtime
} // namespace executorch

Expand Down

0 comments on commit 44978f1

Please sign in to comment.