Skip to content

Commit

Permalink
bugfix for: 'aisle.GEMM' op operand onnx#1 must be tensor of 32-bit s…
Browse files Browse the repository at this point in the history
…ignless integer values, but got 'tensor<1x4xi64>
  • Loading branch information
darotsr committed Jun 20, 2024
1 parent 7e76263 commit 6574582
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions src/Conversion/ONNXToAISLE/helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
using namespace mlir;
namespace onnx_to_aisle {

inline void getVectorShape(Value value, llvm::SmallVector<int64_t>&Vec) {
inline void getVectorShape(Value value, llvm::SmallVector<int32_t>&Vec) {
TensorType t_type = value.getType().cast<TensorType>();
if (!t_type)
return;
Expand Down Expand Up @@ -49,16 +49,16 @@ inline void getAttrValues(::mlir::ArrayAttr array, SmallVector<int> &Vec) {
template <typename _Class, typename _Func, typename _Operation>
spade::AISLEQConstantOp create(ConversionPatternRewriter &rewriter,
_Operation &op, const char *name, _Func mem_func) {
auto elementType = rewriter.getI64Type();
auto elementType = rewriter.getI32Type();
auto paramShape = mlir::RankedTensorType::get(
llvm::ArrayRef<std::int64_t>{1, 4}, elementType);
llvm::ArrayRef<int64_t>{1, 4}, elementType);
_Class adaptor(op);
auto get_val = std::mem_fn(mem_func);
Value X = get_val(adaptor);
SmallVector<int64_t> Vec;
SmallVector<int32_t> Vec;
onnx_to_aisle::getVectorShape(X, Vec);
auto input_shape_dense =
DenseElementsAttr::get(paramShape, ArrayRef<int64_t>(Vec));
DenseElementsAttr::get(paramShape, ArrayRef<int32_t>(Vec));
auto iShapeParam = rewriter.create<spade::AISLEQConstantOp>(
op.getLoc(), name, input_shape_dense);
return iShapeParam;
Expand All @@ -67,10 +67,10 @@ spade::AISLEQConstantOp create(ConversionPatternRewriter &rewriter,
template <typename _Operation>
spade::AISLEQConstantOp create(ConversionPatternRewriter &rewriter,
_Operation &op, const char *name, const SmallVector<int64_t> &shape) {
auto elementType = rewriter.getI64Type();
auto elementType = rewriter.getI32Type();
auto paramShape = mlir::RankedTensorType::get(
llvm::ArrayRef<std::int64_t>{1, 4}, elementType);
SmallVector<int64_t> Vec;
llvm::ArrayRef<int64_t>{1, 4}, elementType);
SmallVector<int32_t> Vec;
//make it rank 4
auto rank = shape.size();
/*
Expand All @@ -86,7 +86,7 @@ spade::AISLEQConstantOp create(ConversionPatternRewriter &rewriter,
for (size_t i = 0; i < rank; ++i)
Vec.push_back(shape[i]);
auto input_shape_dense =
DenseElementsAttr::get(paramShape, ArrayRef<int64_t>(Vec));
DenseElementsAttr::get(paramShape, ArrayRef<int32_t>(Vec));
auto iShapeParam = rewriter.create<spade::AISLEQConstantOp>(
op.getLoc(), name, input_shape_dense);
return iShapeParam;
Expand Down

0 comments on commit 6574582

Please sign in to comment.