From 09cdda92bfcb4a0c43ed3f4c0ce32521205c88e4 Mon Sep 17 00:00:00 2001 From: Francois Bojarski Date: Mon, 8 Jan 2024 14:18:22 +0100 Subject: [PATCH 1/6] perf(BIN): adaptative nb line --- .../zktracer/module/bin/BinOperation.java | 103 ++++++--- .../linea/zktracer/module/bin/Trace.java | 201 ++++++++---------- 2 files changed, 167 insertions(+), 137 deletions(-) diff --git a/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/BinOperation.java b/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/BinOperation.java index 748d71189b..731798592a 100644 --- a/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/BinOperation.java +++ b/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/BinOperation.java @@ -28,7 +28,6 @@ import net.consensys.linea.zktracer.bytestheta.BaseBytes; import net.consensys.linea.zktracer.container.ModuleOperation; import net.consensys.linea.zktracer.opcode.OpCode; -import net.consensys.linea.zktracer.types.Bytes16; import net.consensys.linea.zktracer.types.UnsignedByte; import org.apache.tuweni.bytes.Bytes; import org.apache.tuweni.bytes.Bytes32; @@ -40,9 +39,26 @@ public class BinOperation extends ModuleOperation { private static final int LIMB_SIZE = 16; - @EqualsAndHashCode.Include private final OpCode opCode; - @EqualsAndHashCode.Include private final BaseBytes arg1; - @EqualsAndHashCode.Include private final BaseBytes arg2; + @EqualsAndHashCode.Include + private final OpCode opCode; + @EqualsAndHashCode.Include + private final BaseBytes arg1; + @EqualsAndHashCode.Include + private final BaseBytes arg2; + + public BinOperation(OpCode opCode, BaseBytes arg1, BaseBytes arg2) { + this.opCode = opCode; + this.arg1 = arg1; + this.arg2 = arg2; + this.ctMax = maxCt(); + } + + private static final int LLARGE = 16; + private static final int LLARGEMO = 15; + private final OpCode opCode; + private final BaseBytes arg1; + private final BaseBytes arg2; + private final int ctMax; private List lastEightBits = List.of(false); private boolean bit4 = false; private int low4 = 0; @@ -50,17 +66,49 @@ public class BinOperation extends ModuleOperation { private int pivotThreshold = 0; private int pivot = 0; + <<<<<<>>>>>> 49b702a1 (perf(BIN): adaptative nb line) } @Override protected int computeLineCount() { - return this.maxCt(); + return this.ctMax + 1; } private int maxCt() { - return isOneLineInstruction() ? 1 : LIMB_SIZE; + return switch (opCode) { + case NOT -> LLARGEMO; + case BYTE, SIGNEXTEND -> arg1.getHigh().isZero() ? LLARGEMO : 0; + case AND, OR, XOR -> Math.max( + 0, + Math.max( + Math.max( + arg1.getHigh().trimLeadingZeros().size(), + arg2.getHigh().trimLeadingZeros().size()), + Math.max( + arg1.getLow().trimLeadingZeros().size(), + arg2.getLow().trimLeadingZeros().size())) + - 1); + default -> throw new IllegalStateException("Unexpected value: " + opCode); + }; } private boolean isSmall() { @@ -177,22 +225,27 @@ private void compute() { public void traceBinOperation(int stamp, Trace trace) { this.compute(); - final Bytes16 resHi = this.getResult().getHigh(); - final Bytes16 resLo = this.getResult().getLow(); + final int length = ctMax + 1; + final int offset = LLARGE - length; + + final Bytes arg1Hi = this.arg1.getHigh().slice(offset, length); + final Bytes arg1Lo = this.arg1.getLow().slice(offset, length); + final Bytes arg2Hi = this.arg2.getHigh().slice(offset, length); + final Bytes arg2Lo = this.arg2.getLow().slice(offset, length); + final Bytes resHi = this.getResult().getHigh().slice(offset, length); + final Bytes resLo = this.getResult().getLow().slice(offset, length); final List bit1 = this.getBit1(); - final List bits = - Stream.concat(this.getFirstEightBits().stream(), this.lastEightBits.stream()).toList(); - for (int ct = 0; ct < this.maxCt(); ct++) { + final List bits = Stream.concat(this.getFirstEightBits().stream(), this.lastEightBits.stream()).toList(); + for (int ct = 0; ct <= this.ctMax; ct++) { trace .stamp(Bytes.ofUnsignedInt(stamp)) - .oneLineInstruction(this.maxCt() == 1) - .mli(this.maxCt() != 1) + .ctMax(UnsignedByte.of(ctMax)) .counter(UnsignedByte.of(ct)) .inst(UnsignedByte.of(this.opCode().byteValue())) - .argument1Hi(this.arg1().getHigh()) - .argument1Lo(this.arg1().getLow()) - .argument2Hi(this.arg2().getHigh()) - .argument2Lo(this.arg2().getLow()) + .argument1Hi(arg1Hi) + .argument1Lo(arg1Lo) + .argument2Hi(arg2Hi) + .argument2Lo(arg2Lo) .resultHi(resHi) .resultLo(resLo) .isAnd(this.opCode() == OpCode.AND) @@ -208,16 +261,16 @@ public void traceBinOperation(int stamp, Trace trace) { .neg(bits.get(0)) .bit1(bit1.get(ct)) .pivot(UnsignedByte.of(this.pivot)) - .byte1(UnsignedByte.of(this.arg1().getHigh().get(ct))) - .byte2(UnsignedByte.of(this.arg1().getLow().get(ct))) - .byte3(UnsignedByte.of(this.arg2().getHigh().get(ct))) - .byte4(UnsignedByte.of(this.arg2().getLow().get(ct))) + .byte1(UnsignedByte.of(arg1Hi.get(ct))) + .byte2(UnsignedByte.of(arg1Lo.get(ct))) + .byte3(UnsignedByte.of(arg2Hi.get(ct))) + .byte4(UnsignedByte.of(arg2Lo.get(ct))) .byte5(UnsignedByte.of(resHi.get(ct))) .byte6(UnsignedByte.of(resLo.get(ct))) - .acc1(this.arg1().getHigh().slice(0, ct + 1)) - .acc2(this.arg1().getLow().slice(0, ct + 1)) - .acc3(this.arg2().getHigh().slice(0, ct + 1)) - .acc4(this.arg2().getLow().slice(0, ct + 1)) + .acc1(arg1Hi.slice(0, ct + 1)) + .acc2(arg1Lo.slice(0, ct + 1)) + .acc3(arg2Hi.slice(0, ct + 1)) + .acc4(arg2Lo.slice(0, ct + 1)) .acc5(resHi.slice(0, ct + 1)) .acc6(resLo.slice(0, ct + 1)) .xxxByteHi(UnsignedByte.of(resHi.get(ct))) diff --git a/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/Trace.java b/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/Trace.java index 900497e7fe..1e77325c15 100644 --- a/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/Trace.java +++ b/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/Trace.java @@ -54,6 +54,7 @@ public class Trace { private final MappedByteBuffer byte5; private final MappedByteBuffer byte6; private final MappedByteBuffer counter; + private final MappedByteBuffer ctMax; private final MappedByteBuffer inst; private final MappedByteBuffer isAnd; private final MappedByteBuffer isByte; @@ -62,9 +63,7 @@ public class Trace { private final MappedByteBuffer isSignextend; private final MappedByteBuffer isXor; private final MappedByteBuffer low4; - private final MappedByteBuffer mli; private final MappedByteBuffer neg; - private final MappedByteBuffer oneLineInstruction; private final MappedByteBuffer pivot; private final MappedByteBuffer resultHi; private final MappedByteBuffer resultLo; @@ -95,6 +94,7 @@ static List headers(int length) { new ColumnHeader("bin.BYTE_5", 1, length), new ColumnHeader("bin.BYTE_6", 1, length), new ColumnHeader("bin.COUNTER", 1, length), + new ColumnHeader("bin.CT_MAX", 1, length), new ColumnHeader("bin.INST", 1, length), new ColumnHeader("bin.IS_AND", 1, length), new ColumnHeader("bin.IS_BYTE", 1, length), @@ -103,9 +103,7 @@ static List headers(int length) { new ColumnHeader("bin.IS_SIGNEXTEND", 1, length), new ColumnHeader("bin.IS_XOR", 1, length), new ColumnHeader("bin.LOW_4", 1, length), - new ColumnHeader("bin.MLI", 1, length), new ColumnHeader("bin.NEG", 1, length), - new ColumnHeader("bin.ONE_LINE_INSTRUCTION", 1, length), new ColumnHeader("bin.PIVOT", 1, length), new ColumnHeader("bin.RESULT_HI", 32, length), new ColumnHeader("bin.RESULT_LO", 32, length), @@ -136,24 +134,23 @@ public Trace(List buffers) { this.byte5 = buffers.get(17); this.byte6 = buffers.get(18); this.counter = buffers.get(19); - this.inst = buffers.get(20); - this.isAnd = buffers.get(21); - this.isByte = buffers.get(22); - this.isNot = buffers.get(23); - this.isOr = buffers.get(24); - this.isSignextend = buffers.get(25); - this.isXor = buffers.get(26); - this.low4 = buffers.get(27); - this.mli = buffers.get(28); + this.ctMax = buffers.get(20); + this.inst = buffers.get(21); + this.isAnd = buffers.get(22); + this.isByte = buffers.get(23); + this.isNot = buffers.get(24); + this.isOr = buffers.get(25); + this.isSignextend = buffers.get(26); + this.isXor = buffers.get(27); + this.low4 = buffers.get(28); this.neg = buffers.get(29); - this.oneLineInstruction = buffers.get(30); - this.pivot = buffers.get(31); - this.resultHi = buffers.get(32); - this.resultLo = buffers.get(33); - this.small = buffers.get(34); - this.stamp = buffers.get(35); - this.xxxByteHi = buffers.get(36); - this.xxxByteLo = buffers.get(37); + this.pivot = buffers.get(30); + this.resultHi = buffers.get(31); + this.resultLo = buffers.get(32); + this.small = buffers.get(33); + this.stamp = buffers.get(34); + this.xxxByteHi = buffers.get(35); + this.xxxByteLo = buffers.get(36); } public int size() { @@ -444,110 +441,110 @@ public Trace counter(final UnsignedByte b) { return this; } - public Trace inst(final UnsignedByte b) { + public Trace ctMax(final UnsignedByte b) { if (filled.get(20)) { - throw new IllegalStateException("bin.INST already set"); + throw new IllegalStateException("bin.CT_MAX already set"); } else { filled.set(20); } - inst.put(b.toByte()); + ctMax.put(b.toByte()); return this; } - public Trace isAnd(final Boolean b) { + public Trace inst(final UnsignedByte b) { if (filled.get(21)) { - throw new IllegalStateException("bin.IS_AND already set"); + throw new IllegalStateException("bin.INST already set"); } else { filled.set(21); } - isAnd.put((byte) (b ? 1 : 0)); + inst.put(b.toByte()); return this; } - public Trace isByte(final Boolean b) { + public Trace isAnd(final Boolean b) { if (filled.get(22)) { - throw new IllegalStateException("bin.IS_BYTE already set"); + throw new IllegalStateException("bin.IS_AND already set"); } else { filled.set(22); } - isByte.put((byte) (b ? 1 : 0)); + isAnd.put((byte) (b ? 1 : 0)); return this; } - public Trace isNot(final Boolean b) { + public Trace isByte(final Boolean b) { if (filled.get(23)) { - throw new IllegalStateException("bin.IS_NOT already set"); + throw new IllegalStateException("bin.IS_BYTE already set"); } else { filled.set(23); } - isNot.put((byte) (b ? 1 : 0)); + isByte.put((byte) (b ? 1 : 0)); return this; } - public Trace isOr(final Boolean b) { + public Trace isNot(final Boolean b) { if (filled.get(24)) { - throw new IllegalStateException("bin.IS_OR already set"); + throw new IllegalStateException("bin.IS_NOT already set"); } else { filled.set(24); } - isOr.put((byte) (b ? 1 : 0)); + isNot.put((byte) (b ? 1 : 0)); return this; } - public Trace isSignextend(final Boolean b) { + public Trace isOr(final Boolean b) { if (filled.get(25)) { - throw new IllegalStateException("bin.IS_SIGNEXTEND already set"); + throw new IllegalStateException("bin.IS_OR already set"); } else { filled.set(25); } - isSignextend.put((byte) (b ? 1 : 0)); + isOr.put((byte) (b ? 1 : 0)); return this; } - public Trace isXor(final Boolean b) { + public Trace isSignextend(final Boolean b) { if (filled.get(26)) { - throw new IllegalStateException("bin.IS_XOR already set"); + throw new IllegalStateException("bin.IS_SIGNEXTEND already set"); } else { filled.set(26); } - isXor.put((byte) (b ? 1 : 0)); + isSignextend.put((byte) (b ? 1 : 0)); return this; } - public Trace low4(final UnsignedByte b) { + public Trace isXor(final Boolean b) { if (filled.get(27)) { - throw new IllegalStateException("bin.LOW_4 already set"); + throw new IllegalStateException("bin.IS_XOR already set"); } else { filled.set(27); } - low4.put(b.toByte()); + isXor.put((byte) (b ? 1 : 0)); return this; } - public Trace mli(final Boolean b) { + public Trace low4(final UnsignedByte b) { if (filled.get(28)) { - throw new IllegalStateException("bin.MLI already set"); + throw new IllegalStateException("bin.LOW_4 already set"); } else { filled.set(28); } - mli.put((byte) (b ? 1 : 0)); + low4.put(b.toByte()); return this; } @@ -564,23 +561,11 @@ public Trace neg(final Boolean b) { return this; } - public Trace oneLineInstruction(final Boolean b) { - if (filled.get(30)) { - throw new IllegalStateException("bin.ONE_LINE_INSTRUCTION already set"); - } else { - filled.set(30); - } - - oneLineInstruction.put((byte) (b ? 1 : 0)); - - return this; - } - public Trace pivot(final UnsignedByte b) { - if (filled.get(31)) { + if (filled.get(30)) { throw new IllegalStateException("bin.PIVOT already set"); } else { - filled.set(31); + filled.set(30); } pivot.put(b.toByte()); @@ -589,10 +574,10 @@ public Trace pivot(final UnsignedByte b) { } public Trace resultHi(final Bytes b) { - if (filled.get(32)) { + if (filled.get(31)) { throw new IllegalStateException("bin.RESULT_HI already set"); } else { - filled.set(32); + filled.set(31); } final byte[] bs = b.toArrayUnsafe(); @@ -605,10 +590,10 @@ public Trace resultHi(final Bytes b) { } public Trace resultLo(final Bytes b) { - if (filled.get(33)) { + if (filled.get(32)) { throw new IllegalStateException("bin.RESULT_LO already set"); } else { - filled.set(33); + filled.set(32); } final byte[] bs = b.toArrayUnsafe(); @@ -621,10 +606,10 @@ public Trace resultLo(final Bytes b) { } public Trace small(final Boolean b) { - if (filled.get(34)) { + if (filled.get(33)) { throw new IllegalStateException("bin.SMALL already set"); } else { - filled.set(34); + filled.set(33); } small.put((byte) (b ? 1 : 0)); @@ -633,10 +618,10 @@ public Trace small(final Boolean b) { } public Trace stamp(final Bytes b) { - if (filled.get(35)) { + if (filled.get(34)) { throw new IllegalStateException("bin.STAMP already set"); } else { - filled.set(35); + filled.set(34); } final byte[] bs = b.toArrayUnsafe(); @@ -649,10 +634,10 @@ public Trace stamp(final Bytes b) { } public Trace xxxByteHi(final UnsignedByte b) { - if (filled.get(36)) { + if (filled.get(35)) { throw new IllegalStateException("bin.XXX_BYTE_HI already set"); } else { - filled.set(36); + filled.set(35); } xxxByteHi.put(b.toByte()); @@ -661,10 +646,10 @@ public Trace xxxByteHi(final UnsignedByte b) { } public Trace xxxByteLo(final UnsignedByte b) { - if (filled.get(37)) { + if (filled.get(36)) { throw new IllegalStateException("bin.XXX_BYTE_LO already set"); } else { - filled.set(37); + filled.set(36); } xxxByteLo.put(b.toByte()); @@ -754,39 +739,39 @@ public Trace validateRow() { } if (!filled.get(20)) { - throw new IllegalStateException("bin.INST has not been filled"); + throw new IllegalStateException("bin.CT_MAX has not been filled"); } if (!filled.get(21)) { - throw new IllegalStateException("bin.IS_AND has not been filled"); + throw new IllegalStateException("bin.INST has not been filled"); } if (!filled.get(22)) { - throw new IllegalStateException("bin.IS_BYTE has not been filled"); + throw new IllegalStateException("bin.IS_AND has not been filled"); } if (!filled.get(23)) { - throw new IllegalStateException("bin.IS_NOT has not been filled"); + throw new IllegalStateException("bin.IS_BYTE has not been filled"); } if (!filled.get(24)) { - throw new IllegalStateException("bin.IS_OR has not been filled"); + throw new IllegalStateException("bin.IS_NOT has not been filled"); } if (!filled.get(25)) { - throw new IllegalStateException("bin.IS_SIGNEXTEND has not been filled"); + throw new IllegalStateException("bin.IS_OR has not been filled"); } if (!filled.get(26)) { - throw new IllegalStateException("bin.IS_XOR has not been filled"); + throw new IllegalStateException("bin.IS_SIGNEXTEND has not been filled"); } if (!filled.get(27)) { - throw new IllegalStateException("bin.LOW_4 has not been filled"); + throw new IllegalStateException("bin.IS_XOR has not been filled"); } if (!filled.get(28)) { - throw new IllegalStateException("bin.MLI has not been filled"); + throw new IllegalStateException("bin.LOW_4 has not been filled"); } if (!filled.get(29)) { @@ -794,34 +779,30 @@ public Trace validateRow() { } if (!filled.get(30)) { - throw new IllegalStateException("bin.ONE_LINE_INSTRUCTION has not been filled"); - } - - if (!filled.get(31)) { throw new IllegalStateException("bin.PIVOT has not been filled"); } - if (!filled.get(32)) { + if (!filled.get(31)) { throw new IllegalStateException("bin.RESULT_HI has not been filled"); } - if (!filled.get(33)) { + if (!filled.get(32)) { throw new IllegalStateException("bin.RESULT_LO has not been filled"); } - if (!filled.get(34)) { + if (!filled.get(33)) { throw new IllegalStateException("bin.SMALL has not been filled"); } - if (!filled.get(35)) { + if (!filled.get(34)) { throw new IllegalStateException("bin.STAMP has not been filled"); } - if (!filled.get(36)) { + if (!filled.get(35)) { throw new IllegalStateException("bin.XXX_BYTE_HI has not been filled"); } - if (!filled.get(37)) { + if (!filled.get(36)) { throw new IllegalStateException("bin.XXX_BYTE_LO has not been filled"); } @@ -913,39 +894,39 @@ public Trace fillAndValidateRow() { } if (!filled.get(20)) { - inst.position(inst.position() + 1); + ctMax.position(ctMax.position() + 1); } if (!filled.get(21)) { - isAnd.position(isAnd.position() + 1); + inst.position(inst.position() + 1); } if (!filled.get(22)) { - isByte.position(isByte.position() + 1); + isAnd.position(isAnd.position() + 1); } if (!filled.get(23)) { - isNot.position(isNot.position() + 1); + isByte.position(isByte.position() + 1); } if (!filled.get(24)) { - isOr.position(isOr.position() + 1); + isNot.position(isNot.position() + 1); } if (!filled.get(25)) { - isSignextend.position(isSignextend.position() + 1); + isOr.position(isOr.position() + 1); } if (!filled.get(26)) { - isXor.position(isXor.position() + 1); + isSignextend.position(isSignextend.position() + 1); } if (!filled.get(27)) { - low4.position(low4.position() + 1); + isXor.position(isXor.position() + 1); } if (!filled.get(28)) { - mli.position(mli.position() + 1); + low4.position(low4.position() + 1); } if (!filled.get(29)) { @@ -953,34 +934,30 @@ public Trace fillAndValidateRow() { } if (!filled.get(30)) { - oneLineInstruction.position(oneLineInstruction.position() + 1); - } - - if (!filled.get(31)) { pivot.position(pivot.position() + 1); } - if (!filled.get(32)) { + if (!filled.get(31)) { resultHi.position(resultHi.position() + 32); } - if (!filled.get(33)) { + if (!filled.get(32)) { resultLo.position(resultLo.position() + 32); } - if (!filled.get(34)) { + if (!filled.get(33)) { small.position(small.position() + 1); } - if (!filled.get(35)) { + if (!filled.get(34)) { stamp.position(stamp.position() + 32); } - if (!filled.get(36)) { + if (!filled.get(35)) { xxxByteHi.position(xxxByteHi.position() + 1); } - if (!filled.get(37)) { + if (!filled.get(36)) { xxxByteLo.position(xxxByteLo.position() + 1); } From 7f0f16e64657a2c8cb1c04f513400d23f6a8f393 Mon Sep 17 00:00:00 2001 From: Francois Bojarski Date: Wed, 24 Jan 2024 17:28:59 +0100 Subject: [PATCH 2/6] fix(bin): rebase fix --- .../zktracer/module/bin/BinOperation.java | 49 +++++-------------- 1 file changed, 11 insertions(+), 38 deletions(-) diff --git a/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/BinOperation.java b/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/BinOperation.java index 731798592a..9477f5fadf 100644 --- a/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/BinOperation.java +++ b/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/BinOperation.java @@ -39,12 +39,9 @@ public class BinOperation extends ModuleOperation { private static final int LIMB_SIZE = 16; - @EqualsAndHashCode.Include - private final OpCode opCode; - @EqualsAndHashCode.Include - private final BaseBytes arg1; - @EqualsAndHashCode.Include - private final BaseBytes arg2; + @EqualsAndHashCode.Include private final OpCode opCode; + @EqualsAndHashCode.Include private final BaseBytes arg1; + @EqualsAndHashCode.Include private final BaseBytes arg2; public BinOperation(OpCode opCode, BaseBytes arg1, BaseBytes arg2) { this.opCode = opCode; @@ -55,9 +52,6 @@ public BinOperation(OpCode opCode, BaseBytes arg1, BaseBytes arg2) { private static final int LLARGE = 16; private static final int LLARGEMO = 15; - private final OpCode opCode; - private final BaseBytes arg1; - private final BaseBytes arg2; private final int ctMax; private List lastEightBits = List.of(false); private boolean bit4 = false; @@ -66,28 +60,6 @@ public BinOperation(OpCode opCode, BaseBytes arg1, BaseBytes arg2) { private int pivotThreshold = 0; private int pivot = 0; - <<<<<<>>>>>> 49b702a1 (perf(BIN): adaptative nb line) - } - @Override protected int computeLineCount() { return this.ctMax + 1; @@ -100,12 +72,12 @@ private int maxCt() { case AND, OR, XOR -> Math.max( 0, Math.max( - Math.max( - arg1.getHigh().trimLeadingZeros().size(), - arg2.getHigh().trimLeadingZeros().size()), - Math.max( - arg1.getLow().trimLeadingZeros().size(), - arg2.getLow().trimLeadingZeros().size())) + Math.max( + arg1.getHigh().trimLeadingZeros().size(), + arg2.getHigh().trimLeadingZeros().size()), + Math.max( + arg1.getLow().trimLeadingZeros().size(), + arg2.getLow().trimLeadingZeros().size())) - 1); default -> throw new IllegalStateException("Unexpected value: " + opCode); }; @@ -235,7 +207,8 @@ public void traceBinOperation(int stamp, Trace trace) { final Bytes resHi = this.getResult().getHigh().slice(offset, length); final Bytes resLo = this.getResult().getLow().slice(offset, length); final List bit1 = this.getBit1(); - final List bits = Stream.concat(this.getFirstEightBits().stream(), this.lastEightBits.stream()).toList(); + final List bits = + Stream.concat(this.getFirstEightBits().stream(), this.lastEightBits.stream()).toList(); for (int ct = 0; ct <= this.ctMax; ct++) { trace .stamp(Bytes.ofUnsignedInt(stamp)) From b8af7898e94f7a77b1c077b0fc88033d19687976 Mon Sep 17 00:00:00 2001 From: Francois Bojarski Date: Wed, 24 Jan 2024 17:36:51 +0100 Subject: [PATCH 3/6] feat: column typing --- .../zktracer/module/bin/BinOperation.java | 6 +- .../linea/zktracer/module/bin/Trace.java | 68 +++++++++++-------- 2 files changed, 41 insertions(+), 33 deletions(-) diff --git a/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/BinOperation.java b/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/BinOperation.java index 9477f5fadf..a488291f77 100644 --- a/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/BinOperation.java +++ b/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/BinOperation.java @@ -211,9 +211,9 @@ public void traceBinOperation(int stamp, Trace trace) { Stream.concat(this.getFirstEightBits().stream(), this.lastEightBits.stream()).toList(); for (int ct = 0; ct <= this.ctMax; ct++) { trace - .stamp(Bytes.ofUnsignedInt(stamp)) - .ctMax(UnsignedByte.of(ctMax)) - .counter(UnsignedByte.of(ct)) + .stamp(Bytes.of(stamp)) + .ctMax(Bytes.of(ctMax)) + .counter(Bytes.of(ct)) .inst(UnsignedByte.of(this.opCode().byteValue())) .argument1Hi(arg1Hi) .argument1Lo(arg1Lo) diff --git a/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/Trace.java b/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/Trace.java index 1e77325c15..ac7110497d 100644 --- a/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/Trace.java +++ b/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/Trace.java @@ -74,16 +74,16 @@ public class Trace { static List headers(int length) { return List.of( - new ColumnHeader("bin.ACC_1", 32, length), - new ColumnHeader("bin.ACC_2", 32, length), - new ColumnHeader("bin.ACC_3", 32, length), - new ColumnHeader("bin.ACC_4", 32, length), - new ColumnHeader("bin.ACC_5", 32, length), - new ColumnHeader("bin.ACC_6", 32, length), - new ColumnHeader("bin.ARGUMENT_1_HI", 32, length), - new ColumnHeader("bin.ARGUMENT_1_LO", 32, length), - new ColumnHeader("bin.ARGUMENT_2_HI", 32, length), - new ColumnHeader("bin.ARGUMENT_2_LO", 32, length), + new ColumnHeader("bin.ACC_1", 16, length), + new ColumnHeader("bin.ACC_2", 16, length), + new ColumnHeader("bin.ACC_3", 16, length), + new ColumnHeader("bin.ACC_4", 16, length), + new ColumnHeader("bin.ACC_5", 16, length), + new ColumnHeader("bin.ACC_6", 16, length), + new ColumnHeader("bin.ARGUMENT_1_HI", 16, length), + new ColumnHeader("bin.ARGUMENT_1_LO", 16, length), + new ColumnHeader("bin.ARGUMENT_2_HI", 16, length), + new ColumnHeader("bin.ARGUMENT_2_LO", 16, length), new ColumnHeader("bin.BIT_1", 1, length), new ColumnHeader("bin.BIT_B_4", 1, length), new ColumnHeader("bin.BITS", 1, length), @@ -105,10 +105,10 @@ static List headers(int length) { new ColumnHeader("bin.LOW_4", 1, length), new ColumnHeader("bin.NEG", 1, length), new ColumnHeader("bin.PIVOT", 1, length), - new ColumnHeader("bin.RESULT_HI", 32, length), - new ColumnHeader("bin.RESULT_LO", 32, length), + new ColumnHeader("bin.RESULT_HI", 16, length), + new ColumnHeader("bin.RESULT_LO", 16, length), new ColumnHeader("bin.SMALL", 1, length), - new ColumnHeader("bin.STAMP", 32, length), + new ColumnHeader("bin.STAMP", 3, length), new ColumnHeader("bin.XXX_BYTE_HI", 1, length), new ColumnHeader("bin.XXX_BYTE_LO", 1, length)); } @@ -429,26 +429,34 @@ public Trace byte6(final UnsignedByte b) { return this; } - public Trace counter(final UnsignedByte b) { + public Trace counter(final Bytes b) { if (filled.get(19)) { throw new IllegalStateException("bin.COUNTER already set"); } else { filled.set(19); } - counter.put(b.toByte()); + final byte[] bs = b.toArrayUnsafe(); + for (int i = bs.length; i < 32; i++) { + counter.put((byte) 0); + } + counter.put(b.toArrayUnsafe()); return this; } - public Trace ctMax(final UnsignedByte b) { + public Trace ctMax(final Bytes b) { if (filled.get(20)) { throw new IllegalStateException("bin.CT_MAX already set"); } else { filled.set(20); } - ctMax.put(b.toByte()); + final byte[] bs = b.toArrayUnsafe(); + for (int i = bs.length; i < 32; i++) { + ctMax.put((byte) 0); + } + ctMax.put(b.toArrayUnsafe()); return this; } @@ -814,43 +822,43 @@ public Trace validateRow() { public Trace fillAndValidateRow() { if (!filled.get(0)) { - acc1.position(acc1.position() + 32); + acc1.position(acc1.position() + 16); } if (!filled.get(1)) { - acc2.position(acc2.position() + 32); + acc2.position(acc2.position() + 16); } if (!filled.get(2)) { - acc3.position(acc3.position() + 32); + acc3.position(acc3.position() + 16); } if (!filled.get(3)) { - acc4.position(acc4.position() + 32); + acc4.position(acc4.position() + 16); } if (!filled.get(4)) { - acc5.position(acc5.position() + 32); + acc5.position(acc5.position() + 16); } if (!filled.get(5)) { - acc6.position(acc6.position() + 32); + acc6.position(acc6.position() + 16); } if (!filled.get(6)) { - argument1Hi.position(argument1Hi.position() + 32); + argument1Hi.position(argument1Hi.position() + 16); } if (!filled.get(7)) { - argument1Lo.position(argument1Lo.position() + 32); + argument1Lo.position(argument1Lo.position() + 16); } if (!filled.get(8)) { - argument2Hi.position(argument2Hi.position() + 32); + argument2Hi.position(argument2Hi.position() + 16); } if (!filled.get(9)) { - argument2Lo.position(argument2Lo.position() + 32); + argument2Lo.position(argument2Lo.position() + 16); } if (!filled.get(11)) { @@ -938,11 +946,11 @@ public Trace fillAndValidateRow() { } if (!filled.get(31)) { - resultHi.position(resultHi.position() + 32); + resultHi.position(resultHi.position() + 16); } if (!filled.get(32)) { - resultLo.position(resultLo.position() + 32); + resultLo.position(resultLo.position() + 16); } if (!filled.get(33)) { @@ -950,7 +958,7 @@ public Trace fillAndValidateRow() { } if (!filled.get(34)) { - stamp.position(stamp.position() + 32); + stamp.position(stamp.position() + 3); } if (!filled.get(35)) { From 4e518ec479af9fc8d6c2fc1e44c0e9efc7ad860b Mon Sep 17 00:00:00 2001 From: Francois Bojarski Date: Wed, 24 Jan 2024 17:47:57 +0100 Subject: [PATCH 4/6] fix(constraint): ras --- .../zktracer/module/bin/BinOperation.java | 4 ++-- .../linea/zktracer/module/bin/Trace.java | 20 ++++++------------- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/BinOperation.java b/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/BinOperation.java index a488291f77..1fc7afbce1 100644 --- a/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/BinOperation.java +++ b/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/BinOperation.java @@ -212,8 +212,8 @@ public void traceBinOperation(int stamp, Trace trace) { for (int ct = 0; ct <= this.ctMax; ct++) { trace .stamp(Bytes.of(stamp)) - .ctMax(Bytes.of(ctMax)) - .counter(Bytes.of(ct)) + .ctMax(UnsignedByte.of(ctMax)) + .counter(UnsignedByte.of(ct)) .inst(UnsignedByte.of(this.opCode().byteValue())) .argument1Hi(arg1Hi) .argument1Lo(arg1Lo) diff --git a/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/Trace.java b/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/Trace.java index ac7110497d..748c01cf47 100644 --- a/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/Trace.java +++ b/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/Trace.java @@ -108,7 +108,7 @@ static List headers(int length) { new ColumnHeader("bin.RESULT_HI", 16, length), new ColumnHeader("bin.RESULT_LO", 16, length), new ColumnHeader("bin.SMALL", 1, length), - new ColumnHeader("bin.STAMP", 3, length), + new ColumnHeader("bin.STAMP", 4, length), new ColumnHeader("bin.XXX_BYTE_HI", 1, length), new ColumnHeader("bin.XXX_BYTE_LO", 1, length)); } @@ -429,34 +429,26 @@ public Trace byte6(final UnsignedByte b) { return this; } - public Trace counter(final Bytes b) { + public Trace counter(final UnsignedByte b) { if (filled.get(19)) { throw new IllegalStateException("bin.COUNTER already set"); } else { filled.set(19); } - final byte[] bs = b.toArrayUnsafe(); - for (int i = bs.length; i < 32; i++) { - counter.put((byte) 0); - } - counter.put(b.toArrayUnsafe()); + counter.put(b.toByte()); return this; } - public Trace ctMax(final Bytes b) { + public Trace ctMax(final UnsignedByte b) { if (filled.get(20)) { throw new IllegalStateException("bin.CT_MAX already set"); } else { filled.set(20); } - final byte[] bs = b.toArrayUnsafe(); - for (int i = bs.length; i < 32; i++) { - ctMax.put((byte) 0); - } - ctMax.put(b.toArrayUnsafe()); + ctMax.put(b.toByte()); return this; } @@ -958,7 +950,7 @@ public Trace fillAndValidateRow() { } if (!filled.get(34)) { - stamp.position(stamp.position() + 3); + stamp.position(stamp.position() + 4); } if (!filled.get(35)) { From fe4b35d035dc867b6edd91fe0ad357c95df3f9bb Mon Sep 17 00:00:00 2001 From: Francois Bojarski Date: Wed, 24 Jan 2024 23:02:00 +0100 Subject: [PATCH 5/6] fix: ras --- .../net/consensys/linea/zktracer/module/bin/BinOperation.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/BinOperation.java b/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/BinOperation.java index 1fc7afbce1..513ef272fc 100644 --- a/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/BinOperation.java +++ b/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/BinOperation.java @@ -200,6 +200,7 @@ public void traceBinOperation(int stamp, Trace trace) { final int length = ctMax + 1; final int offset = LLARGE - length; + final Bytes stampBytes = Bytes.minimalBytes(stamp); final Bytes arg1Hi = this.arg1.getHigh().slice(offset, length); final Bytes arg1Lo = this.arg1.getLow().slice(offset, length); final Bytes arg2Hi = this.arg2.getHigh().slice(offset, length); @@ -211,7 +212,7 @@ public void traceBinOperation(int stamp, Trace trace) { Stream.concat(this.getFirstEightBits().stream(), this.lastEightBits.stream()).toList(); for (int ct = 0; ct <= this.ctMax; ct++) { trace - .stamp(Bytes.of(stamp)) + .stamp(stampBytes) .ctMax(UnsignedByte.of(ctMax)) .counter(UnsignedByte.of(ct)) .inst(UnsignedByte.of(this.opCode().byteValue())) From 48fbefddf11390a8530dcc4a114c28b2007c9565 Mon Sep 17 00:00:00 2001 From: Francois Bojarski Date: Thu, 25 Jan 2024 11:35:28 +0100 Subject: [PATCH 6/6] fix: corset export --- .../linea/zktracer/module/bin/Trace.java | 52 +++++++++---------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/Trace.java b/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/Trace.java index 748c01cf47..1e77325c15 100644 --- a/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/Trace.java +++ b/arithmetization/src/main/java/net/consensys/linea/zktracer/module/bin/Trace.java @@ -74,16 +74,16 @@ public class Trace { static List headers(int length) { return List.of( - new ColumnHeader("bin.ACC_1", 16, length), - new ColumnHeader("bin.ACC_2", 16, length), - new ColumnHeader("bin.ACC_3", 16, length), - new ColumnHeader("bin.ACC_4", 16, length), - new ColumnHeader("bin.ACC_5", 16, length), - new ColumnHeader("bin.ACC_6", 16, length), - new ColumnHeader("bin.ARGUMENT_1_HI", 16, length), - new ColumnHeader("bin.ARGUMENT_1_LO", 16, length), - new ColumnHeader("bin.ARGUMENT_2_HI", 16, length), - new ColumnHeader("bin.ARGUMENT_2_LO", 16, length), + new ColumnHeader("bin.ACC_1", 32, length), + new ColumnHeader("bin.ACC_2", 32, length), + new ColumnHeader("bin.ACC_3", 32, length), + new ColumnHeader("bin.ACC_4", 32, length), + new ColumnHeader("bin.ACC_5", 32, length), + new ColumnHeader("bin.ACC_6", 32, length), + new ColumnHeader("bin.ARGUMENT_1_HI", 32, length), + new ColumnHeader("bin.ARGUMENT_1_LO", 32, length), + new ColumnHeader("bin.ARGUMENT_2_HI", 32, length), + new ColumnHeader("bin.ARGUMENT_2_LO", 32, length), new ColumnHeader("bin.BIT_1", 1, length), new ColumnHeader("bin.BIT_B_4", 1, length), new ColumnHeader("bin.BITS", 1, length), @@ -105,10 +105,10 @@ static List headers(int length) { new ColumnHeader("bin.LOW_4", 1, length), new ColumnHeader("bin.NEG", 1, length), new ColumnHeader("bin.PIVOT", 1, length), - new ColumnHeader("bin.RESULT_HI", 16, length), - new ColumnHeader("bin.RESULT_LO", 16, length), + new ColumnHeader("bin.RESULT_HI", 32, length), + new ColumnHeader("bin.RESULT_LO", 32, length), new ColumnHeader("bin.SMALL", 1, length), - new ColumnHeader("bin.STAMP", 4, length), + new ColumnHeader("bin.STAMP", 32, length), new ColumnHeader("bin.XXX_BYTE_HI", 1, length), new ColumnHeader("bin.XXX_BYTE_LO", 1, length)); } @@ -814,43 +814,43 @@ public Trace validateRow() { public Trace fillAndValidateRow() { if (!filled.get(0)) { - acc1.position(acc1.position() + 16); + acc1.position(acc1.position() + 32); } if (!filled.get(1)) { - acc2.position(acc2.position() + 16); + acc2.position(acc2.position() + 32); } if (!filled.get(2)) { - acc3.position(acc3.position() + 16); + acc3.position(acc3.position() + 32); } if (!filled.get(3)) { - acc4.position(acc4.position() + 16); + acc4.position(acc4.position() + 32); } if (!filled.get(4)) { - acc5.position(acc5.position() + 16); + acc5.position(acc5.position() + 32); } if (!filled.get(5)) { - acc6.position(acc6.position() + 16); + acc6.position(acc6.position() + 32); } if (!filled.get(6)) { - argument1Hi.position(argument1Hi.position() + 16); + argument1Hi.position(argument1Hi.position() + 32); } if (!filled.get(7)) { - argument1Lo.position(argument1Lo.position() + 16); + argument1Lo.position(argument1Lo.position() + 32); } if (!filled.get(8)) { - argument2Hi.position(argument2Hi.position() + 16); + argument2Hi.position(argument2Hi.position() + 32); } if (!filled.get(9)) { - argument2Lo.position(argument2Lo.position() + 16); + argument2Lo.position(argument2Lo.position() + 32); } if (!filled.get(11)) { @@ -938,11 +938,11 @@ public Trace fillAndValidateRow() { } if (!filled.get(31)) { - resultHi.position(resultHi.position() + 16); + resultHi.position(resultHi.position() + 32); } if (!filled.get(32)) { - resultLo.position(resultLo.position() + 16); + resultLo.position(resultLo.position() + 32); } if (!filled.get(33)) { @@ -950,7 +950,7 @@ public Trace fillAndValidateRow() { } if (!filled.get(34)) { - stamp.position(stamp.position() + 4); + stamp.position(stamp.position() + 32); } if (!filled.get(35)) {