Merge pull request #1295 from j2kun:issue-1284
PiperOrigin-RevId: 720311933
copybara-github committed Jan 27, 2025
2 parents b1436b4 + 1cd7dd9 commit b9500b6
Showing 8 changed files with 226 additions and 114 deletions.
51 changes: 41 additions & 10 deletions docs/content/en/docs/Design/relinearization_ilp.md
@@ -50,13 +50,15 @@ rotation has key-switching built in, and multiplication relinearizes by default.
That said, many FHE implementations do allow for the relinearization operation
to be deferred. A useful such situation is when a series of independent
multiplications are performed, and the results are added together. Addition can
operate in any key basis (though all inputs must have the same key basis), and
so the relinearization op that follows each multiplication can be deferred until
after the additions are complete, at which point there is only one
relinearization to perform. This technique is usually called _lazy
relinearization_. It has the benefit of avoiding expensive relinearization
operations, as well as reducing noise growth, as relinearization adds noise to
the ciphertext, which can further reduce the need for bootstrapping.
operate in any key basis (though depending on the backend FHE implementation's
details, all inputs may require the same key basis, cf.
[Optional operand agreement](#optional-operand-agreement)), and so the
relinearization op that follows each multiplication can be deferred until after
the additions are complete, at which point there is only one relinearization to
perform. This technique is usually called _lazy relinearization_. It has the
benefit of avoiding expensive relinearization operations, as well as reducing
noise growth, as relinearization adds noise to the ciphertext, which can further
reduce the need for bootstrapping.
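
The degree bookkeeping behind lazy relinearization can be sketched in a few lines. The snippet below is a toy model, not HEIR's implementation: it tracks only the key basis degree of each ciphertext, assuming fresh ciphertexts are linear (degree 1), multiplication sums the degrees, and addition takes the max.

```python
def multiply(deg_a, deg_b):
    # The tensor-product multiplication sums key basis degrees.
    return deg_a + deg_b

def relinearize(deg):
    # Key-switch back to the linear key basis.
    return 1

def sum_of_products(degree_pairs, lazy):
    """Return (final degree, relinearization count) for sum_i(a_i * b_i)."""
    relin_count = 0
    acc_degree = 0
    for a, b in degree_pairs:
        prod = multiply(a, b)
        if not lazy:
            prod = relinearize(prod)  # eager: relinearize every product
            relin_count += 1
        acc_degree = max(acc_degree, prod)  # addition of the products
    if lazy:
        acc_degree = relinearize(acc_degree)  # one relin after the adds
        relin_count += 1
    return acc_degree, relin_count

pairs = [(1, 1)] * 4  # four independent products of linear ciphertexts
print(sum_of_products(pairs, lazy=False))  # (1, 4): four relins
print(sum_of_products(pairs, lazy=True))   # (1, 1): a single relin
```

Either way the result ends up linear; the lazy schedule simply performs one relinearization instead of four.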

In much of the literature, lazy relinearization is applied manually. See for
example
@@ -128,13 +130,12 @@ TODO(#1018): update docs when objective is generalized.

### Constraints

#### Simple constraints

The simple constraints are as follows:

- Initial key basis degree: For each block argument, $\\textup{KB}\_v$ is fixed
to equal the `dimension` parameter on the RLWE ciphertext type.
- Operand agreement: For each operation with operand SSA values $v_1, \\dots,
v_k$, $\\textup{KB}\_{v_1} = \\dots = \\textup{KB}\_{v_k}$, i.e., all key
basis inputs must match.
- Special linearized ops: `bgv.rotate` and `func.return` require linearized
inputs, i.e., $\\textup{KB}\_{v_i} = 1$ for all inputs $v_i$ to these
operations.
@@ -146,6 +147,36 @@ The simple constraints are as follows:
only op that increases the degree, and all operands are constrained to have
equal degree.
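
As a concrete illustration, the simple constraints can be checked by hand on a toy program. This sketch is illustrative only (the pass builds the real constraints with OR-Tools math_opt); it assumes a hypothetical program `%2 = mul(%0, %1)` followed by `rotate(%2)`, where `%0` and `%1` are block arguments of a ciphertext type with `dimension` 2 and `%2` is the mul result after any inserted relinearization.

```python
DIMENSION = 2  # the `dimension` parameter on the RLWE ciphertext type

def satisfies_simple_constraints(kb):
    # Initial key basis degree: block arguments are fixed by the type.
    if kb["%0"] != DIMENSION or kb["%1"] != DIMENSION:
        return False
    # Operand agreement: all inputs to the mul must match.
    if kb["%0"] != kb["%1"]:
        return False
    # Special linearized ops: rotate (and func.return) need KB == 1.
    if kb["%2"] != 1:
        return False
    return True

# A relinearization inserted after the mul brings %2 back to degree 1.
print(satisfies_simple_constraints({"%0": 2, "%1": 2, "%2": 1}))  # True
print(satisfies_simple_constraints({"%0": 2, "%1": 2, "%2": 3}))  # False
```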

#### Optional operand agreement

There are two versions of the model: one where an operation requires the input
key basis degrees of all its operands to be equal, and one where differing key
basis degrees are allowed.

This is an option because the model was originally implemented under the
incorrect assumption that CPU backends like OpenFHE and Lattigo require the
operands' key basis degrees to be equal for ops like ciphertext addition. When
we discovered this was not the case, we generalized the model to support both
cases, since other backends may have this requirement.

When operands must have the same key basis degree, for each operation with
operand SSA values $v_1, \\dots, v_k$ we add the constraint
$\\textup{KB}\_{v_1} = \\dots = \\textup{KB}\_{v_k}$, i.e., all key basis
inputs must match.

When operands may have different key basis degrees, we instead constrain each
operation result's key basis degree (before relinearization) to be at least as
large as every operand's key basis degree, and hence at least their max: for
all $i$, $\\textup{KB}\_{\\textup{result}(o)}^{br} \\geq \\textup{KB}\_{v_i}$.
Note that we rely on an implicit behavior of the model: even if the solver
chooses key basis degree variables for these op results that are larger than
the max of the operand degrees, the resulting optimal solution is the same.
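
The two modes can be contrasted with a pair of predicates (an illustrative sketch, not the solver constraints themselves):

```python
def equality_mode_ok(operand_degrees):
    # KB_{v_1} = ... = KB_{v_k}: every operand shares one degree.
    return all(d == operand_degrees[0] for d in operand_degrees)

def mixed_mode_ok(operand_degrees, result_before_relin):
    # KB_result^br >= KB_{v_i} for each operand v_i, i.e., the result
    # is at least the max operand degree; minimizing the objective is
    # what keeps it from exceeding that max in an optimal solution.
    return all(result_before_relin >= d for d in operand_degrees)

print(equality_mode_ok([2, 1]))     # False: mixed degrees rejected
print(mixed_mode_ok([2, 1], 2))     # True: result bounded below by the max
print(mixed_mode_ok([2, 1], 1))     # False: below an operand degree
```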

TODO(#1018): this will change to a more principled approach when the objective
is generalized.

#### Impact of relinearization choices on key basis degree

The remaining constraints control the dynamics of how the key basis degree
changes as relinearizations are inserted.
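
At their core, these dynamics pair each op result with a binary decision: a hypothetical simplification of the before/after-relinearization variables and per-op decision variables used by the pass can be sketched as follows.

```python
def degree_after(before_relin_degree, insert_relin):
    # A relinearization resets the key basis degree to 1 (linear);
    # otherwise the pre-relinearization degree passes through.
    return 1 if insert_relin else before_relin_degree

print(degree_after(2, insert_relin=True))   # 1
print(degree_after(2, insert_relin=False))  # 2
```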

@@ -5,23 +5,22 @@
#include <string>
#include <utility>

#include "lib/Dialect/Mgmt/IR/MgmtOps.h"
#include "lib/Analysis/DimensionAnalysis/DimensionAnalysis.h"
#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"
#include "lib/Dialect/Secret/IR/SecretOps.h"
#include "lib/Dialect/TensorExt/IR/TensorExtOps.h"
#include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project
#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "llvm/include/llvm/Support/Casting.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
#include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "llvm/include/llvm/Support/Casting.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project

// Avoid copybara mangling and separate third party includes with a comment.
#include "absl/status/statusor.h" // from @com_google_absl
@@ -157,15 +156,6 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
}
});

// The objective is to minimize the number of relinearization ops.
// TODO(#1018): improve the objective function to account for differing
// costs of operations at varying degrees.
math_opt::LinearExpression obj;
for (auto &[op, decisionVar] : decisionVariables) {
obj += decisionVar;
}
model.Minimize(obj);

// Constraints to initialize the key basis degree variables at the start of
// the computation.
for (auto &[value, var] : keyBasisVars) {
@@ -176,46 +166,52 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
}
}

// For each operation, constrain its inputs to all have the same key basis
// degree.
std::string cstName;
opToRunOn->walk([&](Operation *op) {
if (op->getNumOperands() <= 1) {
return;
}
// For each operation, constrain its inputs to all have the same key basis
// degree. Most FHE backends we're aware of do not require this, and can
// handle mixed-degree operations like ciphertext addition. When we do require
// this, the output of an operation like ciphertext addition can be passed
// through from the input unchanged. If we don't require this, the output
// of the addition must be a max over the input degrees.
if (!allowMixedDegreeOperands) {
opToRunOn->walk([&](Operation *op) {
if (op->getNumOperands() <= 1) {
return;
}

// secret generic op arguments are not constrained
// instead their block arguments are constrained
if (isa<secret::GenericOp>(op)) {
return;
}
// secret generic op arguments are not constrained
// instead their block arguments are constrained
if (isa<secret::GenericOp>(op)) {
return;
}

std::string name = uniqueName(op);
std::string name = uniqueName(op);

// only equality for secret operands
SmallVector<OpOperand *, 4> secretOperands;
getSecretOperands(op, secretOperands, solver);
if (secretOperands.size() <= 1) {
return;
}
// only equality for secret operands
SmallVector<OpOperand *, 4> secretOperands;
getSecretOperands(op, secretOperands, solver);
if (secretOperands.size() <= 1) {
return;
}

auto anchorVar = keyBasisVars.at(secretOperands[0]->get());
auto anchorVar = keyBasisVars.at(secretOperands[0]->get());

// degree(operand 0) == degree(operand i)
for (OpOperand *opOperand : secretOperands) {
if (!keyBasisVars.contains(opOperand->get())) {
continue;
}
auto operandDegreeVar = keyBasisVars.at(opOperand->get());
if (anchorVar == operandDegreeVar) {
continue;
// degree(operand 0) == degree(operand i)
for (OpOperand *opOperand : secretOperands) {
if (!keyBasisVars.contains(opOperand->get())) {
continue;
}
auto operandDegreeVar = keyBasisVars.at(opOperand->get());
if (anchorVar == operandDegreeVar) {
continue;
}
std::stringstream ss;
ss << "ArgKeyBasisEquality_" << opOperand->getOperandNumber() << "_"
<< name;
model.AddLinearConstraint(operandDegreeVar == anchorVar, ss.str());
}
std::stringstream ss;
ss << "ArgKeyBasisEquality_" << opOperand->getOperandNumber() << "_"
<< name;
model.AddLinearConstraint(operandDegreeVar == anchorVar, ss.str());
}
});
});
}

// Some ops require a linear key basis. Yield is a special case
// where we require returned values from funcs to be linearized.
@@ -242,6 +238,12 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
});
});

// When mixed-degree ops are enabled, the default result degree of an op is
// the max of the operand degrees. This next block of code adds inequality
// constraints to ensure the result is at least as large as each operand,
// but the model does not constrain it to equal that max value.
std::unordered_set<const math_opt::Variable *> extraVarsForObjective;

// Add constraints that set the before_relin variables appropriately
opToRunOn->walk([&](Operation *op) {
llvm::TypeSwitch<Operation &>(*op)
@@ -289,11 +291,18 @@
})
.Default([&](Operation &op) {
// For any other op, the key basis does not change unless we insert
// a relin op. Because the verifier ensures the operands and results
// have identical key bases, we can just pass through the first
// argument to the before_relin variable.
// a relin op. The operands may have the same basis degree, if that
// is required by the backend and allowMixedDegreeOperands is false,
// in which case we can just forward the degree of the first secret
// operand. Otherwise, we have to require the output to be the max
// of the inputs, which requires inequality constraints.
//
// before_relin = arg1_degree
//
// or,
// before_relin >= arg1_degree
// before_relin >= arg2_degree
// before_relin >= arg3_degree

// secret generic op arguments are not constrained
// instead their block arguments are constrained
@@ -305,26 +314,55 @@
}
SmallVector<OpOperand *, 4> secretOperands;
getSecretOperands(&op, secretOperands, solver);
// this works because we constrain argDegreeVar for all
// SecretOperands to be the same
auto argDegreeVar = keyBasisVars.at(secretOperands[0]->get());

for (Value result : op.getResults()) {
auto resultBeforeRelinVar = beforeRelinVars.at(result);
std::string opName = uniqueName(&op);
std::string ddPrefix = "DecisionDynamics_" + opName;
std::string opName = uniqueName(&op);
if (allowMixedDegreeOperands) {
for (OpOperand *opOperand : secretOperands) {
std::string ddPrefix =
"DecisionDynamics_Mixed_" + opName + "_" +
std::to_string(opOperand->getOperandNumber());
auto argDegreeVar = keyBasisVars.at(opOperand->get());
for (OpResult opResult : op.getOpResults()) {
Value result = opResult;
const math_opt::Variable &resultBeforeRelinVar =
beforeRelinVars.at(result);
cstName =
ddPrefix + "_" + std::to_string(opResult.getResultNumber());
model.AddLinearConstraint(resultBeforeRelinVar >= argDegreeVar,
cstName);
extraVarsForObjective.insert(&resultBeforeRelinVar);
}
}
} else {
auto argDegreeVar = keyBasisVars.at(secretOperands[0]->get());

cstName = ddPrefix + "_0";
// This is mildly wasteful, but the presolve will prune it out and
// it shouldn't affect the solve time. It simply helps us do
// bookkeeping for the before/after relin vars uniformly across
// all cases.
model.AddLinearConstraint(resultBeforeRelinVar == argDegreeVar,
cstName);
for (Value result : op.getResults()) {
auto resultBeforeRelinVar = beforeRelinVars.at(result);
std::string opName = uniqueName(&op);
std::string ddPrefix = "DecisionDynamics_" + opName;

cstName = ddPrefix + "_0";
// This is mildly wasteful, but the presolve will prune it out and
// it shouldn't affect the solve time. It simply helps us do
// bookkeeping for the before/after relin vars uniformly across
// all cases.
model.AddLinearConstraint(resultBeforeRelinVar == argDegreeVar,
cstName);
}
}
});
});

// The objective is to minimize the number of relinearization ops.
// TODO(#1018): improve the objective function to account for differing costs
// of operations at varying degrees, as well as the cost of relinearizing
// based on the starting degree of the input.
math_opt::LinearExpression obj;
for (auto &[op, decisionVar] : decisionVariables) {
obj += decisionVar;
}
model.Minimize(obj);

// Add constraints that control the effect of relinearization insertion.
opToRunOn->walk([&](Operation *op) {
// We don't need a type switch here because the only difference
@@ -1,8 +1,6 @@
#ifndef LIB_ANALYSIS_OPTIMIZE_RELINEARIZATIONANALYSIS_H
#define LIB_ANALYSIS_OPTIMIZE_RELINEARIZATIONANALYSIS_H

#include "lib/Analysis/DimensionAnalysis/DimensionAnalysis.h"
#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"
#include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
@@ -13,10 +11,12 @@ namespace heir {
class OptimizeRelinearizationAnalysis {
public:
OptimizeRelinearizationAnalysis(Operation *op, DataFlowSolver *solver,
bool useLocBasedVariableNames)
bool useLocBasedVariableNames,
bool allowMixedDegreeOperands)
: opToRunOn(op),
solver(solver),
useLocBasedVariableNames(useLocBasedVariableNames) {}
useLocBasedVariableNames(useLocBasedVariableNames),
allowMixedDegreeOperands(allowMixedDegreeOperands) {}
~OptimizeRelinearizationAnalysis() = default;

LogicalResult solve();
Expand All @@ -40,7 +40,8 @@ class OptimizeRelinearizationAnalysis {
private:
Operation *opToRunOn;
DataFlowSolver *solver;
bool useLocBasedVariableNames = false;
bool useLocBasedVariableNames;
bool allowMixedDegreeOperands;
llvm::DenseMap<Operation *, bool> solution;
llvm::DenseMap<Value, int> solutionKeyBasisDegreeBeforeRelin;
};