diff --git a/src/Conversion/ONNXToAISLE/helper.hpp b/src/Conversion/ONNXToAISLE/helper.hpp index a2000991d2..7ed424a14e 100644 --- a/src/Conversion/ONNXToAISLE/helper.hpp +++ b/src/Conversion/ONNXToAISLE/helper.hpp @@ -15,7 +15,7 @@ using namespace mlir; namespace onnx_to_aisle { -inline void getVectorShape(Value value, llvm::SmallVector&Vec) { +inline void getVectorShape(Value value, llvm::SmallVector&Vec) { TensorType t_type = value.getType().cast(); if (!t_type) return; @@ -49,16 +49,16 @@ inline void getAttrValues(::mlir::ArrayAttr array, SmallVector &Vec) { template spade::AISLEQConstantOp create(ConversionPatternRewriter &rewriter, _Operation &op, const char *name, _Func mem_func) { - auto elementType = rewriter.getI64Type(); + auto elementType = rewriter.getI32Type(); auto paramShape = mlir::RankedTensorType::get( - llvm::ArrayRef{1, 4}, elementType); + llvm::ArrayRef{1, 4}, elementType); _Class adaptor(op); auto get_val = std::mem_fn(mem_func); Value X = get_val(adaptor); - SmallVector Vec; + SmallVector Vec; onnx_to_aisle::getVectorShape(X, Vec); auto input_shape_dense = - DenseElementsAttr::get(paramShape, ArrayRef(Vec)); + DenseElementsAttr::get(paramShape, ArrayRef(Vec)); auto iShapeParam = rewriter.create( op.getLoc(), name, input_shape_dense); return iShapeParam; @@ -67,10 +67,10 @@ spade::AISLEQConstantOp create(ConversionPatternRewriter &rewriter, template spade::AISLEQConstantOp create(ConversionPatternRewriter &rewriter, _Operation &op, const char *name, const SmallVector &shape) { - auto elementType = rewriter.getI64Type(); + auto elementType = rewriter.getI32Type(); auto paramShape = mlir::RankedTensorType::get( - llvm::ArrayRef{1, 4}, elementType); - SmallVector Vec; + llvm::ArrayRef{1, 4}, elementType); + SmallVector Vec; //make it rank 4 auto rank = shape.size(); /* @@ -86,7 +86,7 @@ spade::AISLEQConstantOp create(ConversionPatternRewriter &rewriter, for (size_t i = 0; i < rank; ++i) Vec.push_back(shape[i]); auto input_shape_dense = - DenseElementsAttr::get(paramShape, ArrayRef(Vec)); + DenseElementsAttr::get(paramShape, ArrayRef(Vec)); auto iShapeParam = rewriter.create( op.getLoc(), name, input_shape_dense); return iShapeParam;