diff --git a/lib/Dialect/TensorExt/Transforms/LayoutConversionToShiftNetwork.cpp b/lib/Dialect/TensorExt/Transforms/LayoutConversionToShiftNetwork.cpp index 14b19b13a..69bf14300 100644 --- a/lib/Dialect/TensorExt/Transforms/LayoutConversionToShiftNetwork.cpp +++ b/lib/Dialect/TensorExt/Transforms/LayoutConversionToShiftNetwork.cpp @@ -143,100 +143,180 @@ class VosVosErkinShiftNetworks { DenseMap> rotationGroups; }; -struct RewriteLayoutConversion : public OpRewritePattern { - RewriteLayoutConversion(mlir::MLIRContext *context, - VosVosErkinShiftNetworks shiftNetworks) - : mlir::OpRewritePattern(context), - shiftNetworks(std::move(shiftNetworks)) {} - - LogicalResult matchAndRewrite(ConvertLayoutOp op, - PatternRewriter &rewriter) const override { - // Convert the input and output layouts to an explicit permutation. - AffineMap inputLayout = op.getFromLayout().getValue(); - AffineMap outputLayout = op.getToLayout().getValue(); - - // Only support a 1-D tensor - if (op.getTensor().getType().getRank() != 1) { - return op.emitError("requires a one-dimensional tensor"); +// Create a tensor with zeros everywhere except for the indices specified in +// the input `indices` vector. +Value createMask(TypedValue tensor, + const SmallVector &indices, IRRewriter &rewriter) { + auto elementType = tensor.getType().getElementType(); + SmallVector maskAttrs; + + for (int64_t i = 0; i < tensor.getType().getDimSize(0); i++) { + maskAttrs.push_back(rewriter.getIntegerAttr(elementType, 0)); + } + for (int64_t index : indices) { + maskAttrs[index] = rewriter.getIntegerAttr(elementType, 1); + } + + auto denseAttr = DenseElementsAttr::get(tensor.getType(), maskAttrs); + auto constant = + rewriter.create(tensor.getLoc(), denseAttr); + return constant.getResult(); +} + +Value rotateGroup(TypedValue tensor, + const RotationGroup &group, int64_t ciphertextSize, + IRRewriter &rewriter) { + std::optional result = std::nullopt; + // As we rotate indices by partial shifts, we need to keep track of where + // each index currently is in the tensor. + DenseMap inputIndexToCurrentPosition; + inputIndexToCurrentPosition.reserve(group.size()); + for (int64_t index : group) { + inputIndexToCurrentPosition[index] = index; + } + + for (int64_t rotationAmount = 1; rotationAmount <= ciphertextSize; + rotationAmount <<= 1) { + SmallVector indicesToRotate; + for (int64_t index : group) { + if (index & rotationAmount) { + indicesToRotate.push_back(inputIndexToCurrentPosition[index]); + } + } + if (indicesToRotate.empty()) { + continue; } - // For now assume the layout maps have one result (single ciphertext) - if (inputLayout.getNumResults() != 1 || outputLayout.getNumResults() != 1) { - return op.emitError() - << "Shift network lowering only supports layout affine_maps with " - "a single result (i.e., one ciphertext)."; + Value mask = createMask(tensor, indicesToRotate, rewriter); + arith::MulIOp maskOp = + rewriter.create(tensor.getLoc(), tensor, mask); + // rotating right, so negate the shift amount + Value rotated = rewriter.create( + tensor.getLoc(), maskOp.getResult(), + rewriter.create(tensor.getLoc(), -rotationAmount, + rewriter.getI32Type())); + + if (result.has_value()) { + result = rewriter.create(tensor.getLoc(), result.value(), + rotated); + } else { + result = rotated; } - // FIXME: Should I simplify these to better facilitate the equality check? - if (inputLayout == outputLayout) { - // Just forward the operand - rewriter.replaceOp(op, op.getOperand()); - return success(); + for (auto index : indicesToRotate) { + inputIndexToCurrentPosition[index] = + (inputIndexToCurrentPosition[index] + rotationAmount) % + ciphertextSize; } + } - // The concrete permutation is the result of iterating over the index space - // of the tensors, and mapping fromLayout.eval(index) to - // toLayout.eval(index). - ArrayRef dims = op.getTensor().getType().getShape(); - int64_t ciphertextSize = shiftNetworks.getCiphertextSize(); - // Initial permutation starts as the identity permutation. - SmallVector permutation = identity(ciphertextSize); - - // Looking for something like llvm::product_iterator, but found nothing. - // Iterating manually and using mod arithmetic to get the per-axis indices. - SmallVector indices; - indices.resize(dims.size()); - for (size_t index = 0; index < op.getTensor().getType().getNumElements(); - ++index) { - // Unflatten the index into dimensional components - int dimIndex = dims.size() - 1; - for (int64_t dim : llvm::reverse(dims)) { - indices[dimIndex] = index % dim; - index /= dim; - --dimIndex; - } + return result.has_value() ? result.value() : tensor; +} - SmallVector inputLayoutResults; - SmallVector outputLayoutResults; - SmallVector operandConstants; - for (int64_t i = 0; i < dims.size(); i++) { - operandConstants.push_back(rewriter.getI64IntegerAttr(indices[i])); - } - if (failed( - inputLayout.constantFold(operandConstants, inputLayoutResults)) || - failed(outputLayout.constantFold(operandConstants, - outputLayoutResults))) { - return op.emitError( - "unable to statically evaluate one of the two affine maps."); - } +LogicalResult convertLayoutOp(ConvertLayoutOp op, + VosVosErkinShiftNetworks &shiftNetworks, + int64_t ciphertextSize) { + IRRewriter rewriter(op.getContext()); - int64_t inputLayoutResultIndex = - cast(inputLayoutResults[0]).getInt(); - int64_t outputLayoutResultIndex = - cast(outputLayoutResults[0]).getInt(); - permutation[inputLayoutResultIndex] = outputLayoutResultIndex; - } + // Convert the input and output layouts to an explicit permutation. + AffineMap inputLayout = op.getFromLayout().getValue(); + AffineMap outputLayout = op.getToLayout().getValue(); - LLVM_DEBUG({ - llvm::dbgs() << "ConvertLayoutOp produces underlying permutation: "; - for (int i = 0; i < permutation.size(); i++) { - llvm::dbgs() << i << " -> " << permutation[i] << ", "; - if (i % 10 == 9) { - llvm::dbgs() << "\n"; - } - } - llvm::dbgs() << "\n"; - }); + // Only support a 1-D tensor + if (op.getTensor().getType().getRank() != 1) { + return op.emitError("requires a one-dimensional tensor"); + } - ArrayRef rotationGroup = - shiftNetworks.computeShiftNetwork(permutation); + // For now assume the layout maps have one result (single ciphertext) + if (inputLayout.getNumResults() != 1 || outputLayout.getNumResults() != 1) { + return op.emitError() + << "Shift network lowering only supports layout affine_maps with " + "a single result (i.e., one ciphertext)."; + } + // FIXME: Should I simplify these to better facilitate the equality check? + if (inputLayout == outputLayout) { + // Just forward the operand + rewriter.replaceOp(op, op.getOperand()); return success(); } - private: - VosVosErkinShiftNetworks shiftNetworks; -}; + // The concrete permutation is the result of iterating over the index space + // of the tensors, and mapping fromLayout.eval(index) to + // toLayout.eval(index). + ArrayRef dims = op.getTensor().getType().getShape(); + // Initial permutation starts as the identity permutation. + SmallVector permutation = identity(ciphertextSize); + + // Looking for something like llvm::product_iterator, but found nothing. + // Iterating manually and using mod arithmetic to get the per-axis indices. + SmallVector indices; + indices.resize(dims.size()); + for (size_t index = 0; index < op.getTensor().getType().getNumElements(); + ++index) { + // Unflatten the index into dimensional components + int dimIndex = dims.size() - 1; + for (int64_t dim : llvm::reverse(dims)) { + indices[dimIndex] = index % dim; + index /= dim; + --dimIndex; + } + + SmallVector inputLayoutResults; + SmallVector outputLayoutResults; + SmallVector operandConstants; + for (int64_t i = 0; i < dims.size(); i++) { + operandConstants.push_back(rewriter.getI64IntegerAttr(indices[i])); + } + if (failed( + inputLayout.constantFold(operandConstants, inputLayoutResults)) || + failed( + outputLayout.constantFold(operandConstants, outputLayoutResults))) { + return op.emitError( + "unable to statically evaluate one of the two affine maps."); + } + + int64_t inputLayoutResultIndex = + cast(inputLayoutResults[0]).getInt(); + int64_t outputLayoutResultIndex = + cast(outputLayoutResults[0]).getInt(); + permutation[inputLayoutResultIndex] = outputLayoutResultIndex; + } + + LLVM_DEBUG({ + llvm::dbgs() << "ConvertLayoutOp produces underlying permutation: "; + for (int i = 0; i < permutation.size(); i++) { + llvm::dbgs() << i << " -> " << permutation[i] << ", "; + if (i % 10 == 9) { + llvm::dbgs() << "\n"; + } + } + llvm::dbgs() << "\n"; + }); + + FrozenVector permKey = FrozenVector(std::move(permutation)); + ArrayRef rotationGroup = + shiftNetworks.computeShiftNetwork(permKey); + assert(!rotationGroup.empty() && + "Shift network must have at least one group"); + + // Process each rotation group separately with a full set of power-of-two + // shifts. Then sum the results together. + rewriter.setInsertionPointAfter(op); + std::optional result = std::nullopt; + for (const RotationGroup &group : rotationGroup) { + Value perGroupResult = + rotateGroup(op.getTensor(), group, ciphertextSize, rewriter); + if (result.has_value()) + result = + rewriter.create(op.getLoc(), *result, perGroupResult); + else + result = perGroupResult; + } + + rewriter.replaceOp(op, result.value()); + return success(); +} struct LayoutConversionToShiftNetwork : impl::LayoutConversionToShiftNetworkBase { @@ -247,9 +327,12 @@ struct LayoutConversionToShiftNetwork RewritePatternSet patterns(context); VosVosErkinShiftNetworks shiftNetworks{ciphertextSize}; - patterns.add(context, shiftNetworks); - (void)walkAndApplyPatterns(getOperation(), std::move(patterns)); + getOperation()->walk([&](ConvertLayoutOp op) { + if (failed(convertLayoutOp(op, shiftNetworks, ciphertextSize))) { + signalPassFailure(); + } + }); } };