Skip to content

Commit

Permalink
[Torch] [TMTensor] Added mask and is_causal support for torch.aten.sc…
Browse files Browse the repository at this point in the history
…aled_dot_product_attention (#3690)

Enabled mask and is_causal parameters for torch.aten.scaled_dot_product
attention + relevant comments + tests.

The tests added highlight the new capabilities introduced in this PR,
including:

Attention with F16 mask
Attention with Boolean mask
Causal attention with same Q K V shapes
Causal attention without Q K V shapes

Made sure that one cannot input both mask and is_causal.
  • Loading branch information
rohan-tan-bhowmik authored Sep 9, 2024
1 parent 0a788e0 commit e86f56b
Show file tree
Hide file tree
Showing 5 changed files with 383 additions and 47 deletions.
33 changes: 26 additions & 7 deletions include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -252,13 +252,14 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
["generateScalarImplementation"]>]> {
let summary = "Attention operator";
let description = [{
This operator takes in 3 tensors: query(Q), key(K) and value(V) and computes
the attention. Each of the inputs has shape BxNxd where B is the
of the batch dimension, N is the sequence length and d is head dimension.
Typically N >>> d. Mathematically, the attention is defined as
matmul(softmax(matmul(Q, transpose(K))), V) and has shape BxNxd. Usually,
this operator also performs scaling, masking and dropout, but we leave
that out of the current implementation.
This operator takes in 3 to 4 tensors: query(Q), key(K), value(V), and an
optional mask(M) to compute the attention. These tensors must take on shapes
BxMxK1 for Q, BxK2xK1 for K, BxK2xN for V, and BxMxK2 for M. For all these
shapes, B represents the batch dimension, M represents sequence length, N
represents head dimension, and K1 and K2 are hidden dimensions.
Attention is defined as matmul(softmax(matmul(Q, transpose(K))+M), V) and
has shape BxMxN. Usually, this operator also performs scaling, masking and
dropout, but we leave that out of the current implementation.
}];

let arguments = (ins Variadic<AnyShaped>:$inputs,
Expand Down Expand Up @@ -287,6 +288,12 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
Value getValue() {
return getInputOperand(2)->get();
}
std::optional<Value> getAttnMask() {
if (getNumInputs() < 4) {
return std::nullopt;
}
return getInputOperand(3)->get();
}
Value getOutput() {
return getOutputOperand(0)->get();
}
Expand All @@ -299,6 +306,12 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
ShapedType getValueType() {
return cast<ShapedType>(getValue().getType());
}
std::optional<ShapedType> getAttnMaskType() {
if (getAttnMask()){
return cast<ShapedType>((*getAttnMask()).getType());
}
return std::nullopt;
}
ShapedType getOutputType() {
return cast<ShapedType>(getOutput().getType());
}
Expand All @@ -311,6 +324,12 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
int64_t getValueRank() {
return getValueType().getRank();
}
std::optional<int64_t> getAttnMaskRank() {
if (getAttnMask()){
return (*getAttnMaskType()).getRank();
}
return std::nullopt;
}
int64_t getOutputRank() {
return getOutputType().getRank();
}
Expand Down
111 changes: 89 additions & 22 deletions lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1578,26 +1578,94 @@ class ConvertAtenScaledDotProductAttentionOp
LogicalResult
matchAndRewrite(AtenScaledDotProductAttentionOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value mask = op.getAttnMask();

auto opTy = cast<ValueTensorType>(op.getType()).toBuiltinTensor();
auto query = adaptor.getQuery();
auto value = adaptor.getValue();
auto key = adaptor.getKey();
auto mask = adaptor.getAttnMask();
auto queryTy = cast<ShapedType>(query.getType());
auto valueTy = cast<ShapedType>(value.getType());
auto keyTy = cast<ShapedType>(key.getType());

Value dropoutP = op.getDropoutP();
Value isCausal = op.getIsCausal();
Value scale = op.getScale();
Value enableGQA = op.getEnableGqa();
Type elementType =
cast<ShapedType>(adaptor.getQuery().getType()).getElementType();

// Verify inputs (only support defaults)
if (!isa<Torch::NoneType>(mask.getType()))
return rewriter.notifyMatchFailure(op.getLoc(),
"attention masking not supported");
double dropout;
if (!matchPattern(dropoutP, m_TorchConstantFloat(&dropout)) ||
dropout > 0.0)
return rewriter.notifyMatchFailure(op.getLoc(), "dropout not supported");

bool causal;
if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal)
return rewriter.notifyMatchFailure(
op.getLoc(), "causal attention masking not supported");
if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal) {
if (!isa<Torch::NoneType>(mask.getType())) {
return rewriter.notifyMatchFailure(
op.getLoc(), "expected no attention mask when isCausal is true");
}

SmallVector<OpFoldResult> maskSizes;

if (queryTy.hasStaticShape() && keyTy.hasStaticShape()) {
auto seqLenQ =
rewriter.getIndexAttr(queryTy.getDimSize(queryTy.getRank() - 2));
auto seqLenK =
rewriter.getIndexAttr(keyTy.getDimSize(keyTy.getRank() - 2));
maskSizes = {seqLenQ, seqLenK};
for (int i = queryTy.getRank() - 3; i >= 0; --i) {
auto batchSize = rewriter.getIndexAttr(queryTy.getDimSize(i));
maskSizes.insert(maskSizes.begin(), batchSize);
}
} else { // Dynamic shape case: <?x?x...x?xf32> for example
for (int i = 0; i < queryTy.getRank() - 2; ++i) {
Value batchSize =
rewriter.create<tensor::DimOp>(op.getLoc(), query, i);
maskSizes.push_back(batchSize);
}
Value seqLenQ = rewriter.create<tensor::DimOp>(op.getLoc(), query,
queryTy.getRank() - 2);
Value seqLenK = rewriter.create<tensor::DimOp>(op.getLoc(), key,
keyTy.getRank() - 2);
maskSizes.push_back(seqLenQ);
maskSizes.push_back(seqLenK);
}

Type maskType = getElementTypeOrSelf(queryTy);
Value emptyMask =
rewriter.create<tensor::EmptyOp>(op.getLoc(), maskSizes, maskType);

Value zero = rewriter.create<arith::ConstantOp>(
op.getLoc(),
rewriter.getFloatAttr(getElementTypeOrSelf(maskType), 0.0));
Value negInf = rewriter.create<arith::ConstantOp>(
op.getLoc(),
rewriter.getFloatAttr(getElementTypeOrSelf(maskType), -INFINITY));

mask = rewriter.create<linalg::FillOp>(op.getLoc(), zero, emptyMask)
.getResult(0);

int64_t rank = cast<ShapedType>(queryTy).getRank();
AffineMap maskMap = rewriter.getMultiDimIdentityMap(rank);
SmallVector<utils::IteratorType> iteratorTypes(
rank, utils::IteratorType::parallel);
auto genericOp = rewriter.create<linalg::GenericOp>(
op.getLoc(), mask.getType(), ValueRange{}, mask,
SmallVector<AffineMap>{maskMap}, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value i = b.create<linalg::IndexOp>(loc, queryTy.getRank() - 2);
Value j = b.create<linalg::IndexOp>(loc, queryTy.getRank() - 1);

Value cond =
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, i, j);
Value select = b.create<arith::SelectOp>(loc, cond, zero, negInf);
b.create<linalg::YieldOp>(loc, select);
});
mask = genericOp.getResult(0);
}

if (!isa<Torch::NoneType>(scale.getType())) {
double scaleFloat;
if (!matchPattern(scale, m_TorchConstantFloat(&scaleFloat)) ||
Expand All @@ -1611,14 +1679,6 @@ class ConvertAtenScaledDotProductAttentionOp
return rewriter.notifyMatchFailure(
op.getLoc(), "grouped query attention not supported");

auto opTy = cast<ValueTensorType>(op.getType()).toBuiltinTensor();
auto query = adaptor.getQuery();
auto value = adaptor.getValue();
auto key = adaptor.getKey();
auto queryTy = cast<ShapedType>(query.getType());
auto valueTy = cast<ShapedType>(value.getType());
auto keyTy = cast<ShapedType>(key.getType());

if (queryTy.getRank() != valueTy.getRank() ||
queryTy.getRank() != keyTy.getRank())
return rewriter.notifyMatchFailure(op, "operand ranks do not match");
Expand Down Expand Up @@ -1659,6 +1719,9 @@ class ConvertAtenScaledDotProductAttentionOp
query = collapseBatch(query);
key = collapseBatch(key);
value = collapseBatch(value);
if (!isa<mlir::torch::Torch::NoneType>(mask.getType())) {
mask = collapseBatch(mask);
}

SmallVector<int64_t> outSizes(cast<ShapedType>(query.getType()).getShape());
SmallVector<int64_t> valueSizes(
Expand All @@ -1672,13 +1735,17 @@ class ConvertAtenScaledDotProductAttentionOp
Value output = createZeroInitTensor(rewriter, op.getLoc(), outSizesDynamic,
elementType);

SmallVector<Value> inputs = SmallVector<Value>{query, key, value};

if (!isa<mlir::torch::Torch::NoneType>(mask.getType())) {
inputs.push_back(mask);
}

// Overwrite with tm_tensor::attention
Value attention =
rewriter
.create<AttentionOp>(loc, outType,
SmallVector<Value>{query, key, value},
SmallVector<Value>{output})
.getResult()[0];
Value attention = rewriter
.create<AttentionOp>(loc, outType, inputs,
SmallVector<Value>{output})
.getResult()[0];

if (opTy != outType) {
attention = rewriter.create<tensor::ExpandShapeOp>(loc, opTy, attention,
Expand Down
94 changes: 87 additions & 7 deletions lib/Dialect/TMTensor/IR/TMTensorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,49 @@ LogicalResult AttentionOp::verify() {
Operation *op = getOperation();
ShapedType queryType = getQueryType();
ShapedType keyType = getKeyType();
ShapedType valueType = getValueType();

auto optionalMaskType = getAttnMaskType();
ShapedType maskType = optionalMaskType ? *optionalMaskType : ShapedType();

ArrayRef<int64_t> queryShape = queryType.getShape();
ArrayRef<int64_t> keyShape = keyType.getShape();
ArrayRef<int64_t> valueShape = valueType.getShape();
ArrayRef<int64_t> maskShape =
optionalMaskType ? maskType.getShape() : ArrayRef<int64_t>();

for (int i = 0, s = queryShape.size() - 2; i < s; ++i) {
if (keyShape[i] != queryShape[i])
if (keyShape[i] != queryShape[i]) {
return op->emitOpError("query and key batch mismatch");
}
}
if (keyShape.back() != queryShape.back())
if (keyShape.back() != queryShape.back()) {
return op->emitOpError("query and key head dimension mismatch");
}

for (int i = 0, s = queryShape.size() - 2; i < s; ++i) {
if (valueShape[i] != queryShape[i]) {
return op->emitOpError("query and value batch dimension mismatch");
}
}
if (keyShape[keyShape.size() - 2] != valueShape[valueShape.size() - 2]) {
return op->emitOpError("key and value sequence length dimension mismatch");
}
if (optionalMaskType) {
for (int i = 0, s = maskShape.size() - 2; i < s; ++i) {
if (maskShape[i] != queryShape[i]) {
return op->emitOpError("query and mask batch dimension mismatch");
}
}
if (maskShape[maskShape.size() - 2] != queryShape[queryShape.size() - 2]) {
return op->emitOpError(
"mask sequence length and query sequence length mismatch");
}
if (maskShape[maskShape.size() - 1] != keyShape[keyShape.size() - 2]) {
return op->emitOpError(
"mask sequence lengt and key sequence length mismatch");
}
}
return success();
}

Expand Down Expand Up @@ -168,10 +203,15 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
Value query = getQuery();
Value key = getKey();
Value value = getValue();

auto optionalMask = getAttnMask();
Value mask = optionalMask ? *optionalMask : Value();

Value output = getOutput();
auto queryType = cast<MemRefType>(query.getType());
auto keyType = cast<MemRefType>(key.getType());
auto valueType = cast<MemRefType>(value.getType());
auto maskType = mask ? cast<MemRefType>(mask.getType()) : MemRefType();
auto queryRank = queryType.getRank();
auto keyRank = keyType.getRank();
auto valueRank = valueType.getRank();
Expand All @@ -180,6 +220,9 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,

Value zeroF = b.create<arith::ConstantOp>(loc, elementType,
b.getFloatAttr(elementType, 0.0));
Value negInfF = b.create<arith::ConstantOp>(
loc, elementType,
b.getFloatAttr(elementType, -std::numeric_limits<double>::infinity()));

// TODO: This needs to be fixed, it assumes everything is dynamic however if
// any shapes are static the `memref.alloc` generated is illegal.
Expand Down Expand Up @@ -214,14 +257,43 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
/*transposed=*/true);

// weight = softmax(weight)
Value one = b.create<arith::ConstantIndexOp>(loc, 1);
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
Value dim = weightDynSizes[weightRank - 1];
Value scaleFactor = b.create<math::SqrtOp>(
loc, b.create<arith::UIToFPOp>(
loc, elementType,
b.create<arith::IndexCastUIOp>(loc, b.getI32Type(),
queryDynSizes[queryRank - 1])));

// weight = (weight - max(weight)) / math.sqrt(querySizes[-1])
Value one = b.create<arith::ConstantIndexOp>(loc, 1);
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
b.create<scf::ParallelOp>(
loc, SmallVector<Value>(weightRank, zero), weightDynSizes,
SmallVector<Value>(weightRank, one),
[&](OpBuilder &b, Location loc, ValueRange localIVs) {
Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
x = b.create<arith::DivFOp>(loc, x, scaleFactor);
b.create<memref::StoreOp>(loc, x, weight, localIVs);
});

// Apply mask to weights if mask is given
if (mask) {
b.create<scf::ParallelOp>(
loc, SmallVector<Value>(weightRank, zero), weightDynSizes,
SmallVector<Value>(weightRank, one),
[&](OpBuilder &b, Location loc, ValueRange localIVs) {
Value weightValue = b.create<memref::LoadOp>(loc, weight, localIVs);
Value maskValue = b.create<memref::LoadOp>(loc, mask, localIVs);
if (maskType.getElementType().isInteger(1)) {
maskValue =
b.create<arith::SelectOp>(loc, maskValue, zeroF, negInfF);
}
Value maskedWeight =
b.create<arith::AddFOp>(loc, weightValue, maskValue);
b.create<memref::StoreOp>(loc, maskedWeight, weight, localIVs);
});
}

// calculate max(weight)
Value init = b.create<memref::LoadOp>(loc, weight,
SmallVector<Value>(weightRank, zero));
Expand Down Expand Up @@ -249,7 +321,6 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
[&](OpBuilder &b, Location loc, ValueRange localIVs) {
Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
x = b.create<arith::SubFOp>(loc, x, globalMax);
x = b.create<arith::DivFOp>(loc, x, scaleFactor);
b.create<memref::StoreOp>(loc, x, weight, localIVs);
});
// calculate exp(weight)
Expand Down Expand Up @@ -307,10 +378,19 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
[&](OpBuilder &b, Location loc, ValueRange localIVs) {
SmallVector<Value> sumIVs(localIVs);
sumIVs.pop_back();

Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
Value sum = b.create<memref::LoadOp>(loc, expWeightSum, sumIVs);
x = b.create<arith::DivFOp>(loc, x, sum);
b.create<memref::StoreOp>(loc, x, weight, localIVs);
Value divResult = b.create<arith::DivFOp>(loc, x, sum);

// Set to 0 if sum is 0 (can occur during boolean mask / large negative
// QK)
Value isSumZero =
b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ, sum, zeroF);
Value result =
b.create<arith::SelectOp>(loc, isSumZero, zeroF, divResult);

b.create<memref::StoreOp>(loc, result, weight, localIVs);
});

// output = weight @ value
Expand Down
Loading

0 comments on commit e86f56b

Please sign in to comment.