Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Event generation using test_mpc and IPA integration test + migrate to HTTP2 #647

Merged
merged 13 commits into from
May 22, 2023
37 changes: 15 additions & 22 deletions benches/oneshot/ipa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,8 @@ use ipa::{
ff::Fp32BitPrime,
helpers::{query::IpaQueryConfig, GatewayConfig},
test_fixture::{
ipa::{
generate_random_user_records_in_reverse_chronological_order, ipa_in_the_clear,
test_ipa, IpaSecurityModel,
},
TestWorld, TestWorldConfig,
ipa::{ipa_in_the_clear, test_ipa, IpaSecurityModel},
EventGenerator, EventGeneratorConfig, TestWorld, TestWorldConfig,
},
};
use rand::{rngs::StdRng, thread_rng, Rng, SeedableRng};
Expand Down Expand Up @@ -38,7 +35,7 @@ struct Args {
query_size: usize,
/// The maximum number of records for each person.
#[arg(short = 'u', long, default_value = "50")]
records_per_user: usize,
records_per_user: u32,
/// The contribution cap for each person.
#[arg(short = 'c', long, default_value = "3")]
per_user_cap: u32,
Expand Down Expand Up @@ -109,23 +106,19 @@ async fn run(args: Args) -> Result<(), Error> {
"Using random seed: {seed} for {q} records",
q = args.query_size
);
let mut rng = StdRng::seed_from_u64(seed);

let mut raw_data = Vec::with_capacity(args.query_size + args.records_per_user);
while raw_data.len() < args.query_size {
let mut records_for_user = generate_random_user_records_in_reverse_chronological_order(
&mut rng,
args.records_per_user,
args.breakdown_keys,
args.max_trigger_value,
);
records_for_user.truncate(args.query_size - raw_data.len());
raw_data.append(&mut records_for_user);
}
let rng = StdRng::seed_from_u64(seed);
let raw_data = EventGenerator::with_config(
rng,
EventGeneratorConfig {
max_trigger_value: NonZeroU32::try_from(args.max_trigger_value).unwrap(),
max_breakdown_key: NonZeroU32::try_from(args.breakdown_keys).unwrap(),
max_events_per_user: NonZeroU32::try_from(args.records_per_user).unwrap(),
..Default::default()
},
)
.take(args.query_size)
.collect::<Vec<_>>();

// Sort the records in chronological order
// This is part of the IPA spec. Callers should do this before sending a batch of records in for processing.
raw_data.sort_unstable_by(|a, b| a.timestamp.cmp(&b.timestamp));
let expected_results =
ipa_in_the_clear(&raw_data, args.per_user_cap, args.attribution_window());

Expand Down
139 changes: 105 additions & 34 deletions src/bin/test_mpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use hyper::http::uri::Scheme;
use ipa::{
cli::{
playbook::{secure_mul, semi_honest, InputSource},
Verbosity,
CsvSerializer, Verbosity,
},
config::NetworkConfig,
ff::{Field, FieldType, Fp31, Fp32BitPrime, Serializable},
Expand All @@ -16,9 +16,21 @@ use ipa::{
test_fixture::{
config::TestConfigBuilder,
ipa::{ipa_in_the_clear, TestRawDataRecord},
EventGenerator, EventGeneratorConfig,
},
};
use std::{error::Error, fmt::Debug, fs, ops::Add, path::PathBuf, time::Duration};
use rand::thread_rng;
use std::{
error::Error,
fmt::Debug,
fs,
fs::OpenOptions,
io,
io::{stdout, Write},
ops::Add,
path::PathBuf,
time::Duration,
};
use tokio::time::sleep;

#[derive(Debug, Parser)]
Expand Down Expand Up @@ -77,7 +89,29 @@ enum TestAction {
/// Execute end-to-end multiplication.
Multiply,
/// Execute IPA in semi-honest majority setting
SemiHonestIPA,
SemiHonestIpa(IpaQueryConfig),
/// Generate inputs for IPA
GenIpaInputs {
/// Number of records to generate
#[clap(long, short = 'c')]
akoshelev marked this conversation as resolved.
Show resolved Hide resolved
count: u32,

/// The destination file for generated records
#[arg(long)]
output_file: Option<PathBuf>,

#[clap(flatten)]
gen_args: EventGeneratorConfig,
},
}

// Command-line options describing generated inputs: a per-user record limit and a
// breakdown count.
// NOTE(review): no use of this struct is visible in this view — `TestAction::GenIpaInputs`
// flattens `EventGeneratorConfig` instead. Confirm whether it is still needed.
#[derive(Debug, clap::Args)]
struct GenInputArgs {
/// Maximum records per user
#[clap(long)]
max_per_user: u32,
/// number of breakdowns
// NOTE(review): unlike `max_per_user`, this field carries no `#[clap(long)]` attribute —
// confirm that the difference is intentional.
breakdowns: u32,
}

async fn clients_ready(clients: &[MpcHelperClient; 3]) -> bool {
Expand Down Expand Up @@ -124,7 +158,7 @@ where
i += 1;
}

tracing::info!("{table}");
tracing::info!("\n{table}\n");

assert!(
mismatch.is_empty(),
Expand All @@ -135,61 +169,98 @@ where

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
fn make_clients(disable_https: bool, config_path: Option<&PathBuf>) -> [MpcHelperClient; 3] {
let scheme = if disable_https {
let args = Args::parse();
let _handle = args.logging.setup_logging();

let make_clients = || async {
let scheme = if args.disable_https {
Scheme::HTTP
} else {
Scheme::HTTPS
};

let config_path = args.network.as_deref();
let mut wait = args.wait;

let config = if let Some(path) = config_path {
NetworkConfig::from_toml_str(&fs::read_to_string(path).unwrap()).unwrap()
} else {
TestConfigBuilder::with_default_test_ports().build().network
}
.override_scheme(&scheme);
MpcHelperClient::from_conf(&config)
}
let clients = MpcHelperClient::from_conf(&config);
while wait > 0 && !clients_ready(&clients).await {
tracing::debug!("waiting for servers to come up");
sleep(Duration::from_secs(1)).await;
wait -= 1;
}

let args = Args::parse();
let _handle = args.logging.setup_logging();
clients
};

let clients = make_clients(args.disable_https, args.network.as_ref());
match args.action {
TestAction::Multiply => multiply(&args, &make_clients().await).await,
TestAction::SemiHonestIpa(config) => {
semi_honest_ipa(&args, &config, &make_clients().await).await
}
TestAction::GenIpaInputs {
count,
output_file,
gen_args,
} => gen_inputs(count, output_file, gen_args).unwrap(),
};

let mut wait = args.wait;
while wait > 0 && !clients_ready(&clients).await {
println!("waiting for servers to come up");
sleep(Duration::from_secs(1)).await;
wait -= 1;
}
Ok(())
}

match args.action {
TestAction::Multiply => multiply(args, &clients).await,
TestAction::SemiHonestIPA => semi_honest_ipa(args, &clients).await,
/// Generates `count` events and writes them out in CSV form, one event per line.
///
/// Events come from an [`EventGenerator`] built from `thread_rng()` and `args`. When
/// `output_file` is provided the records are written to a newly created file at that path
/// (`create_new` makes the call fail if the file already exists); otherwise they are written
/// to standard output.
///
/// ## Errors
/// Returns an error if the output file cannot be created, or if serializing, writing or
/// flushing any record fails.
fn gen_inputs(
    count: u32,
    output_file: Option<PathBuf>,
    args: EventGeneratorConfig,
) -> io::Result<()> {
    let event_gen = EventGenerator::with_config(thread_rng(), args).take(count as usize);
    let sink: Box<dyn Write> = if let Some(path) = output_file {
        Box::new(OpenOptions::new().write(true).create_new(true).open(path)?)
    } else {
        Box::new(stdout().lock())
    };
    // Every record is emitted through several small writes; buffering keeps that from turning
    // into one system call per field when the destination is a file.
    let mut writer = io::BufWriter::new(sink);

    for event in event_gen {
        event.to_csv(&mut writer)?;
        // `write_all` rather than `write`: a bare `write` may report a short write, which would
        // silently drop the line terminator.
        writer.write_all(b"\n")?;
    }

    // Flush explicitly so that a failure is reported to the caller instead of being swallowed
    // when the buffer is dropped.
    writer.flush()
}

async fn semi_honest_ipa(args: Args, helper_clients: &[MpcHelperClient; 3]) {
async fn semi_honest_ipa(
args: &Args,
ipa_query_config: &IpaQueryConfig,
helper_clients: &[MpcHelperClient; 3],
) {
let input = InputSource::from(&args.input);
let ipa_query_config = IpaQueryConfig {
per_user_credit_cap: 3,
max_breakdown_key: 3,
num_multi_bits: 3,
attribution_window_seconds: None,
};
let query_type = QueryType::Ipa(ipa_query_config.clone());
let query_config = QueryConfig {
field_type: args.input.field,
query_type,
};
let query_id = helper_clients[0].create_query(query_config).await.unwrap();
let input_rows = input.iter::<TestRawDataRecord>().collect::<Vec<_>>();
let expected = ipa_in_the_clear(
&input_rows,
ipa_query_config.per_user_credit_cap,
ipa_query_config.attribution_window_seconds,
);
let expected = {
let mut r = ipa_in_the_clear(
&input_rows,
ipa_query_config.per_user_credit_cap,
ipa_query_config.attribution_window_seconds,
);

// pad the output vector to the max breakdown key, to make sure it is aligned with the MPC results
// truncate shouldn't happen unless in_the_clear is badly broken
r.resize(
usize::try_from(ipa_query_config.max_breakdown_key).unwrap(),
0,
);
r
};

let actual = match args.input.field {
FieldType::Fp31 => {
Expand All @@ -210,7 +281,7 @@ async fn semi_honest_ipa(args: Args, helper_clients: &[MpcHelperClient; 3]) {
}

async fn multiply_in_field<F: Field>(
args: Args,
args: &Args,
helper_clients: &[MpcHelperClient; 3],
query_id: QueryId,
) where
Expand All @@ -226,7 +297,7 @@ async fn multiply_in_field<F: Field>(
validate(expected, actual);
}

async fn multiply(args: Args, helper_clients: &[MpcHelperClient; 3]) {
async fn multiply(args: &Args, helper_clients: &[MpcHelperClient; 3]) {
let query_config = QueryConfig {
field_type: args.input.field,
query_type: QueryType::TestMultiply,
Expand Down
22 changes: 22 additions & 0 deletions src/cli/csv.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use std::{io, io::Write};

pub trait CsvSerializer {
akoshelev marked this conversation as resolved.
Show resolved Hide resolved
/// Serializes `self` in CSV format, writing the encoded bytes into `buf`.
///
/// ## Errors
/// If writing to `buf` fails, for example due to insufficient capacity in `buf`, or for
/// other reasons reported by the underlying writer.
fn to_csv<W: Write>(&self, buf: &mut W) -> io::Result<()>;
}

#[cfg(any(test, feature = "test-fixture"))]
impl CsvSerializer for crate::test_fixture::ipa::TestRawDataRecord {
    /// Writes the record as five comma-separated fields in this order: timestamp, user id,
    /// trigger flag (as `0` or `1`), breakdown key, trigger value. No line terminator is
    /// appended.
    fn to_csv<W: Write>(&self, buf: &mut W) -> io::Result<()> {
        let Self {
            timestamp,
            user_id,
            is_trigger_report,
            breakdown_key,
            trigger_value,
            ..
        } = self;
        let trigger_flag = u8::from(*is_trigger_report);

        // Formatting directly into the writer means integers are serialized without an
        // intermediate allocation.
        write!(buf, "{},", timestamp)?;
        write!(buf, "{},", user_id)?;
        write!(buf, "{},", trigger_flag)?;
        write!(buf, "{},", breakdown_key)?;
        write!(buf, "{}", trigger_value)
    }
}
2 changes: 2 additions & 0 deletions src/cli/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod csv;
#[cfg(feature = "web-app")]
mod keygen;
mod metric_collector;
Expand All @@ -7,6 +8,7 @@ pub mod playbook;
mod test_setup;
mod verbosity;

pub use csv::CsvSerializer;
#[cfg(feature = "web-app")]
pub use keygen::{keygen, KeygenArgs};
pub use metric_collector::{install_collector, CollectorHandle};
Expand Down
5 changes: 5 additions & 0 deletions src/helpers/transport/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,15 @@ impl Step for QueryType {}

// Configuration values for an IPA query.
// With the `enable-serde` feature the struct derives `Serialize`/`Deserialize`; with the `clap`
// feature it derives `clap::Args`, exposing each field as a `--long` command-line option.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "enable-serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "clap", derive(clap::Args))]
pub struct IpaQueryConfig {
// Command-line default: 5.
#[cfg_attr(feature = "clap", arg(long, default_value = "5"))]
pub per_user_credit_cap: u32,
// Command-line default: 5.
#[cfg_attr(feature = "clap", arg(long, default_value = "5"))]
pub max_breakdown_key: u32,
// Optional; no command-line default is declared for this field.
#[cfg_attr(feature = "clap", arg(long))]
pub attribution_window_seconds: Option<NonZeroU32>,
// Command-line default: 3.
#[cfg_attr(feature = "clap", arg(long, default_value = "3"))]
pub num_multi_bits: u32,
}

Expand Down
38 changes: 15 additions & 23 deletions src/protocol/ipa/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -467,11 +467,8 @@ pub mod tests {
},
test_fixture::{
input::GenericReportTestInput,
ipa::{
generate_random_user_records_in_reverse_chronological_order, ipa_in_the_clear,
test_ipa, IpaSecurityModel,
},
Reconstruct, Runner, TestWorld, TestWorldConfig,
ipa::{ipa_in_the_clear, test_ipa, IpaSecurityModel},
EventGenerator, EventGeneratorConfig, Reconstruct, Runner, TestWorld, TestWorldConfig,
},
};
use generic_array::GenericArray;
Expand Down Expand Up @@ -878,31 +875,26 @@ pub mod tests {
pub async fn random_ipa_check() {
const MAX_BREAKDOWN_KEY: u32 = 64;
const MAX_TRIGGER_VALUE: u32 = 5;
const NUM_USERS: usize = 8;
const MAX_RECORDS_PER_USER: usize = 8;
const NUM_USERS: u64 = 8;
const MAX_RECORDS_PER_USER: u32 = 8;
const NUM_MULTI_BITS: u32 = 3;
const ATTRIBUTION_WINDOW_SECONDS: Option<NonZeroU32> = NonZeroU32::new(86_400);
type TestField = Fp32BitPrime;

let random_seed = thread_rng().gen();
println!("Using random seed: {random_seed}");
let mut rng = StdRng::seed_from_u64(random_seed);

let mut random_user_records = Vec::with_capacity(NUM_USERS);
for _ in 0..NUM_USERS {
let records_for_user = generate_random_user_records_in_reverse_chronological_order(
&mut rng,
MAX_RECORDS_PER_USER,
MAX_BREAKDOWN_KEY,
let rng = StdRng::seed_from_u64(random_seed);
let raw_data = EventGenerator::with_config(
rng,
EventGeneratorConfig::new(
NUM_USERS,
MAX_TRIGGER_VALUE,
);
random_user_records.push(records_for_user);
}
let mut raw_data = random_user_records.concat();

// Sort the records in chronological order
// This is part of the IPA spec. Callers should do this before sending a batch of records in for processing.
raw_data.sort_unstable_by(|a, b| a.timestamp.cmp(&b.timestamp));
MAX_BREAKDOWN_KEY,
MAX_RECORDS_PER_USER,
),
)
.take(usize::try_from(NUM_USERS * u64::from(MAX_RECORDS_PER_USER)).unwrap())
.collect::<Vec<_>>();

let config = TestWorldConfig {
gateway_config: GatewayConfig::new(raw_data.len().clamp(4, 1024)),
Expand Down
Loading