[REFBACKEND] Add support for returning multiple different return types.
Add dynamic registration of the consume-return functions to the execution
engine, so that functions returning multiple values and different return types are supported.
Also update the .style.yapf indentation to 4.
Prashant Kumar committed Apr 21, 2022
1 parent b69db60 commit 33c9d25
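
The change works in two halves. On the MLIR side, the RefBackend calling-convention pass rewrites each public function's `return` into a call to a consume function whose name encodes the return types, and records every signature it actually encounters instead of checking against a hard-coded list. On the Python side, `RefBackendInvoker` scans the lowered module for those generated names and registers one ctypes callback per name with the `ExecutionEngine`. A minimal sketch of the naming convention, where the helper name `consume_func_name_for` is illustrative rather than part of the change:

# Sketch only: the prefix and type codes below match the ones used in this
# commit, but consume_func_name_for is a hypothetical helper.
CONSUME_RETURN_FUNC_PREFIX = "refbackend_consume_func_return_"

def consume_func_name_for(type_codes):
    # e.g. ["mrf32", "mri64"] -> "refbackend_consume_func_return_mrf32_mri64"
    return CONSUME_RETURN_FUNC_PREFIX + "_".join(type_codes)

assert consume_func_name_for(["mrf32", "mri64"]) == \
    "refbackend_consume_func_return_mrf32_mri64"
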
Showing 5 changed files with 140 additions and 160 deletions.
1 change: 1 addition & 0 deletions e2e_testing/torchscript/xfail_sets.py
@@ -152,4 +152,5 @@
    "ElementwiseNeFloatTensorModule_basic",
    "ConvolutionModule2DStatic_basic",
    "ElementwiseNegModule_basic",
    "TestMultipleTensorReturn_basic",
}
49 changes: 2 additions & 47 deletions lib/RefBackend/RefBackend.cpp
Expand Up @@ -115,7 +115,7 @@ static void replaceReturnWithCall(OpBuilder b, func::ReturnOp op,
}

static LogicalResult mungeFunction(
FuncOp func, std::set<std::string> &supportedConsumeFuncReturnFuncs,
FuncOp func,
std::map<std::string, std::vector<Type>> &invokedConsumeFuncReturnFuncs) {
// Only need to call mungeFunction for functions callable from outside of the
// module.
@@ -147,7 +147,6 @@ static LogicalResult mungeFunction(
  }

  SmallVector<Operation *> toErase;
  bool isSupported = true;
  func.walk([&](func::ReturnOp op) {
    auto types = op.getOperandTypes();
    b.setInsertionPoint(op);
@@ -169,72 +168,28 @@
      retVals.push_back(retVal);
    }

    auto supportedFuncsEnd = supportedConsumeFuncReturnFuncs.end();
    std::string funcName = getConsumeReturnFunctionNameForReturnTypes(retTypes);
    if (supportedConsumeFuncReturnFuncs.find(funcName) == supportedFuncsEnd) {
      op.emitError("Supported return types:"
                   "mri1, mri32, mri64, mrf32, mrf64, i1, i64, f32, f64,"
                   "(mrf32, mri64), (mrf32, mrf32), (mrf64, mrf64),"
                   "(mrf32, mrf32, mrf32)");
      isSupported = false;
    }

    auto invokedFuncsEnd = invokedConsumeFuncReturnFuncs.end();
    if (invokedConsumeFuncReturnFuncs.find(funcName) == invokedFuncsEnd)
      invokedConsumeFuncReturnFuncs.insert({funcName, retTypes});
    replaceReturnWithCall(b, op, funcName, retTypes, retVals, toErase);
  });
  if (!isSupported)
    return failure();
  func.setType(FunctionType::get(func.getContext(), newArgTypes, {}));
  for (Operation *op : toErase)
    op->erase();
  return success();
}

static std::set<std::string> getSupportedConsumeFuncReturnFuncs(OpBuilder &b) {
  std::set<std::string> funcNames;
  Type mri1 = UnrankedMemRefType::get(b.getI1Type(), 0);
  Type mri32 = UnrankedMemRefType::get(b.getI32Type(), 0);
  Type mri64 = UnrankedMemRefType::get(b.getI64Type(), 0);
  Type mrf32 = UnrankedMemRefType::get(b.getF32Type(), 0);
  Type mrf64 = UnrankedMemRefType::get(b.getF64Type(), 0);
  Type i1 = b.getI1Type();
  Type i64 = b.getI64Type();
  Type f32 = b.getF32Type();
  Type f64 = b.getF64Type();

  SmallVector<TypeRange> supportedReturnTypes = {mri1,
                                                 mri32,
                                                 mri64,
                                                 mrf32,
                                                 mrf64,
                                                 i1,
                                                 i64,
                                                 f32,
                                                 f64,
                                                 {mrf32, mri64},
                                                 {mrf32, mrf32},
                                                 {mrf64, mrf64},
                                                 {mrf32, mrf32, mrf32}};

  llvm::for_each(supportedReturnTypes, [&](TypeRange &types) {
    funcNames.insert(getConsumeReturnFunctionNameForReturnTypes(types));
  });
  return funcNames;
}

namespace {
class MungeCallingConventions
    : public MungeCallingConventionsBase<MungeCallingConventions> {
  void runOnOperation() override {
    auto module = getOperation();
    OpBuilder b(module.getBodyRegion());
    static std::set<std::string> supported =
        getSupportedConsumeFuncReturnFuncs(b);
    std::map<std::string, std::vector<Type>> invokedConsumeFuncReturnFuncs;
    for (auto func : module.getOps<FuncOp>()) {
      if (failed(mungeFunction(func, supported, invokedConsumeFuncReturnFuncs)))
      if (failed(mungeFunction(func, invokedConsumeFuncReturnFuncs)))
        return signalPassFailure();
    }

179 changes: 66 additions & 113 deletions python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py
@@ -22,133 +22,85 @@
]


def checkArgTypeIsSupported(ty):
def assert_arg_type_is_supported(ty):
    SUPPORTED = [np.float32, np.float64, np.int32, np.int64, np.bool_]
    assert ty in SUPPORTED, f"Only numpy arrays with dtypes in {SUPPORTED} are supported"


memref_type_to_np_dtype = {
    "mrf32": np.float32,
    "mrf64": np.float64,
    "mri1": np.bool_,
    "mri32": np.int32,
    "mri64": np.int64
}
elemental_type_to_ctype = {
    "i1": ctypes.c_bool,
    "i64": ctypes.c_int,
    "f32": ctypes.c_float,
    "f64": ctypes.c_double
}

CONSUME_RETURN_FUNC_PREFIX = "refbackend_consume_func_return_"


def get_return_funcs(module):
    return_prefix_len = len(CONSUME_RETURN_FUNC_PREFIX)
    return_funcs = []
    with module.context:
        for func in module.body:
            # `sym_name` is a quoted string of the form `"refbackend..."`, so strip the quotes.
            func_name = str(func.attributes["sym_name"]).replace('"', '')
            if func_name[:return_prefix_len] == CONSUME_RETURN_FUNC_PREFIX:
                return_funcs.append(func_name)

    return return_funcs


def get_ctype_func(func_name):
    return_prefix_len = len(CONSUME_RETURN_FUNC_PREFIX)
    ret_types = func_name[return_prefix_len:].split("_")
    ctypes_arg = [None]
    for type in ret_types:
        if type in elemental_type_to_ctype:
            ctypes_arg.append(elemental_type_to_ctype[type])
        elif type in memref_type_to_np_dtype:
            ctypes_arg.append(ctypes.POINTER(UnrankedMemRefDescriptor))
        else:
            assert False, f"Not supported type: {type}"

    return ctypes.CFUNCTYPE(*ctypes_arg), ret_types


class RefBackendInvoker:

    def __init__(self, module):
        self.ee = ExecutionEngine(module)
        self.result = None

        @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
        def consume_return_mri1(a):
            self.result = unranked_memref_to_numpy(a, np.bool_)

        @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
        def consume_return_mri32(a):
            self.result = unranked_memref_to_numpy(a, np.int32)

        @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
        def consume_return_mri64(a):
            self.result = unranked_memref_to_numpy(a, np.int64)

        @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
        def consume_return_mrf32(a):
            self.result = unranked_memref_to_numpy(a, np.float32)

        @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
        def consume_return_mrf64(a):
            self.result = unranked_memref_to_numpy(a, np.float64)

        @ctypes.CFUNCTYPE(None, ctypes.c_bool)
        def consume_return_i1(a):
            self.result = a

        @ctypes.CFUNCTYPE(None, ctypes.c_int)
        def consume_return_i64(a):
            self.result = a

        @ctypes.CFUNCTYPE(None, ctypes.c_float)
        def consume_return_f32(a):
            self.result = a

        @ctypes.CFUNCTYPE(None, ctypes.c_double)
        def consume_return_f64(a):
            self.result = a

        @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor),
                          ctypes.POINTER(UnrankedMemRefDescriptor))
        def consume_return_mrf32_mri64(arg0, arg1):
            self.result = unranked_memref_to_numpy(
                arg0, np.float32), unranked_memref_to_numpy(arg1, np.int64)

        @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor),
                          ctypes.POINTER(UnrankedMemRefDescriptor))
        def consume_return_mrf32_mrf32(arg0, arg1):
            self.result = unranked_memref_to_numpy(
                arg0, np.float32), unranked_memref_to_numpy(arg1, np.float32)

        @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor),
                          ctypes.POINTER(UnrankedMemRefDescriptor))
        def consume_return_mrf64_mrf64(arg0, arg1):
            self.result = unranked_memref_to_numpy(
                arg0, np.float64), unranked_memref_to_numpy(arg1, np.float64)

        @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor),
                          ctypes.POINTER(UnrankedMemRefDescriptor),
                          ctypes.POINTER(UnrankedMemRefDescriptor))
        def consume_return_mrf32_mrf32_mrf32(arg0, arg1, arg2):
            self.result = unranked_memref_to_numpy(
                arg0, np.float32), unranked_memref_to_numpy(
                    arg1, np.float32), unranked_memref_to_numpy(arg2, np.float32)

        self.ee.register_runtime("refbackend_consume_func_return_mri1",
                                 consume_return_mri1)

        self.ee.register_runtime("refbackend_consume_func_return_mri32",
                                 consume_return_mri32)

        self.ee.register_runtime("refbackend_consume_func_return_mri64",
                                 consume_return_mri64)

        self.ee.register_runtime("refbackend_consume_func_return_mrf32",
                                 consume_return_mrf32)

        self.ee.register_runtime("refbackend_consume_func_return_mrf64",
                                 consume_return_mrf64)

        self.ee.register_runtime("refbackend_consume_func_return_i1",
                                 consume_return_i1)

        self.ee.register_runtime("refbackend_consume_func_return_i64",
                                 consume_return_i64)

        self.ee.register_runtime("refbackend_consume_func_return_f32",
                                 consume_return_f32)

        self.ee.register_runtime("refbackend_consume_func_return_f64",
                                 consume_return_f64)

        self.ee.register_runtime("refbackend_consume_func_return_mrf32_mri64",
                                 consume_return_mrf32_mri64)

        self.ee.register_runtime("refbackend_consume_func_return_mrf32_mrf32",
                                 consume_return_mrf32_mrf32)

        self.ee.register_runtime("refbackend_consume_func_return_mrf64_mrf64",
                                 consume_return_mrf64_mrf64)

        self.ee.register_runtime(
            "refbackend_consume_func_return_mrf32_mrf32_mrf32",
            consume_return_mrf32_mrf32_mrf32)
        return_funcs = get_return_funcs(module)

        for ret_func in return_funcs:
            ctype_wrapper, ret_types = get_ctype_func(ret_func)

            def consume_return_funcs(*args):
                self.result = tuple([
                    arg if type in elemental_type_to_ctype else
                    unranked_memref_to_numpy(arg, memref_type_to_np_dtype[type])
                    for arg, type in zip(args, ret_types)
                ])
                if len(self.result) == 1:
                    self.result = self.result[0]

            self.ee.register_runtime(ret_func,
                                     ctype_wrapper(consume_return_funcs))

    def __getattr__(self, function_name: str):

        def invoke(*args):
            ffi_args = []
            for arg in args:
                checkArgTypeIsSupported(arg.dtype)
                assert_arg_type_is_supported(arg.dtype)
                ffi_args.append(
                    ctypes.pointer(
                        ctypes.pointer(get_unranked_memref_descriptor(arg))))
@@ -202,6 +154,7 @@ def invoke(*args):

class RefBackendLinalgOnTensorsBackend(LinalgOnTensorsBackend):
    """Main entry-point for the reference backend."""

    def __init__(self):
        super().__init__()

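To see what the dynamic registration above does for one concrete signature: a consume function named `refbackend_consume_func_return_mrf32_mri64` is parsed by `get_ctype_func` into the type codes `["mrf32", "mri64"]` and a matching `ctypes.CFUNCTYPE` with one argument per returned value. A standalone sketch of that expansion, using `ctypes.c_void_p` as a stand-in for `ctypes.POINTER(UnrankedMemRefDescriptor)` so it runs without the MLIR runtime package:

import ctypes

CONSUME_RETURN_FUNC_PREFIX = "refbackend_consume_func_return_"
# Elemental results are passed by value; memref ("mr*") results arrive as
# pointers to unranked memref descriptors (approximated by c_void_p here).
elemental_type_to_ctype = {
    "i1": ctypes.c_bool,
    "i64": ctypes.c_int,
    "f32": ctypes.c_float,
    "f64": ctypes.c_double,
}

func_name = "refbackend_consume_func_return_mrf32_mri64"
ret_types = func_name[len(CONSUME_RETURN_FUNC_PREFIX):].split("_")
assert ret_types == ["mrf32", "mri64"]

ctypes_args = [None] + [
    elemental_type_to_ctype.get(t, ctypes.c_void_p) for t in ret_types
]
wrapper = ctypes.CFUNCTYPE(*ctypes_args)
# The invoker wraps its Python callback with `wrapper` and registers it under
# `func_name`, so the compiled code can hand back both results in one call.
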
1 change: 1 addition & 0 deletions python/torch_mlir_e2e_test/test_suite/__init__.py
@@ -50,3 +50,4 @@ def register_all_tests():
    from . import cast
    from . import index_put
    from . import pooling
    from . import return_types
70 changes: 70 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/return_types.py
@@ -0,0 +1,70 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

import torch

from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export

# ==============================================================================


class TestMultipleTensorReturn(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float64, True),
        ([-1, -1], torch.int32, True),
        ([-1, -1], torch.int64, True),
        ([-1, -1], torch.bool, True),
    ])
    def forward(self, a, b, c, d, e):
        return a, b, c, d, e


@register_test_case(module_factory=lambda: TestMultipleTensorReturn())
def TestMultipleTensorReturn_basic(module, tu: TestUtils):
    module.forward(
        tu.rand(3, 4).to(torch.float32),
        tu.rand(2, 3).to(torch.float64),
        tu.rand(2, 3).to(torch.int32),
        tu.rand(2, 3).to(torch.int64),
        tu.rand(2, 3).to(torch.bool))


class TestMultipleTensorAndPrimitiveTypesReturn(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int32, True),
        ([-1, -1], torch.float64, True),
        ([-1, -1], torch.bool, True),
    ])
    def forward(self, a, b, c):
        d = 1
        e = 2.3
        return a, b, c, d, e


@register_test_case(
    module_factory=lambda: TestMultipleTensorAndPrimitiveTypesReturn())
def TestMultipleTensorAndPrimitiveTypesReturn_basic(module, tu: TestUtils):
    module.forward(
        tu.rand(3, 4).to(torch.int32),
        tu.rand(2, 3).to(torch.float64),
        tu.rand(2, 3).to(torch.bool))


# ==============================================================================
