Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into mutual-tls
Browse files Browse the repository at this point in the history
  • Loading branch information
andyleiserson committed May 23, 2023
2 parents aeb578d + 8e625e6 commit aa9849d
Show file tree
Hide file tree
Showing 20 changed files with 902 additions and 216 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ real-world-infra = []
dhat-heap = ["cli", "test-fixture"]
# Enable this feature to enable our colossally weak Fp31.
weak-field = []
step-trace = []

[dependencies]
aes = "0.8"
Expand Down Expand Up @@ -98,6 +99,7 @@ dhat = "0.3.2"
tikv-jemallocator = "0.5.0"

[dev-dependencies]
command-fds = "0.2.2"
permutation = "0.4.1"
proptest = "1.0.0"
tempfile = "3"
Expand Down
37 changes: 15 additions & 22 deletions benches/oneshot/ipa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,8 @@ use ipa::{
ff::Fp32BitPrime,
helpers::{query::IpaQueryConfig, GatewayConfig},
test_fixture::{
ipa::{
generate_random_user_records_in_reverse_chronological_order, ipa_in_the_clear,
test_ipa, IpaSecurityModel,
},
TestWorld, TestWorldConfig,
ipa::{ipa_in_the_clear, test_ipa, IpaSecurityModel},
EventGenerator, EventGeneratorConfig, TestWorld, TestWorldConfig,
},
};
use rand::{rngs::StdRng, thread_rng, Rng, SeedableRng};
Expand Down Expand Up @@ -38,7 +35,7 @@ struct Args {
query_size: usize,
/// The maximum number of records for each person.
#[arg(short = 'u', long, default_value = "50")]
records_per_user: usize,
records_per_user: u32,
/// The contribution cap for each person.
#[arg(short = 'c', long, default_value = "3")]
per_user_cap: u32,
Expand Down Expand Up @@ -109,23 +106,19 @@ async fn run(args: Args) -> Result<(), Error> {
"Using random seed: {seed} for {q} records",
q = args.query_size
);
let mut rng = StdRng::seed_from_u64(seed);

let mut raw_data = Vec::with_capacity(args.query_size + args.records_per_user);
while raw_data.len() < args.query_size {
let mut records_for_user = generate_random_user_records_in_reverse_chronological_order(
&mut rng,
args.records_per_user,
args.breakdown_keys,
args.max_trigger_value,
);
records_for_user.truncate(args.query_size - raw_data.len());
raw_data.append(&mut records_for_user);
}
let rng = StdRng::seed_from_u64(seed);
let raw_data = EventGenerator::with_config(
rng,
EventGeneratorConfig {
max_trigger_value: NonZeroU32::try_from(args.max_trigger_value).unwrap(),
max_breakdown_key: NonZeroU32::try_from(args.breakdown_keys).unwrap(),
max_events_per_user: NonZeroU32::try_from(args.records_per_user).unwrap(),
..Default::default()
},
)
.take(args.query_size)
.collect::<Vec<_>>();

// Sort the records in chronological order
// This is part of the IPA spec. Callers should do this before sending a batch of records in for processing.
raw_data.sort_unstable_by(|a, b| a.timestamp.cmp(&b.timestamp));
let expected_results =
ipa_in_the_clear(&raw_data, args.per_user_cap, args.attribution_window());

Expand Down
44 changes: 40 additions & 4 deletions src/bin/helper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,15 @@ use ipa::{
net::{ClientIdentity, HttpTransport, MpcHelperClient},
AppSetup,
};
use std::{error::Error, fs, path::PathBuf};
use std::{
error::Error,
fs,
net::TcpListener,
os::fd::{FromRawFd, RawFd},
path::PathBuf,
process,
};
use tracing::{error, info};

#[cfg(not(target_env = "msvc"))]
#[global_allocator]
Expand Down Expand Up @@ -93,6 +101,12 @@ struct ServerArgs {
#[arg(short, long, default_value = "3000")]
port: Option<u16>,

/// Use the supplied prebound socket instead of binding a new socket
///
/// This is only intended for avoiding port conflicts in tests.
#[arg(hide = true, long)]
server_socket_fd: Option<RawFd>,

/// Use insecure HTTP
#[arg(short = 'k', long)]
disable_https: bool,
Expand Down Expand Up @@ -172,8 +186,25 @@ async fn server(args: ServerArgs) -> Result<(), Box<dyn Error>> {

let _app = setup.connect(transport.clone());

let listener = args.server_socket_fd
.map(|fd| {
// SAFETY:
// 1. The `--server-socket-fd` option is only intended for use in tests, not in production.
// 2. This must be the only call to from_raw_fd for this file descriptor, to ensure it has
// only one owner.
let listener = unsafe { TcpListener::from_raw_fd(fd) };
if listener.local_addr().is_ok() {
info!("adopting fd {fd} as listening socket");
Ok(listener)
} else {
Err(Box::<dyn Error>::from(format!("the server was asked to listen on fd {fd}, but it does not appear to be a valid socket")))
}
})
.transpose()?;

let (_addr, server_handle) = server
.start(
.start_on(
listener,
// TODO, trace based on the content of the query.
None as Option<()>,
)
Expand All @@ -185,13 +216,18 @@ async fn server(args: ServerArgs) -> Result<(), Box<dyn Error>> {
}

#[tokio::main]
pub async fn main() -> Result<(), Box<dyn Error>> {
pub async fn main() {
let args = Args::parse();
let _handle = args.logging.setup_logging();

match args.command {
let res = match args.command {
None => server(args.server).await,
Some(HelperCommand::Keygen(args)) => keygen(args),
Some(HelperCommand::TestSetup(args)) => test_setup(args),
};

if let Err(e) = res {
error!("{e}");
process::exit(1);
}
}
146 changes: 110 additions & 36 deletions src/bin/test_mpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,31 @@ use hyper::http::uri::Scheme;
use ipa::{
cli::{
playbook::{secure_mul, semi_honest, InputSource},
Verbosity,
CsvSerializer, Verbosity,
},
config::{NetworkConfig, PeerConfig},
config::{ClientConfig, NetworkConfig, PeerConfig},
ff::{Field, FieldType, Fp31, Fp32BitPrime, Serializable},
helpers::query::{IpaQueryConfig, QueryConfig, QueryType},
net::{ClientIdentity, MpcHelperClient},
protocol::{BreakdownKey, MatchKey, QueryId},
secret_sharing::{replicated::semi_honest::AdditiveShare, IntoShares},
test_fixture::ipa::{ipa_in_the_clear, TestRawDataRecord},
test_fixture::{
ipa::{ipa_in_the_clear, TestRawDataRecord},
EventGenerator, EventGeneratorConfig,
},
};
use rand::thread_rng;
use std::{
error::Error,
fmt::Debug,
fs,
fs::OpenOptions,
io,
io::{stdout, Write},
ops::Add,
path::PathBuf,
time::Duration,
};
use std::{error::Error, fmt::Debug, fs, ops::Add, path::PathBuf, time::Duration};
use tokio::time::sleep;

#[derive(Debug, Parser)]
Expand Down Expand Up @@ -74,7 +88,29 @@ enum TestAction {
/// Execute end-to-end multiplication.
Multiply,
/// Execute IPA in semi-honest majority setting
SemiHonestIPA,
SemiHonestIpa(IpaQueryConfig),
/// Generate inputs for IPA
GenIpaInputs {
/// Number of records to generate
#[clap(long, short = 'n')]
count: u32,

/// The destination file for generated records
#[arg(long)]
output_file: Option<PathBuf>,

#[clap(flatten)]
gen_args: EventGeneratorConfig,
},
}

#[derive(Debug, clap::Args)]
struct GenInputArgs {
/// Maximum records per user
#[clap(long)]
max_per_user: u32,
/// number of breakdowns
breakdowns: u32,
}

async fn clients_ready(clients: &[MpcHelperClient; 3]) -> bool {
Expand Down Expand Up @@ -121,7 +157,7 @@ where
i += 1;
}

tracing::info!("{table}");
tracing::info!("\n{table}\n");

assert!(
mismatch.is_empty(),
Expand All @@ -132,12 +168,19 @@ where

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
fn make_clients(disable_https: bool, config_path: Option<&PathBuf>) -> [MpcHelperClient; 3] {
let scheme = if disable_https {
let args = Args::parse();
let _handle = args.logging.setup_logging();

let make_clients = || async {
let scheme = if args.disable_https {
Scheme::HTTP
} else {
Scheme::HTTPS
};

let config_path = args.network.as_deref();
let mut wait = args.wait;

let config = if let Some(path) = config_path {
NetworkConfig::from_toml_str(&fs::read_to_string(path).unwrap()).unwrap()
} else {
Expand All @@ -147,52 +190,83 @@ async fn main() -> Result<(), Box<dyn Error>> {
PeerConfig::new("localhost:3001".parse().unwrap(), None),
PeerConfig::new("localhost:3002".parse().unwrap(), None),
],
client: ClientConfig::default(),
}
}
.override_scheme(&scheme);
MpcHelperClient::from_conf(&config, ClientIdentity::None)
}
let clients = MpcHelperClient::from_conf(&config, ClientIdentity::None);
while wait > 0 && !clients_ready(&clients).await {
tracing::debug!("waiting for servers to come up");
sleep(Duration::from_secs(1)).await;
wait -= 1;
}

let args = Args::parse();
let _handle = args.logging.setup_logging();
clients
};

let clients = make_clients(args.disable_https, args.network.as_ref());
match args.action {
TestAction::Multiply => multiply(&args, &make_clients().await).await,
TestAction::SemiHonestIpa(config) => {
semi_honest_ipa(&args, &config, &make_clients().await).await
}
TestAction::GenIpaInputs {
count,
output_file,
gen_args,
} => gen_inputs(count, output_file, gen_args).unwrap(),
};

let mut wait = args.wait;
while wait > 0 && !clients_ready(&clients).await {
println!("waiting for servers to come up");
sleep(Duration::from_secs(1)).await;
wait -= 1;
}
Ok(())
}

match args.action {
TestAction::Multiply => multiply(args, &clients).await,
TestAction::SemiHonestIPA => semi_honest_ipa(args, &clients).await,
fn gen_inputs(
count: u32,
output_file: Option<PathBuf>,
args: EventGeneratorConfig,
) -> io::Result<()> {
let event_gen = EventGenerator::with_config(thread_rng(), args).take(count as usize);
let mut writer: Box<dyn Write> = if let Some(path) = output_file {
Box::new(OpenOptions::new().write(true).create_new(true).open(path)?)
} else {
Box::new(stdout().lock())
};

for event in event_gen {
event.to_csv(&mut writer)?;
writer.write(&[b'\n'])?;
}

Ok(())
}

async fn semi_honest_ipa(args: Args, helper_clients: &[MpcHelperClient; 3]) {
async fn semi_honest_ipa(
args: &Args,
ipa_query_config: &IpaQueryConfig,
helper_clients: &[MpcHelperClient; 3],
) {
let input = InputSource::from(&args.input);
let ipa_query_config = IpaQueryConfig {
per_user_credit_cap: 3,
max_breakdown_key: 3,
num_multi_bits: 3,
attribution_window_seconds: None,
};
let query_type = QueryType::Ipa(ipa_query_config.clone());
let query_config = QueryConfig {
field_type: args.input.field,
query_type,
};
let query_id = helper_clients[0].create_query(query_config).await.unwrap();
let input_rows = input.iter::<TestRawDataRecord>().collect::<Vec<_>>();
let expected = ipa_in_the_clear(
&input_rows,
ipa_query_config.per_user_credit_cap,
ipa_query_config.attribution_window_seconds,
);
let expected = {
let mut r = ipa_in_the_clear(
&input_rows,
ipa_query_config.per_user_credit_cap,
ipa_query_config.attribution_window_seconds,
);

// pad the output vector to the max breakdown key, to make sure it is aligned with the MPC results
// truncate shouldn't happen unless in_the_clear is badly broken
r.resize(
usize::try_from(ipa_query_config.max_breakdown_key).unwrap(),
0,
);
r
};

let actual = match args.input.field {
FieldType::Fp31 => {
Expand All @@ -213,7 +287,7 @@ async fn semi_honest_ipa(args: Args, helper_clients: &[MpcHelperClient; 3]) {
}

async fn multiply_in_field<F: Field>(
args: Args,
args: &Args,
helper_clients: &[MpcHelperClient; 3],
query_id: QueryId,
) where
Expand All @@ -229,7 +303,7 @@ async fn multiply_in_field<F: Field>(
validate(expected, actual);
}

async fn multiply(args: Args, helper_clients: &[MpcHelperClient; 3]) {
async fn multiply(args: &Args, helper_clients: &[MpcHelperClient; 3]) {
let query_config = QueryConfig {
field_type: args.input.field,
query_type: QueryType::TestMultiply,
Expand Down
Loading

0 comments on commit aa9849d

Please sign in to comment.