diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp index b3697b59bc31..a46aad8f31de 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp @@ -924,31 +924,28 @@ void transform_dialect::WorkgroupSwizzleOp::getEffects( static void setAnchorOpsFromAttributes(VectorLayoutAnalysis &analysis, func::FuncOp funcOp) { - for (Block &block : funcOp) { - for (Operation &op : block) { - for (NamedAttribute attr : op.getAttrs()) { - StringRef name = attr.getName().strref(); - if (name.find("__vector_layout_test_anchor_operand_") != - std::string::npos) { - int operandNum; - name.substr(name.find_last_of("_") + 1) - .getAsInteger(/*Radix=*/10, operandNum); - assert(operandNum < op.getNumOperands() && - "operand number out of range"); - analysis.setAnchor(op.getOperand(operandNum), attr.getValue()); - } - if (name.find("__vector_layout_test_anchor_result_") != - std::string::npos) { - int resultNum; - name.substr(name.find_last_of("_") + 1) - .getAsInteger(/*Radix=*/10, resultNum); - assert(resultNum < op.getNumResults() && - "result number out of range"); - analysis.setAnchor(op.getResult(resultNum), attr.getValue()); - } + funcOp.walk([&](Operation *op) { + for (NamedAttribute attr : op->getAttrs()) { + StringRef name = attr.getName().strref(); + if (name.find("__vector_layout_test_anchor_operand_") != + std::string::npos) { + int operandNum; + name.substr(name.find_last_of("_") + 1) + .getAsInteger(/*Radix=*/10, operandNum); + assert(operandNum < op->getNumOperands() && + "operand number out of range"); + analysis.setAnchor(op->getOperand(operandNum), attr.getValue()); + } + if (name.find("__vector_layout_test_anchor_result_") != + std::string::npos) { + int resultNum; + name.substr(name.find_last_of("_") + 1) + .getAsInteger(/*Radix=*/10, resultNum); + assert(resultNum < op->getNumResults() && "result number out of range"); + analysis.setAnchor(op->getResult(resultNum), attr.getValue()); } } - } + }); } static void emitLayoutRemarks(VectorLayoutAnalysis &analysis, diff --git a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp index fafb74ac82fa..b2a89c43d31b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp @@ -763,7 +763,8 @@ void PropagateLayout::visitRegionSuccessors(RegionBranchOpInterface branch, // Propagate the layouts. for (auto [forwardedLattice, inputLattice] : llvm::zip(forwardedLattices, inputLattices)) { - inputLattice->resolve(forwardedLattice); + ChangeResult changed = inputLattice->resolve(forwardedLattice); + propagateIfChanged(inputLattice, changed); } } } @@ -887,8 +888,9 @@ void EnforceLayout::visitRegionSuccessors(RegionBranchOpInterface branch, int64_t curr = 0; for (auto [forwardedLattice, inputLattice] : llvm::zip(forwardedLattices, inputLattices)) { - forwardedLattice->resolveWithPossibleConflict(inputLattice, - *forwardedOperands[curr]); + ChangeResult changed = forwardedLattice->resolveWithPossibleConflict( + inputLattice, *forwardedOperands[curr]); + propagateIfChanged(forwardedLattice, changed); curr++; } } diff --git a/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir b/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir index 72c42573b069..eb4ba0e9b463 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir @@ -145,3 +145,46 @@ builtin.module attributes { transform.with_named_sequence } { transform.yield } } + +// ----- + +#layout = #iree_vector_ext.layout<<[VECTORY], [16]>, <[BATCHY, VECTORX], [2, 8]>> + +// Propagate and enforce through scf.for +builtin.module attributes { transform.with_named_sequence } { + func.func @scffor(%arr: memref<16x16xf16>, %arr2: memref<16xf16>, %a: vector<16xf16>, %b: vector<16xf16>) -> vector<16xf16> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1024 = arith.constant 1024 : index + %cst_0 = arith.constant 0.0 : f16 + %cst0_1 = arith.constant dense<0.0> : vector<16xf16> + // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ VECTORY], [16]>>}} + + %out = scf.for %iv = %c0 to %c1024 step %c1 iter_args(%arg1 = %cst0_1) -> (vector<16xf16>) { + // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ VECTORY], [16]>>}} + %root = vector.transfer_read %arr[%c0, %c0], %cst_0 {in_bounds = [true, true], "__vector_layout_test_anchor_result_0" = #layout} : memref<16x16xf16>, vector<16x16xf16> + // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ VECTORY], [16]>, <[ BATCHY, VECTORX], [2, 8]>>}} + %root2 = vector.transfer_read %arr2[%c0], %cst_0 {in_bounds = [true]} : memref<16xf16>, vector<16xf16> + // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ VECTORY], [16]>>}} + %root_transpose = vector.transpose %root, [1, 0] : vector<16x16xf16> to vector<16x16xf16> + // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHY, VECTORX], [2, 8]>, <[ VECTORY], [16]>>}} + %root_red = vector.multi_reduction, %root_transpose, %arg1 [0] : vector<16x16xf16> to vector<16xf16> + // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ VECTORY], [16]>>}} + %c = arith.mulf %root_red, %b : vector<16xf16> + // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ VECTORY], [16]>>}} + %d = arith.addf %c, %a : vector<16xf16> + // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ VECTORY], [16]>>}} + %e = arith.divf %d, %root2 : vector<16xf16> + // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ VECTORY], [16]>>}} + scf.yield %e : vector<16xf16> + } + + func.return %out : vector<16xf16> + } + + transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) { + %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op + transform.iree.test_vector_layout_analysis %top_level_func : !transform.any_op + transform.yield + } +}