Skip to content

Commit

Permalink
[VectorLayoutAnalysis] Fix bug in scf.for transfer functions (#15989)
Browse files Browse the repository at this point in the history
Prior to this patch, the scf.for transfer functions were not propagating
change on resolution of scf.for operands/results.
  • Loading branch information
Groverkss authored Dec 28, 2023
1 parent 4592b8f commit b0e8f3c
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -924,31 +924,28 @@ void transform_dialect::WorkgroupSwizzleOp::getEffects(

static void setAnchorOpsFromAttributes(VectorLayoutAnalysis &analysis,
func::FuncOp funcOp) {
for (Block &block : funcOp) {
for (Operation &op : block) {
for (NamedAttribute attr : op.getAttrs()) {
StringRef name = attr.getName().strref();
if (name.find("__vector_layout_test_anchor_operand_") !=
std::string::npos) {
int operandNum;
name.substr(name.find_last_of("_") + 1)
.getAsInteger(/*Radix=*/10, operandNum);
assert(operandNum < op.getNumOperands() &&
"operand number out of range");
analysis.setAnchor(op.getOperand(operandNum), attr.getValue());
}
if (name.find("__vector_layout_test_anchor_result_") !=
std::string::npos) {
int resultNum;
name.substr(name.find_last_of("_") + 1)
.getAsInteger(/*Radix=*/10, resultNum);
assert(resultNum < op.getNumResults() &&
"result number out of range");
analysis.setAnchor(op.getResult(resultNum), attr.getValue());
}
funcOp.walk([&](Operation *op) {
for (NamedAttribute attr : op->getAttrs()) {
StringRef name = attr.getName().strref();
if (name.find("__vector_layout_test_anchor_operand_") !=
std::string::npos) {
int operandNum;
name.substr(name.find_last_of("_") + 1)
.getAsInteger(/*Radix=*/10, operandNum);
assert(operandNum < op->getNumOperands() &&
"operand number out of range");
analysis.setAnchor(op->getOperand(operandNum), attr.getValue());
}
if (name.find("__vector_layout_test_anchor_result_") !=
std::string::npos) {
int resultNum;
name.substr(name.find_last_of("_") + 1)
.getAsInteger(/*Radix=*/10, resultNum);
assert(resultNum < op->getNumResults() && "result number out of range");
analysis.setAnchor(op->getResult(resultNum), attr.getValue());
}
}
}
});
}

static void emitLayoutRemarks(VectorLayoutAnalysis &analysis,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,8 @@ void PropagateLayout::visitRegionSuccessors(RegionBranchOpInterface branch,
// Propagate the layouts.
for (auto [forwardedLattice, inputLattice] :
llvm::zip(forwardedLattices, inputLattices)) {
inputLattice->resolve(forwardedLattice);
ChangeResult changed = inputLattice->resolve(forwardedLattice);
propagateIfChanged(inputLattice, changed);
}
}
}
Expand Down Expand Up @@ -887,8 +888,9 @@ void EnforceLayout::visitRegionSuccessors(RegionBranchOpInterface branch,
int64_t curr = 0;
for (auto [forwardedLattice, inputLattice] :
llvm::zip(forwardedLattices, inputLattices)) {
forwardedLattice->resolveWithPossibleConflict(inputLattice,
*forwardedOperands[curr]);
ChangeResult changed = forwardedLattice->resolveWithPossibleConflict(
inputLattice, *forwardedOperands[curr]);
propagateIfChanged(forwardedLattice, changed);
curr++;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,46 @@ builtin.module attributes { transform.with_named_sequence } {
transform.yield
}
}

// -----

#layout = #iree_vector_ext.layout<<[VECTORY], [16]>, <[BATCHY, VECTORX], [2, 8]>>

// Propagate and enforce through scf.for
builtin.module attributes { transform.with_named_sequence } {
func.func @scffor(%arr: memref<16x16xf16>, %arr2: memref<16xf16>, %a: vector<16xf16>, %b: vector<16xf16>) -> vector<16xf16> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c1024 = arith.constant 1024 : index
%cst_0 = arith.constant 0.0 : f16
%cst0_1 = arith.constant dense<0.0> : vector<16xf16>
// expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ VECTORY], [16]>>}}

%out = scf.for %iv = %c0 to %c1024 step %c1 iter_args(%arg1 = %cst0_1) -> (vector<16xf16>) {
// expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ VECTORY], [16]>>}}
%root = vector.transfer_read %arr[%c0, %c0], %cst_0 {in_bounds = [true, true], "__vector_layout_test_anchor_result_0" = #layout} : memref<16x16xf16>, vector<16x16xf16>
// expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ VECTORY], [16]>, <[ BATCHY, VECTORX], [2, 8]>>}}
%root2 = vector.transfer_read %arr2[%c0], %cst_0 {in_bounds = [true]} : memref<16xf16>, vector<16xf16>
// expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ VECTORY], [16]>>}}
%root_transpose = vector.transpose %root, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
// expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHY, VECTORX], [2, 8]>, <[ VECTORY], [16]>>}}
%root_red = vector.multi_reduction<add>, %root_transpose, %arg1 [0] : vector<16x16xf16> to vector<16xf16>
// expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ VECTORY], [16]>>}}
%c = arith.mulf %root_red, %b : vector<16xf16>
// expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ VECTORY], [16]>>}}
%d = arith.addf %c, %a : vector<16xf16>
// expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ VECTORY], [16]>>}}
%e = arith.divf %d, %root2 : vector<16xf16>
// expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ VECTORY], [16]>>}}
scf.yield %e : vector<16xf16>
}

func.return %out : vector<16xf16>
}

transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
%top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
transform.iree.test_vector_layout_analysis %top_level_func : !transform.any_op
transform.yield
}
}

0 comments on commit b0e8f3c

Please sign in to comment.