-
Notifications
You must be signed in to change notification settings - Fork 533
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
194 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
# Also available under a BSD-style license. See LICENSE. | ||
|
||
from typing import Any | ||
|
||
import torch | ||
import torch_mlir | ||
|
||
from torch_mlir_e2e_test.tcp_backends.abc import TcpBackend | ||
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem | ||
from torch_mlir_e2e_test.utils import convert_annotations_to_placeholders | ||
from .utils import ( | ||
recursively_convert_to_numpy, | ||
recursively_convert_from_numpy, | ||
) | ||
|
||
|
||
class TcpBackendTestConfig(TestConfig): | ||
"""Base class for TestConfig's that are implemented with linalg-on-tensors. | ||
This class handles all the common lowering that torch-mlir does before | ||
reaching the linalg-on-tensors abstraction level. | ||
""" | ||
def __init__(self, backend: TcpBackend): | ||
super().__init__() | ||
self.backend = backend | ||
|
||
def compile(self, program: torch.nn.Module) -> Any: | ||
example_args = convert_annotations_to_placeholders(program.forward) | ||
module = torch_mlir.compile( | ||
program, example_args, output_type="tcp") | ||
|
||
return self.backend.compile(module) | ||
|
||
|
||
|
||
def run(self, artifact: Any, trace: Trace) -> Trace: | ||
backend_module = self.backend.load(artifact) | ||
result: Trace = [] | ||
for item in trace: | ||
numpy_inputs = recursively_convert_to_numpy(item.inputs) | ||
outputs = getattr(backend_module, item.symbol)(*numpy_inputs) | ||
output = recursively_convert_from_numpy(outputs) | ||
result.append( | ||
TraceItem(symbol=item.symbol, | ||
inputs=item.inputs, | ||
output=output)) | ||
return result |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
# Also available under a BSD-style license. See LICENSE. | ||
|
||
import abc | ||
from typing import TypeVar | ||
|
||
import torch | ||
|
||
from torch_mlir.ir import Module | ||
|
||
# A type shared between the result of `TcpBackend.compile` and the | ||
# input to `TcpBackend.load`. Each backend will likely have a | ||
# different definition of this type. | ||
CompiledArtifact = TypeVar('CompiledArtifact') | ||
|
||
# A wrapper around a backend-specific loaded program representation | ||
# that uniformly translates the `x.method(...)` interface expected of | ||
# Torch modules into appropriate lower-level operations. | ||
Invoker = TypeVar('Invoker') | ||
|
||
|
||
class TcpBackend(abc.ABC): | ||
"""The interface to an TCP backend. | ||
Backends are recommended to raise meaningful exceptions in case of error, | ||
ideally with easy reproduction instructions. | ||
""" | ||
@abc.abstractmethod | ||
def compile(self, module: Module) -> CompiledArtifact: | ||
"""Compile the provided MLIR module into a compiled artifact. | ||
The module adheres to the TCP backend contract | ||
(see the VerifyTcpBackendContract pass). | ||
The compiled artifact can be any type, but must be correctly | ||
interpreted by the `load` method. | ||
""" | ||
|
||
@abc.abstractmethod | ||
def load(self, artifact: CompiledArtifact) -> Invoker: | ||
"""Load the compiled artifact into a uniformly invokable form. | ||
The compiled artifact is the result of a previous call to `compile`. | ||
See the description of `Invoker` for the requirements on the returned | ||
type. | ||
""" |
46 changes: 46 additions & 0 deletions
46
python/torch_mlir_e2e_test/tcp_backends/linalg_on_tensors.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
# Also available under a BSD-style license. See LICENSE. | ||
|
||
from torch_mlir.ir import * | ||
from torch_mlir.passmanager import * | ||
from torch_mlir.compiler_utils import run_pipeline_with_repro_report | ||
|
||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend | ||
|
||
from .abc import TcpBackend | ||
|
||
__all__ = [ | ||
"LinalgOnTensorsTcpBackend", | ||
] | ||
|
||
class LinalgOnTensorsTcpBackend(TcpBackend): | ||
"""Main entry-point for the linalg-on-tensors based TCP backend. | ||
This currently uses the linalg-on-tensors RefBackend for actual execution. | ||
""" | ||
def __init__(self): | ||
super().__init__() | ||
self.refbackend = RefBackendLinalgOnTensorsBackend() | ||
|
||
def compile(self, imported_module: Module): | ||
"""Compiles an imported module that satisfied the TCP backend contract. | ||
Args: | ||
imported_module: The MLIR module consisting of funcs in the TCP | ||
dialect. | ||
Returns: | ||
An opaque, backend specific compiled artifact object that can be | ||
passed to `load`. | ||
""" | ||
run_pipeline_with_repro_report( | ||
imported_module, | ||
"builtin.module(func.func(convert-tcp-to-linalg))", | ||
"Lowering TCP to Linalg-on-Tensors") | ||
|
||
return self.refbackend.compile(imported_module) | ||
|
||
def load(self, module): | ||
"""Loads a compiled artifact into the runtime.""" | ||
return self.refbackend.load(module) |