Skip to content

Commit

Permalink
[MLIR][ONNX] Add Utils files with createConstantIntList
Browse files Browse the repository at this point in the history
  • Loading branch information
AmosLewis committed Jan 11, 2024
1 parent b4d7cab commit e33d6a0
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 14 deletions.
23 changes: 23 additions & 0 deletions include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#ifndef TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H
#define TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H

#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"

namespace mlir::torch::onnx_c {

Value createConstantIntList(OpBinder binder,
ConversionPatternRewriter &rewriter,
SmallVector<int64_t> cstInput);

} // namespace mlir::torch::onnx_c

#endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H
1 change: 1 addition & 0 deletions lib/Conversion/TorchOnnxToTorch/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ add_mlir_conversion_library(TorchMLIRTorchOnnxToTorch
Passes.cpp
Patterns.cpp
TorchOnnxToTorch.cpp
Utils.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchOnnxToTorch
Expand Down
15 changes: 1 addition & 14 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,13 @@
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::onnx_c;

static Value createConstantIntList(OpBinder binder,
ConversionPatternRewriter &rewriter,
SmallVector<int64_t> cstInput) {
SmallVector<Value> cstValue;
for (int64_t i : cstInput) {
cstValue.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
}
return rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
cstValue);
}

// Simple rewrites for the default domain.
// See: https://onnx.ai/onnx/operators/
// For operators that are effectively version invariant, we register with
Expand Down
28 changes: 28 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/Utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::onnx_c;

Value mlir::torch::onnx_c::createConstantIntList(
OpBinder binder, ConversionPatternRewriter &rewriter,
SmallVector<int64_t> cstInput) {
SmallVector<Value> cstValue;
for (int64_t i : cstInput) {
cstValue.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
}
return rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
cstValue);
}

0 comments on commit e33d6a0

Please sign in to comment.