Skip to content

Commit

Permalink
[Tcp] Add e2e test for TCP
Browse files Browse the repository at this point in the history
  • Loading branch information
navahgar committed Dec 12, 2022
1 parent a15afc3 commit 21f233b
Show file tree
Hide file tree
Showing 8 changed files with 194 additions and 2 deletions.
10 changes: 8 additions & 2 deletions e2e_testing/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
LinalgOnTensorsBackendTestConfig,
MhloBackendTestConfig,
NativeTorchTestConfig,
TcpBackendTestConfig,
TorchScriptTestConfig,
TosaBackendTestConfig,
EagerModeTestConfig,
Expand All @@ -26,16 +27,17 @@

from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
from torch_mlir_e2e_test.mhlo_backends.linalg_on_tensors import LinalgOnTensorsMhloBackend
from torch_mlir_e2e_test.tcp_backends.linalg_on_tensors import LinalgOnTensorsTcpBackend
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend

from .xfail_sets import REFBACKEND_XFAIL_SET, MHLO_PASS_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET
from .xfail_sets import REFBACKEND_XFAIL_SET, MHLO_PASS_SET, TCP_PASS_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET

# Import tests to register them in the global registry.
from torch_mlir_e2e_test.test_suite import register_all_tests
register_all_tests()

def _get_argparse():
config_choices = ['native_torch', 'torchscript', 'refbackend', 'mhlo', 'tosa', 'eager_mode', 'lazy_tensor_core', 'torchdynamo']
config_choices = ['native_torch', 'torchscript', 'refbackend', 'mhlo', 'tcp', 'tosa', 'eager_mode', 'lazy_tensor_core', 'torchdynamo']
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
parser.add_argument('-c', '--config',
choices=config_choices,
Expand All @@ -44,6 +46,7 @@ def _get_argparse():
Meaning of options:
"refbackend": run through torch-mlir's RefBackend.
"mhlo": run through torch-mlir's default MHLO backend.
"tcp": run through torch-mlir's default TCP backend.
"tosa": run through torch-mlir's default TOSA backend.
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
Expand Down Expand Up @@ -85,6 +88,9 @@ def main():
if args.config == 'mhlo':
config = MhloBackendTestConfig(LinalgOnTensorsMhloBackend())
xfail_set = all_test_unique_names - MHLO_PASS_SET
elif args.config == 'tcp':
config = TcpBackendTestConfig(LinalgOnTensorsTcpBackend())
xfail_set = all_test_unique_names - TCP_PASS_SET
elif args.config == 'native_torch':
config = NativeTorchTestConfig()
xfail_set = {}
Expand Down
23 changes: 23 additions & 0 deletions e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,29 @@
"AtenRoundIntModule_basic",
}

TCP_PASS_SET = {
"AtenRoundIntModule_basic",
"AtenToDeviceModule_basic",
"BoolTensorReturnFalseModule_basic",
"BoolTensorReturnMixedModule_basic",
"BoolTensorReturnTrueModule_basic",
"BroadcastToModule_basic",
"DropoutEvalFloatModule_basic",
"DropoutEvalIntModule_basic",
"ElementwiseAddModule_basic",
"ElementwiseToDtypeIdentityModule_basic",
"ElementwiseUnaryModule_basic",
"ExpandModule_basic",
"ReturnThreeTensorFloat32_basic",
"ReturnTwoTensorF32I64_basic",
"TModuleRank0_basic",
"TModuleRank1_basic",
"TestMultipleTensorReturn_basic",
"TypeAsSameModule_basic",
"UnsafeView1DFoldModule_basic",
"View1DFoldModule_basic",
}

# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
Expand Down
17 changes: 17 additions & 0 deletions python/torch_mlir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ class OutputType(Enum):
# as taking the `TORCH` output type and lowering it to MHLO.
MHLO = "mhlo"

# This output type consists of `tcp` dialect ops. It can be thought of
# as taking the `TORCH` output type and lowering it to TCP.
TCP = "tcp"

# Raw output of the JIT IR importer. This is not expected to be useful
# for end-users, but can be convenient for development or reporting bugs.
RAW = "raw"
Expand Down Expand Up @@ -243,6 +247,7 @@ def _get_for_tracing(
OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm', 'torch.aten.linear'],
OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', ],
OutputType.MHLO: [],
OutputType.TCP: [],
}


Expand Down Expand Up @@ -414,4 +419,16 @@ def compile(model: torch.nn.Module,
print("MHLO Backend IR")
print(mb.module)
return mb.module

elif output_type == OutputType.TCP:
run_pipeline_with_repro_report(
mb.module,
"builtin.module(torch-backend-to-tcp-backend-pipeline)",
"Lowering Torch Backend IR -> TCP Backend IR")
if verbose:
print("\n====================")
print("TCP Backend IR")
print(mb.module)
return mb.module

raise Exception(f"Unknown OutputType: {output_type}")
1 change: 1 addition & 0 deletions python/torch_mlir_e2e_test/configs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .native_torch import NativeTorchTestConfig
from .torchscript import TorchScriptTestConfig
from .mhlo_backend import MhloBackendTestConfig
from .tcp_backend import TcpBackendTestConfig
from .tosa_backend import TosaBackendTestConfig
from .eager_mode import EagerModeTestConfig
from .torchdynamo import TorchDynamoTestConfig
50 changes: 50 additions & 0 deletions python/torch_mlir_e2e_test/configs/tcp_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

from typing import Any

import torch
import torch_mlir

from torch_mlir_e2e_test.tcp_backends.abc import TcpBackend
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
from torch_mlir_e2e_test.utils import convert_annotations_to_placeholders
from .utils import (
recursively_convert_to_numpy,
recursively_convert_from_numpy,
)


class TcpBackendTestConfig(TestConfig):
"""Base class for TestConfig's that are implemented with linalg-on-tensors.
This class handles all the common lowering that torch-mlir does before
reaching the linalg-on-tensors abstraction level.
"""
def __init__(self, backend: TcpBackend):
super().__init__()
self.backend = backend

def compile(self, program: torch.nn.Module) -> Any:
example_args = convert_annotations_to_placeholders(program.forward)
module = torch_mlir.compile(
program, example_args, output_type="tcp")

return self.backend.compile(module)



def run(self, artifact: Any, trace: Trace) -> Trace:
backend_module = self.backend.load(artifact)
result: Trace = []
for item in trace:
numpy_inputs = recursively_convert_to_numpy(item.inputs)
outputs = getattr(backend_module, item.symbol)(*numpy_inputs)
output = recursively_convert_from_numpy(outputs)
result.append(
TraceItem(symbol=item.symbol,
inputs=item.inputs,
output=output))
return result
Empty file.
49 changes: 49 additions & 0 deletions python/torch_mlir_e2e_test/tcp_backends/abc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

import abc
from typing import TypeVar

import torch

from torch_mlir.ir import Module

# A type shared between the result of `TcpBackend.compile` and the
# input to `TcpBackend.load`. Each backend will likely have a
# different definition of this type.
CompiledArtifact = TypeVar('CompiledArtifact')

# A wrapper around a backend-specific loaded program representation
# that uniformly translates the `x.method(...)` interface expected of
# Torch modules into appropriate lower-level operations.
Invoker = TypeVar('Invoker')


class TcpBackend(abc.ABC):
"""The interface to an TCP backend.
Backends are recommended to raise meaningful exceptions in case of error,
ideally with easy reproduction instructions.
"""
@abc.abstractmethod
def compile(self, module: Module) -> CompiledArtifact:
"""Compile the provided MLIR module into a compiled artifact.
The module adheres to the TCP backend contract
(see the VerifyTcpBackendContract pass).
The compiled artifact can be any type, but must be correctly
interpreted by the `load` method.
"""

@abc.abstractmethod
def load(self, artifact: CompiledArtifact) -> Invoker:
"""Load the compiled artifact into a uniformly invokable form.
The compiled artifact is the result of a previous call to `compile`.
See the description of `Invoker` for the requirements on the returned
type.
"""
46 changes: 46 additions & 0 deletions python/torch_mlir_e2e_test/tcp_backends/linalg_on_tensors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

from torch_mlir.ir import *
from torch_mlir.passmanager import *
from torch_mlir.compiler_utils import run_pipeline_with_repro_report

from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend

from .abc import TcpBackend

__all__ = [
"LinalgOnTensorsTcpBackend",
]

class LinalgOnTensorsTcpBackend(TcpBackend):
"""Main entry-point for the linalg-on-tensors based TCP backend.
This currently uses the linalg-on-tensors RefBackend for actual execution.
"""
def __init__(self):
super().__init__()
self.refbackend = RefBackendLinalgOnTensorsBackend()

def compile(self, imported_module: Module):
"""Compiles an imported module that satisfied the TCP backend contract.
Args:
imported_module: The MLIR module consisting of funcs in the TCP
dialect.
Returns:
An opaque, backend specific compiled artifact object that can be
passed to `load`.
"""
run_pipeline_with_repro_report(
imported_module,
"builtin.module(func.func(convert-tcp-to-linalg))",
"Lowering TCP to Linalg-on-Tensors")

return self.refbackend.compile(imported_module)

def load(self, module):
"""Loads a compiled artifact into the runtime."""
return self.refbackend.load(module)

0 comments on commit 21f233b

Please sign in to comment.