Skip to content

Commit

Permalink
fix: Add test suite for torch.compile backend (#1849)
Browse files: browse the repository at this point in the history
  • Loading branch information
gs-olive authored Apr 26, 2023
1 parent 2addf5e commit 7b67780
Show file tree
Hide file tree
Showing 7 changed files with 292 additions and 39 deletions.
17 changes: 17 additions & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,22 @@ commands:
- store_artifacts:
path: /tmp/testlogs

test-dynamo-torch_compile-core:
  # Runs pytest from py/torch_tensorrt/dynamo/torch_compile/test and writes a
  # JUnit XML report that CircleCI picks up via store_test_results.
  # NOTE(review): confirm the report path below does not collide with the one
  # written by the sibling test-dynamo-torch_compile command.
  description: "Test the Dynamo torch_compile core components (unit tests)"
  steps:
    - run:
        name: Run Dynamo torch_compile core tests
        command: |
          cd py/torch_tensorrt/dynamo/torch_compile
          pushd test/
          pytest --junitxml=/tmp/artifacts/test_results/dynamo/torch_compile/test_results.xml
          popd
    - store_test_results:
        path: /tmp/artifacts
    - store_artifacts:
        path: /tmp/testlogs

test-dynamo-torch_compile:
description: "Test the Dynamo torch_compile path"
steps:
Expand Down Expand Up @@ -953,6 +969,7 @@ jobs:
# We install torch after torch-trt because pip automatically enforces the version constraint otherwise
- dump-test-env
- test-dynamo-torch_compile
- test-dynamo-torch_compile-core
- test-dynamo-fx_ts

package-x86_64-linux:
Expand Down
39 changes: 0 additions & 39 deletions py/torch_tensorrt/dynamo/test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,42 +13,3 @@ def cosine_similarity(gt_tensor, pred_tensor):
res = res.cpu().detach().item()

return res


def same_output_format(trt_output, torch_output):
    """Return whether two (possibly nested) outputs have the same layout.

    Tuples, lists and dicts are walked recursively: both sides must use the
    same collection type, hold the same number of entries (and the same keys
    for dicts), and every pair of members must match in turn. Any other value
    is treated as a leaf and matches only when both sides have exactly the
    same type.

    Raises:
        AssertionError: if ``trt_output`` is a ``set`` or ``frozenset``.
    """
    if isinstance(trt_output, (tuple, list)):
        # A tuple only matches a tuple, and a list only matches a list.
        sequence_type = tuple if isinstance(trt_output, tuple) else list
        if not isinstance(torch_output, sequence_type):
            return False
        if len(trt_output) != len(torch_output):
            return False
        return all(
            same_output_format(trt_item, torch_item)
            for trt_item, torch_item in zip(trt_output, torch_output)
        )

    if isinstance(trt_output, dict):
        if not isinstance(torch_output, dict):
            return False
        if len(trt_output) != len(torch_output):
            return False
        if not (trt_output.keys() == torch_output.keys()):
            return False
        return all(
            same_output_format(trt_output[name], torch_output[name])
            for name in trt_output.keys()
        )

    if isinstance(trt_output, (set, frozenset)):
        raise AssertionError(
            "Unsupported output type 'set' encountered in output format check."
        )

    # Leaf values: the exact types must coincide.
    return type(trt_output) is type(torch_output)
57 changes: 57 additions & 0 deletions py/torch_tensorrt/dynamo/torch_compile/test/test_compiler_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from torch_tensorrt.dynamo.torch_compile.utils import prepare_device, prepare_inputs
from utils import same_output_format
import torch_tensorrt
import unittest
import torch


class TestPrepareDevice(unittest.TestCase):
    """Checks that ``prepare_device`` returns a ``torch.device`` with the requested index."""

    def test_prepare_cuda_device(self):
        """A ``torch.device`` input keeps its GPU index."""
        gpu_id = 0
        device = torch.device(f"cuda:{gpu_id}")
        prepared_device = prepare_device(device)
        # Dedicated assertions report the offending value/type on failure,
        # unlike assertTrue on a pre-evaluated boolean.
        self.assertIsInstance(prepared_device, torch.device)
        self.assertEqual(prepared_device.index, gpu_id)

    def test_prepare_trt_device(self):
        """A ``torch_tensorrt.Device`` input is converted and keeps its GPU index."""
        gpu_id = 4
        device = torch_tensorrt.Device(gpu_id=gpu_id)
        prepared_device = prepare_device(device)
        self.assertIsInstance(prepared_device, torch.device)
        self.assertEqual(prepared_device.index, gpu_id)


class TestPrepareInputs(unittest.TestCase):
    """Checks that ``prepare_inputs`` keeps the collection layout of its argument."""

    def _assert_layout_preserved(self, raw_inputs):
        """Run ``prepare_inputs`` and compare the result's layout to the input's.

        Leaf types are deliberately not compared (``enforce_tensor_type=False``).
        """
        converted = prepare_inputs(raw_inputs)
        layout_matches = same_output_format(
            raw_inputs, converted, enforce_tensor_type=False
        )
        self.assertTrue(layout_matches)

    def test_prepare_single_tensor_input(self):
        """A list holding one plain tensor."""
        self._assert_layout_preserved([torch.ones((4, 4))])

    def test_prepare_trt_input(self):
        """A list holding one ``torch_tensorrt.Input``."""
        trt_spec = torch_tensorrt.Input(shape=(4, 3), dtype=torch.float)
        self._assert_layout_preserved([trt_spec])

    def test_prepare_mixed_type_compound_tensor_input(self):
        """A dict of nested lists/tuples mixing tensors and ``Input`` objects."""
        first_group = [
            torch.ones((4, 4)),
            torch_tensorrt.Input(shape=(4, 3), dtype=torch.float),
        ]
        second_group = (
            torch.rand((5, 1)),
            (torch.rand((5, 1)), torch_tensorrt.Input(shape=(2, 3))),
        )
        self._assert_layout_preserved({"first": first_group, "second": second_group})


# Run the unittest runner when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
54 changes: 54 additions & 0 deletions py/torch_tensorrt/dynamo/torch_compile/test/test_lowering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from functools import partial
from utils import fx_dynamo_testing_backend
from torch.testing._internal.common_utils import run_tests, TestCase
import torch


class TestTRTModule(TestCase):
    """Lowering checks driven through the testing-only Dynamo backend."""

    def test_lowering_inplace_op(self):
        """In-place ``add_``/``relu_`` calls should appear as ``add``/``relu`` after lowering."""

        class FullySupported(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)

            def forward(self, x, y):
                summed = torch.ops.aten.add_.Tensor(x, y)
                return torch.ops.aten.relu_.default(summed)

        # Targets that must be found in the partitioned graphs once the
        # decompositions have run; entries are dropped as they are seen.
        remaining_ops = {torch.ops.aten.add.Tensor, torch.ops.aten.relu.default}

        # The backend appends every partitioned graph it produces to this list.
        captured_graphs = []
        traced_module = torch.fx.symbolic_trace(FullySupported())
        recording_backend = partial(
            fx_dynamo_testing_backend,
            store_intermediate_graphs=captured_graphs,
        )

        # Compile and execute once so the backend is actually invoked.
        compiled_module = torch.compile(traced_module, backend=recording_backend)
        lhs = torch.rand(5).cuda()
        rhs = torch.rand(5).cuda()
        compiled_module(lhs, rhs)

        # Collect the call_function targets of each partition's submodules
        # and strike them from the expected set.
        for graph_module in captured_graphs:
            for segment in graph_module.children():
                called_targets = {
                    node.target
                    for node in segment.graph.nodes
                    if node.op == "call_function"
                }
                remaining_ops.difference_update(called_targets)

        self.assertEqual(
            len(remaining_ops), 0, "All operators should have been decomposed"
        )


# Hand control to torch's test runner when this file is executed directly.
if __name__ == "__main__":
    run_tests()
68 changes: 68 additions & 0 deletions py/torch_tensorrt/dynamo/torch_compile/test/test_partitioning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from torch_tensorrt.dynamo.torch_compile.lowering import partition
from torch.testing._internal.common_utils import run_tests, TestCase
import torch
from copy import deepcopy
import numpy as np


class TestPartitioning(TestCase):
    """Checks how many child modules ``partition`` produces for small traced graphs."""

    def _segment_count(self, module):
        """Symbolically trace ``module``, partition a copy, and count its child modules."""
        traced = torch.fx.symbolic_trace(module)
        partitioned = partition(deepcopy(traced))
        return len(list(partitioned.named_children()))

    def test_partition_fully_supported_one_op(self):
        """A graph with a single operator is expected to yield no segments."""

        class FullySupportedOneOp(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)

            def forward(self, x, y):
                return torch.ops.aten.add.Tensor(x, y)

        self.assertEqual(
            self._segment_count(FullySupportedOneOp()),
            0,
            "Single operators should not be segmented",
        )

    def test_partition_fully_supported_multi_op(self):
        """A chain of aten operators is expected to land in one segment."""

        class FullySupportedMultiOp(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)

            def forward(self, x, y):
                difference = torch.ops.aten.sub.Tensor(x, y)
                # NOTE(review): aten.cat usually takes a sequence of tensors as
                # its first argument; here the two values are passed as separate
                # positionals. The call is only traced symbolically in this
                # test -- confirm this is intended.
                joined = torch.ops.aten.cat.default(x, difference)
                activated = torch.ops.aten.relu.default(joined)
                return torch.ops.aten.pow.Tensor_Scalar(activated, 2)

        self.assertEqual(
            self._segment_count(FullySupportedMultiOp()),
            1,
            "All operators are supported, there should be one segment",
        )

    def test_partition_partially_supported_multi_op(self):
        """NumPy reductions placed between aten operators are expected to give two segments."""

        class PartiallySupportedMultiOp(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)

            def forward(self, x, y):
                first_sum = torch.ops.aten.add.Tensor(x, y)
                second_sum = torch.ops.aten.add.Tensor(x, first_sum)
                # The np.sum calls are the operators the test treats as unsupported.
                combined = np.sum(first_sum) + np.sum(second_sum)
                activated = torch.ops.aten.relu.default(combined)
                return torch.ops.aten.pow.Tensor_Scalar(activated, 2)

        self.assertEqual(
            self._segment_count(PartiallySupportedMultiOp()),
            2,
            "Unsupported operators interleave supported ones, expected 2 segments",
        )


# Hand control to torch's test runner when this file is executed directly.
if __name__ == "__main__":
    run_tests()
94 changes: 94 additions & 0 deletions py/torch_tensorrt/dynamo/torch_compile/test/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from copy import deepcopy
from functools import partial
from typing import List, Sequence
import torch
from torch_tensorrt.dynamo.torch_compile.lowering._decompositions import (
get_decompositions,
)
from torch_tensorrt.dynamo.torch_compile.lowering._partition import (
partition,
)

from torch._dynamo.backends.common import fake_tensor_unsupported

from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler


@fake_tensor_unsupported
def fx_dynamo_testing_backend(
    gm: torch.fx.GraphModule,
    sample_inputs: Sequence[torch.Tensor],
    *,
    store_intermediate_graphs: List,
):
    """Dynamo backend intended for tests only.

    Runs ``gm`` through AOTAutograd (which translates operators to aten) using
    the project's decompositions, with ``compile_module_testing`` as the
    forward compiler.

    Args:
        gm: Graph module handed over by Dynamo.
        sample_inputs: Example inputs forwarded to ``aot_module_simplified``.
        store_intermediate_graphs: List passed through to
            ``compile_module_testing``, which appends to it.

    Returns:
        The result of ``aot_module_simplified``.
    """
    # Bind the recording list so the compiler matches the (gm, inputs) signature.
    recording_compiler = partial(
        compile_module_testing,
        store_intermediate_graphs=store_intermediate_graphs,
    )
    boxed_compiler = make_boxed_compiler(recording_compiler)
    decomposition_table = get_decompositions()

    return aot_module_simplified(
        gm,
        sample_inputs,
        fw_compiler=boxed_compiler,
        decompositions=decomposition_table,
    )


def compile_module_testing(
    gm: torch.fx.GraphModule,
    example_inputs: Sequence[torch.Tensor],
    *,
    store_intermediate_graphs: List,
) -> torch.fx.GraphModule:
    """Compiler intended for tests only: partition ``gm`` and record the result.

    Args:
        gm: Graph module to partition.
        example_inputs: Accepted for signature compatibility; not used here.
        store_intermediate_graphs: List that receives a deep copy of the
            partitioned module.

    Returns:
        The partitioned module itself (not the stored copy).
    """
    result = partition(gm)

    # Keep a deep copy so the recorded graph is a separate object from the
    # module that is returned.
    snapshot = deepcopy(result)
    store_intermediate_graphs.append(snapshot)

    return result


def same_output_format(trt_output, torch_output, enforce_tensor_type=True):
    """Return whether two (possibly nested) outputs have the same layout.

    Tuples, lists and dicts are walked recursively: both sides must use the
    same collection type, hold the same number of entries (and the same keys
    for dicts), and every pair of members must match in turn.

    Args:
        trt_output: First output structure.
        torch_output: Second output structure.
        enforce_tensor_type: When True, non-collection (leaf) values must have
            exactly the same type. When False, leaf types are ignored, but a
            leaf on one side still never matches a collection on the other.

    Returns:
        True if the two structures agree, False otherwise.

    Raises:
        AssertionError: if ``trt_output`` is a ``set`` or ``frozenset``.
    """
    if isinstance(trt_output, (tuple, list)):
        # A tuple only matches a tuple, and a list only matches a list.
        sequence_type = tuple if isinstance(trt_output, tuple) else list
        return (
            isinstance(torch_output, sequence_type)
            and len(trt_output) == len(torch_output)
            and all(
                same_output_format(
                    trt_entry, torch_entry, enforce_tensor_type=enforce_tensor_type
                )
                for trt_entry, torch_entry in zip(trt_output, torch_output)
            )
        )

    if isinstance(trt_output, dict):
        # Equal key views already imply equal sizes.
        return (
            isinstance(torch_output, dict)
            and trt_output.keys() == torch_output.keys()
            and all(
                same_output_format(
                    trt_output[key],
                    torch_output[key],
                    enforce_tensor_type=enforce_tensor_type,
                )
                for key in trt_output
            )
        )

    if isinstance(trt_output, (set, frozenset)):
        raise AssertionError(
            "Unsupported output type 'set' encountered in output format check."
        )

    if enforce_tensor_type:
        return type(trt_output) is type(torch_output)

    # Leaf types are not compared in this mode, but the structure still must
    # agree: a leaf paired with a collection is a mismatch. Without this check
    # the comparison would be asymmetric (leaf-vs-list True, list-vs-leaf False).
    return not isinstance(torch_output, (tuple, list, dict, set, frozenset))
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/torch_compile/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,5 @@ def prepare_device(device: Union[Device, torch.device]) -> torch.device:
raise ValueError(
"Invalid device provided. Supported options: torch.device | torch_tensorrt.Device"
)

return device

0 comments on commit 7b67780

Please sign in to comment.