diff --git a/crates/polars-ops/src/frame/join/hash_join/multiple_keys.rs b/crates/polars-ops/src/frame/join/hash_join/multiple_keys.rs index 5488d0a5caf1..6bc586d518a1 100644 --- a/crates/polars-ops/src/frame/join/hash_join/multiple_keys.rs +++ b/crates/polars-ops/src/frame/join/hash_join/multiple_keys.rs @@ -88,40 +88,39 @@ fn create_build_table_outer( // We will create a hashtable in every thread. // We use the hash to partition the keys to the matching hashtable. // Every thread traverses all keys/hashes and ignores the ones that doesn't fall in that partition. - POOL.install(|| { - (0..n_partitions).into_par_iter().map(|part_no| { - let mut hash_tbl: HashMap = - HashMap::with_capacity_and_hasher(_HASHMAP_INIT_SIZE, Default::default()); - - let mut offset = 0; - for hashes in hashes { - for hashes in hashes.data_views() { - let len = hashes.len(); - let mut idx = 0; - hashes.iter().for_each(|h| { - // partition hashes by thread no. - // So only a part of the hashes go to this hashmap - if part_no == hash_to_partition(*h, n_partitions) { - let idx = idx + offset; - populate_multiple_key_hashmap( - &mut hash_tbl, - idx, - *h, - keys, - || (false, unitvec![idx]), - |v| v.1.push(idx), - ) - } - idx += 1; - }); + let par_iter = (0..n_partitions).into_par_iter().map(|part_no| { + let mut hash_tbl: HashMap = + HashMap::with_capacity_and_hasher(_HASHMAP_INIT_SIZE, Default::default()); + + let mut offset = 0; + for hashes in hashes { + for hashes in hashes.data_views() { + let len = hashes.len(); + let mut idx = 0; + hashes.iter().for_each(|h| { + // partition hashes by thread no. + // So only a part of the hashes go to this hashmap + if part_no == hash_to_partition(*h, n_partitions) { + let idx = idx + offset; + populate_multiple_key_hashmap( + &mut hash_tbl, + idx, + *h, + keys, + || (false, unitvec![idx]), + |v| v.1.push(idx), + ) + } + idx += 1; + }); - offset += len as IdxSize; - } + offset += len as IdxSize; } - hash_tbl - }) - }) - .collect() + } + hash_tbl + }); + + POOL.install(|| par_iter.collect()) } /// Probe the build table and add tuples to the results (inner join) @@ -360,40 +359,32 @@ pub(crate) fn create_build_table_semi_anti( // We will create a hashtable in every thread. // We use the hash to partition the keys to the matching hashtable. // Every thread traverses all keys/hashes and ignores the ones that doesn't fall in that partition. - POOL.install(|| { - (0..n_partitions).into_par_iter().map(|part_no| { - let mut hash_tbl: HashMap = - HashMap::with_capacity_and_hasher(_HASHMAP_INIT_SIZE, Default::default()); - - let mut offset = 0; - for hashes in hashes { - for hashes in hashes.data_views() { - let len = hashes.len(); - let mut idx = 0; - hashes.iter().for_each(|h| { - // partition hashes by thread no. - // So only a part of the hashes go to this hashmap - if part_no == hash_to_partition(*h, n_partitions) { - let idx = idx + offset; - populate_multiple_key_hashmap( - &mut hash_tbl, - idx, - *h, - keys, - || (), - |_| (), - ) - } - idx += 1; - }); + let par_iter = (0..n_partitions).into_par_iter().map(|part_no| { + let mut hash_tbl: HashMap = + HashMap::with_capacity_and_hasher(_HASHMAP_INIT_SIZE, Default::default()); + + let mut offset = 0; + for hashes in hashes { + for hashes in hashes.data_views() { + let len = hashes.len(); + let mut idx = 0; + hashes.iter().for_each(|h| { + // partition hashes by thread no. + // So only a part of the hashes go to this hashmap + if part_no == hash_to_partition(*h, n_partitions) { + let idx = idx + offset; + populate_multiple_key_hashmap(&mut hash_tbl, idx, *h, keys, || (), |_| ()) + } + idx += 1; + }); - offset += len as IdxSize; - } + offset += len as IdxSize; } - hash_tbl - }) - }) - .collect() + } + hash_tbl + }); + + POOL.install(|| par_iter.collect()) } #[cfg(feature = "semi_anti_join")] @@ -423,46 +414,43 @@ pub(crate) fn semi_anti_join_multiple_keys_impl<'a>( // next we probe the other relation // code duplication is because we want to only do the swap check once - POOL.install(move || { - probe_hashes - .into_par_iter() - .zip(offsets) - .flat_map(move |(probe_hashes, offset)| { - // local reference - let hash_tbls = &hash_tbls; - let mut results = - Vec::with_capacity(probe_hashes.len() / POOL.current_num_threads()); - let local_offset = offset; - - let mut idx_a = local_offset as IdxSize; - for probe_hashes in probe_hashes.data_views() { - for &h in probe_hashes { - // probe table that contains the hashed value - let current_probe_table = - unsafe { hash_tbls.get_unchecked(hash_to_partition(h, n_tables)) }; - - let entry = current_probe_table.raw_entry().from_hash(h, |idx_hash| { - let idx_b = idx_hash.idx; - // SAFETY: - // indices in a join operation are always in bounds. - unsafe { - compare_df_rows2(a, b, idx_a as usize, idx_b as usize, join_nulls) - } - }); - - match entry { - // left and right matches - Some((_, _)) => results.push((idx_a, true)), - // only left values, right = null - None => results.push((idx_a, false)), + probe_hashes + .into_par_iter() + .zip(offsets) + .flat_map(move |(probe_hashes, offset)| { + // local reference + let hash_tbls = &hash_tbls; + let mut results = Vec::with_capacity(probe_hashes.len() / POOL.current_num_threads()); + let local_offset = offset; + + let mut idx_a = local_offset as IdxSize; + for probe_hashes in probe_hashes.data_views() { + for &h in probe_hashes { + // probe table that contains the hashed value + let current_probe_table = + unsafe { hash_tbls.get_unchecked(hash_to_partition(h, n_tables)) }; + + let entry = current_probe_table.raw_entry().from_hash(h, |idx_hash| { + let idx_b = idx_hash.idx; + // SAFETY: + // indices in a join operation are always in bounds. + unsafe { + compare_df_rows2(a, b, idx_a as usize, idx_b as usize, join_nulls) } - idx_a += 1; + }); + + match entry { + // left and right matches + Some((_, _)) => results.push((idx_a, true)), + // only left values, right = null + None => results.push((idx_a, false)), } + idx_a += 1; } + } - results - }) - }) + results + }) } #[cfg(feature = "semi_anti_join")] @@ -471,10 +459,10 @@ pub fn _left_anti_multiple_keys( b: &mut DataFrame, join_nulls: bool, ) -> Vec { - semi_anti_join_multiple_keys_impl(a, b, join_nulls) + let par_iter = semi_anti_join_multiple_keys_impl(a, b, join_nulls) .filter(|tpls| !tpls.1) - .map(|tpls| tpls.0) - .collect() + .map(|tpls| tpls.0); + POOL.install(|| par_iter.collect()) } #[cfg(feature = "semi_anti_join")] @@ -483,10 +471,10 @@ pub fn _left_semi_multiple_keys( b: &mut DataFrame, join_nulls: bool, ) -> Vec { - semi_anti_join_multiple_keys_impl(a, b, join_nulls) + let par_iter = semi_anti_join_multiple_keys_impl(a, b, join_nulls) .filter(|tpls| tpls.1) - .map(|tpls| tpls.0) - .collect() + .map(|tpls| tpls.0); + POOL.install(|| par_iter.collect()) } /// Probe the build table and add tuples to the results (inner join) diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys_semi_anti.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_semi_anti.rs index 93268036c43d..00ad29499715 100644 --- a/crates/polars-ops/src/frame/join/hash_join/single_keys_semi_anti.rs +++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_semi_anti.rs @@ -13,23 +13,23 @@ where // We will create a hashtable in every thread. // We use the hash to partition the keys to the matching hashtable. // Every thread traverses all keys/hashes and ignores the ones that doesn't fall in that partition. - POOL.install(|| { - (0..n_partitions).into_par_iter().map(|partition_no| { - let mut hash_tbl: PlHashSet = PlHashSet::with_capacity(_HASHMAP_INIT_SIZE); - for keys in &keys { - keys.into_iter().for_each(|k| { - if partition_no == hash_to_partition(k.dirty_hash(), n_partitions) { - hash_tbl.insert(k); - } - }); - } - hash_tbl - }) - }) - .collect() + let par_iter = (0..n_partitions).into_par_iter().map(|partition_no| { + let mut hash_tbl: PlHashSet = PlHashSet::with_capacity(_HASHMAP_INIT_SIZE); + for keys in &keys { + keys.into_iter().for_each(|k| { + if partition_no == hash_to_partition(k.dirty_hash(), n_partitions) { + hash_tbl.insert(k); + } + }); + } + hash_tbl + }); + POOL.install(|| par_iter.collect()) } -pub(super) fn semi_anti_impl( +/// Construct a ParallelIterator, but doesn't iterate it. This means the caller +/// context (or wherever it gets iterated) should be in POOL.install. +fn semi_anti_impl( probe: Vec, build: Vec, ) -> impl ParallelIterator @@ -46,40 +46,38 @@ where let n_tables = hash_sets.len(); // next we probe the other relation - POOL.install(move || { - probe - .into_par_iter() - .zip(offsets) - // probes_hashes: Vec processed by this thread - // offset: offset index - .flat_map(move |(probe, offset)| { - // local reference - let hash_sets = &hash_sets; - let probe_iter = probe.into_iter(); + // This is not wrapped in POOL.install because it is not being iterated here + probe + .into_par_iter() + .zip(offsets) + // probes_hashes: Vec processed by this thread + // offset: offset index + .flat_map(move |(probe, offset)| { + // local reference + let hash_sets = &hash_sets; + let probe_iter = probe.into_iter(); - // assume the result tuples equal length of the no. of hashes processed by this thread. - let mut results = Vec::with_capacity(probe_iter.size_hint().1.unwrap()); + // assume the result tuples equal length of the no. of hashes processed by this thread. + let mut results = Vec::with_capacity(probe_iter.size_hint().1.unwrap()); - probe_iter.enumerate().for_each(|(idx_a, k)| { - let idx_a = (idx_a + offset) as IdxSize; - // probe table that contains the hashed value - let current_probe_table = unsafe { - hash_sets.get_unchecked(hash_to_partition(k.dirty_hash(), n_tables)) - }; + probe_iter.enumerate().for_each(|(idx_a, k)| { + let idx_a = (idx_a + offset) as IdxSize; + // probe table that contains the hashed value + let current_probe_table = + unsafe { hash_sets.get_unchecked(hash_to_partition(k.dirty_hash(), n_tables)) }; - // we already hashed, so we don't have to hash again. - let value = current_probe_table.get(&k); + // we already hashed, so we don't have to hash again. + let value = current_probe_table.get(&k); - match value { - // left and right matches - Some(_) => results.push((idx_a, true)), - // only left values, right = null - None => results.push((idx_a, false)), - } - }); - results - }) - }) + match value { + // left and right matches + Some(_) => results.push((idx_a, true)), + // only left values, right = null + None => results.push((idx_a, false)), + } + }); + results + }) } pub(super) fn hash_join_tuples_left_anti(probe: Vec, build: Vec) -> Vec @@ -87,10 +85,10 @@ where I: IntoIterator + Copy + Send + Sync, T: Send + Hash + Eq + Sync + Copy + DirtyHash, { - semi_anti_impl(probe, build) + let par_iter = semi_anti_impl(probe, build) .filter(|tpls| !tpls.1) - .map(|tpls| tpls.0) - .collect() + .map(|tpls| tpls.0); + POOL.install(|| par_iter.collect()) } pub(super) fn hash_join_tuples_left_semi(probe: Vec, build: Vec) -> Vec @@ -98,8 +96,8 @@ where I: IntoIterator + Copy + Send + Sync, T: Send + Hash + Eq + Sync + Copy + DirtyHash, { - semi_anti_impl(probe, build) + let par_iter = semi_anti_impl(probe, build) .filter(|tpls| tpls.1) - .map(|tpls| tpls.0) - .collect() + .map(|tpls| tpls.0); + POOL.install(|| par_iter.collect()) } diff --git a/crates/polars-ops/src/frame/join/hash_join/sort_merge.rs b/crates/polars-ops/src/frame/join/hash_join/sort_merge.rs index d9b849ce1e59..6d97ee4735f4 100644 --- a/crates/polars-ops/src/frame/join/hash_join/sort_merge.rs +++ b/crates/polars-ops/src/frame/join/hash_join/sort_merge.rs @@ -23,13 +23,12 @@ where let slice_left = s_left.cont_slice().unwrap(); let slice_right = s_right.cont_slice().unwrap(); - let indexes = offsets - .into_par_iter() - .map(|(offset, len)| { - let slice_left = &slice_left[offset..offset + len]; - sorted_join::left::join(slice_left, slice_right, offset as IdxSize) - }) - .collect::>(); + let indexes = offsets.into_par_iter().map(|(offset, len)| { + let slice_left = &slice_left[offset..offset + len]; + sorted_join::left::join(slice_left, slice_right, offset as IdxSize) + }); + let indexes = POOL.install(|| indexes.collect::>()); + let lefts = indexes.iter().map(|t| &t.0).collect::>(); let rights = indexes.iter().map(|t| &t.1).collect::>(); @@ -96,13 +95,12 @@ where let slice_left = s_left.cont_slice().unwrap(); let slice_right = s_right.cont_slice().unwrap(); - let indexes = offsets - .into_par_iter() - .map(|(offset, len)| { - let slice_left = &slice_left[offset..offset + len]; - sorted_join::inner::join(slice_left, slice_right, offset as IdxSize) - }) - .collect::>(); + let indexes = offsets.into_par_iter().map(|(offset, len)| { + let slice_left = &slice_left[offset..offset + len]; + sorted_join::inner::join(slice_left, slice_right, offset as IdxSize) + }); + let indexes = POOL.install(|| indexes.collect::>()); + let lefts = indexes.iter().map(|t| &t.0).collect::>(); let rights = indexes.iter().map(|t| &t.1).collect::>();