diff --git a/docs/content/en/docs/Design/relinearization_ilp.md b/docs/content/en/docs/Design/relinearization_ilp.md
index edde97ff3..9752821a5 100644
--- a/docs/content/en/docs/Design/relinearization_ilp.md
+++ b/docs/content/en/docs/Design/relinearization_ilp.md
@@ -50,13 +50,15 @@ rotation has key-switching built in, and multiplication relinearizes by default.
 That said, many FHE implementations do allow for the relinearization operation
 to be deferred. A useful such situation is when a series of independent
 multiplications are performed, and the results are added together. Addition can
-operate in any key basis (though all inputs must have the same key basis), and
-so the relinearization op that follows each multiplication can be deferred until
-after the additions are complete, at which point there is only one
-relinearization to perform. This technique is usually called _lazy
-relinearization_. It has the benefit of avoiding expensive relinearization
-operations, as well as reducing noise growth, as relinearization adds noise to
-the ciphertext, which can further reduce the need for bootstrapping.
+operate in any key basis (though depending on the backend FHE implementation's
+details, all inputs may require the same key basis, cf.
+[Optional operand agreement](#optional-operand-agreement)), and so the
+relinearization op that follows each multiplication can be deferred until after
+the additions are complete, at which point there is only one relinearization to
+perform. This technique is usually called _lazy relinearization_. It has the
+benefit of avoiding expensive relinearization operations, as well as reducing
+noise growth, as relinearization adds noise to the ciphertext, which can further
+reduce the need for bootstrapping.
 
 In much of the literature, lazy relinearization is applied manually. See for
 example
@@ -128,13 +130,12 @@ TODO(#1018): update docs when objective is generalized.
 
 ### Constraints
 
+#### Simple constraints
+
 The simple constraints are as follows:
 
 - Initial key basis degree: For each block argument, $\\textup{KB}\_v$ is fixed
   to equal the `dimension` parameter on the RLWE ciphertext type.
-- Operand agreement: For each operation with operand SSA values $v_1, \\dots,
-  v_k$, $\\textup{KB}\_{v_1} = \\dots = \\textup{KB}\_{v_k}$, i.e., all key
-  basis inputs must match.
 - Special linearized ops: `bgv.rotate` and `func.return` require linearized
   inputs, i.e., $\\textup{KB}\_{v_i} = 1$ for all inputs $v_i$ to these
   operations.
@@ -146,6 +147,36 @@ The simple constraints are as follows:
   only op that increases the degree, and all operands are constrained to have
   equal degree.
 
+#### Optional operand agreement
+
+There are two versions of the model: one where an operation requires the input
+key basis degrees of all its operands to be equal, and one where differing key
+basis degrees are allowed.
+
+This is an option because the model was originally implemented under the
+incorrect assumption that CPU backends like OpenFHE and Lattigo require the key
+basis degrees of operands to be equal for ops like ciphertext addition. When we
+discovered this was not the case, we generalized the model to support both
+cases, in case other backends do have this requirement.
+
+When operands must have the same key basis degree, for each operation with
+operand SSA values $v_1, \\dots, v_k$, we add the constraint
+$\\textup{KB}\_{v_1} = \\dots = \\textup{KB}\_{v_k}$, i.e., all key basis inputs
+must match.
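+
+For example, consider the following sketch of IR (illustrative only: the types
+are arbitrary and the surrounding `secret.generic` plumbing is elided). Both
+multiplication results have key basis degree 2, and operand agreement ties the
+degrees of the two `arith.addi` operands together, so if one operand is
+relinearized before the addition, the other must be as well:
+
+```mlir
+%0 = arith.muli %x, %x : tensor<8xi16>  // key basis degree 2
+%1 = arith.muli %y, %y : tensor<8xi16>  // key basis degree 2
+%2 = arith.addi %0, %1 : tensor<8xi16>  // operand degrees constrained equal
+```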
+
+When operands may have different key basis degrees, we instead add the
+constraint that each operation's result key basis degree (before
+relinearization) is at least as large as the max of all operand key basis
+degrees: for all $i$, $\\textup{KB}\_{\\textup{result}(o)}^{br} \\geq
+\\textup{KB}\_{v_i}$. Note that we are relying on an implicit behavior of the
+model to ensure that, even if the solver chooses key basis degree variables for
+these op results larger than the max of the operand degrees, the resulting
+optimal solution is the same.
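+
+For example (an illustrative instance of the constraints above), a ciphertext
+addition $o$ with operand degrees $\\textup{KB}\_{v_1} = 2$ and
+$\\textup{KB}\_{v_2} = 1$ yields the constraints
+$\\textup{KB}\_{\\textup{result}(o)}^{br} \\geq 2$ and
+$\\textup{KB}\_{\\textup{result}(o)}^{br} \\geq 1$; any value at least 2 is
+feasible, and an optimal solution exists in which it is exactly 2.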
#include "absl/status/statusor.h" // from @com_google_absl @@ -157,15 +156,6 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() { } }); - // The objective is to minimize the number of relinearization ops. - // TODO(#1018): improve the objective function to account for differing - // costs of operations at varying degrees. - math_opt::LinearExpression obj; - for (auto &[op, decisionVar] : decisionVariables) { - obj += decisionVar; - } - model.Minimize(obj); - // Constraints to initialize the key basis degree variables at the start of // the computation. for (auto &[value, var] : keyBasisVars) { @@ -176,46 +166,52 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() { } } - // For each operation, constrain its inputs to all have the same key basis - // degree. std::string cstName; - opToRunOn->walk([&](Operation *op) { - if (op->getNumOperands() <= 1) { - return; - } + // For each operation, constrain its inputs to all have the same key basis + // degree. Most FHE backends we're aware of do not require this, and can + // handle mixed-degree operations like ciphertext addition. When we do require + // this, the output of an operation like ciphertext addition can be passed + // through from the input unchanged. If we don't require this, the output + // of the addition must be a max over the input degrees. + if (!allowMixedDegreeOperands) { + opToRunOn->walk([&](Operation *op) { + if (op->getNumOperands() <= 1) { + return; + } - // secret generic op arguments are not constrained - // instead their block arguments are constrained - if (isa(op)) { - return; - } + // secret generic op arguments are not constrained + // instead their block arguments are constrained + if (isa(op)) { + return; + } - std::string name = uniqueName(op); + std::string name = uniqueName(op); - // only equality for secret operands - SmallVector secretOperands; - getSecretOperands(op, secretOperands, solver); - if (secretOperands.size() <= 1) { - return; - } + // only equality for secret operands + SmallVector secretOperands; + getSecretOperands(op, secretOperands, solver); + if (secretOperands.size() <= 1) { + return; + } - auto anchorVar = keyBasisVars.at(secretOperands[0]->get()); + auto anchorVar = keyBasisVars.at(secretOperands[0]->get()); - // degree(operand 0) == degree(operand i) - for (OpOperand *opOperand : secretOperands) { - if (!keyBasisVars.contains(opOperand->get())) { - continue; - } - auto operandDegreeVar = keyBasisVars.at(opOperand->get()); - if (anchorVar == operandDegreeVar) { - continue; + // degree(operand 0) == degree(operand i) + for (OpOperand *opOperand : secretOperands) { + if (!keyBasisVars.contains(opOperand->get())) { + continue; + } + auto operandDegreeVar = keyBasisVars.at(opOperand->get()); + if (anchorVar == operandDegreeVar) { + continue; + } + std::stringstream ss; + ss << "ArgKeyBasisEquality_" << opOperand->getOperandNumber() << "_" + << name; + model.AddLinearConstraint(operandDegreeVar == anchorVar, ss.str()); } - std::stringstream ss; - ss << "ArgKeyBasisEquality_" << opOperand->getOperandNumber() << "_" - << name; - model.AddLinearConstraint(operandDegreeVar == anchorVar, ss.str()); - } - }); + }); + } // Some ops require a linear key basis. Yield is a special case // where we require returned values from funcs to be linearized. @@ -242,6 +238,12 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() { }); }); + // When mixed-degree ops are enabled, the default result degree of an op is + // the max of the operand degree. 
+  // When mixed-degree ops are enabled, the default result degree of an op is
+  // the max of the operand degrees. The next block of code adds inequality
+  // constraints to ensure each result is at least as large as each operand
+  // degree, but the model does not constrain it to equal that max.
+  std::unordered_set<const math_opt::Variable *> extraVarsForObjective;
+
   // Add constraints that set the before_relin variables appropriately
   opToRunOn->walk([&](Operation *op) {
     llvm::TypeSwitch<Operation &>(*op)
@@ -289,11 +291,18 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
         })
         .Default([&](Operation &op) {
           // For any other op, the key basis does not change unless we insert
-          // a relin op. Because the verifier ensures the operands and results
-          // have identical key bases, we can just pass through the first
-          // argument to the before_relin variable.
+          // a relin op. The operands must all have the same key basis degree
+          // when allowMixedDegreeOperands is false, in which case we can just
+          // forward the degree of the first secret operand. Otherwise, we
+          // require the output to be at least the max of the inputs, which
+          // takes inequality constraints.
           //
           // before_relin = arg1_degree
+          //
+          // or,
+          // before_relin >= arg1_degree
+          // before_relin >= arg2_degree
+          // before_relin >= arg3_degree
 
           // secret generic op arguments are not constrained
           // instead their block arguments are constrained
@@ -305,26 +314,55 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
           }
           SmallVector<OpOperand *> secretOperands;
           getSecretOperands(&op, secretOperands, solver);
-          // this works because we constraint argDegreeVar for all
-          // SecretOperands to be the same
-          auto argDegreeVar = keyBasisVars.at(secretOperands[0]->get());
-          for (Value result : op.getResults()) {
-            auto resultBeforeRelinVar = beforeRelinVars.at(result);
-            std::string opName = uniqueName(&op);
-            std::string ddPrefix = "DecisionDynamics_" + opName;
+          std::string opName = uniqueName(&op);
+          if (allowMixedDegreeOperands) {
+            for (OpOperand *opOperand : secretOperands) {
+              std::string ddPrefix =
+                  "DecisionDynamics_Mixed_" + opName + "_" +
+                  std::to_string(opOperand->getOperandNumber());
+              auto argDegreeVar = keyBasisVars.at(opOperand->get());
+              for (OpResult opResult : op.getOpResults()) {
+                Value result = opResult;
+                const math_opt::Variable &resultBeforeRelinVar =
+                    beforeRelinVars.at(result);
+                cstName =
+                    ddPrefix + "_" + std::to_string(opResult.getResultNumber());
+                model.AddLinearConstraint(resultBeforeRelinVar >= argDegreeVar,
+                                          cstName);
+                extraVarsForObjective.insert(&resultBeforeRelinVar);
+              }
+            }
+          } else {
+            auto argDegreeVar = keyBasisVars.at(secretOperands[0]->get());
 
-          cstName = ddPrefix + "_0";
-          // This is mildly wasteful, but the presolve will prune it out and
-          // it shouldn't affect the solve time. It simply helps us do
-          // bookkeeping for the before/after relin vars uniformly across
-          // all cases.
-          model.AddLinearConstraint(resultBeforeRelinVar == argDegreeVar,
-                                    cstName);
+            for (Value result : op.getResults()) {
+              auto resultBeforeRelinVar = beforeRelinVars.at(result);
+              std::string ddPrefix = "DecisionDynamics_" + opName;
+
+              cstName = ddPrefix + "_0";
+              // This is mildly wasteful, but the presolve will prune it out
+              // and it shouldn't affect the solve time. It simply helps us do
+              // bookkeeping for the before/after relin vars uniformly across
+              // all cases.
+              model.AddLinearConstraint(resultBeforeRelinVar == argDegreeVar,
+                                        cstName);
+            }
           }
         });
   });
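+
+  // Worked example (illustrative): for %2 = arith.addi %0, %1 with key basis
+  // degrees 2 and 1, the mixed-degree branch above emits
+  //   before_relin(%2) >= degree(%0)  (= 2)
+  //   before_relin(%2) >= degree(%1)  (= 1)
+  // while the equality mode forwards degree(%0) directly, since the operand
+  // agreement constraints already force degree(%0) == degree(%1).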
+
+  // The objective is to minimize the number of relinearization ops.
+  // TODO(#1018): improve the objective function to account for differing costs
+  // of operations at varying degrees, as well as the cost of relinearizing
+  // based on the starting degree of the input.
+  math_opt::LinearExpression obj;
+  for (auto &[op, decisionVar] : decisionVariables) {
+    obj += decisionVar;
+  }
+  model.Minimize(obj);
+
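+  // Illustrative note: for the lazy relinearization pattern
+  //   relin?(x*x) + relin?(y*y)
+  // each candidate insertion point contributes one decision variable, so the
+  // solver prefers a single relinearization after the addition (when
+  // downstream ops permit it) over relinearizing both products; see
+  // tests/Transforms/optimize_relinearization/issue_1284.mlir.
+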
"mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" // from @llvm-project #include "mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/include/mlir/IR/Types.h" // from @llvm-project -#include "mlir/include/mlir/IR/Value.h" // from @llvm-project -#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/include/mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project +#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/include/mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project namespace mlir { namespace heir { @@ -53,8 +39,8 @@ struct OptimizeRelinearization op.erase(); }); - OptimizeRelinearizationAnalysis analysis(genericOp, solver, - useLocBasedVariableNames); + OptimizeRelinearizationAnalysis analysis( + genericOp, solver, useLocBasedVariableNames, allowMixedDegreeOperands); if (failed(analysis.solve())) { genericOp->emitError("Failed to solve the optimization problem"); return signalPassFailure(); diff --git a/lib/Transforms/OptimizeRelinearization/OptimizeRelinearization.td b/lib/Transforms/OptimizeRelinearization/OptimizeRelinearization.td index 7e8fe6006..458a3c42e 100644 --- a/lib/Transforms/OptimizeRelinearization/OptimizeRelinearization.td +++ b/lib/Transforms/OptimizeRelinearization/OptimizeRelinearization.td @@ -38,6 +38,13 @@ def OptimizeRelinearization : Pass <"optimize-relinearization"> { /*default=*/"false", "When true, the ILP uses op source locations in variable names, " "which can help debug ILP model bugs.">, + Option<"allowMixedDegreeOperands", + "allow-mixed-degree-operands", + "bool", + /*default=*/"true", + "When true, allow ops to have mixed-degree ciphertexts as inputs, e.g., " + "adding two ciphertexts with different key bases; this is supported by " + "many FHE backends, like OpenFHE and Lattigo">, ]; } diff --git a/tests/Transforms/optimize_relinearization/force_equal_args.mlir b/tests/Transforms/optimize_relinearization/force_equal_args.mlir new file mode 100644 index 000000000..49aeb30c5 --- /dev/null +++ b/tests/Transforms/optimize_relinearization/force_equal_args.mlir @@ -0,0 +1,24 @@ +// RUN: heir-opt --mlir-print-local-scope --secretize --mlir-to-secret-arithmetic --optimize-relinearization='allow-mixed-degree-operands=false' %s | FileCheck %s + +// CHECK-LABEL: func.func @relinearize_both_add_operands +// CHECK: secret.generic +// CHECK: arith.muli +// CHECK-NEXT: mgmt.relinearize +// CHECK-NEXT: arith.muli +// CHECK-NEXT: mgmt.relinearize +// CHECK-NEXT: tensor_ext.rotate +// CHECK-NEXT: arith.addi +// CHECK-NOT: mgmt.relinearize +// CHECK-NEXT: secret.yield +func.func @relinearize_both_add_operands(%arg0: tensor<8xi16>, %arg1: tensor<8xi16>) -> tensor<8xi16> { + %0 = arith.muli %arg0, %arg0: tensor<8xi16> 
+  %1 = mgmt.relinearize %0 : tensor<8xi16>
+  %2 = arith.muli %arg1, %arg1: tensor<8xi16>
+  %3 = mgmt.relinearize %2 : tensor<8xi16>
+
+  // Rotation requires degree 1 key basis input
+  %c1 = arith.constant 1 : index
+  %6 = tensor_ext.rotate %3, %c1 : tensor<8xi16>, index
+  %7 = arith.addi %1, %6 : tensor<8xi16>
+  func.return %7 : tensor<8xi16>
+}
diff --git a/tests/Transforms/optimize_relinearization/issue_1284.mlir b/tests/Transforms/optimize_relinearization/issue_1284.mlir
new file mode 100644
index 000000000..d2d5ff1ff
--- /dev/null
+++ b/tests/Transforms/optimize_relinearization/issue_1284.mlir
@@ -0,0 +1,25 @@
+// RUN: heir-opt --mlir-to-secret-arithmetic --secret-insert-mgmt-bgv --optimize-relinearization %s | FileCheck %s
+
+// CHECK-LABEL: func.func @repro
+// CHECK-COUNT-1: mgmt.relinearize
+// CHECK-NOT: mgmt.relinearize
+func.func @repro(%x: i16 {secret.secret}, %y: i16 {secret.secret}, %p: i16) -> (i16) {
+  %xx = arith.muli %x, %x : i16
+  %yy = arith.muli %y, %y : i16
+  %0 = arith.addi %xx, %yy : i16
+  %xp = arith.muli %x, %p : i16
+  %1 = arith.addi %xp, %0 : i16
+  func.return %1 : i16
+}
+
+// CHECK-LABEL: func.func @repro2
+// CHECK-COUNT-1: mgmt.relinearize
+// CHECK-NOT: mgmt.relinearize
+func.func @repro2(%x: i16 {secret.secret}, %y: i16 {secret.secret}, %p: i16) -> (i16) {
+  %xx = arith.muli %x, %x : i16
+  %yy = arith.muli %y, %y : i16
+  %0 = arith.addi %xx, %yy : i16
+  %xp = arith.muli %x, %p : i16
+  %1 = arith.addi %0, %xp : i16
+  func.return %1 : i16
+}
diff --git a/tests/Transforms/optimize_relinearization/optimize_relinearization.mlir b/tests/Transforms/optimize_relinearization/optimize_relinearization.mlir
index b8a9eec3c..17d0f0f71 100644
--- a/tests/Transforms/optimize_relinearization/optimize_relinearization.mlir
+++ b/tests/Transforms/optimize_relinearization/optimize_relinearization.mlir
@@ -195,11 +195,11 @@ func.func @smoke_test(%arg0: tensor<8xi16>, %arg1: tensor<8xi16>) -> tensor<8xi1
 // CHECK-LABEL: func.func @rotation_needs_linear_inputs
 // CHECK: secret.generic
 // CHECK: arith.muli
-// CHECK-NEXT: mgmt.relinearize
 // CHECK-NEXT: arith.muli
 // CHECK-NEXT: mgmt.relinearize
 // CHECK-NEXT: tensor_ext.rotate
 // CHECK-NEXT: arith.addi
+// CHECK-NEXT: mgmt.relinearize
 // CHECK-NEXT: secret.yield
 func.func @rotation_needs_linear_inputs(%arg0: tensor<8xi16>, %arg1: tensor<8xi16>) -> tensor<8xi16> {
   %0 = arith.muli %arg0, %arg0: tensor<8xi16>