diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h new file mode 100644 index 000000000000..058fee4da4a2 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h @@ -0,0 +1,23 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H +#define TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H + +#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" + +namespace mlir::torch::onnx_c { + +Value createConstantIntList(OpBinder binder, + ConversionPatternRewriter &rewriter, + SmallVector cstInput); + +} // namespace mlir::torch::onnx_c + +#endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H diff --git a/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt b/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt index 807db64eac64..4a5015816609 100644 --- a/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt +++ b/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt @@ -5,6 +5,7 @@ add_mlir_conversion_library(TorchMLIRTorchOnnxToTorch Passes.cpp Patterns.cpp TorchOnnxToTorch.cpp + Utils.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchOnnxToTorch diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 51a69b3091aa..bf599bf7b8df 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -8,26 +8,13 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" +#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::onnx_c; -static Value createConstantIntList(OpBinder binder, - ConversionPatternRewriter &rewriter, - SmallVector cstInput) { - SmallVector cstValue; - for (int64_t i : cstInput) { - cstValue.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); - } - return rewriter.create( - binder.getLoc(), - Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), - cstValue); -} - // Simple rewrites for the default domain. // See: https://onnx.ai/onnx/operators/ // For operators that are effectively version invariant, we register with diff --git a/lib/Conversion/TorchOnnxToTorch/Utils.cpp b/lib/Conversion/TorchOnnxToTorch/Utils.cpp new file mode 100644 index 000000000000..8f5a2e67c0cb --- /dev/null +++ b/lib/Conversion/TorchOnnxToTorch/Utils.cpp @@ -0,0 +1,28 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::onnx_c; + +Value mlir::torch::onnx_c::createConstantIntList( + OpBinder binder, ConversionPatternRewriter &rewriter, + SmallVector cstInput) { + SmallVector cstValue; + for (int64_t i : cstInput) { + cstValue.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + return rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstValue); +}