Skip to content

Commit

Permalink
Add a verifier for bgv.rotate and corresponding test (#287) && Nit: R…
Browse files Browse the repository at this point in the history
…emove 'this->' in bgv verifiers
  • Loading branch information
Maokami authored and jaeho committed Nov 30, 2023
1 parent c057ae2 commit 04fb0e2
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 19 deletions.
2 changes: 2 additions & 0 deletions include/Dialect/BGV/IR/BGVOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def BGV_Rotate : BGV_Op<"rotate", [SameOperandsAndResultRings]> {
let results = (outs
Ciphertext:$output
);

let hasVerifier = 1;
}

def BGV_Negate : BGV_Op<"negate", [SameOperandsAndResultType]> {
Expand Down
49 changes: 30 additions & 19 deletions lib/Dialect/BGV/IR/BGVDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,45 +42,56 @@ void BGVDialect::initialize() {
}

LogicalResult MulOp::verify() {
auto x = this->getX().getType();
auto y = this->getY().getType();
auto x = getX().getType();
auto y = getY().getType();
if (x.getDim() != y.getDim()) {
return this->emitOpError() << "input dimensions do not match";
return emitOpError() << "input dimensions do not match";
}
auto out = this->getOutput().getType();
auto out = getOutput().getType();
if (out.getDim() != 1 + x.getDim()) {
return this->emitOpError() << "output.dim == x.dim + 1 does not hold";
return emitOpError() << "output.dim == x.dim + 1 does not hold";
}
return success();
}
LogicalResult Rotate::verify() {
auto x = getX().getType();
if (x.getDim() != 2) {
return emitOpError() << "x.dim == 2 does not hold";
}
auto out = getOutput().getType();
if (out.getDim() != 2) {
return emitOpError() << "output.dim == 2 does not hold";
}
return success();
}

LogicalResult Relinearize::verify() {
auto x = this->getX().getType();
auto out = this->getOutput().getType();
if (x.getDim() != this->getFromBasis().size()) {
return this->emitOpError() << "input dimension does not match from_basis";
auto x = getX().getType();
auto out = getOutput().getType();
if (x.getDim() != getFromBasis().size()) {
return emitOpError() << "input dimension does not match from_basis";
}
if (out.getDim() != this->getToBasis().size()) {
return this->emitOpError() << "output dimension does not match to_basis";
if (out.getDim() != getToBasis().size()) {
return emitOpError() << "output dimension does not match to_basis";
}
return success();
}

LogicalResult ModulusSwitch::verify() {
auto x = this->getX().getType();
auto x = getX().getType();
auto rings = x.getRings().getRings().size();
auto to = this->getToLevel();
auto from = this->getFromLevel();
auto to = getToLevel();
auto from = getFromLevel();
if (to < 0 || to >= from || from >= rings) {
return this->emitOpError() << "invalid levels, should be true: 0 <= " << to
<< " < " << from << " < " << rings;
return emitOpError() << "invalid levels, should be true: 0 <= " << to
<< " < " << from << " < " << rings;
}
if (x.getLevel().has_value() && x.getLevel().value() != from) {
return this->emitOpError() << "input level does not match from_level";
return emitOpError() << "input level does not match from_level";
}
auto outLvl = this->getOutput().getType().getLevel();
auto outLvl = getOutput().getType().getLevel();
if (!outLvl.has_value() || outLvl.value() != to) {
return this->emitOpError()
return emitOpError()
<< "output level should be specified and match to_level";
}
return success();
Expand Down
13 changes: 13 additions & 0 deletions tests/bgv/verifier.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// RUN: heir-opt --verify-diagnostics %s

#my_poly = #polynomial.polynomial<1 + x**1024>
#ring1 = #polynomial.ring<cmod=463187969, ideal=#my_poly>
#ring2 = #polynomial.ring<cmod=33538049, ideal=#my_poly>
#rings = #bgv.rings<#ring1, #ring2>

func.func @test_input_dimension_error(%input: !bgv.ciphertext<rings=#rings, dim=3>) {
// expected-error@+1 {{x.dim == 2 does not hold}}
%out = bgv.rotate (%input) {offset = 4} : (!bgv.ciphertext<rings=#rings, dim=3>) -> !bgv.ciphertext<rings=#rings>

return
}

0 comments on commit 04fb0e2

Please sign in to comment.