Skip to content

Commit

Permalink
[ONNX] Support onnx.LSTM (llvm#2969)
Browse files Browse the repository at this point in the history
This PR only performs a lit test. In lieu of an e2e test, nod-ai/SHARK-TestSuite#142 makede sure that the lowering works & the numbers check out.

Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>
  • Loading branch information
renxida and Xida Ren authored Apr 8, 2024
1 parent 1d6e4c3 commit dd967eb
Show file tree
Hide file tree
Showing 9 changed files with 569 additions and 2 deletions.
19 changes: 19 additions & 0 deletions include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,25 @@ struct OpBinder {
return failure();
}

ParseResult stringArrayAttr(llvm::SmallVector<std::string> &values,
StringRef nameSuffix) {
SmallString<64> name("torch.onnx.");
name.append(nameSuffix);
auto attr = op->getAttr(name);
if (!attr)
return success();
if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
for (auto element : arrayAttr) {
StringAttr stringAttr = element.dyn_cast<StringAttr>();
if (!stringAttr)
return failure();
values.push_back(stringAttr.getValue().str());
}
return success();
}
return failure();
}

ParseResult denseElementsAttr(ElementsAttr elementsattr,
StringRef nameSuffix) {
SmallString<64> name("torch.onnx.");
Expand Down
6 changes: 6 additions & 0 deletions include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
#ifndef TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H
#define TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H

#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

namespace mlir::torch::onnx_c {

Expand All @@ -20,6 +23,9 @@ Value createConstantIntList(OpBinder binder,

Type getQTorchTypeFromTorchIntType(Type ty);

LogicalResult OnnxLstmExpander(OpBinder binder,
ConversionPatternRewriter &rewriter);

bool areAllElementsDistinct(SmallVector<int64_t> array);

} // namespace mlir::torch::onnx_c
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TorchOnnxToTorch/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_mlir_conversion_library(TorchMLIRTorchOnnxToTorch
DefaultDomainAtoF.cpp
DefaultDomainGtoP.cpp
DefaultDomainQtoZ.cpp
OnnxLstmExpander.cpp
Passes.cpp
Patterns.cpp
TorchOnnxToTorch.cpp
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, resultType, operand);
return success();
});
patterns.onOp("LSTM", 1, onnx_c::OnnxLstmExpander);
patterns.onOp(
"LogSoftmax", 13,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
return success();
});
patterns.onOp(
"Squeeze", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
"Squeeze", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value data;
Value axes;
Expand Down
Loading

0 comments on commit dd967eb

Please sign in to comment.