From b1429a2fb7e251df6738e44f01c2f5c988f943a2 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Fri, 1 Dec 2023 16:29:23 -0800 Subject: [PATCH] test option and fix bug --- .../Secret/Transforms/DistributeGeneric.cpp | 71 ++----------------- tests/secret/distribute_generic_flags.mlir | 29 ++++++++ tests/secret/verifier.mlir | 27 ------- 3 files changed, 33 insertions(+), 94 deletions(-) create mode 100644 tests/secret/distribute_generic_flags.mlir diff --git a/lib/Dialect/Secret/Transforms/DistributeGeneric.cpp b/lib/Dialect/Secret/Transforms/DistributeGeneric.cpp index 8f897f202..707dc783d 100644 --- a/lib/Dialect/Secret/Transforms/DistributeGeneric.cpp +++ b/lib/Dialect/Secret/Transforms/DistributeGeneric.cpp @@ -225,11 +225,6 @@ struct SplitGeneric : public OpRewritePattern { // to the corresponding secret operands (via the block argument number). rewriter.startRootUpdate(genericOp); - LLVM_DEBUG({ - llvm::dbgs() << "\n\ngeneric op before updating loop operands\n\n"; - genericOp.dump(); - }); - // Set the loop op's operands that came from the secret generic block // to be the the corresponding operand of the generic op. for (OpOperand &operand : opToDistribute.getOpOperands()) { @@ -240,11 +235,6 @@ struct SplitGeneric : public OpRewritePattern { } } - LLVM_DEBUG({ - llvm::dbgs() << "\n\ngeneric op after updating loop operands:\n\n"; - genericOp.dump(); - }); - // Set the op's region iter arg types, which need to match the possibly // new type of the operands modified above for (auto [arg, operand] : @@ -252,23 +242,7 @@ struct SplitGeneric : public OpRewritePattern { arg.setType(operand.getType()); } - LLVM_DEBUG({ - llvm::dbgs() << "\n\ngeneric op after updating region iter args\n\n"; - genericOp.dump(); - }); - - // There is a slight type conflict here: the loop's iter arg is - // secret, but its block argument is just index. Since the - // CollapseSecretlessGeneric pattern will resolve this type conflict - // later, we leave it as-is here. - opToDistribute.moveBefore(genericOp); - - LLVM_DEBUG({ - llvm::dbgs() << "\n\nparent after moving loop out of generic body:\n\n"; - genericOp->getParentOp()->dump(); - }); - // Now the loop is before the secret generic, but the generic still // yields the loop's result (the loop should yield the generic's result) // and the generic's body still needs to be moved inside the loop. @@ -284,11 +258,6 @@ struct SplitGeneric : public OpRewritePattern { // Move the generic op to be the first op of the loop body. genericOp->moveBefore(&loopBodyBlocks.front().getOperations().front()); - LLVM_DEBUG({ - llvm::dbgs() << "\n\nloop after moving generic into the loop body:\n\n"; - opToDistribute.dump(); - }); - // Update the yielded values by the terminators of the two ops' blocks. auto yieldedValues = loop.getYieldedValues(); genericOp.getBody(0)->getTerminator()->setOperands(yieldedValues); @@ -300,11 +269,6 @@ struct SplitGeneric : public OpRewritePattern { terminator->setOperands(genericOp.getResults()); } - LLVM_DEBUG({ - llvm::dbgs() << "\n\nloop after updating yielded values:\n\n"; - opToDistribute.dump(); - }); - // Update the return type of the loop op to match its terminator. auto resultRange = loop.getLoopResults(); if (resultRange.has_value()) { @@ -314,22 +278,11 @@ struct SplitGeneric : public OpRewritePattern { } } - LLVM_DEBUG({ - llvm::dbgs() << "\n\nloop after updating return types:\n\n"; - opToDistribute.dump(); - }); - // Move the old loop body ops into the secret.generic for (auto *op : loopBodyOps) { op->moveBefore(genericOp.getBody(0)->getTerminator()); } - LLVM_DEBUG({ - llvm::dbgs() << "\n\nloop after moving old loop body ops into the " - "secret.generic:\n\n"; - opToDistribute.dump(); - }); - // One of the secret.generic's inputs may still refer to the loop's // iter_args initializer, when now it should refer to the iter_arg itself. for (OpOperand &operand : genericOp->getOpOperands()) { @@ -339,12 +292,6 @@ struct SplitGeneric : public OpRewritePattern { } } - LLVM_DEBUG({ - llvm::dbgs() - << "\n\nloop after updating secret.generic to use iter_arg:\n\n"; - opToDistribute.dump(); - }); - // The ops within the secret generic may still refer to the loop // iter_args, which are not part of of the secret.generic's block. To be // a bit more general, walk the entire generic body, and for any operand @@ -371,12 +318,6 @@ struct SplitGeneric : public OpRewritePattern { } }); - LLVM_DEBUG({ - llvm::dbgs() << "\n\nloop after updating op args to use plaintext " - "analogues:\n\n"; - opToDistribute.dump(); - }); - // Finally, ops that came after the original secret.generic may still // refer to a secret.generic result, when now they should refer to the // corresponding result of the loop, if the loop has results. @@ -391,12 +332,6 @@ struct SplitGeneric : public OpRewritePattern { } } - LLVM_DEBUG({ - llvm::dbgs() - << "\n\nloop after updating potential downstream users\n\n"; - opToDistribute.getParentOp()->dump(); - }); - rewriter.finalizeRootUpdate(genericOp); return; } @@ -482,7 +417,7 @@ struct SplitGeneric : public OpRewritePattern { return failure(); } - Operation *opToDistribute; + Operation *opToDistribute = nullptr; bool first = true; if (opsToDistribute.empty()) { opToDistribute = &body->front(); @@ -492,6 +427,8 @@ struct SplitGeneric : public OpRewritePattern { // affine.for) if (std::find(opsToDistribute.begin(), opsToDistribute.end(), op.getName().getStringRef()) != opsToDistribute.end()) { + LLVM_DEBUG(llvm::dbgs() + << "Found op to distribute: " << op.getName() << "\n"); opToDistribute = &op; break; } @@ -501,7 +438,7 @@ struct SplitGeneric : public OpRewritePattern { // Base case: if none of a generic op's member ops are in the list of ops // to process, stop. - if (!opToDistribute) return failure(); + if (opToDistribute == nullptr) return failure(); if (numOps == 2 && !opToDistribute->getRegions().empty()) { distributeThroughRegionHoldingOp(op, *opToDistribute, rewriter); diff --git a/tests/secret/distribute_generic_flags.mlir b/tests/secret/distribute_generic_flags.mlir new file mode 100644 index 000000000..f7cf00c61 --- /dev/null +++ b/tests/secret/distribute_generic_flags.mlir @@ -0,0 +1,29 @@ +// RUN: heir-opt --secret-distribute-generic="distribute-through=affine.for" %s | FileCheck %s + +// CHECK-LABEL: test_affine_for +// CHECK-SAME: %[[value:.*]]: !secret.secret +// CHECK-SAME: %[[data:.*]]: !secret.secret> +func.func @test_affine_for( + %value: !secret.secret, + %data: !secret.secret>) -> !secret.secret> { + // CHECK: affine.for + // CHECK: secret.generic + // CHECK-NEXT: bb + // CHECK-NEXT: affine.load + // CHECK-NEXT: arith.addi + // CHECK-NEXT: affine.store + // CHECK-NEXT: secret.yield + // CHECK-NOT: secret.generic + // CHECK: return %[[data]] + secret.generic + ins(%value, %data : !secret.secret, !secret.secret>) { + ^bb0(%clear_value: i32, %clear_data : memref<10xi32>): + affine.for %i = 0 to 10 { + %2 = affine.load %clear_data[%i] : memref<10xi32> + %3 = arith.addi %2, %clear_value : i32 + affine.store %3, %clear_data[%i] : memref<10xi32> + } + secret.yield + } -> () + func.return %data : !secret.secret> +} diff --git a/tests/secret/verifier.mlir b/tests/secret/verifier.mlir index 1dbfd1959..b19c8c768 100644 --- a/tests/secret/verifier.mlir +++ b/tests/secret/verifier.mlir @@ -27,33 +27,6 @@ func.func @test_secret_type_mismatch(%value: !secret.secret, %c1: i32) { // ----- -func.func @test_refers_to_value_outside_block(%value: !secret.secret) { - %c1 = arith.constant 1 : i32 - // expected-error@+1 {{uses a value defined outside the block}} - %Z = secret.generic - ins(%value : !secret.secret) { - ^bb0(%clear_value: i32): - %1 = arith.addi %clear_value, %c1 : i32 - secret.yield %1 : i32 - } -> (!secret.secret) - return -} - -// ----- - -func.func @test_refers_to_block_argument_outside_block(%value: !secret.secret, %c1 : i32) { - // expected-error@+1 {{uses a block argument defined outside the block}} - %Z = secret.generic - ins(%value : !secret.secret) { - ^bb0(%clear_value: i32): - %1 = arith.addi %clear_value, %c1 : i32 - secret.yield %1 : i32 - } -> (!secret.secret) - return -} - -// ----- - func.func @ensure_yield_inside_generic(%value: !secret.secret) { // expected-error@+1 {{expects parent op 'secret.generic'}} secret.yield %value : !secret.secret