diff --git a/prover/Cargo.lock b/prover/Cargo.lock index fbbf7bbf36..59a9101528 100644 --- a/prover/Cargo.lock +++ b/prover/Cargo.lock @@ -4230,7 +4230,7 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "scroll-proving-sdk" version = "0.1.0" -source = "git+https://github.com/scroll-tech/scroll-proving-sdk.git?rev=7f8dca4#7f8dca4f6af29995fed4ea45048e30a8784e1b6d" +source = "git+https://github.com/scroll-tech/scroll-proving-sdk.git?rev=aa1a9fa#aa1a9fa50d309a7834ed52787b54af1c60feacf3" dependencies = [ "anyhow", "async-trait", @@ -4242,7 +4242,6 @@ dependencies = [ "hex", "http 1.1.0", "log", - "prover 0.13.0", "rand", "reqwest 0.12.4", "reqwest-middleware", diff --git a/prover/Cargo.toml b/prover/Cargo.toml index 52e0dfd3ed..fdd5da28bc 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -7,10 +7,10 @@ edition = "2021" [patch.crates-io] -ethers-signers = { git = "https://github.com/scroll-tech/ethers-rs.git", branch = "v2.0.7" } +ethers-signers = { git = "https://github.com/scroll-tech/ethers-rs.git", branch = "v2.0.7" } halo2curves = { git = "https://github.com/scroll-tech/halo2curves", branch = "v0.1.0" } [patch."https://github.com/privacy-scaling-explorations/halo2.git"] -halo2_proofs = { git = "https://github.com/scroll-tech/halo2.git", branch = "v1.1" } +halo2_proofs = { git = "https://github.com/scroll-tech/halo2.git", branch = "v1.1" } [patch."https://github.com/privacy-scaling-explorations/poseidon.git"] poseidon = { git = "https://github.com/scroll-tech/poseidon.git", branch = "main" } [patch."https://github.com/privacy-scaling-explorations/bls12_381"] @@ -28,10 +28,20 @@ futures = "0.3.30" ethers-core = { git = "https://github.com/scroll-tech/ethers-rs.git", branch = "v2.0.7" } ethers-providers = { git = "https://github.com/scroll-tech/ethers-rs.git", branch = "v2.0.7" } halo2_proofs = { git = "https://github.com/scroll-tech/halo2.git", branch = "v1.1" } -snark-verifier-sdk = { git = "https://github.com/scroll-tech/snark-verifier", branch = "develop", default-features = false, features = ["loader_halo2", "loader_evm", "halo2-pse"] } -prover_darwin = { git = "https://github.com/scroll-tech/zkevm-circuits.git", tag = "v0.12.2", package = "prover", default-features = false, features = ["parallel_syn", "scroll"] } -prover_darwin_v2 = { git = "https://github.com/scroll-tech/zkevm-circuits.git", tag = "v0.13.1", package = "prover", default-features = false, features = ["parallel_syn", "scroll"] } -scroll-proving-sdk = { git = "https://github.com/scroll-tech/scroll-proving-sdk.git", rev = "7f8dca4"} +snark-verifier-sdk = { git = "https://github.com/scroll-tech/snark-verifier", branch = "develop", default-features = false, features = [ + "loader_halo2", + "loader_evm", + "halo2-pse", +] } +prover_darwin = { git = "https://github.com/scroll-tech/zkevm-circuits.git", tag = "v0.12.2", package = "prover", default-features = false, features = [ + "parallel_syn", + "scroll", +] } +prover_darwin_v2 = { git = "https://github.com/scroll-tech/zkevm-circuits.git", tag = "v0.13.1", package = "prover", default-features = false, features = [ + "parallel_syn", + "scroll", +] } +scroll-proving-sdk = { git = "https://github.com/scroll-tech/scroll-proving-sdk.git", rev = "aa1a9fa" } base64 = "0.13.1" reqwest = { version = "0.12.4", features = ["gzip"] } reqwest-middleware = "0.3" diff --git a/prover/src/main.rs b/prover/src/main.rs index 5a9b09bb56..6c51f8c856 100644 --- a/prover/src/main.rs +++ b/prover/src/main.rs @@ -4,17 +4,14 @@ mod config; mod prover; mod types; -mod utils; mod zk_circuits_handler; use clap::{ArgAction, Parser}; -use prover::LocalProver; +use prover::{LocalProver, LocalProverConfig}; use scroll_proving_sdk::{ - config::Config, prover::ProverBuilder, utils::{get_version, init_tracing}, }; -use utils::get_prover_type; #[derive(Parser, Debug)] #[clap(disable_version_flag = true)] @@ -43,23 +40,10 @@ async fn main() -> anyhow::Result<()> { std::process::exit(0); } - let cfg: Config = Config::from_file(args.config_file)?; - let mut prover_types = vec![]; - cfg.prover.circuit_types.iter().for_each(|circuit_type| { - if let Some(pt) = get_prover_type(*circuit_type) { - if !prover_types.contains(&pt) { - prover_types.push(pt); - } - } - }); - let local_prover = LocalProver::new( - cfg.prover - .local - .clone() - .ok_or_else(|| anyhow::anyhow!("Missing local prover configuration"))?, - prover_types, - ); - let prover = ProverBuilder::new(cfg) + let cfg = LocalProverConfig::from_file(args.config_file)?; + let sdk_config = cfg.sdk_config.clone(); + let local_prover = LocalProver::new(cfg); + let prover = ProverBuilder::new(sdk_config) .with_proving_service(Box::new(local_prover)) .build() .await?; diff --git a/prover/src/prover.rs b/prover/src/prover.rs index 73b12360d7..b0a69c39bb 100644 --- a/prover/src/prover.rs +++ b/prover/src/prover.rs @@ -1,12 +1,8 @@ -use crate::{ - types::ProverType, - utils::get_prover_type, - zk_circuits_handler::{CircuitsHandler, CircuitsHandlerProvider}, -}; -use anyhow::Result; +use crate::zk_circuits_handler::{CircuitsHandler, CircuitsHandlerProvider}; +use anyhow::{anyhow, Result}; use async_trait::async_trait; use scroll_proving_sdk::{ - config::LocalProverConfig, + config::Config as SdkConfig, prover::{ proving_service::{ GetVkRequest, GetVkResponse, ProveRequest, ProveResponse, QueryTaskRequest, @@ -15,15 +11,44 @@ use scroll_proving_sdk::{ ProvingService, }, }; +use serde::{Deserialize, Serialize}; use std::{ + fs::File, sync::{Arc, Mutex}, time::{SystemTime, UNIX_EPOCH}, }; use tokio::{runtime::Handle, sync::RwLock, task::JoinHandle}; +#[derive(Clone, Serialize, Deserialize)] +pub struct LocalProverConfig { + pub sdk_config: SdkConfig, + pub high_version_circuit: CircuitConfig, + pub low_version_circuit: CircuitConfig, +} + +impl LocalProverConfig { + pub fn from_reader(reader: R) -> Result + where + R: std::io::Read, + { + serde_json::from_reader(reader).map_err(|e| anyhow!(e)) + } + + pub fn from_file(file_name: String) -> Result { + let file = File::open(file_name)?; + Self::from_reader(&file) + } +} + +#[derive(Clone, Serialize, Deserialize)] +pub struct CircuitConfig { + pub hard_fork_name: String, + pub params_path: String, + pub assets_path: String, +} + pub struct LocalProver { config: LocalProverConfig, - prover_types: Vec, circuits_handler_provider: RwLock, next_task_id: Arc>, current_task: Arc>>>>, @@ -35,20 +60,11 @@ impl ProvingService for LocalProver { true } async fn get_vks(&self, req: GetVkRequest) -> GetVkResponse { - let mut prover_types = vec![]; - req.circuit_types.iter().for_each(|circuit_type| { - if let Some(pt) = get_prover_type(*circuit_type) { - if !prover_types.contains(&pt) { - prover_types.push(pt); - } - } - }); - let vks = self .circuits_handler_provider .read() .await - .init_vks(&self.config, prover_types) + .init_vks(&self.config, req.proof_types) .await; GetVkResponse { vks, error: None } } @@ -57,7 +73,7 @@ impl ProvingService for LocalProver { .circuits_handler_provider .write() .await - .get_circuits_handler(&req.hard_fork_name, self.prover_types.clone()) + .get_circuits_handler(&req.hard_fork_name) .expect("failed to get circuit handler"); match self.do_prove(req, handler).await { @@ -114,13 +130,12 @@ impl ProvingService for LocalProver { } impl LocalProver { - pub fn new(config: LocalProverConfig, prover_types: Vec) -> Self { + pub fn new(config: LocalProverConfig) -> Self { let circuits_handler_provider = CircuitsHandlerProvider::new(config.clone()) .expect("failed to create circuits handler provider"); Self { config, - prover_types, circuits_handler_provider: RwLock::new(circuits_handler_provider), next_task_id: Arc::new(Mutex::new(0)), current_task: Arc::new(Mutex::new(None)), @@ -150,7 +165,7 @@ impl LocalProver { Ok(ProveResponse { task_id: task_id.to_string(), - circuit_type: req.circuit_type, + proof_type: req.proof_type, circuit_version: req.circuit_version, hard_fork_name: req.hard_fork_name, status: TaskStatus::Proving, diff --git a/prover/src/types.rs b/prover/src/types.rs index 39a99b37c8..273749238e 100644 --- a/prover/src/types.rs +++ b/prover/src/types.rs @@ -5,46 +5,6 @@ use scroll_proving_sdk::prover::types::CircuitType; pub type CommonHash = H256; -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum ProverType { - Chunk, - Batch, -} - -impl ProverType { - fn from_u8(v: u8) -> Self { - match v { - 1 => ProverType::Chunk, - 2 => ProverType::Batch, - _ => { - panic!("invalid prover_type") - } - } - } -} - -impl Serialize for ProverType { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - match *self { - ProverType::Chunk => serializer.serialize_u8(1), - ProverType::Batch => serializer.serialize_u8(2), - } - } -} - -impl<'de> Deserialize<'de> for ProverType { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let v: u8 = u8::deserialize(deserializer)?; - Ok(ProverType::from_u8(v)) - } -} - #[derive(Serialize, Deserialize, Default)] pub struct Task { #[serde(rename = "type", default)] diff --git a/prover/src/utils.rs b/prover/src/utils.rs deleted file mode 100644 index b0602e9965..0000000000 --- a/prover/src/utils.rs +++ /dev/null @@ -1,18 +0,0 @@ -use crate::types::ProverType; -use scroll_proving_sdk::prover::types::CircuitType; - -pub fn get_circuit_types(prover_type: ProverType) -> Vec { - match prover_type { - ProverType::Chunk => vec![CircuitType::Chunk], - ProverType::Batch => vec![CircuitType::Batch, CircuitType::Bundle], - } -} - -pub fn get_prover_type(task_type: CircuitType) -> Option { - match task_type { - CircuitType::Undefined => None, - CircuitType::Chunk => Some(ProverType::Chunk), - CircuitType::Batch => Some(ProverType::Batch), - CircuitType::Bundle => Some(ProverType::Batch), - } -} diff --git a/prover/src/zk_circuits_handler.rs b/prover/src/zk_circuits_handler.rs index 5f007a1b35..08ca004b7e 100644 --- a/prover/src/zk_circuits_handler.rs +++ b/prover/src/zk_circuits_handler.rs @@ -2,15 +2,12 @@ mod common; mod darwin; mod darwin_v2; -use crate::{config::AssetsDirEnvConfig, types::ProverType, utils::get_circuit_types}; +use crate::{config::AssetsDirEnvConfig, prover::LocalProverConfig}; use anyhow::{bail, Result}; use async_trait::async_trait; use darwin::DarwinHandler; use darwin_v2::DarwinV2Handler; -use scroll_proving_sdk::{ - config::LocalProverConfig, - prover::{proving_service::ProveRequest, CircuitType}, -}; +use scroll_proving_sdk::prover::{proving_service::ProveRequest, ProofType}; use std::{collections::HashMap, sync::Arc}; type HardForkName = String; @@ -23,15 +20,13 @@ pub mod utils { #[async_trait] pub trait CircuitsHandler: Send + Sync { - async fn get_vk(&self, task_type: CircuitType) -> Option>; + async fn get_vk(&self, task_type: ProofType) -> Option>; async fn get_proof_data(&self, prove_request: ProveRequest) -> Result; } -type CircuitsHandlerBuilder = fn( - prover_types: Vec, - config: &LocalProverConfig, -) -> Result>; +type CircuitsHandlerBuilder = + fn(proof_types: Vec, config: &LocalProverConfig) -> Result>; pub struct CircuitsHandlerProvider { config: LocalProverConfig, @@ -49,7 +44,7 @@ impl CircuitsHandlerProvider { } fn handler_builder( - prover_types: Vec, + proof_types: Vec, config: &LocalProverConfig, ) -> Result> { log::info!( @@ -58,7 +53,7 @@ impl CircuitsHandlerProvider { ); AssetsDirEnvConfig::enable_first(); DarwinHandler::new( - prover_types, + proof_types, &config.low_version_circuit.params_path, &config.low_version_circuit.assets_path, ) @@ -70,7 +65,7 @@ impl CircuitsHandlerProvider { ); fn next_handler_builder( - prover_types: Vec, + proof_types: Vec, config: &LocalProverConfig, ) -> Result> { log::info!( @@ -79,7 +74,7 @@ impl CircuitsHandlerProvider { ); AssetsDirEnvConfig::enable_second(); DarwinV2Handler::new( - prover_types, + proof_types, &config.high_version_circuit.params_path, &config.high_version_circuit.assets_path, ) @@ -104,7 +99,6 @@ impl CircuitsHandlerProvider { pub fn get_circuits_handler( &mut self, hard_fork_name: &String, - prover_types: Vec, ) -> Result>> { match &self.current_fork_name { Some(fork_name) if fork_name == hard_fork_name => { @@ -121,8 +115,11 @@ impl CircuitsHandlerProvider { ); if let Some(builder) = self.circuits_handler_builder_map.get(hard_fork_name) { log::info!("building circuits handler for {hard_fork_name}"); - let handler = builder(prover_types, &self.config) - .expect("failed to build circuits handler"); + let handler = builder( + self.config.sdk_config.prover.supported_proof_types.clone(), + &self.config, + ) + .expect("failed to build circuits handler"); self.current_fork_name = Some(hard_fork_name.clone()); let arc_handler = Arc::new(handler); self.current_circuit = Some(arc_handler.clone()); @@ -137,26 +134,24 @@ impl CircuitsHandlerProvider { pub async fn init_vks( &self, config: &LocalProverConfig, - prover_types: Vec, + proof_types: Vec, ) -> Vec { let mut vks = Vec::new(); for (hard_fork_name, build) in self.circuits_handler_builder_map.iter() { let handler = - build(prover_types.clone(), config).expect("failed to build circuits handler"); - - for prover_type in prover_types.iter() { - for task_type in get_circuit_types(*prover_type).into_iter() { - let vk = handler - .get_vk(task_type) - .await - .map_or("".to_string(), utils::encode_vk); - log::info!( - "vk for {hard_fork_name}, is {vk}, task_type: {:?}", - task_type - ); - if !vk.is_empty() { - vks.push(vk) - } + build(proof_types.clone(), config).expect("failed to build circuits handler"); + + for prover_type in &proof_types { + let vk = handler + .get_vk(*prover_type) + .await + .map_or("".to_string(), utils::encode_vk); + log::info!( + "vk for {hard_fork_name}, is {vk}, prover_type: {:?}", + prover_type + ); + if !vk.is_empty() { + vks.push(vk) } } } diff --git a/prover/src/zk_circuits_handler/common.rs b/prover/src/zk_circuits_handler/common.rs index e88628ad65..43a45ce394 100644 --- a/prover/src/zk_circuits_handler/common.rs +++ b/prover/src/zk_circuits_handler/common.rs @@ -1,10 +1,9 @@ use std::{collections::BTreeMap, rc::Rc}; -use crate::types::ProverType; - use once_cell::sync::OnceCell; use halo2_proofs::{halo2curves::bn256::Bn256, poly::kzg::commitment::ParamsKZG}; +use scroll_proving_sdk::prover::ProofType; static mut PARAMS_MAP: OnceCell>>> = OnceCell::new(); @@ -20,9 +19,9 @@ where } } -pub fn get_degrees(prover_types: &std::collections::HashSet, f: F) -> Vec +pub fn get_degrees(prover_types: &std::collections::HashSet, f: F) -> Vec where - F: FnMut(&ProverType) -> Vec, + F: FnMut(&ProofType) -> Vec, { prover_types .iter() diff --git a/prover/src/zk_circuits_handler/darwin.rs b/prover/src/zk_circuits_handler/darwin.rs index 1644dabb4b..2fb8c10a0e 100644 --- a/prover/src/zk_circuits_handler/darwin.rs +++ b/prover/src/zk_circuits_handler/darwin.rs @@ -1,9 +1,8 @@ use super::{common::*, CircuitsHandler}; -use crate::types::ProverType; use anyhow::{bail, Context, Ok, Result}; use async_trait::async_trait; use once_cell::sync::Lazy; -use scroll_proving_sdk::prover::{proving_service::ProveRequest, CircuitType}; +use scroll_proving_sdk::prover::{proving_service::ProveRequest, ProofType}; use serde::Deserialize; use tokio::sync::RwLock; @@ -45,50 +44,55 @@ pub struct DarwinHandler { impl DarwinHandler { pub fn new_multi( - prover_types: Vec, + proof_types: Vec, params_dir: &str, assets_dir: &str, ) -> Result { let class_name = std::intrinsics::type_name::(); - let prover_types_set = prover_types + let proof_types_set = proof_types .into_iter() - .collect::>(); + .collect::>(); let mut handler = Self { batch_prover: None, chunk_prover: None, }; - let degrees: Vec = get_degrees(&prover_types_set, |prover_type| match prover_type { - ProverType::Chunk => ZKEVM_DEGREES.clone(), - ProverType::Batch => AGG_DEGREES.clone(), + let degrees: Vec = get_degrees(&proof_types_set, |proof_type| match proof_type { + ProofType::Chunk => ZKEVM_DEGREES.clone(), + ProofType::Batch => AGG_DEGREES.clone(), + ProofType::Bundle => AGG_DEGREES.clone(), + _ => unreachable!(), }); let params_map = get_params_map_instance(|| { log::info!( "calling get_params_map from {}, prover_types: {:?}, degrees: {:?}", class_name, - prover_types_set, + proof_types_set, degrees ); CommonProver::load_params_map(params_dir, °rees) }); - for prover_type in prover_types_set { - match prover_type { - ProverType::Chunk => { + for proof_type in proof_types_set { + match proof_type { + ProofType::Chunk => { handler.chunk_prover = Some(RwLock::new(ChunkProver::from_params_and_assets( params_map, assets_dir, ))); } - ProverType::Batch => { - handler.batch_prover = Some(RwLock::new(BatchProver::from_params_and_assets( - params_map, assets_dir, - ))) + ProofType::Batch | ProofType::Bundle => { + if handler.batch_prover.is_none() { + handler.batch_prover = Some(RwLock::new( + BatchProver::from_params_and_assets(params_map, assets_dir), + )) + } } + _ => unreachable!(), } } Ok(handler) } - pub fn new(prover_types: Vec, params_dir: &str, assets_dir: &str) -> Result { - Self::new_multi(prover_types, params_dir, assets_dir) + pub fn new(proof_types: Vec, params_dir: &str, assets_dir: &str) -> Result { + Self::new_multi(proof_types, params_dir, assets_dir) } async fn gen_chunk_proof_raw(&self, chunk_trace: Vec) -> Result { @@ -176,17 +180,17 @@ impl DarwinHandler { #[async_trait] impl CircuitsHandler for DarwinHandler { - async fn get_vk(&self, task_type: CircuitType) -> Option> { + async fn get_vk(&self, task_type: ProofType) -> Option> { match task_type { - CircuitType::Chunk => self.chunk_prover.as_ref().unwrap().read().await.get_vk(), - CircuitType::Batch => self + ProofType::Chunk => self.chunk_prover.as_ref().unwrap().read().await.get_vk(), + ProofType::Batch => self .batch_prover .as_ref() .unwrap() .read() .await .get_batch_vk(), - CircuitType::Bundle => self + ProofType::Bundle => self .batch_prover .as_ref() .unwrap() @@ -198,10 +202,10 @@ impl CircuitsHandler for DarwinHandler { } async fn get_proof_data(&self, prove_request: ProveRequest) -> Result { - match prove_request.circuit_type { - CircuitType::Chunk => self.gen_chunk_proof(prove_request).await, - CircuitType::Batch => self.gen_batch_proof(prove_request).await, - CircuitType::Bundle => self.gen_bundle_proof(prove_request).await, + match prove_request.proof_type { + ProofType::Chunk => self.gen_chunk_proof(prove_request).await, + ProofType::Batch => self.gen_batch_proof(prove_request).await, + ProofType::Bundle => self.gen_bundle_proof(prove_request).await, _ => unreachable!(), } } @@ -250,15 +254,15 @@ mod tests { #[tokio::test] async fn test_circuits() -> Result<()> { let bi_handler = DarwinHandler::new_multi( - vec![ProverType::Chunk, ProverType::Batch], + vec![ProofType::Chunk, ProofType::Batch], &PARAMS_PATH, &ASSETS_PATH, )?; let chunk_handler = bi_handler; - let chunk_vk = chunk_handler.get_vk(CircuitType::Chunk).await.unwrap(); + let chunk_vk = chunk_handler.get_vk(ProofType::Chunk).await.unwrap(); - check_vk(CircuitType::Chunk, chunk_vk, "chunk vk must be available"); + check_vk(ProofType::Chunk, chunk_vk, "chunk vk must be available"); let chunk_dir_paths = get_chunk_dir_paths()?; log::info!("chunk_dir_paths, {:?}", chunk_dir_paths); let mut chunk_infos = vec![]; @@ -279,8 +283,8 @@ mod tests { } let batch_handler = chunk_handler; - let batch_vk = batch_handler.get_vk(CircuitType::Batch).await.unwrap(); - check_vk(CircuitType::Batch, batch_vk, "batch vk must be available"); + let batch_vk = batch_handler.get_vk(ProofType::Batch).await.unwrap(); + check_vk(ProofType::Batch, batch_vk, "batch vk must be available"); let batch_task_detail = make_batch_task_detail(chunk_infos, chunk_proofs); log::info!("start to prove batch"); let batch_proof = batch_handler.gen_batch_proof_raw(batch_task_detail).await?; @@ -303,19 +307,19 @@ mod tests { // } } - fn check_vk(proof_type: CircuitType, vk: Vec, info: &str) { + fn check_vk(proof_type: ProofType, vk: Vec, info: &str) { log::info!("check_vk, {:?}", proof_type); let vk_from_file = read_vk(proof_type).unwrap(); assert_eq!(vk_from_file, encode_vk(vk), "{info}") } - fn read_vk(proof_type: CircuitType) -> Result { + fn read_vk(proof_type: ProofType) -> Result { log::info!("read_vk, {:?}", proof_type); let vk_file = match proof_type { - CircuitType::Chunk => CHUNK_VK_PATH.clone(), - CircuitType::Batch => BATCH_VK_PATH.clone(), - CircuitType::Bundle => todo!(), - CircuitType::Undefined => unreachable!(), + ProofType::Chunk => CHUNK_VK_PATH.clone(), + ProofType::Batch => BATCH_VK_PATH.clone(), + ProofType::Bundle => todo!(), + ProofType::Undefined => unreachable!(), }; let data = std::fs::read(vk_file)?; diff --git a/prover/src/zk_circuits_handler/darwin_v2.rs b/prover/src/zk_circuits_handler/darwin_v2.rs index d6e5813ff9..f2ada7b434 100644 --- a/prover/src/zk_circuits_handler/darwin_v2.rs +++ b/prover/src/zk_circuits_handler/darwin_v2.rs @@ -1,9 +1,8 @@ use super::{common::*, CircuitsHandler}; -use crate::types::ProverType; use anyhow::{bail, Context, Ok, Result}; use async_trait::async_trait; use once_cell::sync::Lazy; -use scroll_proving_sdk::prover::{proving_service::ProveRequest, CircuitType}; +use scroll_proving_sdk::prover::{proving_service::ProveRequest, ProofType}; use serde::Deserialize; use tokio::sync::RwLock; @@ -45,50 +44,55 @@ pub struct DarwinV2Handler { impl DarwinV2Handler { pub fn new_multi( - prover_types: Vec, + proof_types: Vec, params_dir: &str, assets_dir: &str, ) -> Result { let class_name = std::intrinsics::type_name::(); - let prover_types_set = prover_types + let proof_types_set = proof_types .into_iter() - .collect::>(); + .collect::>(); let mut handler = Self { batch_prover: None, chunk_prover: None, }; - let degrees: Vec = get_degrees(&prover_types_set, |prover_type| match prover_type { - ProverType::Chunk => ZKEVM_DEGREES.clone(), - ProverType::Batch => AGG_DEGREES.clone(), + let degrees: Vec = get_degrees(&proof_types_set, |prover_type| match prover_type { + ProofType::Chunk => ZKEVM_DEGREES.clone(), + ProofType::Batch => AGG_DEGREES.clone(), + ProofType::Bundle => AGG_DEGREES.clone(), + _ => unreachable!(), }); let params_map = get_params_map_instance(|| { log::info!( "calling get_params_map from {}, prover_types: {:?}, degrees: {:?}", class_name, - prover_types_set, + proof_types_set, degrees ); CommonProver::load_params_map(params_dir, °rees) }); - for prover_type in prover_types_set { - match prover_type { - ProverType::Chunk => { + for proof_type in proof_types_set { + match proof_type { + ProofType::Chunk => { handler.chunk_prover = Some(RwLock::new(ChunkProver::from_params_and_assets( params_map, assets_dir, ))); } - ProverType::Batch => { - handler.batch_prover = Some(RwLock::new(BatchProver::from_params_and_assets( - params_map, assets_dir, - ))) + ProofType::Batch | ProofType::Bundle => { + if handler.batch_prover.is_none() { + handler.batch_prover = Some(RwLock::new( + BatchProver::from_params_and_assets(params_map, assets_dir), + )) + } } + _ => unreachable!(), } } Ok(handler) } - pub fn new(prover_types: Vec, params_dir: &str, assets_dir: &str) -> Result { - Self::new_multi(prover_types, params_dir, assets_dir) + pub fn new(proof_types: Vec, params_dir: &str, assets_dir: &str) -> Result { + Self::new_multi(proof_types, params_dir, assets_dir) } async fn gen_chunk_proof_raw(&self, chunk_trace: Vec) -> Result { @@ -176,17 +180,17 @@ impl DarwinV2Handler { #[async_trait] impl CircuitsHandler for DarwinV2Handler { - async fn get_vk(&self, task_type: CircuitType) -> Option> { + async fn get_vk(&self, task_type: ProofType) -> Option> { match task_type { - CircuitType::Chunk => self.chunk_prover.as_ref().unwrap().read().await.get_vk(), - CircuitType::Batch => self + ProofType::Chunk => self.chunk_prover.as_ref().unwrap().read().await.get_vk(), + ProofType::Batch => self .batch_prover .as_ref() .unwrap() .read() .await .get_batch_vk(), - CircuitType::Bundle => self + ProofType::Bundle => self .batch_prover .as_ref() .unwrap() @@ -198,10 +202,10 @@ impl CircuitsHandler for DarwinV2Handler { } async fn get_proof_data(&self, prove_request: ProveRequest) -> Result { - match prove_request.circuit_type { - CircuitType::Chunk => self.gen_chunk_proof(prove_request).await, - CircuitType::Batch => self.gen_batch_proof(prove_request).await, - CircuitType::Bundle => self.gen_bundle_proof(prove_request).await, + match prove_request.proof_type { + ProofType::Chunk => self.gen_chunk_proof(prove_request).await, + ProofType::Batch => self.gen_batch_proof(prove_request).await, + ProofType::Bundle => self.gen_bundle_proof(prove_request).await, _ => unreachable!(), } } @@ -254,15 +258,15 @@ mod tests { #[tokio::test] async fn test_circuits() -> Result<()> { let bi_handler = DarwinV2Handler::new_multi( - vec![ProverType::Chunk, ProverType::Batch], + vec![ProofType::Chunk, ProofType::Batch], &PARAMS_PATH, &ASSETS_PATH, )?; let chunk_handler = bi_handler; - let chunk_vk = chunk_handler.get_vk(CircuitType::Chunk).await.unwrap(); + let chunk_vk = chunk_handler.get_vk(ProofType::Chunk).await.unwrap(); - check_vk(CircuitType::Chunk, chunk_vk, "chunk vk must be available"); + check_vk(ProofType::Chunk, chunk_vk, "chunk vk must be available"); let chunk_dir_paths = get_chunk_dir_paths()?; log::info!("chunk_dir_paths, {:?}", chunk_dir_paths); let mut chunk_traces = vec![]; @@ -284,8 +288,8 @@ mod tests { } let batch_handler = chunk_handler; - let batch_vk = batch_handler.get_vk(CircuitType::Batch).await.unwrap(); - check_vk(CircuitType::Batch, batch_vk, "batch vk must be available"); + let batch_vk = batch_handler.get_vk(ProofType::Batch).await.unwrap(); + check_vk(ProofType::Batch, batch_vk, "batch vk must be available"); let batch_task_detail = make_batch_task_detail(chunk_traces, chunk_proofs, None); log::info!("start to prove batch"); let batch_proof = batch_handler.gen_batch_proof_raw(batch_task_detail).await?; @@ -361,19 +365,19 @@ mod tests { } } - fn check_vk(proof_type: CircuitType, vk: Vec, info: &str) { + fn check_vk(proof_type: ProofType, vk: Vec, info: &str) { log::info!("check_vk, {:?}", proof_type); let vk_from_file = read_vk(proof_type).unwrap(); assert_eq!(vk_from_file, encode_vk(vk), "{info}") } - fn read_vk(proof_type: CircuitType) -> Result { + fn read_vk(proof_type: ProofType) -> Result { log::info!("read_vk, {:?}", proof_type); let vk_file = match proof_type { - CircuitType::Chunk => CHUNK_VK_PATH.clone(), - CircuitType::Batch => BATCH_VK_PATH.clone(), - CircuitType::Bundle => todo!(), - CircuitType::Undefined => unreachable!(), + ProofType::Chunk => CHUNK_VK_PATH.clone(), + ProofType::Batch => BATCH_VK_PATH.clone(), + ProofType::Bundle => todo!(), + ProofType::Undefined => unreachable!(), }; let data = std::fs::read(vk_file)?;