Skip to content

Commit

Permalink
finish first draft of layout conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Feb 1, 2025
1 parent 899d1a8 commit de67600
Showing 1 changed file with 165 additions and 82 deletions.
247 changes: 165 additions & 82 deletions lib/Dialect/TensorExt/Transforms/LayoutConversionToShiftNetwork.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,100 +143,180 @@ class VosVosErkinShiftNetworks {
DenseMap<Permutation, llvm::SmallVector<RotationGroup>> rotationGroups;
};

struct RewriteLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
RewriteLayoutConversion(mlir::MLIRContext *context,
VosVosErkinShiftNetworks shiftNetworks)
: mlir::OpRewritePattern<ConvertLayoutOp>(context),
shiftNetworks(std::move(shiftNetworks)) {}

LogicalResult matchAndRewrite(ConvertLayoutOp op,
PatternRewriter &rewriter) const override {
// Convert the input and output layouts to an explicit permutation.
AffineMap inputLayout = op.getFromLayout().getValue();
AffineMap outputLayout = op.getToLayout().getValue();

// Only support a 1-D tensor
if (op.getTensor().getType().getRank() != 1) {
return op.emitError("requires a one-dimensional tensor");
// Create a tensor with zeros everywhere except for the indices specified in
// the input `indices` vector.
//
// The mask constant has the same (1-D) type as `tensor`, so an elementwise
// multiply against it zeroes every entry whose index is not in `indices`.
Value createMask(TypedValue<RankedTensorType> tensor,
                 const SmallVector<int64_t> &indices, IRRewriter &rewriter) {
  auto elementType = tensor.getType().getElementType();
  Attribute zero = rewriter.getIntegerAttr(elementType, 0);
  Attribute one = rewriter.getIntegerAttr(elementType, 1);

  // Build the zero-filled vector in one shot rather than appending one
  // element at a time, then flip the requested indices to one.
  SmallVector<Attribute> maskAttrs(tensor.getType().getDimSize(0), zero);
  for (int64_t index : indices) {
    maskAttrs[index] = one;
  }

  auto denseAttr = DenseElementsAttr::get(tensor.getType(), maskAttrs);
  auto constant =
      rewriter.create<arith::ConstantOp>(tensor.getLoc(), denseAttr);
  return constant.getResult();
}

// Rotate the elements of `tensor` belonging to `group` toward their target
// positions using one masked rotation per power of two up to `ciphertextSize`,
// summing the partial results together.
Value rotateGroup(TypedValue<RankedTensorType> tensor,
const RotationGroup &group, int64_t ciphertextSize,
IRRewriter &rewriter) {
// Accumulator for the sum of per-stage partial tensors; stays empty until a
// stage actually rotates something.
std::optional<Value> result = std::nullopt;
// As we rotate indices by partial shifts, we need to keep track of where
// each index currently is in the tensor.
DenseMap<int64_t, int64_t> inputIndexToCurrentPosition;
inputIndexToCurrentPosition.reserve(group.size());
for (int64_t index : group) {
inputIndexToCurrentPosition[index] = index;
}

for (int64_t rotationAmount = 1; rotationAmount <= ciphertextSize;
rotationAmount <<= 1) {
// Collect the current positions of the group members participating in this
// power-of-two stage.
SmallVector<int64_t> indicesToRotate;
for (int64_t index : group) {
// NOTE(review): this tests a bit of the ORIGINAL index, not of a required
// shift amount — confirm `group` members encode the rotation needed,
// otherwise this looks like a bug.
if (index & rotationAmount) {
indicesToRotate.push_back(inputIndexToCurrentPosition[index]);
}
}
if (indicesToRotate.empty()) {
continue;
}

// NOTE(review): the next five lines appear to be residue of the deleted
// RewriteLayoutConversion pattern interleaved by the diff view; they
// reference names (inputLayout, outputLayout, op) that do not exist in this
// function and must be removed before this compiles.
// For now assume the layout maps have one result (single ciphertext)
if (inputLayout.getNumResults() != 1 || outputLayout.getNumResults() != 1) {
return op.emitError()
<< "Shift network lowering only supports layout affine_maps with "
"a single result (i.e., one ciphertext).";
// Zero out everything except this stage's elements, then rotate the masked
// tensor by the stage amount.
Value mask = createMask(tensor, indicesToRotate, rewriter);
arith::MulIOp maskOp =
rewriter.create<arith::MulIOp>(tensor.getLoc(), tensor, mask);
// rotating right, so negate the shift amount
Value rotated = rewriter.create<tensor_ext::RotateOp>(
tensor.getLoc(), maskOp.getResult(),
rewriter.create<arith::ConstantIntOp>(tensor.getLoc(), -rotationAmount,
rewriter.getI32Type()));

// Fold this stage's partial result into the running sum.
if (result.has_value()) {
result = rewriter.create<arith::AddIOp>(tensor.getLoc(), result.value(),
rotated);
} else {
result = rotated;
}

// NOTE(review): more residue of the deleted pattern follows (the layout
// equality short-circuit, again referencing names not in scope here).
// FIXME: Should I simplify these to better facilitate the equality check?
if (inputLayout == outputLayout) {
// Just forward the operand
rewriter.replaceOp(op, op.getOperand());
return success();
// Record the new position of every element moved this stage.
// NOTE(review): `indicesToRotate` holds CURRENT positions, while the map is
// keyed by original input index — confirm these keys are intended.
for (auto index : indicesToRotate) {
inputIndexToCurrentPosition[index] =
(inputIndexToCurrentPosition[index] + rotationAmount) %
ciphertextSize;
}
}

// NOTE(review): the block below is residue of the deleted pattern's
// permutation construction, interleaved by the diff view.
// The concrete permutation is the result of iterating over the index space
// of the tensors, and mapping fromLayout.eval(index) to
// toLayout.eval(index).
ArrayRef<int64_t> dims = op.getTensor().getType().getShape();
int64_t ciphertextSize = shiftNetworks.getCiphertextSize();
// Initial permutation starts as the identity permutation.
SmallVector<int64_t> permutation = identity(ciphertextSize);

// Looking for something like llvm::product_iterator, but found nothing.
// Iterating manually and using mod arithmetic to get the per-axis indices.
SmallVector<int64_t, 4> indices;
indices.resize(dims.size());
for (size_t index = 0; index < op.getTensor().getType().getNumElements();
++index) {
// Unflatten the index into dimensional components
int dimIndex = dims.size() - 1;
for (int64_t dim : llvm::reverse(dims)) {
indices[dimIndex] = index % dim;
index /= dim;
--dimIndex;
}
// If no stage rotated anything, the tensor is already in place.
return result.has_value() ? result.value() : tensor;
}

SmallVector<Attribute> inputLayoutResults;
SmallVector<Attribute> outputLayoutResults;
SmallVector<Attribute> operandConstants;
for (int64_t i = 0; i < dims.size(); i++) {
operandConstants.push_back(rewriter.getI64IntegerAttr(indices[i]));
}
if (failed(
inputLayout.constantFold(operandConstants, inputLayoutResults)) ||
failed(outputLayout.constantFold(operandConstants,
outputLayoutResults))) {
return op.emitError(
"unable to statically evaluate one of the two affine maps.");
}
// Lower one ConvertLayoutOp to a shift network: build the explicit
// permutation implied by the op's from/to layout affine maps, split it into
// rotation groups, rotate each group with power-of-two shifts, and sum the
// per-group results to replace the op.
LogicalResult convertLayoutOp(ConvertLayoutOp op,
VosVosErkinShiftNetworks &shiftNetworks,
int64_t ciphertextSize) {
IRRewriter rewriter(op.getContext());

// NOTE(review): the six lines down to the stray closing brace are residue of
// the deleted RewriteLayoutConversion pattern interleaved by the diff view
// (they duplicate the permutation assignment that appears later in this
// function) and must be removed before this compiles.
int64_t inputLayoutResultIndex =
cast<IntegerAttr>(inputLayoutResults[0]).getInt();
int64_t outputLayoutResultIndex =
cast<IntegerAttr>(outputLayoutResults[0]).getInt();
permutation[inputLayoutResultIndex] = outputLayoutResultIndex;
}
// Convert the input and output layouts to an explicit permutation.
AffineMap inputLayout = op.getFromLayout().getValue();
AffineMap outputLayout = op.getToLayout().getValue();

// NOTE(review): this LLVM_DEBUG block is diff residue — it duplicates the
// one further below and uses `permutation` before it is declared here.
LLVM_DEBUG({
llvm::dbgs() << "ConvertLayoutOp produces underlying permutation: ";
for (int i = 0; i < permutation.size(); i++) {
llvm::dbgs() << i << " -> " << permutation[i] << ", ";
if (i % 10 == 9) {
llvm::dbgs() << "\n";
}
}
llvm::dbgs() << "\n";
});
// Only support a 1-D tensor
if (op.getTensor().getType().getRank() != 1) {
return op.emitError("requires a one-dimensional tensor");
}

// NOTE(review): the next two lines are diff residue duplicating the
// computeShiftNetwork call made near the end of this function.
ArrayRef<RotationGroup> rotationGroup =
shiftNetworks.computeShiftNetwork(permutation);
// For now assume the layout maps have one result (single ciphertext)
if (inputLayout.getNumResults() != 1 || outputLayout.getNumResults() != 1) {
return op.emitError()
<< "Shift network lowering only supports layout affine_maps with "
"a single result (i.e., one ciphertext).";
}

// FIXME: Should I simplify these to better facilitate the equality check?
if (inputLayout == outputLayout) {
// Just forward the operand
rewriter.replaceOp(op, op.getOperand());
return success();
}

// NOTE(review): the next three lines are residue of the deleted pattern's
// private member section, interleaved by the diff view.
private:
VosVosErkinShiftNetworks shiftNetworks;
};
// The concrete permutation is the result of iterating over the index space
// of the tensors, and mapping fromLayout.eval(index) to
// toLayout.eval(index).
ArrayRef<int64_t> dims = op.getTensor().getType().getShape();
// Initial permutation starts as the identity permutation.
SmallVector<int64_t> permutation = identity(ciphertextSize);

// Looking for something like llvm::product_iterator, but found nothing.
// Iterating manually and using mod arithmetic to get the per-axis indices.
SmallVector<int64_t, 4> indices;
indices.resize(dims.size());
for (size_t index = 0; index < op.getTensor().getType().getNumElements();
++index) {
// Unflatten the index into dimensional components
int dimIndex = dims.size() - 1;
for (int64_t dim : llvm::reverse(dims)) {
indices[dimIndex] = index % dim;
index /= dim;
--dimIndex;
}

// Statically evaluate both layout maps at this multi-index; each yields a
// single flattened ciphertext slot.
SmallVector<Attribute> inputLayoutResults;
SmallVector<Attribute> outputLayoutResults;
SmallVector<Attribute> operandConstants;
for (int64_t i = 0; i < dims.size(); i++) {
operandConstants.push_back(rewriter.getI64IntegerAttr(indices[i]));
}
if (failed(
inputLayout.constantFold(operandConstants, inputLayoutResults)) ||
failed(
outputLayout.constantFold(operandConstants, outputLayoutResults))) {
return op.emitError(
"unable to statically evaluate one of the two affine maps.");
}

// Map the slot the input layout assigns this element to the slot the output
// layout assigns it.
int64_t inputLayoutResultIndex =
cast<IntegerAttr>(inputLayoutResults[0]).getInt();
int64_t outputLayoutResultIndex =
cast<IntegerAttr>(outputLayoutResults[0]).getInt();
permutation[inputLayoutResultIndex] = outputLayoutResultIndex;
}

LLVM_DEBUG({
llvm::dbgs() << "ConvertLayoutOp produces underlying permutation: ";
for (int i = 0; i < permutation.size(); i++) {
llvm::dbgs() << i << " -> " << permutation[i] << ", ";
if (i % 10 == 9) {
llvm::dbgs() << "\n";
}
}
llvm::dbgs() << "\n";
});

// Freeze the permutation so it can serve as the cache key for the computed
// shift network.
FrozenVector<int64_t> permKey = FrozenVector<int64_t>(std::move(permutation));
ArrayRef<RotationGroup> rotationGroup =
shiftNetworks.computeShiftNetwork(permKey);
assert(!rotationGroup.empty() &&
"Shift network must have at least one group");

// Process each rotation group separately with a full set of power-of-two
// shifts. Then sum the results together.
rewriter.setInsertionPointAfter(op);
std::optional<Value> result = std::nullopt;
for (const RotationGroup &group : rotationGroup) {
Value perGroupResult =
rotateGroup(op.getTensor(), group, ciphertextSize, rewriter);
if (result.has_value())
result =
rewriter.create<arith::AddIOp>(op.getLoc(), *result, perGroupResult);
else
result = perGroupResult;
}

rewriter.replaceOp(op, result.value());
return success();
}

struct LayoutConversionToShiftNetwork
: impl::LayoutConversionToShiftNetworkBase<LayoutConversionToShiftNetwork> {
Expand All @@ -247,9 +327,12 @@ struct LayoutConversionToShiftNetwork
RewritePatternSet patterns(context);

VosVosErkinShiftNetworks shiftNetworks{ciphertextSize};
patterns.add<RewriteLayoutConversion>(context, shiftNetworks);

(void)walkAndApplyPatterns(getOperation(), std::move(patterns));
getOperation()->walk([&](ConvertLayoutOp op) {
if (failed(convertLayoutOp(op, shiftNetworks, ciphertextSize))) {
signalPassFailure();
}
});
}
};

Expand Down

0 comments on commit de67600

Please sign in to comment.