Commit
I64TensorAttr -> DenseI64ArrayAttr
sjain-stanford committed Jan 31, 2024
1 parent 1db8e32 commit a0ad9a8
Showing 5 changed files with 41 additions and 93 deletions.
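The change applies one mechanical pattern across all five files: i64 dimension lists that were previously packed into a tensor-typed DenseIntElementsAttr (via rewriter.getI64TensorAttr or a hand-rolled RankedTensorType) now use DenseI64ArrayAttr, and the unset DenseElementsAttr placeholder for convolution window reversal becomes DenseBoolArrayAttr, following StableHLO's move to array attributes for static integer lists. A minimal before/after sketch, assuming MLIR's standard Builder API; the variable names are illustrative, not taken from the diff:

// Sketch only: assumes a PatternRewriter (or any mlir::Builder) named
// `rewriter` is in scope, as in the conversion patterns below.
SmallVector<int64_t> dims = {0, 1};

// Before: the i64 data had to be wrapped in a tensor-typed elements attribute.
auto attrTy = RankedTensorType::get(
    {static_cast<int64_t>(dims.size())}, rewriter.getIntegerType(64));
DenseIntElementsAttr oldAttr = DenseIntElementsAttr::get(attrTy, dims);
// ...or, equivalently, via the convenience builder:
DenseIntElementsAttr oldAttr2 = rewriter.getI64TensorAttr(dims);

// After: DenseI64ArrayAttr stores the raw i64s directly, with no tensor type.
DenseI64ArrayAttr newAttr = rewriter.getDenseI64ArrayAttr(dims);

A default-constructed attribute such as DenseI64ArrayAttr bcastDimensions; is null, so the chlo::Broadcast* builder calls below leave the optional broadcast_dimensions attribute unset.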
18 changes: 9 additions & 9 deletions lib/Conversion/TorchToStablehlo/Basic.cpp
@@ -377,12 +377,12 @@ class ConvertAtenAddSubOp : public OpConversionPattern<AtenOpT> {
if (!skipMultiplyAlpha(op.getAlpha())) {
Value alpha = hlo::scalarToStablehloTensor(rewriter, op,
adaptor.getAlpha(), outElemTy);
- DenseIntElementsAttr bcastDimensions;
+ DenseI64ArrayAttr bcastDimensions;
rhs = rewriter.create<chlo::BroadcastMulOp>(op->getLoc(), rhs, alpha,
bcastDimensions);
}

- DenseIntElementsAttr bcastDimensions;
+ DenseI64ArrayAttr bcastDimensions;
rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
bcastDimensions);
return success();
@@ -424,7 +424,7 @@ class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
outElemTy);
}
- DenseIntElementsAttr bcastDimensions;
+ DenseI64ArrayAttr bcastDimensions;
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
auto loc = op.getLoc();
@@ -542,7 +542,7 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
} else {
return op.emitError("operator haven't been supported");
}
- DenseIntElementsAttr bcastDimensions;
+ DenseI64ArrayAttr bcastDimensions;
rewriter.replaceOpWithNewOp<chlo::BroadcastCompareOp>(
op, outType, lhs, rhs, bcastDimensions, compareDirectionAttr,
compareTypeAttr);
@@ -570,7 +570,7 @@ class ConvertAtenLogicalBinaryOp : public OpConversionPattern<AtenOpT> {
Value rhs =
hlo::promoteType(rewriter, op.getLoc(), adaptor.getOther(), outType);

- DenseIntElementsAttr bcastDimensions;
+ DenseI64ArrayAttr bcastDimensions;
rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
bcastDimensions);
return success();
@@ -757,7 +757,7 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
llvm::to_vector<4>(llvm::seq<int64_t>(leadingRank, totalRank));
rewriter.replaceOpWithNewOp<stablehlo::DynamicBroadcastInDimOp>(
op, outType, self, bcastShapeTensor,
- rewriter.getI64TensorAttr(dimensionNumbers));
+ rewriter.getDenseI64ArrayAttr(dimensionNumbers));
}
return success();
}
@@ -887,7 +887,7 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
if (!rhsType) {
rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy);
}
- DenseIntElementsAttr bcastDimensions;
+ DenseI64ArrayAttr bcastDimensions;
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
auto loc = op.getLoc();
@@ -1478,7 +1478,7 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(

Value window =
rewriter.create<stablehlo::DynamicIotaOp>(loc, outType, resultLength, 0);
- DenseIntElementsAttr broadcastDimensions;
+ DenseI64ArrayAttr broadcastDimensions;
Value mulOut = rewriter.create<chlo::BroadcastMulOp>(loc, window, step,
broadcastDimensions);
rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(op, mulOut, start,
@@ -1721,7 +1721,7 @@ LogicalResult ConvertAtenOp<AtenFillScalarOp>::matchAndRewrite(
rewriter.create<shape::ShapeOfOp>(op->getLoc(), adaptor.getSelf());
Value bcastScalar = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
op->getLoc(), outType, scalarTensor, shapeTensor,
- rewriter.getI64TensorAttr({}));
+ rewriter.getDenseI64ArrayAttr({}));
rewriter.replaceOp(op, bcastScalar);
return success();
}
4 changes: 2 additions & 2 deletions lib/Conversion/TorchToStablehlo/GatherScatter.cpp
@@ -511,7 +511,7 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(

rewriter.replaceOpWithNewOp<stablehlo::GatherOp>(
op, input, gatherIndicies, dimsAttr,
- rewriter.getI64TensorAttr(sliceSizes));
+ rewriter.getDenseI64ArrayAttr(sliceSizes));
return success();
}

@@ -835,7 +835,7 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(

rewriter.replaceOpWithNewOp<stablehlo::GatherOp>(
op, resultType, input, finalIndexTensor, dimsAttr,
- rewriter.getI64TensorAttr(sliceSizes));
+ rewriter.getDenseI64ArrayAttr(sliceSizes));
return success();
}

32 changes: 10 additions & 22 deletions lib/Conversion/TorchToStablehlo/Linear.cpp
@@ -39,10 +39,7 @@ Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
RankedTensorType outTy =
RankedTensorType::get(shape, tensorTy.getElementType());

- RankedTensorType attrTy =
-     RankedTensorType::get({static_cast<int64_t>(broadcastDims.size())},
-                           rewriter.getIntegerType(64));
- auto broadcastAttr = DenseIntElementsAttr::get(attrTy, broadcastDims);
+ auto broadcastAttr = rewriter.getDenseI64ArrayAttr(broadcastDims);

auto broadcast = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
loc, outTy, tensor, stablehloShape, broadcastAttr);
@@ -549,8 +546,7 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp<AtenConvolutionOp> {

// Prepare for transposed convolution
SmallVector<int64_t> stablehloStrideVec(nSpatialDims, 1);
- DenseIntElementsAttr stablehloStride =
-     rewriter.getI64TensorAttr(stablehloStrideVec);
+ auto stablehloStride = rewriter.getDenseI64ArrayAttr(stablehloStrideVec);
SmallVector<int64_t> stablehloPaddingVec(nSpatialDims * 2, 0);
for (int i = 0; i < nSpatialDims; ++i) {
int64_t padInt = dilation[i] * (weightShape[i + 2] - 1) - padding[i];
@@ -563,15 +559,13 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp<AtenConvolutionOp> {
stablehloPaddingVec);
SmallVector<int64_t> stablehloLhsDilationVec(nSpatialDims);
std::copy(stride.begin(), stride.end(), stablehloLhsDilationVec.begin());
- DenseIntElementsAttr stablehloLhsDilation =
-     rewriter.getI64TensorAttr(stablehloLhsDilationVec);
+ auto stablehloLhsDilation = rewriter.getDenseI64ArrayAttr(stablehloLhsDilationVec);
SmallVector<int64_t> stablehloRhsDilationVec(nSpatialDims);
std::copy(dilation.begin(), dilation.end(),
stablehloRhsDilationVec.begin());
- DenseIntElementsAttr stablehloRhsDilation =
-     rewriter.getI64TensorAttr(stablehloRhsDilationVec);
+ auto stablehloRhsDilation = rewriter.getDenseI64ArrayAttr(stablehloRhsDilationVec);

- DenseElementsAttr windowReversal;
+ DenseBoolArrayAttr windowReversal;
ArrayAttr precisionConfig;

SmallVector<int64_t> spatialDims;
@@ -614,10 +608,7 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp<AtenConvolutionOp> {
int64_t nDims = outType.getRank();

// Get stablehlo::ConvolutionOp attributes
- DenseIntElementsAttr stablehloWindowStride = DenseIntElementsAttr::get(
-     RankedTensorType::get({static_cast<long int>(stride.size())},
-                           rewriter.getI64Type()),
-     stride);
+ auto stablehloWindowStride = rewriter.getDenseI64ArrayAttr(stride);
std::vector<int64_t> stablehloPaddingVec;
for (size_t i = 0; i < padding.size(); i++) {
stablehloPaddingVec.emplace_back(padding[i]);
@@ -628,10 +619,7 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp<AtenConvolutionOp> {
{static_cast<long int>(padding.size()), static_cast<long int>(2)},
rewriter.getI64Type()),
stablehloPaddingVec);
- DenseIntElementsAttr stablehloRhsDilation = DenseIntElementsAttr::get(
-     RankedTensorType::get({static_cast<long int>(dilation.size())},
-                           rewriter.getI64Type()),
-     dilation);
+ auto stablehloRhsDilation = rewriter.getDenseI64ArrayAttr(dilation);
SmallVector<int64_t> spatialDimensions;
for (int64_t i = 2; i < nDims; i++) {
spatialDimensions.emplace_back(i);
@@ -648,8 +636,8 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp<AtenConvolutionOp> {
/*outputSpatialDimensions=*/spatialDimensions);

// stablehlo::ConvolutionOp's optional attributes, leave them as default
- DenseIntElementsAttr stablehloLhsDilation;
- DenseElementsAttr windowReversal;
+ DenseI64ArrayAttr stablehloLhsDilation;
+ DenseBoolArrayAttr windowReversal;
ArrayAttr precisionConfig;

auto stablehloConvOp = rewriter.create<stablehlo::ConvolutionOp>(
@@ -781,7 +769,7 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp<AtenConvolutionOp> {
options.dimSizeIndexBits);
bias = hlo::promoteType(rewriter, op.getLoc(), bias, outTy);

- DenseIntElementsAttr bcastDimensions;
+ DenseI64ArrayAttr bcastDimensions;
rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
op, outTy, stablehloConvResult, bias, bcastDimensions);
return success();
73 changes: 18 additions & 55 deletions lib/Conversion/TorchToStablehlo/Pooling.cpp
@@ -136,19 +136,10 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
stablehloPadding[stablehloPadding.size() - 1] = padding[1];

- DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
-     RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
-                           rewriter.getI64Type()),
-     stablehloKernelSize);
- DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
-     RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
-                           rewriter.getI64Type()),
-     stablehloStride);
- DenseIntElementsAttr baseDilations;
- DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
-     RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
-                           rewriter.getI64Type()),
-     stablehloDilation);
+ auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
+ auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
+ DenseI64ArrayAttr baseDilations;
+ auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
@@ -242,19 +233,10 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
stablehloPadding[stablehloPadding.size() - 1] = padding[1];

- DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
-     RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
-                           rewriter.getI64Type()),
-     stablehloKernelSize);
- DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
-     RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
-                           rewriter.getI64Type()),
-     stablehloStride);
- DenseIntElementsAttr baseDilations;
- DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
-     RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
-                           rewriter.getI64Type()),
-     stablehloDilation);
+ auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
+ auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
+ DenseI64ArrayAttr baseDilations;
+ auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
@@ -453,20 +435,10 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp<AtenOpT> {
Value initVal =
createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);

- DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
-     RankedTensorType::get(
-         {static_cast<int64_t>(stablehloKernelSize.size())},
-         rewriter.getI64Type()),
-     stablehloKernelSize);
- DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
-     RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
-                           rewriter.getI64Type()),
-     stablehloStride);
- DenseIntElementsAttr baseDilations;
- DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
-     RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
-                           rewriter.getI64Type()),
-     stablehloDilation);
+ auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
+ auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
+ DenseI64ArrayAttr baseDilations;
+ auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
@@ -508,7 +480,7 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp<AtenOpT> {
.value();
}
divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy);
- DenseIntElementsAttr bcastDimensions;
+ DenseI64ArrayAttr bcastDimensions;
rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>(
op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions);
return success();
@@ -528,7 +500,7 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp<AtenOpT> {
windowSizeConst = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
op->getLoc(),
RankedTensorType::get(inputTy.getShape(), outTy.getElementType()),
- windowSizeConst, inputShapeTensor, rewriter.getI64TensorAttr({}));
+ windowSizeConst, inputShapeTensor, rewriter.getDenseI64ArrayAttr({}));

Value zero = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
auto reduceWindowSize = rewriter.create<stablehlo::ReduceWindowOp>(
@@ -599,19 +571,10 @@ LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
stablehloPadding[dim * 2] = inputShape[dim] - 1;

- DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
-     RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
-                           rewriter.getI64Type()),
-     stablehloKernelSize);
- DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
-     RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
-                           rewriter.getI64Type()),
-     stablehloStride);
- DenseIntElementsAttr baseDilations;
- DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
-     RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
-                           rewriter.getI64Type()),
-     stablehloDilation);
+ auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
+ auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
+ DenseI64ArrayAttr baseDilations;
+ auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
7 changes: 2 additions & 5 deletions lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp
@@ -241,10 +241,7 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
if (!do_bcast) {
return input;
}
- DenseIntElementsAttr bcast_attr = DenseIntElementsAttr::get(
-     RankedTensorType::get({static_cast<long int>(bcastDims.size())},
-                           rewriter.getI64Type()),
-     bcastDims);
+ auto bcast_attr = rewriter.getDenseI64ArrayAttr(bcastDims);
auto bcast_op = rewriter.create<stablehlo::BroadcastInDimOp>(
op->getLoc(), outType, input, bcast_attr);
return bcast_op.getResult();
@@ -360,7 +357,7 @@ Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
auto constTensor = rewriter.create<stablehlo::ConstantOp>(loc, constAttr);
return rewriter
.create<stablehlo::DynamicBroadcastInDimOp>(
- loc, outType, constTensor, shape, rewriter.getI64TensorAttr({}))
+ loc, outType, constTensor, shape, rewriter.getDenseI64ArrayAttr({}))
.getResult();
}
} // namespace hlo