diff --git a/src/query/executor.rs b/src/query/executor.rs index 788723430..312349008 100644 --- a/src/query/executor.rs +++ b/src/query/executor.rs @@ -1,5 +1,5 @@ use crate::{ - ff::{Field, GaloisField, Serializable}, + ff::{Field, FieldType, Fp32BitPrime, GaloisField, Serializable}, helpers::{ negotiate_prss, query::{QueryConfig, QueryType}, @@ -8,20 +8,28 @@ use crate::{ protocol::{ attribution::input::MCAggregateCreditOutputRow, context::{MaliciousContext, SemiHonestContext}, + prss::Endpoint as PrssEndpoint, step::{self, StepNarrow}, }, - query::runner::IpaRunner, + query::runner::IpaQuery, secret_sharing::{replicated::semi_honest::AdditiveShare, Linear as LinearSecretSharing}, task::JoinHandle, }; +#[cfg(any(test, feature = "cli", feature = "test-fixture"))] +use crate::query::runner::execute_test_multiply; use crate::query::runner::QueryResult; +use futures::FutureExt; use generic_array::GenericArray; use rand::rngs::StdRng; use rand_core::SeedableRng; #[cfg(all(feature = "shuttle", test))] use shuttle::future as tokio; -use std::fmt::Debug; +use std::{ + fmt::Debug, + future::{ready, Future}, + pin::Pin, +}; use typenum::Unsigned; pub trait Result: Send + Debug { @@ -63,12 +71,86 @@ where } } -#[allow(unused)] -pub fn start_query( +pub fn execute( config: QueryConfig, gateway: Gateway, input: ByteArrStream, ) -> JoinHandle { + match (config.query_type, config.field_type) { + #[cfg(any(test, feature = "cli", feature = "test-fixture"))] + (QueryType::TestMultiply, FieldType::Fp31) => { + do_query(config, gateway, input, |prss, gateway, input| { + Box::pin(execute_test_multiply::( + prss, gateway, input, + )) + }) + } + #[cfg(any(test, feature = "cli", feature = "test-fixture"))] + (QueryType::TestMultiply, FieldType::Fp32BitPrime) => { + do_query(config, gateway, input, |prss, gateway, input| { + Box::pin(execute_test_multiply::(prss, gateway, input)) + }) + } + #[cfg(any(test, feature = "weak-field"))] + (QueryType::SemiHonestIpa(ipa_config), FieldType::Fp31) => { + do_query(config, gateway, input, move |prss, gateway, input| { + let ctx = SemiHonestContext::new(prss, gateway); + Box::pin( + IpaQuery::::new(ipa_config) + .execute(ctx, input) + .then(|res| ready(res.map(|out| Box::new(out) as Box))), + ) + }) + } + (QueryType::SemiHonestIpa(ipa_config), FieldType::Fp32BitPrime) => { + do_query(config, gateway, input, move |prss, gateway, input| { + let ctx = SemiHonestContext::new(prss, gateway); + Box::pin( + IpaQuery::::new(ipa_config) + .execute(ctx, input) + .then(|res| ready(res.map(|out| Box::new(out) as Box))), + ) + }) + } + #[cfg(any(test, feature = "weak-field"))] + (QueryType::MaliciousIpa(ipa_config), FieldType::Fp31) => { + do_query(config, gateway, input, move |prss, gateway, input| { + let ctx = MaliciousContext::new(prss, gateway); + Box::pin( + IpaQuery::::new(ipa_config) + .execute(ctx, input) + .then(|res| ready(res.map(|out| Box::new(out) as Box))), + ) + }) + } + (QueryType::MaliciousIpa(ipa_config), FieldType::Fp32BitPrime) => { + do_query(config, gateway, input, move |prss, gateway, input| { + let ctx = MaliciousContext::new(prss, gateway); + Box::pin( + IpaQuery::::new(ipa_config) + .execute(ctx, input) + .then(|res| ready(res.map(|out| Box::new(out) as Box))), + ) + }) + } + } +} + +pub fn do_query( + config: QueryConfig, + gateway: Gateway, + input: ByteArrStream, + query_impl: F, +) -> JoinHandle +where + F: for<'a> FnOnce( + &'a PrssEndpoint, + &'a Gateway, + ByteArrStream, + ) -> Pin + Send + 'a>> + + Send + + 'static, +{ tokio::spawn(async move { // TODO: make it a generic argument for this function let mut rng = StdRng::from_entropy(); @@ -76,41 +158,17 @@ pub fn start_query( let step = step::Descriptive::default().narrow(&config.query_type); let prss = negotiate_prss(&gateway, &step, &mut rng).await.unwrap(); - match config.query_type { - #[cfg(any(test, feature = "cli", feature = "test-fixture"))] - QueryType::TestMultiply => { - super::runner::TestMultiplyRunner - .run( - SemiHonestContext::new(&prss, &gateway), - config.field_type, - input, - ) - .await - } - QueryType::SemiHonestIpa(ipa_query_config) => { - IpaRunner(ipa_query_config) - .run( - SemiHonestContext::new(&prss, &gateway), - config.field_type, - input, - ) - .await - } - QueryType::MaliciousIpa(ipa_query_config) => Ok(IpaRunner(ipa_query_config) - .malicious_run( - MaliciousContext::new(&prss, &gateway), - config.field_type, - input, - ) - .await), - } + query_impl(&prss, &gateway, input).await }) } #[cfg(all(test, not(feature = "shuttle"), feature = "in-memory-infra"))] mod tests { - use super::*; - use crate::{ff::Fp31, secret_sharing::IntoShares}; + use crate::{ + ff::{Field, Fp31}, + query::ProtocolResult, + secret_sharing::{replicated::semi_honest::AdditiveShare, IntoShares}, + }; #[test] fn serialize_result() { diff --git a/src/query/processor.rs b/src/query/processor.rs index 2fad94683..725602254 100644 --- a/src/query/processor.rs +++ b/src/query/processor.rs @@ -1,4 +1,5 @@ use crate::{ + error::Error as ProtocolError, helpers::{ query::{PrepareQuery, QueryConfig, QueryInput}, Gateway, GatewayConfig, Role, RoleAssignment, Transport, TransportError, TransportImpl, @@ -85,6 +86,8 @@ pub enum QueryCompletionError { #[from] source: StateError, }, + #[error("query execution failed: {0}")] + ExecutionError(#[from] ProtocolError), } impl Debug for Processor { @@ -201,11 +204,7 @@ impl Processor { ); queries.insert( input.query_id, - QueryState::Running(executor::start_query( - config, - gateway, - input.input_stream, - )), + QueryState::Running(executor::execute(config, gateway, input.input_stream)), ); Ok(()) } else { @@ -242,13 +241,13 @@ impl Processor { match queries.remove(&query_id) { Some(QueryState::Running(handle)) => { queries.insert(query_id, QueryState::AwaitingCompletion); - Ok(CompletionHandle::new( + CompletionHandle::new( RemoveQuery { query_id, queries: &self.queries, }, handle, - )) + ) } Some(state) => { let state_error = StateError::InvalidState { @@ -256,15 +255,15 @@ impl Processor { to: QueryStatus::Running, }; queries.insert(query_id, state); - Err(QueryCompletionError::StateError { + return Err(QueryCompletionError::StateError { source: state_error, - }) + }); } - None => Err(QueryCompletionError::NoSuchQuery(query_id)), + None => return Err(QueryCompletionError::NoSuchQuery(query_id)), } - }?; + }; // release mutex before await - Ok(handle.await.unwrap()) + Ok(handle.await?) } } diff --git a/src/query/runner/ipa.rs b/src/query/runner/ipa.rs index cf27564b8..91067e73b 100644 --- a/src/query/runner/ipa.rs +++ b/src/query/runner/ipa.rs @@ -1,121 +1,74 @@ use crate::{ error::Error, - ff::{FieldType, Fp32BitPrime, GaloisField, PrimeField, Serializable}, + ff::{Gf2, PrimeField, Serializable}, helpers::{query::IpaQueryConfig, ByteArrStream}, protocol::{ - attribution::input::MCAggregateCreditOutputRow, - context::{MaliciousContext, SemiHonestContext}, + attribution::input::{MCAggregateCreditOutputRow, MCCappedCreditsWithAggregationBit}, + basics::Reshare, + boolean::RandomBits, + context::{UpgradableContext, UpgradedContext}, ipa::{ipa, IPAInputRow}, - BreakdownKey, MatchKey, + sort::generate_permutation::ShuffledPermutationWrapper, + BasicProtocols, BreakdownKey, MatchKey, RecordId, + }, + secret_sharing::{ + replicated::{malicious::DowngradeMalicious, semi_honest::AdditiveShare}, + Linear as LinearSecretSharing, }, - query::ProtocolResult, - secret_sharing::replicated::{malicious, semi_honest::AdditiveShare}, }; -use futures_util::StreamExt; -use std::future::Future; +use futures::StreamExt; +use std::marker::PhantomData; use typenum::Unsigned; -pub struct Runner(pub IpaQueryConfig); - -impl Runner { - pub async fn run( - &self, - ctx: SemiHonestContext<'_>, - field: FieldType, - input: ByteArrStream, - ) -> Result, Error> { - Ok(match field { - #[cfg(any(test, feature = "weak-field"))] - FieldType::Fp31 => Box::new( - self.run_internal::(ctx, input) - .await?, - ), - FieldType::Fp32BitPrime => Box::new( - self.run_internal::(ctx, input) - .await?, - ), - }) - } - - // This is intentionally made not async because it does not capture `self`. - fn run_internal<'a, F: PrimeField, MK: GaloisField, BK: GaloisField>( - &self, - ctx: SemiHonestContext<'a>, - input: ByteArrStream, - ) -> impl Future< - Output = std::result::Result< - Vec, BK>>, - Error, - >, - > + 'a - where - IPAInputRow: Serializable, - AdditiveShare: Serializable, - { - let config = self.0; - async move { - let mut input = input.align( as Serializable>::Size::USIZE); - let mut input_vec = Vec::new(); - while let Some(data) = input.next().await { - input_vec.extend(IPAInputRow::::from_byte_slice(&data.unwrap())); - } +pub struct IpaQuery(IpaQueryConfig, PhantomData<(F, C, S)>); - ipa(ctx, input_vec.as_slice(), config).await - } +impl IpaQuery { + pub fn new(config: IpaQueryConfig) -> Self { + Self(config, PhantomData) } +} - pub async fn malicious_run( - &self, - ctx: MaliciousContext<'_>, - field: FieldType, +impl IpaQuery +where + C: UpgradableContext + Send, + C::UpgradedContext: UpgradedContext + RandomBits, + S: LinearSecretSharing + + BasicProtocols, F> + + Reshare, RecordId> + + Serializable + + DowngradeMalicious> + + 'static, + C::UpgradedContext: UpgradedContext, + SB: LinearSecretSharing + + BasicProtocols, Gf2> + + DowngradeMalicious> + + 'static, + F: PrimeField, + IPAInputRow: Serializable, + ShuffledPermutationWrapper>: DowngradeMalicious>, + MCCappedCreditsWithAggregationBit: + DowngradeMalicious>>, + MCAggregateCreditOutputRow: + DowngradeMalicious, BreakdownKey>>, + AdditiveShare: Serializable, +{ + pub async fn execute<'a>( + self, + ctx: C, input: ByteArrStream, - ) -> Box { - match field { - #[cfg(any(test, feature = "weak-field"))] - FieldType::Fp31 => Box::new( - self.malicious_run_internal::(ctx, input) - .await - .expect("IPA query failed"), - ), - FieldType::Fp32BitPrime => Box::new( - self.malicious_run_internal::(ctx, input) - .await - .expect("IPA query failed"), - ), + ) -> Result, BreakdownKey>>, Error> { + let Self(config, _) = self; + + let mut input = + input.align( as Serializable>::Size::USIZE); + let mut input_vec = Vec::new(); + while let Some(data) = input.next().await { + input_vec.extend(IPAInputRow::::from_byte_slice( + &data.unwrap(), + )); } - } - // This is intentionally made not async because it does not capture `self`. - fn malicious_run_internal< - 'a, - F: PrimeField + crate::secret_sharing::replicated::malicious::ExtendableField, - MK: GaloisField, - BK: GaloisField, - >( - &self, - ctx: MaliciousContext<'a>, - input: ByteArrStream, - ) -> impl Future< - Output = std::result::Result< - Vec, BK>>, - Error, - >, - > + 'a - where - IPAInputRow: Serializable, - AdditiveShare: Serializable, - malicious::AdditiveShare: Serializable, - { - let config = self.0; - async move { - let mut input = input.align( as Serializable>::Size::USIZE); - let mut input_vec = Vec::new(); - while let Some(data) = input.next().await { - input_vec.extend(IPAInputRow::::from_byte_slice(&data.unwrap())); - } - - ipa(ctx, input_vec.as_slice(), config).await - } + ipa(ctx, input_vec.as_slice(), config).await } } @@ -176,7 +129,7 @@ mod tests { max_breakdown_key: 3, }; let input = ByteArrStream::from(shares); - Runner(query_config).run_internal::(ctx, input) + IpaQuery::new(query_config).execute(ctx, input) })) .await; @@ -237,7 +190,7 @@ mod tests { max_breakdown_key: 3, }; let input = ByteArrStream::from(shares); - Runner(query_config).malicious_run_internal::(ctx, input) + IpaQuery::new(query_config).execute(ctx, input) })) .await; diff --git a/src/query/runner/mod.rs b/src/query/runner/mod.rs index 7f05f71ef..a2be242b9 100644 --- a/src/query/runner/mod.rs +++ b/src/query/runner/mod.rs @@ -2,9 +2,10 @@ mod ipa; #[cfg(any(test, feature = "cli", feature = "test-fixture"))] mod test_multiply; -pub(super) use self::ipa::Runner as IpaRunner; use crate::{error::Error, query::ProtocolResult}; + +pub(super) use self::ipa::IpaQuery; #[cfg(any(test, feature = "cli", feature = "test-fixture"))] -pub(super) use test_multiply::Runner as TestMultiplyRunner; +pub(super) use test_multiply::execute_test_multiply; pub(super) type QueryResult = Result, Error>; diff --git a/src/query/runner/test_multiply.rs b/src/query/runner/test_multiply.rs index 616f57c2f..00c33ba69 100644 --- a/src/query/runner/test_multiply.rs +++ b/src/query/runner/test_multiply.rs @@ -1,79 +1,76 @@ use crate::{ error::Error, - ff::{Field, FieldType, Fp32BitPrime, Serializable}, - helpers::{ByteArrStream, TotalRecords}, + ff::{PrimeField, Serializable}, + helpers::{ByteArrStream, Gateway, TotalRecords}, protocol::{ basics::SecureMul, context::{Context, SemiHonestContext}, + prss::Endpoint as PrssEndpoint, RecordId, }, - query::ProtocolResult, + query::runner::QueryResult, secret_sharing::replicated::semi_honest::AdditiveShare as Replicated, }; use futures_util::StreamExt; use typenum::Unsigned; -pub struct Runner; - -impl Runner { - pub async fn run( - &self, - ctx: SemiHonestContext<'_>, - field: FieldType, - input: ByteArrStream, - ) -> Result, Error> { - Ok(match field { - #[cfg(any(test, feature = "weak-field"))] - FieldType::Fp31 => Box::new(self.run_internal::(ctx, input).await?), - FieldType::Fp32BitPrime => { - Box::new(self.run_internal::(ctx, input).await?) - } - }) - } +pub async fn execute_test_multiply<'a, F>( + prss: &'a PrssEndpoint, + gateway: &'a Gateway, + input: ByteArrStream, +) -> QueryResult +where + F: PrimeField, + Replicated: Serializable, +{ + let ctx = SemiHonestContext::new(prss, gateway); + Ok(Box::new( + execute_test_multiply_internal::(ctx, input).await?, + )) +} - async fn run_internal( - &self, - ctx: SemiHonestContext<'_>, - input: ByteArrStream, - ) -> std::result::Result>, Error> - where - Replicated: Serializable, - { - let ctx = ctx.set_total_records(TotalRecords::Indeterminate); +pub async fn execute_test_multiply_internal( + ctx: SemiHonestContext<'_>, + input: ByteArrStream, +) -> Result>, Error> +where + F: PrimeField, + Replicated: Serializable, +{ + let ctx = ctx.set_total_records(TotalRecords::Indeterminate); - let mut input = input.align( as Serializable>::Size::USIZE); - let mut results = Vec::new(); - while let Some(v) = input.next().await { - // multiply pairs - let mut a = None; - let mut record_id = 0_u32; - for share in Replicated::::from_byte_slice(&v.unwrap()) { - match a { - None => a = Some(share), - Some(a_v) => { - let result = a_v - .multiply(&share, ctx.clone(), RecordId::from(record_id)) - .await - .unwrap(); - results.push(result); - record_id += 1; - a = None; - } + let mut input = input.align( as Serializable>::Size::USIZE); + let mut results = Vec::new(); + while let Some(v) = input.next().await { + // multiply pairs + let mut a = None; + let mut record_id = 0_u32; + for share in Replicated::::from_byte_slice(&v.unwrap()) { + match a { + None => a = Some(share), + Some(a_v) => { + let result = a_v + .multiply(&share, ctx.clone(), RecordId::from(record_id)) + .await + .unwrap(); + results.push(result); + record_id += 1; + a = None; } } - - assert!(a.is_none()); } - Ok(results) + assert!(a.is_none()); } + + Ok(results) } #[cfg(all(test, not(feature = "shuttle"), feature = "in-memory-infra"))] mod tests { use super::*; use crate::{ - ff::Fp31, + ff::{Field, Fp31}, secret_sharing::IntoShares, test_fixture::{join3v, Reconstruct, TestWorld}, }; @@ -108,7 +105,7 @@ mod tests { helper_shares .into_iter() .zip(contexts) - .map(|(shares, context)| Runner.run_internal::(context, shares)), + .map(|(shares, context)| execute_test_multiply_internal::(context, shares)), ) .await;