Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: update circuit evaluation #81

Merged
merged 3 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ serde = { version = "1.0.196", features = ["derive"] }
thiserror = "1.0.59"
strum_macros = "0.26.4"
strum = "0.26.2"
sim-circuit = { git = "https://github.com/brech1/sim-circuit" }

# DSL
circom-circom_algebra = { git = "https://github.com/iden3/circom", package = "circom_algebra" }
Expand All @@ -29,8 +30,3 @@ circom-dag = { git = "https://github.com/iden3/circom", package = "dag" }
circom-parser = { git = "https://github.com/iden3/circom", package = "parser" }
circom-program_structure = { git = "https://github.com/iden3/circom", package = "program_structure" }
circom-type_analysis = { git = "https://github.com/iden3/circom", package = "type_analysis" }

# MPZ
mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", package = "mpz-circuits" }
bmr16-mpz = { git = "https://github.com/tkmct/mpz", package = "mpz-circuits" }
sim-circuit = { git = "https://github.com/brech1/sim-circuit" }
61 changes: 54 additions & 7 deletions src/arithmetic_circuit.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,64 @@
use crate::compiler::{ArithmeticGate, CircuitError};
use circom_program_structure::ast::ExpressionInfixOpcode;
use serde::{Deserialize, Serialize};
use serde_json::{from_str, to_string};
use sim_circuit::arithmetic_circuit::ArithmeticCircuit as SimArithmeticCircuit;
use std::{
collections::HashMap,
io::{BufRead, BufReader, BufWriter, Write},
str::FromStr,
};
use strum_macros::{Display as StrumDisplay, EnumString};

/// The supported Arithmetic gate types.
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, EnumString, StrumDisplay)]
pub enum AGateType {
AAdd,
ADiv,
AEq,
AGEq,
AGt,
ALEq,
ALt,
AMul,
ANeq,
ASub,
AXor,
APow,
AIntDiv,
AMod,
AShiftL,
AShiftR,
ABoolOr,
ABoolAnd,
ABitOr,
ABitAnd,
}

impl From<&ExpressionInfixOpcode> for AGateType {
fn from(opcode: &ExpressionInfixOpcode) -> Self {
match opcode {
ExpressionInfixOpcode::Mul => AGateType::AMul,
ExpressionInfixOpcode::Div => AGateType::ADiv,
ExpressionInfixOpcode::Add => AGateType::AAdd,
ExpressionInfixOpcode::Sub => AGateType::ASub,
ExpressionInfixOpcode::Pow => AGateType::APow,
ExpressionInfixOpcode::IntDiv => AGateType::AIntDiv,
ExpressionInfixOpcode::Mod => AGateType::AMod,
ExpressionInfixOpcode::ShiftL => AGateType::AShiftL,
ExpressionInfixOpcode::ShiftR => AGateType::AShiftR,
ExpressionInfixOpcode::LesserEq => AGateType::ALEq,
ExpressionInfixOpcode::GreaterEq => AGateType::AGEq,
ExpressionInfixOpcode::Lesser => AGateType::ALt,
ExpressionInfixOpcode::Greater => AGateType::AGt,
ExpressionInfixOpcode::Eq => AGateType::AEq,
ExpressionInfixOpcode::NotEq => AGateType::ANeq,
ExpressionInfixOpcode::BoolOr => AGateType::ABoolOr,
ExpressionInfixOpcode::BoolAnd => AGateType::ABoolAnd,
ExpressionInfixOpcode::BitOr => AGateType::ABitOr,
ExpressionInfixOpcode::BitAnd => AGateType::ABitAnd,
ExpressionInfixOpcode::BitXor => AGateType::AXor,
}
}
}

#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ArithmeticCircuit {
Expand All @@ -29,10 +81,6 @@ pub struct ConstantInfo {
}

impl ArithmeticCircuit {
pub fn to_sim(&self) -> SimArithmeticCircuit {
from_str(&to_string(self).unwrap()).unwrap()
}

pub fn get_bristol_string(&self) -> Result<String, CircuitError> {
let mut output = Vec::new();
let mut writer = BufWriter::new(&mut output);
Expand Down Expand Up @@ -216,7 +264,6 @@ impl BristolLine {
#[cfg(test)]
mod tests {
use super::*;
use crate::compiler::AGateType;
use std::io::{BufReader, Cursor};

// Helper function to create a sample ArithmeticCircuit
Expand Down
163 changes: 1 addition & 162 deletions src/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,78 +3,15 @@
//! This module defines the data structures used to represent the arithmetic circuit.

use crate::{
arithmetic_circuit::{ArithmeticCircuit, CircuitInfo, ConstantInfo},
arithmetic_circuit::{AGateType, ArithmeticCircuit, CircuitInfo, ConstantInfo},
program::ProgramError,
topological_sort::topological_sort,
};
use bmr16_mpz::{
arithmetic::{
circuit::ArithmeticCircuit as MpzCircuit,
ops::{add, cmul, mul, sub},
types::CrtRepr,
ArithCircuitError as MpzCircuitError,
},
ArithmeticCircuitBuilder,
};
use circom_program_structure::ast::ExpressionInfixOpcode;
use log::debug;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use strum_macros::{Display as StrumDisplay, EnumString};
use thiserror::Error;

/// Types of gates that can be used in an arithmetic circuit.
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, EnumString, StrumDisplay)]
pub enum AGateType {
AAdd,
ADiv,
AEq,
AGEq,
AGt,
ALEq,
ALt,
AMul,
ANeq,
ASub,
AXor,
APow,
AIntDiv,
AMod,
AShiftL,
AShiftR,
ABoolOr,
ABoolAnd,
ABitOr,
ABitAnd,
}

impl From<&ExpressionInfixOpcode> for AGateType {
fn from(opcode: &ExpressionInfixOpcode) -> Self {
match opcode {
ExpressionInfixOpcode::Mul => AGateType::AMul,
ExpressionInfixOpcode::Div => AGateType::ADiv,
ExpressionInfixOpcode::Add => AGateType::AAdd,
ExpressionInfixOpcode::Sub => AGateType::ASub,
ExpressionInfixOpcode::Pow => AGateType::APow,
ExpressionInfixOpcode::IntDiv => AGateType::AIntDiv,
ExpressionInfixOpcode::Mod => AGateType::AMod,
ExpressionInfixOpcode::ShiftL => AGateType::AShiftL,
ExpressionInfixOpcode::ShiftR => AGateType::AShiftR,
ExpressionInfixOpcode::LesserEq => AGateType::ALEq,
ExpressionInfixOpcode::GreaterEq => AGateType::AGEq,
ExpressionInfixOpcode::Lesser => AGateType::ALt,
ExpressionInfixOpcode::Greater => AGateType::AGt,
ExpressionInfixOpcode::Eq => AGateType::AEq,
ExpressionInfixOpcode::NotEq => AGateType::ANeq,
ExpressionInfixOpcode::BoolOr => AGateType::ABoolOr,
ExpressionInfixOpcode::BoolAnd => AGateType::ABoolAnd,
ExpressionInfixOpcode::BitOr => AGateType::ABitOr,
ExpressionInfixOpcode::BitAnd => AGateType::ABitAnd,
ExpressionInfixOpcode::BitXor => AGateType::AXor,
}
}
}

/// Represents a signal in the circuit, with a name and an optional value.
#[derive(Debug, Serialize, Deserialize)]
pub struct Signal {
Expand Down Expand Up @@ -541,94 +478,6 @@ impl Compiler {
})
}

/// Builds an arithmetic circuit using the mpz circuit builder.
pub fn build_mpz_circuit(&self, report: &CircuitReport) -> Result<MpzCircuit, CircuitError> {
let builder = ArithmeticCircuitBuilder::new();

// Initialize CRT signals map with the circuit inputs
let mut crt_signals: HashMap<u32, CrtRepr> =
report
.inputs
.iter()
.try_fold(HashMap::new(), |mut acc, signal| {
let input = builder
.add_input::<u32>(signal.names[0].to_string())
.map_err(CircuitError::MPZCircuitError)?;
acc.insert(signal.id, input.repr);
Ok::<_, CircuitError>(acc)
})?;

// Initialize a vec for indices of gates that need processing
let mut to_process = std::collections::VecDeque::new();
to_process.extend(0..self.gates.len());

while let Some(index) = to_process.pop_front() {
let gate = &self.gates[index];

if let (Some(lh_in_repr), Some(rh_in_repr)) =
(crt_signals.get(&gate.lh_in), crt_signals.get(&gate.rh_in))
{
let result_repr = match gate.op {
AGateType::AAdd => {
add(&mut builder.state().borrow_mut(), lh_in_repr, rh_in_repr)
.map_err(|e| e.into())
}
AGateType::AMul => {
// Get the constant value from one of the signals if available
let constant_value = self
.signals
.get(&gate.lh_in)
.and_then(|signal| signal.value.map(|v| v as u64))
.or_else(|| {
self.signals
.get(&gate.rh_in)
.and_then(|signal| signal.value.map(|v| v as u64))
});

// Perform multiplication depending on whether one input is a constant
if let Some(value) = constant_value {
Ok::<_, CircuitError>(cmul(
&mut builder.state().borrow_mut(),
lh_in_repr,
value,
))
} else {
mul(&mut builder.state().borrow_mut(), lh_in_repr, rh_in_repr)
.map_err(|e| e.into())
}
}
AGateType::ASub => {
sub(&mut builder.state().borrow_mut(), lh_in_repr, rh_in_repr)
.map_err(|e| e.into())
}
_ => {
return Err(CircuitError::UnsupportedGateType(format!(
"{:?} not supported by MPZ",
gate.op
)))
}
}?;

crt_signals.insert(gate.out, result_repr);
} else {
// Not ready to process, push back for later attempt.
to_process.push_back(index);
}
}

// Add output signals
for signal in &report.outputs {
let output_repr = crt_signals
.get(&signal.id)
.ok_or_else(|| CircuitError::UnprocessedNode)?;
builder.add_output(output_repr);
}

builder
.build()
.map_err(|_| CircuitError::MPZCircuitBuilderError)
}

/// Returns a node id and increments the count.
fn get_node_id(&mut self) -> u32 {
self.node_count += 1;
Expand Down Expand Up @@ -694,10 +543,6 @@ pub enum CircuitError {
DisconnectedSignal,
#[error(transparent)]
IOError(#[from] std::io::Error),
#[error("MPZ arithmetic circuit error: {0}")]
MPZCircuitError(MpzCircuitError),
#[error("MPZ arithmetic circuit builder error")]
MPZCircuitBuilderError,
#[error(transparent)]
ParseIntError(#[from] std::num::ParseIntError),
#[error("Signal already declared")]
Expand All @@ -720,12 +565,6 @@ impl From<CircuitError> for ProgramError {
}
}

impl From<MpzCircuitError> for CircuitError {
fn from(e: MpzCircuitError) -> Self {
CircuitError::MPZCircuitError(e)
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
4 changes: 3 additions & 1 deletion src/process.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
//!
//! Handles execution of statements and expressions for arithmetic circuit generation within a `Runtime` environment.

use crate::compiler::{AGateType, Compiler};
use crate::arithmetic_circuit::AGateType;
use crate::compiler::Compiler;
use crate::program::ProgramError;
use crate::runtime::{
generate_u32, increment_indices, u32_to_access, Context, DataAccess, DataType, NestedValue,
Expand Down Expand Up @@ -759,6 +760,7 @@ fn to_equivalent_infix(op: &ExpressionPrefixOpcode) -> (u32, ExpressionInfixOpco
ExpressionPrefixOpcode::Complement => (u32::MAX, ExpressionInfixOpcode::BitXor),
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
Loading
Loading