Skip to content

Commit

Permalink
Merge pull request #647 from akoshelev/test-mpc-inputs
Browse files Browse the repository at this point in the history
Event generation using test_mpc and IPA integration test + migrate to HTTP2
  • Loading branch information
akoshelev authored May 22, 2023
2 parents 5b04442 + 4f2ae94 commit 8e625e6
Show file tree
Hide file tree
Showing 18 changed files with 803 additions and 191 deletions.
37 changes: 15 additions & 22 deletions benches/oneshot/ipa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,8 @@ use ipa::{
ff::Fp32BitPrime,
helpers::{query::IpaQueryConfig, GatewayConfig},
test_fixture::{
ipa::{
generate_random_user_records_in_reverse_chronological_order, ipa_in_the_clear,
test_ipa, IpaSecurityModel,
},
TestWorld, TestWorldConfig,
ipa::{ipa_in_the_clear, test_ipa, IpaSecurityModel},
EventGenerator, EventGeneratorConfig, TestWorld, TestWorldConfig,
},
};
use rand::{rngs::StdRng, thread_rng, Rng, SeedableRng};
Expand Down Expand Up @@ -38,7 +35,7 @@ struct Args {
query_size: usize,
/// The maximum number of records for each person.
#[arg(short = 'u', long, default_value = "50")]
records_per_user: usize,
records_per_user: u32,
/// The contribution cap for each person.
#[arg(short = 'c', long, default_value = "3")]
per_user_cap: u32,
Expand Down Expand Up @@ -109,23 +106,19 @@ async fn run(args: Args) -> Result<(), Error> {
"Using random seed: {seed} for {q} records",
q = args.query_size
);
let mut rng = StdRng::seed_from_u64(seed);

let mut raw_data = Vec::with_capacity(args.query_size + args.records_per_user);
while raw_data.len() < args.query_size {
let mut records_for_user = generate_random_user_records_in_reverse_chronological_order(
&mut rng,
args.records_per_user,
args.breakdown_keys,
args.max_trigger_value,
);
records_for_user.truncate(args.query_size - raw_data.len());
raw_data.append(&mut records_for_user);
}
let rng = StdRng::seed_from_u64(seed);
let raw_data = EventGenerator::with_config(
rng,
EventGeneratorConfig {
max_trigger_value: NonZeroU32::try_from(args.max_trigger_value).unwrap(),
max_breakdown_key: NonZeroU32::try_from(args.breakdown_keys).unwrap(),
max_events_per_user: NonZeroU32::try_from(args.records_per_user).unwrap(),
..Default::default()
},
)
.take(args.query_size)
.collect::<Vec<_>>();

// Sort the records in chronological order
// This is part of the IPA spec. Callers should do this before sending a batch of records in for processing.
raw_data.sort_unstable_by(|a, b| a.timestamp.cmp(&b.timestamp));
let expected_results =
ipa_in_the_clear(&raw_data, args.per_user_cap, args.attribution_window());

Expand Down
139 changes: 105 additions & 34 deletions src/bin/test_mpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use hyper::http::uri::Scheme;
use ipa::{
cli::{
playbook::{secure_mul, semi_honest, InputSource},
Verbosity,
CsvSerializer, Verbosity,
},
config::NetworkConfig,
ff::{Field, FieldType, Fp31, Fp32BitPrime, Serializable},
Expand All @@ -16,9 +16,21 @@ use ipa::{
test_fixture::{
config::TestConfigBuilder,
ipa::{ipa_in_the_clear, TestRawDataRecord},
EventGenerator, EventGeneratorConfig,
},
};
use std::{error::Error, fmt::Debug, fs, ops::Add, path::PathBuf, time::Duration};
use rand::thread_rng;
use std::{
error::Error,
fmt::Debug,
fs,
fs::OpenOptions,
io,
io::{stdout, Write},
ops::Add,
path::PathBuf,
time::Duration,
};
use tokio::time::sleep;

#[derive(Debug, Parser)]
Expand Down Expand Up @@ -77,7 +89,29 @@ enum TestAction {
/// Execute end-to-end multiplication.
Multiply,
/// Execute IPA in semi-honest majority setting
SemiHonestIPA,
SemiHonestIpa(IpaQueryConfig),
/// Generate inputs for IPA
GenIpaInputs {
/// Number of records to generate
#[clap(long, short = 'n')]
count: u32,

/// The destination file for generated records
#[arg(long)]
output_file: Option<PathBuf>,

#[clap(flatten)]
gen_args: EventGeneratorConfig,
},
}

#[derive(Debug, clap::Args)]
struct GenInputArgs {
/// Maximum records per user
#[clap(long)]
max_per_user: u32,
/// number of breakdowns
breakdowns: u32,
}

async fn clients_ready(clients: &[MpcHelperClient; 3]) -> bool {
Expand Down Expand Up @@ -124,7 +158,7 @@ where
i += 1;
}

tracing::info!("{table}");
tracing::info!("\n{table}\n");

assert!(
mismatch.is_empty(),
Expand All @@ -135,61 +169,98 @@ where

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
fn make_clients(disable_https: bool, config_path: Option<&PathBuf>) -> [MpcHelperClient; 3] {
let scheme = if disable_https {
let args = Args::parse();
let _handle = args.logging.setup_logging();

let make_clients = || async {
let scheme = if args.disable_https {
Scheme::HTTP
} else {
Scheme::HTTPS
};

let config_path = args.network.as_deref();
let mut wait = args.wait;

let config = if let Some(path) = config_path {
NetworkConfig::from_toml_str(&fs::read_to_string(path).unwrap()).unwrap()
} else {
TestConfigBuilder::with_default_test_ports().build().network
}
.override_scheme(&scheme);
MpcHelperClient::from_conf(&config)
}
let clients = MpcHelperClient::from_conf(&config);
while wait > 0 && !clients_ready(&clients).await {
tracing::debug!("waiting for servers to come up");
sleep(Duration::from_secs(1)).await;
wait -= 1;
}

let args = Args::parse();
let _handle = args.logging.setup_logging();
clients
};

let clients = make_clients(args.disable_https, args.network.as_ref());
match args.action {
TestAction::Multiply => multiply(&args, &make_clients().await).await,
TestAction::SemiHonestIpa(config) => {
semi_honest_ipa(&args, &config, &make_clients().await).await
}
TestAction::GenIpaInputs {
count,
output_file,
gen_args,
} => gen_inputs(count, output_file, gen_args).unwrap(),
};

let mut wait = args.wait;
while wait > 0 && !clients_ready(&clients).await {
println!("waiting for servers to come up");
sleep(Duration::from_secs(1)).await;
wait -= 1;
}
Ok(())
}

match args.action {
TestAction::Multiply => multiply(args, &clients).await,
TestAction::SemiHonestIPA => semi_honest_ipa(args, &clients).await,
fn gen_inputs(
count: u32,
output_file: Option<PathBuf>,
args: EventGeneratorConfig,
) -> io::Result<()> {
let event_gen = EventGenerator::with_config(thread_rng(), args).take(count as usize);
let mut writer: Box<dyn Write> = if let Some(path) = output_file {
Box::new(OpenOptions::new().write(true).create_new(true).open(path)?)
} else {
Box::new(stdout().lock())
};

for event in event_gen {
event.to_csv(&mut writer)?;
writer.write(&[b'\n'])?;
}

Ok(())
}

async fn semi_honest_ipa(args: Args, helper_clients: &[MpcHelperClient; 3]) {
async fn semi_honest_ipa(
args: &Args,
ipa_query_config: &IpaQueryConfig,
helper_clients: &[MpcHelperClient; 3],
) {
let input = InputSource::from(&args.input);
let ipa_query_config = IpaQueryConfig {
per_user_credit_cap: 3,
max_breakdown_key: 3,
num_multi_bits: 3,
attribution_window_seconds: None,
};
let query_type = QueryType::Ipa(ipa_query_config.clone());
let query_config = QueryConfig {
field_type: args.input.field,
query_type,
};
let query_id = helper_clients[0].create_query(query_config).await.unwrap();
let input_rows = input.iter::<TestRawDataRecord>().collect::<Vec<_>>();
let expected = ipa_in_the_clear(
&input_rows,
ipa_query_config.per_user_credit_cap,
ipa_query_config.attribution_window_seconds,
);
let expected = {
let mut r = ipa_in_the_clear(
&input_rows,
ipa_query_config.per_user_credit_cap,
ipa_query_config.attribution_window_seconds,
);

// pad the output vector to the max breakdown key, to make sure it is aligned with the MPC results
// truncate shouldn't happen unless in_the_clear is badly broken
r.resize(
usize::try_from(ipa_query_config.max_breakdown_key).unwrap(),
0,
);
r
};

let actual = match args.input.field {
FieldType::Fp31 => {
Expand All @@ -210,7 +281,7 @@ async fn semi_honest_ipa(args: Args, helper_clients: &[MpcHelperClient; 3]) {
}

async fn multiply_in_field<F: Field>(
args: Args,
args: &Args,
helper_clients: &[MpcHelperClient; 3],
query_id: QueryId,
) where
Expand All @@ -226,7 +297,7 @@ async fn multiply_in_field<F: Field>(
validate(expected, actual);
}

async fn multiply(args: Args, helper_clients: &[MpcHelperClient; 3]) {
async fn multiply(args: &Args, helper_clients: &[MpcHelperClient; 3]) {
let query_config = QueryConfig {
field_type: args.input.field,
query_type: QueryType::TestMultiply,
Expand Down
22 changes: 22 additions & 0 deletions src/cli/csv.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use std::{io, io::Write};

pub trait Serializer {
/// Converts self into a CSV-encoded byte string
/// ## Errors
/// If this conversion fails due to insufficient capacity in `buf` or other reasons.
fn to_csv<W: Write>(&self, buf: &mut W) -> io::Result<()>;
}

#[cfg(any(test, feature = "test-fixture"))]
impl Serializer for crate::test_fixture::ipa::TestRawDataRecord {
fn to_csv<W: Write>(&self, buf: &mut W) -> io::Result<()> {
// fmt::write is cool because it does not allocate when serializing integers
write!(buf, "{},", self.timestamp)?;
write!(buf, "{},", self.user_id)?;
write!(buf, "{},", u8::from(self.is_trigger_report))?;
write!(buf, "{},", self.breakdown_key)?;
write!(buf, "{}", self.trigger_value)?;

Ok(())
}
}
2 changes: 2 additions & 0 deletions src/cli/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod csv;
#[cfg(feature = "web-app")]
mod keygen;
mod metric_collector;
Expand All @@ -7,6 +8,7 @@ pub mod playbook;
mod test_setup;
mod verbosity;

pub use csv::Serializer as CsvSerializer;
#[cfg(feature = "web-app")]
pub use keygen::{keygen, KeygenArgs};
pub use metric_collector::{install_collector, CollectorHandle};
Expand Down
16 changes: 14 additions & 2 deletions src/cli/test_setup.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
cli::{keygen, KeygenArgs},
config::{NetworkConfig, PeerConfig},
config::{ClientConfig, NetworkConfig, PeerConfig},
};
use clap::Args;
use std::{
Expand All @@ -24,6 +24,10 @@ pub struct TestSetupArgs {
#[arg(long)]
disable_https: bool,

/// Configure helper clients to use HTTP1 instead of default HTTP version (HTTP2 at the moment).
#[arg(long, default_value_t = false)]
use_http1: bool,

#[arg(short, long, num_args = 3, value_name = "PORT", default_values = vec!["3000", "3001", "3002"])]
ports: Vec<u16>,
}
Expand Down Expand Up @@ -66,7 +70,15 @@ pub fn test_setup(args: TestSetupArgs) -> Result<(), Box<dyn Error>> {
.try_into()
.unwrap();

let network_config = toml::to_string_pretty(&NetworkConfig { peers })?;
let client_config = if args.use_http1 {
ClientConfig::use_http1()
} else {
ClientConfig::default()
};
let network_config = toml::to_string_pretty(&NetworkConfig {
peers,
client: client_config,
})?;

fs::write(args.output_dir.join("network.toml"), network_config)?;

Expand Down
Loading

0 comments on commit 8e625e6

Please sign in to comment.