Skip to content

Commit

Permalink
Merge pull request private-attribution#683 from andyleiserson/runner
Browse files Browse the repository at this point in the history
Reduce duplicated code involved in running a query
  • Loading branch information
akoshelev authored Jun 7, 2023
2 parents 43e9fc8 + 158e936 commit 4164900
Show file tree
Hide file tree
Showing 5 changed files with 214 additions and 206 deletions.
128 changes: 93 additions & 35 deletions src/query/executor.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{
ff::{Field, GaloisField, Serializable},
ff::{Field, FieldType, Fp32BitPrime, GaloisField, Serializable},
helpers::{
negotiate_prss,
query::{QueryConfig, QueryType},
Expand All @@ -8,20 +8,28 @@ use crate::{
protocol::{
attribution::input::MCAggregateCreditOutputRow,
context::{MaliciousContext, SemiHonestContext},
prss::Endpoint as PrssEndpoint,
step::{self, StepNarrow},
},
query::runner::IpaRunner,
query::runner::IpaQuery,
secret_sharing::{replicated::semi_honest::AdditiveShare, Linear as LinearSecretSharing},
task::JoinHandle,
};

#[cfg(any(test, feature = "cli", feature = "test-fixture"))]
use crate::query::runner::execute_test_multiply;
use crate::query::runner::QueryResult;
use futures::FutureExt;
use generic_array::GenericArray;
use rand::rngs::StdRng;
use rand_core::SeedableRng;
#[cfg(all(feature = "shuttle", test))]
use shuttle::future as tokio;
use std::fmt::Debug;
use std::{
fmt::Debug,
future::{ready, Future},
pin::Pin,
};
use typenum::Unsigned;

pub trait Result: Send + Debug {
Expand Down Expand Up @@ -63,54 +71,104 @@ where
}
}

#[allow(unused)]
pub fn start_query(
pub fn execute(
config: QueryConfig,
gateway: Gateway,
input: ByteArrStream,
) -> JoinHandle<QueryResult> {
match (config.query_type, config.field_type) {
#[cfg(any(test, feature = "cli", feature = "test-fixture"))]
(QueryType::TestMultiply, FieldType::Fp31) => {
do_query(config, gateway, input, |prss, gateway, input| {
Box::pin(execute_test_multiply::<crate::ff::Fp31>(
prss, gateway, input,
))
})
}
#[cfg(any(test, feature = "cli", feature = "test-fixture"))]
(QueryType::TestMultiply, FieldType::Fp32BitPrime) => {
do_query(config, gateway, input, |prss, gateway, input| {
Box::pin(execute_test_multiply::<Fp32BitPrime>(prss, gateway, input))
})
}
#[cfg(any(test, feature = "weak-field"))]
(QueryType::SemiHonestIpa(ipa_config), FieldType::Fp31) => {
do_query(config, gateway, input, move |prss, gateway, input| {
let ctx = SemiHonestContext::new(prss, gateway);
Box::pin(
IpaQuery::<crate::ff::Fp31, _, _>::new(ipa_config)
.execute(ctx, input)
.then(|res| ready(res.map(|out| Box::new(out) as Box<dyn Result>))),
)
})
}
(QueryType::SemiHonestIpa(ipa_config), FieldType::Fp32BitPrime) => {
do_query(config, gateway, input, move |prss, gateway, input| {
let ctx = SemiHonestContext::new(prss, gateway);
Box::pin(
IpaQuery::<Fp32BitPrime, _, _>::new(ipa_config)
.execute(ctx, input)
.then(|res| ready(res.map(|out| Box::new(out) as Box<dyn Result>))),
)
})
}
#[cfg(any(test, feature = "weak-field"))]
(QueryType::MaliciousIpa(ipa_config), FieldType::Fp31) => {
do_query(config, gateway, input, move |prss, gateway, input| {
let ctx = MaliciousContext::new(prss, gateway);
Box::pin(
IpaQuery::<crate::ff::Fp31, _, _>::new(ipa_config)
.execute(ctx, input)
.then(|res| ready(res.map(|out| Box::new(out) as Box<dyn Result>))),
)
})
}
(QueryType::MaliciousIpa(ipa_config), FieldType::Fp32BitPrime) => {
do_query(config, gateway, input, move |prss, gateway, input| {
let ctx = MaliciousContext::new(prss, gateway);
Box::pin(
IpaQuery::<Fp32BitPrime, _, _>::new(ipa_config)
.execute(ctx, input)
.then(|res| ready(res.map(|out| Box::new(out) as Box<dyn Result>))),
)
})
}
}
}

pub fn do_query<F>(
config: QueryConfig,
gateway: Gateway,
input: ByteArrStream,
query_impl: F,
) -> JoinHandle<QueryResult>
where
F: for<'a> FnOnce(
&'a PrssEndpoint,
&'a Gateway,
ByteArrStream,
) -> Pin<Box<dyn Future<Output = QueryResult> + Send + 'a>>
+ Send
+ 'static,
{
tokio::spawn(async move {
// TODO: make it a generic argument for this function
let mut rng = StdRng::from_entropy();
// Negotiate PRSS first
let step = step::Descriptive::default().narrow(&config.query_type);
let prss = negotiate_prss(&gateway, &step, &mut rng).await.unwrap();

match config.query_type {
#[cfg(any(test, feature = "cli", feature = "test-fixture"))]
QueryType::TestMultiply => {
super::runner::TestMultiplyRunner
.run(
SemiHonestContext::new(&prss, &gateway),
config.field_type,
input,
)
.await
}
QueryType::SemiHonestIpa(ipa_query_config) => {
IpaRunner(ipa_query_config)
.run(
SemiHonestContext::new(&prss, &gateway),
config.field_type,
input,
)
.await
}
QueryType::MaliciousIpa(ipa_query_config) => Ok(IpaRunner(ipa_query_config)
.malicious_run(
MaliciousContext::new(&prss, &gateway),
config.field_type,
input,
)
.await),
}
query_impl(&prss, &gateway, input).await
})
}

#[cfg(all(test, not(feature = "shuttle"), feature = "in-memory-infra"))]
mod tests {
use super::*;
use crate::{ff::Fp31, secret_sharing::IntoShares};
use crate::{
ff::{Field, Fp31},
query::ProtocolResult,
secret_sharing::{replicated::semi_honest::AdditiveShare, IntoShares},
};

#[test]
fn serialize_result() {
Expand Down
23 changes: 11 additions & 12 deletions src/query/processor.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::{
error::Error as ProtocolError,
helpers::{
query::{PrepareQuery, QueryConfig, QueryInput},
Gateway, GatewayConfig, Role, RoleAssignment, Transport, TransportError, TransportImpl,
Expand Down Expand Up @@ -85,6 +86,8 @@ pub enum QueryCompletionError {
#[from]
source: StateError,
},
#[error("query execution failed: {0}")]
ExecutionError(#[from] ProtocolError),
}

impl Debug for Processor {
Expand Down Expand Up @@ -201,11 +204,7 @@ impl Processor {
);
queries.insert(
input.query_id,
QueryState::Running(executor::start_query(
config,
gateway,
input.input_stream,
)),
QueryState::Running(executor::execute(config, gateway, input.input_stream)),
);
Ok(())
} else {
Expand Down Expand Up @@ -242,29 +241,29 @@ impl Processor {
match queries.remove(&query_id) {
Some(QueryState::Running(handle)) => {
queries.insert(query_id, QueryState::AwaitingCompletion);
Ok(CompletionHandle::new(
CompletionHandle::new(
RemoveQuery {
query_id,
queries: &self.queries,
},
handle,
))
)
}
Some(state) => {
let state_error = StateError::InvalidState {
from: QueryStatus::from(&state),
to: QueryStatus::Running,
};
queries.insert(query_id, state);
Err(QueryCompletionError::StateError {
return Err(QueryCompletionError::StateError {
source: state_error,
})
});
}
None => Err(QueryCompletionError::NoSuchQuery(query_id)),
None => return Err(QueryCompletionError::NoSuchQuery(query_id)),
}
}?;
}; // release mutex before await

Ok(handle.await.unwrap())
Ok(handle.await?)
}
}

Expand Down
Loading

0 comments on commit 4164900

Please sign in to comment.