
Commit

merge fix
EkamSinghPandher committed Oct 30, 2024
1 parent 55c6501 commit a738cbc
Showing 1 changed file with 0 additions and 120 deletions.
120 changes: 0 additions & 120 deletions plonky2/src/fri/oracle.rs
@@ -326,19 +326,8 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
let pols = polynomials.len();
let degree = polynomials[0].len();
let log_n = log2_strict(degree);
<<<<<<< HEAD

if log_n + rate_bits > 1 && polynomials.len() > 0 {
=======

#[cfg(any(test, doctest))]
init_gpu();
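// Only take the GPU path when the extended domain has more than two points,
// there is at least one polynomial, and the total number of extended field
// elements, pols * 2^(log_n + rate_bits), stays below 2^31.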

if log_n + rate_bits > 1
&& polynomials.len() > 0
&& pols * (1 << (log_n + rate_bits)) < (1 << 31)
{
>>>>>>> 8a00c2bc54a76355a0bf73dcaabb560d688cab4d
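// The number of available GPUs is read from the NUM_OF_GPUS environment
// variable; the prover panics if it is not set.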
let _num_gpus: usize = std::env::var("NUM_OF_GPUS")
.expect("NUM_OF_GPUS should be set")
.parse()
@@ -494,116 +483,7 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
// If blinding, salt with two random elements to each leaf vector.
let salt_size = if blinding { SALT_SIZE } else { 0 };
// println!("salt_size: {:?}", salt_size);
<<<<<<< HEAD

=======
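// Multi-GPU LDE path, compiled only with both the `cuda` and `batch` features:
// the polynomial batch is split across devices and each chunk is extended with
// a batched coset NTT on its own GPU.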

#[cfg(all(feature = "cuda", feature = "batch"))]
let num_gpus: usize = std::env::var("NUM_OF_GPUS")
.expect("NUM_OF_GPUS should be set")
.parse()
.unwrap();
// let num_gpus: usize = 1;
#[cfg(all(feature = "cuda", feature = "batch"))]
println!("get num of gpus: {:?}", num_gpus);
#[cfg(all(feature = "cuda", feature = "batch"))]
let total_num_of_fft = polynomials.len();
// println!("total_num_of_fft: {:?}", total_num_of_fft);
#[cfg(all(feature = "cuda", feature = "batch"))]
let per_device_batch = total_num_of_fft.div_ceil(num_gpus);

#[cfg(all(feature = "cuda", feature = "batch"))]
let chunk_size = total_num_of_fft.div_ceil(num_gpus);

#[cfg(all(feature = "cuda", feature = "batch"))]
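// Use the GPUs only for non-empty batches over domains larger than 2^10;
// otherwise fall through to the CPU path below.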
if log_n > 10 && polynomials.len() > 0 {
println!("log_n: {:?}", log_n);
let start_lde = std::time::Instant::now();

// let poly_chunk = polynomials;
// let id = 0;
let ret = polynomials
.par_chunks(chunk_size)
.enumerate()
.flat_map(|(id, poly_chunk)| {
println!(
"invoking ntt_batch, device_id: {:?}, per_device_batch: {:?}",
id, per_device_batch
);

let start = std::time::Instant::now();
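// Allocate one device buffer on GPU `id` (sized for the full polynomial list)
// and copy each polynomial's coefficients to its own offset in parallel; the
// RwLock only shares the host-side handle, and every copy writes a disjoint range.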

let input_domain_size = 1 << log2_strict(degree);
let device_input_data: HostOrDeviceSlice<'_, F> =
HostOrDeviceSlice::cuda_malloc(
id as i32,
input_domain_size * polynomials.len(),
)
.unwrap();
let device_input_data = std::sync::RwLock::new(device_input_data);

poly_chunk.par_iter().enumerate().for_each(|(i, p)| {
// println!("copy for index: {:?}", i);
let _guard = device_input_data.read().unwrap();
let _ = _guard.copy_from_host_offset(
p.coeffs.as_slice(),
input_domain_size * i,
input_domain_size,
);
});

println!("data transform elapsed: {:?}", start.elapsed());
let mut cfg_lde = NTTConfig::default();
cfg_lde.batches = per_device_batch as u32;
cfg_lde.extension_rate_bits = rate_bits as u32;
cfg_lde.are_inputs_on_device = true;
cfg_lde.are_outputs_on_device = true;
cfg_lde.with_coset = true;
println!(
"start cuda_malloc with elements: {:?}",
(1 << log_n) * per_device_batch
);
let mut device_output_data: HostOrDeviceSlice<'_, F> =
HostOrDeviceSlice::cuda_malloc(id as i32, (1 << log_n) * per_device_batch)
.unwrap();

let start = std::time::Instant::now();
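// Run the batched LDE kernel on device `id`, reading and writing device memory.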
lde_batch::<F>(
id,
device_output_data.as_mut_ptr(),
device_input_data.read().unwrap().as_ptr(),
log2_strict(degree),
cfg_lde,
);
println!("real lde_batch elapsed: {:?}", start.elapsed());
let start = std::time::Instant::now();
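// Copy each extended polynomial (1 << log_n values after the LDE) back into
// host memory, one leaf vector per FFT in the batch.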
let nums: Vec<usize> = (0..poly_chunk.len()).collect();
let r = nums
.par_iter()
.map(|i| {
let mut host_data: Vec<F> = vec![F::ZERO; 1 << log_n];
let _ = device_output_data.copy_to_host_offset(
host_data.as_mut_slice(),
(1 << log_n) * i,
1 << log_n,
);
PolynomialValues::new(host_data).values
})
.collect::<Vec<Vec<F>>>();
println!("collect data from gpu used: {:?}", start.elapsed());
r
})
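// If blinding, append `salt_size` random vectors so each leaf gains salt columns.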
.chain(
(0..salt_size)
.into_par_iter()
.map(|_| F::rand_vec(degree << rate_bits)),
)
.collect();
println!("real lde elapsed: {:?}", start_lde.elapsed());
return ret;
}

>>>>>>> 8a00c2bc54a76355a0bf73dcaabb560d688cab4d
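// CPU path: process each polynomial in parallel on the host thread pool.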
let ret = polynomials
.par_iter()
.map(|p| {
