Skip to content

Commit

Permalink
perf(BIN): adaptative number line (#531)
Browse files Browse the repository at this point in the history
* perf(BIN): adaptative nb line

* fix(bin): rebase fix

* feat: column typing

* fix(constraint): ras

* fix: ras

* fix: corset export

* feat(constraint): rebase

* feat(constraint): ras
  • Loading branch information
letypequividelespoubelles authored and delehef committed Feb 29, 2024
1 parent a557e1d commit 2159044
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 137 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import net.consensys.linea.zktracer.bytestheta.BaseBytes;
import net.consensys.linea.zktracer.container.ModuleOperation;
import net.consensys.linea.zktracer.opcode.OpCode;
import net.consensys.linea.zktracer.types.Bytes16;
import net.consensys.linea.zktracer.types.UnsignedByte;
import org.apache.tuweni.bytes.Bytes;
import org.apache.tuweni.bytes.Bytes32;
Expand All @@ -43,24 +42,45 @@ public class BinOperation extends ModuleOperation {
@EqualsAndHashCode.Include private final OpCode opCode;
@EqualsAndHashCode.Include private final BaseBytes arg1;
@EqualsAndHashCode.Include private final BaseBytes arg2;

public BinOperation(OpCode opCode, BaseBytes arg1, BaseBytes arg2) {
this.opCode = opCode;
this.arg1 = arg1;
this.arg2 = arg2;
this.ctMax = maxCt();
}

private static final int LLARGE = 16;
private static final int LLARGEMO = 15;
private final int ctMax;
private List<Boolean> lastEightBits = List.of(false);
private boolean bit4 = false;
private int low4 = 0;
private boolean isSmall = false;
private int pivotThreshold = 0;
private int pivot = 0;

private boolean isOneLineInstruction() {
return (opCode == OpCode.BYTE || opCode == OpCode.SIGNEXTEND) && !arg1.getHigh().isZero();
}

@Override
protected int computeLineCount() {
return this.maxCt();
return this.ctMax + 1;
}

private int maxCt() {
return isOneLineInstruction() ? 1 : LIMB_SIZE;
return switch (opCode) {
case NOT -> LLARGEMO;
case BYTE, SIGNEXTEND -> arg1.getHigh().isZero() ? LLARGEMO : 0;
case AND, OR, XOR -> Math.max(
0,
Math.max(
Math.max(
arg1.getHigh().trimLeadingZeros().size(),
arg2.getHigh().trimLeadingZeros().size()),
Math.max(
arg1.getLow().trimLeadingZeros().size(),
arg2.getLow().trimLeadingZeros().size()))
- 1);
default -> throw new IllegalStateException("Unexpected value: " + opCode);
};
}

private boolean isSmall() {
Expand Down Expand Up @@ -177,22 +197,29 @@ private void compute() {
public void traceBinOperation(int stamp, Trace trace) {
this.compute();

final Bytes16 resHi = this.getResult().getHigh();
final Bytes16 resLo = this.getResult().getLow();
final int length = ctMax + 1;
final int offset = LLARGE - length;

final Bytes stampBytes = Bytes.minimalBytes(stamp);
final Bytes arg1Hi = this.arg1.getHigh().slice(offset, length);
final Bytes arg1Lo = this.arg1.getLow().slice(offset, length);
final Bytes arg2Hi = this.arg2.getHigh().slice(offset, length);
final Bytes arg2Lo = this.arg2.getLow().slice(offset, length);
final Bytes resHi = this.getResult().getHigh().slice(offset, length);
final Bytes resLo = this.getResult().getLow().slice(offset, length);
final List<Boolean> bit1 = this.getBit1();
final List<Boolean> bits =
Stream.concat(this.getFirstEightBits().stream(), this.lastEightBits.stream()).toList();
for (int ct = 0; ct < this.maxCt(); ct++) {
for (int ct = 0; ct <= this.ctMax; ct++) {
trace
.stamp(Bytes.ofUnsignedInt(stamp))
.oneLineInstruction(this.maxCt() == 1)
.mli(this.maxCt() != 1)
.stamp(stampBytes)
.ctMax(UnsignedByte.of(ctMax))
.counter(UnsignedByte.of(ct))
.inst(UnsignedByte.of(this.opCode().byteValue()))
.argument1Hi(this.arg1().getHigh())
.argument1Lo(this.arg1().getLow())
.argument2Hi(this.arg2().getHigh())
.argument2Lo(this.arg2().getLow())
.argument1Hi(arg1Hi)
.argument1Lo(arg1Lo)
.argument2Hi(arg2Hi)
.argument2Lo(arg2Lo)
.resultHi(resHi)
.resultLo(resLo)
.isAnd(this.opCode() == OpCode.AND)
Expand All @@ -208,16 +235,16 @@ public void traceBinOperation(int stamp, Trace trace) {
.neg(bits.get(0))
.bit1(bit1.get(ct))
.pivot(UnsignedByte.of(this.pivot))
.byte1(UnsignedByte.of(this.arg1().getHigh().get(ct)))
.byte2(UnsignedByte.of(this.arg1().getLow().get(ct)))
.byte3(UnsignedByte.of(this.arg2().getHigh().get(ct)))
.byte4(UnsignedByte.of(this.arg2().getLow().get(ct)))
.byte1(UnsignedByte.of(arg1Hi.get(ct)))
.byte2(UnsignedByte.of(arg1Lo.get(ct)))
.byte3(UnsignedByte.of(arg2Hi.get(ct)))
.byte4(UnsignedByte.of(arg2Lo.get(ct)))
.byte5(UnsignedByte.of(resHi.get(ct)))
.byte6(UnsignedByte.of(resLo.get(ct)))
.acc1(this.arg1().getHigh().slice(0, ct + 1))
.acc2(this.arg1().getLow().slice(0, ct + 1))
.acc3(this.arg2().getHigh().slice(0, ct + 1))
.acc4(this.arg2().getLow().slice(0, ct + 1))
.acc1(arg1Hi.slice(0, ct + 1))
.acc2(arg1Lo.slice(0, ct + 1))
.acc3(arg2Hi.slice(0, ct + 1))
.acc4(arg2Lo.slice(0, ct + 1))
.acc5(resHi.slice(0, ct + 1))
.acc6(resLo.slice(0, ct + 1))
.xxxByteHi(UnsignedByte.of(resHi.get(ct)))
Expand Down
Loading

0 comments on commit 2159044

Please sign in to comment.