-
Notifications
You must be signed in to change notification settings - Fork 360
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: Add test suite for torch.compile backend (#1849)
- Loading branch information
Showing
7 changed files
with
292 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
57 changes: 57 additions & 0 deletions
57
py/torch_tensorrt/dynamo/torch_compile/test/test_compiler_utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from torch_tensorrt.dynamo.torch_compile.utils import prepare_device, prepare_inputs | ||
from utils import same_output_format | ||
import torch_tensorrt | ||
import unittest | ||
import torch | ||
|
||
|
||
class TestPrepareDevice(unittest.TestCase): | ||
def test_prepare_cuda_device(self): | ||
gpu_id = 0 | ||
device = torch.device(f"cuda:{gpu_id}") | ||
prepared_device = prepare_device(device) | ||
self.assertTrue(isinstance(prepared_device, torch.device)) | ||
self.assertTrue(prepared_device.index == gpu_id) | ||
|
||
def test_prepare_trt_device(self): | ||
gpu_id = 4 | ||
device = torch_tensorrt.Device(gpu_id=gpu_id) | ||
prepared_device = prepare_device(device) | ||
self.assertTrue(isinstance(prepared_device, torch.device)) | ||
self.assertTrue(prepared_device.index == gpu_id) | ||
|
||
|
||
class TestPrepareInputs(unittest.TestCase): | ||
def test_prepare_single_tensor_input(self): | ||
inputs = [torch.ones((4, 4))] | ||
prepared_inputs = prepare_inputs(inputs) | ||
self.assertTrue( | ||
same_output_format(inputs, prepared_inputs, enforce_tensor_type=False) | ||
) | ||
|
||
def test_prepare_trt_input(self): | ||
inputs = [torch_tensorrt.Input(shape=(4, 3), dtype=torch.float)] | ||
prepared_inputs = prepare_inputs(inputs) | ||
self.assertTrue( | ||
same_output_format(inputs, prepared_inputs, enforce_tensor_type=False) | ||
) | ||
|
||
def test_prepare_mixed_type_compound_tensor_input(self): | ||
inputs = { | ||
"first": [ | ||
torch.ones((4, 4)), | ||
torch_tensorrt.Input(shape=(4, 3), dtype=torch.float), | ||
], | ||
"second": ( | ||
torch.rand((5, 1)), | ||
(torch.rand((5, 1)), torch_tensorrt.Input(shape=(2, 3))), | ||
), | ||
} | ||
prepared_inputs = prepare_inputs(inputs) | ||
self.assertTrue( | ||
same_output_format(inputs, prepared_inputs, enforce_tensor_type=False) | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
54 changes: 54 additions & 0 deletions
54
py/torch_tensorrt/dynamo/torch_compile/test/test_lowering.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from functools import partial | ||
from utils import fx_dynamo_testing_backend | ||
from torch.testing._internal.common_utils import run_tests, TestCase | ||
import torch | ||
|
||
|
||
class TestTRTModule(TestCase): | ||
def test_lowering_inplace_op(self): | ||
class FullySupported(torch.nn.Module): | ||
def __init__(self, *args, **kwargs) -> None: | ||
super().__init__(*args, **kwargs) | ||
|
||
def forward(self, x, y): | ||
x = torch.ops.aten.add_.Tensor(x, y) | ||
x = torch.ops.aten.relu_.default(x) | ||
return x | ||
|
||
# Operations expected to be included in the traced graph after decompositions | ||
expected_ops = {torch.ops.aten.add.Tensor, torch.ops.aten.relu.default} | ||
|
||
# Trace module and set up custom backend to track intermediate graphs | ||
fx_graph = torch.fx.symbolic_trace(FullySupported()) | ||
partitioned_graphs = [] | ||
custom_backend = partial( | ||
fx_dynamo_testing_backend, | ||
store_intermediate_graphs=partitioned_graphs, | ||
) | ||
|
||
# Invoke compilation | ||
compiled_graph = torch.compile(fx_graph, backend=custom_backend) | ||
compiled_graph( | ||
torch.rand( | ||
5, | ||
).cuda(), | ||
torch.rand( | ||
5, | ||
).cuda(), | ||
) | ||
|
||
# Iterate over intermediate graphs, attempt to match nodes | ||
for fx_module in partitioned_graphs: | ||
for _, submodule in fx_module.named_children(): | ||
for node in submodule.graph.nodes: | ||
|
||
if node.op == "call_function" and node.target in expected_ops: | ||
expected_ops.remove(node.target) | ||
|
||
self.assertEqual( | ||
len(expected_ops), 0, "All operators should have been decomposed" | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
run_tests() |
68 changes: 68 additions & 0 deletions
68
py/torch_tensorrt/dynamo/torch_compile/test/test_partitioning.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from torch_tensorrt.dynamo.torch_compile.lowering import partition | ||
from torch.testing._internal.common_utils import run_tests, TestCase | ||
import torch | ||
from copy import deepcopy | ||
import numpy as np | ||
|
||
|
||
class TestPartitioning(TestCase): | ||
def test_partition_fully_supported_one_op(self): | ||
class FullySupportedOneOp(torch.nn.Module): | ||
def __init__(self, *args, **kwargs) -> None: | ||
super().__init__(*args, **kwargs) | ||
|
||
def forward(self, x, y): | ||
return torch.ops.aten.add.Tensor(x, y) | ||
|
||
fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp()) | ||
partitioned_graph = partition(deepcopy(fx_graph)) | ||
self.assertEqual( | ||
len(list(partitioned_graph.named_children())), | ||
0, | ||
"Single operators should not be segmented", | ||
) | ||
|
||
def test_partition_fully_supported_multi_op(self): | ||
class FullySupportedMultiOp(torch.nn.Module): | ||
def __init__(self, *args, **kwargs) -> None: | ||
super().__init__(*args, **kwargs) | ||
|
||
def forward(self, x, y): | ||
sum_ = torch.ops.aten.sub.Tensor(x, y) | ||
concat_ = torch.ops.aten.cat.default(x, sum_) | ||
relu_ = torch.ops.aten.relu.default(concat_) | ||
pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2) | ||
return pow_ | ||
|
||
fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp()) | ||
partitioned_graph = partition(deepcopy(fx_graph)) | ||
self.assertEqual( | ||
len(list(partitioned_graph.named_children())), | ||
1, | ||
"All operators are supported, there should be one segment", | ||
) | ||
|
||
def test_partition_partially_supported_multi_op(self): | ||
class PartiallySupportedMultiOp(torch.nn.Module): | ||
def __init__(self, *args, **kwargs) -> None: | ||
super().__init__(*args, **kwargs) | ||
|
||
def forward(self, x, y): | ||
sum_1 = torch.ops.aten.add.Tensor(x, y) | ||
sum_2 = torch.ops.aten.add.Tensor(x, sum_1) | ||
sum_ = np.sum(sum_1) + np.sum(sum_2) | ||
relu_ = torch.ops.aten.relu.default(sum_) | ||
pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2) | ||
return pow_ | ||
|
||
fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp()) | ||
partitioned_graph = partition(deepcopy(fx_graph)) | ||
self.assertEqual( | ||
len(list(partitioned_graph.named_children())), | ||
2, | ||
"Unsupported operators interleave supported ones, expected 2 segments", | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
run_tests() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
from copy import deepcopy | ||
from functools import partial | ||
from typing import List, Sequence | ||
import torch | ||
from torch_tensorrt.dynamo.torch_compile.lowering._decompositions import ( | ||
get_decompositions, | ||
) | ||
from torch_tensorrt.dynamo.torch_compile.lowering._partition import ( | ||
partition, | ||
) | ||
|
||
from torch._dynamo.backends.common import fake_tensor_unsupported | ||
|
||
from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler | ||
|
||
|
||
@fake_tensor_unsupported | ||
def fx_dynamo_testing_backend( | ||
gm: torch.fx.GraphModule, | ||
sample_inputs: Sequence[torch.Tensor], | ||
*, | ||
store_intermediate_graphs: List, | ||
): | ||
"""Helper Dynamo backend exclusively for testing""" | ||
custom_backend = partial( | ||
compile_module_testing, | ||
store_intermediate_graphs=store_intermediate_graphs, | ||
) | ||
|
||
# Invoke AOTAutograd to translate operators to aten | ||
return aot_module_simplified( | ||
gm, | ||
sample_inputs, | ||
fw_compiler=make_boxed_compiler(custom_backend), | ||
decompositions=get_decompositions(), | ||
) | ||
|
||
|
||
def compile_module_testing( | ||
gm: torch.fx.GraphModule, | ||
example_inputs: Sequence[torch.Tensor], | ||
*, | ||
store_intermediate_graphs: List, | ||
) -> torch.fx.GraphModule: | ||
"""Helper compiler exclusively for testing""" | ||
partitioned_module = partition(gm) | ||
|
||
# Store intermediate graph from partitioned module | ||
store_intermediate_graphs.append(deepcopy(partitioned_module)) | ||
|
||
return partitioned_module | ||
|
||
|
||
def same_output_format(trt_output, torch_output, enforce_tensor_type=True): | ||
# For each encountered collection type, ensure the torch and trt outputs agree | ||
# on type and size, checking recursively through all member elements. | ||
if isinstance(trt_output, tuple): | ||
return ( | ||
isinstance(torch_output, tuple) | ||
and (len(trt_output) == len(torch_output)) | ||
and all( | ||
same_output_format(trt_entry, torch_entry, enforce_tensor_type) | ||
for trt_entry, torch_entry in zip(trt_output, torch_output) | ||
) | ||
) | ||
elif isinstance(trt_output, list): | ||
return ( | ||
isinstance(torch_output, list) | ||
and (len(trt_output) == len(torch_output)) | ||
and all( | ||
same_output_format(trt_entry, torch_entry, enforce_tensor_type) | ||
for trt_entry, torch_entry in zip(trt_output, torch_output) | ||
) | ||
) | ||
elif isinstance(trt_output, dict): | ||
return ( | ||
isinstance(torch_output, dict) | ||
and (len(trt_output) == len(torch_output)) | ||
and (trt_output.keys() == torch_output.keys()) | ||
and all( | ||
same_output_format( | ||
trt_output[key], torch_output[key], enforce_tensor_type | ||
) | ||
for key in trt_output.keys() | ||
) | ||
) | ||
elif isinstance(trt_output, set) or isinstance(trt_output, frozenset): | ||
raise AssertionError( | ||
"Unsupported output type 'set' encountered in output format check." | ||
) | ||
elif enforce_tensor_type: | ||
return type(trt_output) is type(torch_output) | ||
else: | ||
return True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters