Skip to content

Commit

Permalink
Allow getting all backend names (pytorch#8520)
Browse files Browse the repository at this point in the history
Summary:

Allow getting all backends name in both python and c++

Reviewed By: omerjerk

Differential Revision: D69691354
  • Loading branch information
cccclai authored and facebook-github-bot committed Feb 21, 2025
1 parent 735f16e commit c6eb626
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 1 deletion.
1 change: 1 addition & 0 deletions extension/pybindings/portable_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
_create_profile_block, # noqa: F401
_dump_profile_results, # noqa: F401
_get_operator_names, # noqa: F401
_get_registered_backend_names, # noqa: F401
_load_bundled_program_from_buffer, # noqa: F401
_load_for_executorch, # noqa: F401
_load_for_executorch_from_buffer, # noqa: F401
Expand Down
19 changes: 19 additions & 0 deletions extension/pybindings/pybindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <executorch/extension/data_loader/buffer_data_loader.h>
#include <executorch/extension/data_loader/mmap_data_loader.h>
#include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/core/data_loader.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/executor/method.h>
Expand Down Expand Up @@ -91,6 +92,8 @@ using ::executorch::runtime::DataLoader;
using ::executorch::runtime::Error;
using ::executorch::runtime::EValue;
using ::executorch::runtime::EventTracerDebugLogLevel;
using ::executorch::runtime::get_backend_name;
using ::executorch::runtime::get_num_registered_backends;
using ::executorch::runtime::get_registered_kernels;
using ::executorch::runtime::HierarchicalAllocator;
using ::executorch::runtime::Kernel;
Expand Down Expand Up @@ -975,6 +978,18 @@ py::list get_operator_names() {
return res;
}

py::list get_registered_backend_names() {
size_t n_of_registered_backends = get_num_registered_backends();
py::list res;
for (size_t i = 0; i < n_of_registered_backends; i++) {
auto backend_name_res = get_backend_name(i);
THROW_IF_ERROR(backend_name_res.error(), "Failed to get backend name");
auto backend_name = backend_name_res.get();
res.append(backend_name);
}
return res;
}

} // namespace

PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
Expand Down Expand Up @@ -1028,6 +1043,10 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
prof_result.num_bytes);
},
call_guard);
m.def(
"_get_registered_backend_names",
&get_registered_backend_names,
call_guard);
m.def("_get_operator_names", &get_operator_names);
m.def("_create_profile_block", &create_profile_block, call_guard);
m.def(
Expand Down
8 changes: 8 additions & 0 deletions extension/pybindings/test/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,11 @@ runtime.python_test(
"//executorch/kernels/quantized:aot_lib",
],
)

runtime.python_test(
name = "test_backend_pybinding",
srcs = ["test_backend_pybinding.py"],
deps = [
"//executorch/runtime:runtime",
],
)
14 changes: 14 additions & 0 deletions extension/pybindings/test/test_backend_pybinding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import unittest

from executorch.runtime import Runtime


class TestBackendsPybinding(unittest.TestCase):
def test_backend_name_list(
self,
) -> None:

runtime = Runtime.get()
registered_backend_names = runtime.backend_registry.registered_backend_names
self.assertGreaterEqual(len(registered_backend_names), 1)
self.assertIn("XnnpackBackend", registered_backend_names)
18 changes: 17 additions & 1 deletion runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
import functools
from pathlib import Path
from types import ModuleType
from typing import Any, BinaryIO, Dict, Optional, Sequence, Set, Union
from typing import Any, BinaryIO, Dict, List, Optional, Sequence, Set, Union

try:
from executorch.extension.pybindings.portable_lib import (
Expand Down Expand Up @@ -125,6 +125,21 @@ def load_method(self, name: str) -> Optional[Method]:
return self._methods.get(name, None)


class BackendRegistry:
"""The registry of backends that are available to the runtime."""

def __init__(self, legacy_module: ModuleType) -> None:
# TODO: Expose the kernel callables to Python.
self._legacy_module = legacy_module

@property
def registered_backend_names(self) -> List[str]:
"""
Returns the names of all registered backends as a list of strings.
"""
return self._legacy_module._get_registered_backend_names()


class OperatorRegistry:
"""The registry of operators that are available to the runtime."""

Expand Down Expand Up @@ -157,6 +172,7 @@ def get() -> "Runtime":

def __init__(self, *, legacy_module: ModuleType) -> None:
# Public attributes.
self.backend_registry = BackendRegistry(legacy_module)
self.operator_registry = OperatorRegistry(legacy_module)
# Private attributes.
self._legacy_module = legacy_module
Expand Down
11 changes: 11 additions & 0 deletions runtime/backend/interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,16 @@ Error register_backend(const Backend& backend) {
return Error::Ok;
}

size_t get_num_registered_backends() {
return num_registered_backends;
}

Result<const char*> get_backend_name(size_t index) {
if (index >= num_registered_backends) {
return Error::InvalidArgument;
}
return registered_backends[index].name;
}

} // namespace runtime
} // namespace executorch
10 changes: 10 additions & 0 deletions runtime/backend/interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,16 @@ struct Backend {
*/
ET_NODISCARD Error register_backend(const Backend& backend);

/**
* Returns the number of registered backends.
*/
size_t get_num_registered_backends();

/**
* Returns the backend name at the given index.
*/
Result<const char*> get_backend_name(size_t index);

} // namespace runtime
} // namespace executorch

Expand Down

0 comments on commit c6eb626

Please sign in to comment.