diff --git a/Cargo.toml b/Cargo.toml index 70a986c04f156..d2a4958158f21 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -78,6 +78,7 @@ bevy_winit = { path = "crates/bevy_winit", optional = true, version = "0.1" } [dev-dependencies] rand = "0.7.2" serde = { version = "1", features = ["derive"]} +criterion = "0.3" [[example]] name = "hello_world" @@ -179,6 +180,10 @@ path = "examples/ecs/startup_system.rs" name = "ecs_guide" path = "examples/ecs/ecs_guide.rs" +[[example]] +name = "parallel_query" +path = "examples/ecs/parallel_query.rs" + [[example]] name = "breakout" path = "examples/game/breakout.rs" @@ -242,3 +247,8 @@ path = "examples/window/multiple_windows.rs" [[example]] name = "window_settings" path = "examples/window/window_settings.rs" + +[[bench]] +name = "iter" +path = "crates/bevy_tasks/benches/iter.rs" +harness = false diff --git a/crates/bevy_ecs/src/system/query.rs b/crates/bevy_ecs/src/system/query.rs index 0d27c4530b500..353fe14117434 100644 --- a/crates/bevy_ecs/src/system/query.rs +++ b/crates/bevy_ecs/src/system/query.rs @@ -3,6 +3,7 @@ use bevy_hecs::{ Archetype, Component, ComponentError, Entity, Fetch, Query as HecsQuery, QueryOne, Ref, RefMut, World, }; +use bevy_tasks::ParallelIterator; use std::marker::PhantomData; /// Provides scoped access to a World according to a given [HecsQuery] @@ -148,6 +149,29 @@ impl<'w, Q: HecsQuery> QueryBorrow<'w, Q> { iter: None, } } + + /// Like `iter`, but returns child iterators of at most `batch_size` + /// elements + /// + /// Useful for distributing work over a threadpool using the + /// ParallelIterator interface. + /// + /// Batch size needs to be chosen based on the task being done in + /// parallel. The elements in each batch are computed serially, while + /// the batches themselves are computed in parallel. + /// + /// A too small batch size can cause too much overhead, since scheduling + /// each batch could take longer than running the batch. On the other + /// hand, a too large batch size risks that one batch is still running + /// long after the rest have finished. 
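To make this batch-size guidance concrete, here is a rough usage sketch of the new `par_iter`; it only relies on APIs added in this diff (`ComputeTaskPool`, `QueryBorrow::par_iter`, `ParallelIterator::for_each`), while the `Lifetime` component, the system name, and the batch size of 128 are placeholders chosen purely for illustration:

use bevy::{prelude::*, tasks::prelude::*};

// Hypothetical component used only for this sketch.
struct Lifetime(f32);

// Tick down each entity's lifetime in parallel. The per-item work is tiny,
// so a fairly large batch (128) keeps the per-batch scheduling overhead low
// while still producing enough batches to occupy the pool's threads.
fn tick_lifetimes(pool: Res<ComputeTaskPool>, mut lifetimes: Query<&mut Lifetime>) {
    lifetimes
        .iter()        // QueryBorrow
        .par_iter(128) // ParIter: batches of at most 128 elements
        .for_each(&pool, |mut lifetime| {
            lifetime.0 -= 1.0 / 60.0; // assume a ~60 Hz update for this sketch
        });
}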
+ pub fn par_iter<'q>(&'q mut self, batch_size: u32) -> ParIter<'q, 'w, Q> { + ParIter { + borrow: self, + archetype_index: 0, + batch_size, + batch: 0, + } + } } unsafe impl<'w, Q: HecsQuery> Send for QueryBorrow<'w, Q> {} @@ -257,3 +281,61 @@ impl ChunkIter { } } } + +/// Batched version of `QueryIter` +pub struct ParIter<'q, 'w, Q: HecsQuery> { + borrow: &'q mut QueryBorrow<'w, Q>, + archetype_index: u32, + batch_size: u32, + batch: u32, +} + +impl<'q, 'w, Q: HecsQuery> ParallelIterator> for ParIter<'q, 'w, Q> { + type Item = >::Item; + + fn next_batch(&mut self) -> Option> { + loop { + let archetype = self.borrow.archetypes.get(self.archetype_index as usize)?; + let offset = self.batch_size * self.batch; + if offset >= archetype.len() { + self.archetype_index += 1; + self.batch = 0; + continue; + } + if let Some(fetch) = unsafe { Q::Fetch::get(archetype, offset as usize) } { + self.batch += 1; + return Some(Batch { + _marker: PhantomData, + state: ChunkIter { + fetch, + len: self.batch_size.min(archetype.len() - offset), + }, + }); + } else { + self.archetype_index += 1; + debug_assert_eq!( + self.batch, 0, + "query fetch should always reject at the first batch or not at all" + ); + continue; + } + } + } +} + +/// A sequence of entities yielded by `ParIter` +pub struct Batch<'q, Q: HecsQuery> { + _marker: PhantomData<&'q ()>, + state: ChunkIter, +} + +impl<'q, 'w, Q: HecsQuery> Iterator for Batch<'q, Q> { + type Item = >::Item; + + fn next(&mut self) -> Option { + let components = unsafe { self.state.next()? }; + Some(components) + } +} + +unsafe impl<'q, Q: HecsQuery> Send for Batch<'q, Q> {} diff --git a/crates/bevy_tasks/benches/iter.rs b/crates/bevy_tasks/benches/iter.rs new file mode 100644 index 0000000000000..b6bdd910a1314 --- /dev/null +++ b/crates/bevy_tasks/benches/iter.rs @@ -0,0 +1,148 @@ +use bevy_tasks::{ParallelIterator, TaskPoolBuilder}; +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; + +struct ParChunks<'a, T>(std::slice::Chunks<'a, T>); +impl<'a, T> ParallelIterator> for ParChunks<'a, T> +where + T: 'a + Send + Sync, +{ + type Item = &'a T; + + fn next_batch(&mut self) -> Option> { + self.0.next().map(|s| s.iter()) + } +} + +struct ParChunksMut<'a, T>(std::slice::ChunksMut<'a, T>); +impl<'a, T> ParallelIterator> for ParChunksMut<'a, T> +where + T: 'a + Send + Sync, +{ + type Item = &'a mut T; + + fn next_batch(&mut self) -> Option> { + self.0.next().map(|s| s.iter_mut()) + } +} + +fn bench_overhead(c: &mut Criterion) { + fn noop(_: &mut usize) {}; + + let mut v = (0..10000).collect::>(); + c.bench_function("overhead_iter", |b| { + b.iter(|| { + v.iter_mut().for_each(noop); + }) + }); + + let mut v = (0..10000).collect::>(); + let mut group = c.benchmark_group("overhead_par_iter"); + for thread_count in &[1, 2, 4, 8, 16, 32] { + let pool = TaskPoolBuilder::new().num_threads(*thread_count).build(); + group.bench_with_input( + BenchmarkId::new("threads", thread_count), + thread_count, + |b, _| { + b.iter(|| { + ParChunksMut(v.chunks_mut(100)).for_each(&pool, noop); + }) + }, + ); + } + group.finish(); +} + +fn bench_for_each(c: &mut Criterion) { + fn busy_work(n: usize) { + let mut i = n; + while i > 0 { + i = black_box(i - 1); + } + } + + let mut v = (0..10000).collect::>(); + c.bench_function("for_each_iter", |b| { + b.iter(|| { + v.iter_mut().for_each(|x| { + busy_work(10000); + *x *= *x; + }); + }) + }); + + let mut v = (0..10000).collect::>(); + let mut group = c.benchmark_group("for_each_par_iter"); + for thread_count in &[1, 
2, 4, 8, 16, 32] { + let pool = TaskPoolBuilder::new().num_threads(*thread_count).build(); + group.bench_with_input( + BenchmarkId::new("threads", thread_count), + thread_count, + |b, _| { + b.iter(|| { + ParChunksMut(v.chunks_mut(100)).for_each(&pool, |x| { + busy_work(10000); + *x *= *x; + }); + }) + }, + ); + } + group.finish(); +} + +fn bench_many_maps(c: &mut Criterion) { + fn busy_doubles(mut x: usize, n: usize) -> usize { + for _ in 0..n { + x = black_box(x.wrapping_mul(2)); + } + x + } + + let v = (0..10000).collect::>(); + c.bench_function("many_maps_iter", |b| { + b.iter(|| { + v.iter() + .map(|x| busy_doubles(*x, 1000)) + .map(|x| busy_doubles(x, 1000)) + .map(|x| busy_doubles(x, 1000)) + .map(|x| busy_doubles(x, 1000)) + .map(|x| busy_doubles(x, 1000)) + .map(|x| busy_doubles(x, 1000)) + .map(|x| busy_doubles(x, 1000)) + .map(|x| busy_doubles(x, 1000)) + .map(|x| busy_doubles(x, 1000)) + .map(|x| busy_doubles(x, 1000)) + .for_each(drop); + }) + }); + + let v = (0..10000).collect::>(); + let mut group = c.benchmark_group("many_maps_par_iter"); + for thread_count in &[1, 2, 4, 8, 16, 32] { + let pool = TaskPoolBuilder::new().num_threads(*thread_count).build(); + group.bench_with_input( + BenchmarkId::new("threads", thread_count), + thread_count, + |b, _| { + b.iter(|| { + ParChunks(v.chunks(100)) + .map(|x| busy_doubles(*x, 1000)) + .map(|x| busy_doubles(x, 1000)) + .map(|x| busy_doubles(x, 1000)) + .map(|x| busy_doubles(x, 1000)) + .map(|x| busy_doubles(x, 1000)) + .map(|x| busy_doubles(x, 1000)) + .map(|x| busy_doubles(x, 1000)) + .map(|x| busy_doubles(x, 1000)) + .map(|x| busy_doubles(x, 1000)) + .map(|x| busy_doubles(x, 1000)) + .for_each(&pool, drop); + }) + }, + ); + } + group.finish(); +} + +criterion_group!(benches, bench_overhead, bench_for_each, bench_many_maps); +criterion_main!(benches); diff --git a/crates/bevy_tasks/src/iter/adapters.rs b/crates/bevy_tasks/src/iter/adapters.rs new file mode 100644 index 0000000000000..1d6740406a09d --- /dev/null +++ b/crates/bevy_tasks/src/iter/adapters.rs @@ -0,0 +1,224 @@ +use crate::iter::ParallelIterator; + +pub struct Chain { + pub(crate) left: T, + pub(crate) right: U, + pub(crate) left_in_progress: bool, +} + +impl ParallelIterator for Chain +where + B: Iterator + Send, + T: ParallelIterator, + U: ParallelIterator, +{ + type Item = T::Item; + + fn next_batch(&mut self) -> Option { + if self.left_in_progress { + match self.left.next_batch() { + b @ Some(_) => return b, + None => self.left_in_progress = false, + } + } + self.right.next_batch() + } +} + +pub struct Map { + pub(crate) iter: P, + pub(crate) f: F, +} + +impl ParallelIterator> for Map +where + B: Iterator + Send, + U: ParallelIterator, + F: FnMut(U::Item) -> T + Send + Clone, +{ + type Item = T; + + fn next_batch(&mut self) -> Option> { + self.iter.next_batch().map(|b| b.map(self.f.clone())) + } +} + +pub struct Filter { + pub(crate) iter: P, + pub(crate) predicate: F, +} + +impl ParallelIterator> for Filter +where + B: Iterator + Send, + P: ParallelIterator, + F: FnMut(&P::Item) -> bool + Send + Clone, +{ + type Item = P::Item; + + fn next_batch(&mut self) -> Option> { + self.iter + .next_batch() + .map(|b| b.filter(self.predicate.clone())) + } +} + +pub struct FilterMap { + pub(crate) iter: P, + pub(crate) f: F, +} + +impl ParallelIterator> for FilterMap +where + B: Iterator + Send, + P: ParallelIterator, + F: FnMut(P::Item) -> Option + Send + Clone, +{ + type Item = R; + + fn next_batch(&mut self) -> Option> { + self.iter.next_batch().map(|b| 
b.filter_map(self.f.clone())) + } +} + +pub struct FlatMap { + pub(crate) iter: P, + pub(crate) f: F, +} + +impl ParallelIterator> for FlatMap +where + B: Iterator + Send, + P: ParallelIterator, + F: FnMut(P::Item) -> U + Send + Clone, + U: IntoIterator, + U::IntoIter: Send, +{ + type Item = U::Item; + + // This extends each batch using the flat map. The other option is + // to turn each IntoIter into its own batch. + fn next_batch(&mut self) -> Option> { + self.iter.next_batch().map(|b| b.flat_map(self.f.clone())) + } +} + +pub struct Flatten
<P> { + pub(crate) iter: P, +} + +impl<B, P> ParallelIterator<std::iter::Flatten<B>> for Flatten<P>
+where + B: Iterator + Send, + P: ParallelIterator, + B::Item: IntoIterator, + ::IntoIter: Send, +{ + type Item = ::Item; + + // This extends each batch using the flatten. The other option is to + // turn each IntoIter into its own batch. + fn next_batch(&mut self) -> Option> { + self.iter.next_batch().map(|b| b.flatten()) + } +} + +pub struct Fuse
<P> { + pub(crate) iter: Option<P>, +} + +impl<B, P> ParallelIterator<B> for Fuse<P>
+where + B: Iterator + Send, + P: ParallelIterator, +{ + type Item = P::Item; + + fn next_batch(&mut self) -> Option { + match &mut self.iter { + Some(iter) => match iter.next_batch() { + b @ Some(_) => b, + None => { + self.iter = None; + None + } + }, + None => None, + } + } +} + +pub struct Inspect { + pub(crate) iter: P, + pub(crate) f: F, +} + +impl ParallelIterator> for Inspect +where + B: Iterator + Send, + P: ParallelIterator, + F: FnMut(&P::Item) + Send + Clone, +{ + type Item = P::Item; + + fn next_batch(&mut self) -> Option> { + self.iter.next_batch().map(|b| b.inspect(self.f.clone())) + } +} + +pub struct Copied
<P> { + pub(crate) iter: P, +} + +impl<'a, B, P, T> ParallelIterator<std::iter::Copied<B>> for Copied<P>
+where + B: Iterator + Send, + P: ParallelIterator, + T: 'a + Copy, +{ + type Item = T; + + fn next_batch(&mut self) -> Option> { + self.iter.next_batch().map(|b| b.copied()) + } +} + +pub struct Cloned
<P> { + pub(crate) iter: P, +} + +impl<'a, B, P, T> ParallelIterator<std::iter::Cloned<B>> for Cloned<P>
+where + B: Iterator + Send, + P: ParallelIterator, + T: 'a + Copy, +{ + type Item = T; + + fn next_batch(&mut self) -> Option> { + self.iter.next_batch().map(|b| b.cloned()) + } +} + +pub struct Cycle
<P> { + pub(crate) iter: P, + pub(crate) curr: Option<P>, +} + +impl<B, P> ParallelIterator<B> for Cycle<P>
+where + B: Iterator + Send, + P: ParallelIterator + Clone, +{ + type Item = P::Item; + + fn next_batch(&mut self) -> Option { + match self.curr.as_mut().and_then(|c| c.next_batch()) { + batch @ Some(_) => batch, + None => { + self.curr = Some(self.iter.clone()); + self.next_batch() + } + } + } +} diff --git a/crates/bevy_tasks/src/iter/mod.rs b/crates/bevy_tasks/src/iter/mod.rs new file mode 100644 index 0000000000000..2cd8cf9576dbb --- /dev/null +++ b/crates/bevy_tasks/src/iter/mod.rs @@ -0,0 +1,510 @@ +use crate::TaskPool; + +mod adapters; +pub use adapters::*; + +/// ParallelIterator closely emulates the std::iter::Iterator +/// interface. However, it uses bevy_task to compute batches in parallel. +/// +/// Note that the overhead of ParallelIterator is high relative to some +/// workloads. In particular, if the batch size is too small or task being +/// run in parallel is inexpensive, *a ParallelIterator could take longer +/// than a normal Iterator*. Therefore, you should profile your code before +/// using ParallelIterator. +pub trait ParallelIterator +where + B: Iterator + Send, + Self: Sized + Send, +{ + type Item; + + /// Returns the next batch of items for processing. + /// + /// Each batch is an iterator with items of the same type as the + /// ParallelIterator. Returns `None` when there are no batches left. + fn next_batch(&mut self) -> Option; + + /// Returns the bounds on the remaining number of items in the + /// parallel iterator. + /// + /// See [`Iterator::size_hint()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.size_hint) + fn size_hint(&self) -> (usize, Option) { + (0, None) + } + + /// Consumes the parallel iterator and returns the number of items. + /// + /// See [`Iterator::count()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.count) + fn count(mut self, pool: &TaskPool) -> usize { + pool.scope(|s| { + while let Some(batch) = self.next_batch() { + s.spawn(async move { batch.count() }) + } + }) + .iter() + .sum() + } + + /// Consumes the parallel iterator and returns the last item. + /// + /// See [`Iterator::last()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.last) + fn last(mut self, _pool: &TaskPool) -> Option { + let mut last_item = None; + while let Some(batch) = self.next_batch() { + last_item = batch.last(); + } + last_item + } + + /// Consumes the parallel iterator and returns the nth item. + /// + /// See [`Iterator::nth()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.nth) + // TODO: Optimize with size_hint on each batch + fn nth(mut self, _pool: &TaskPool, n: usize) -> Option { + let mut i = 0; + while let Some(batch) = self.next_batch() { + for item in batch { + if i == n { + return Some(item); + } + i += 1; + } + } + None + } + + /// Takes two parallel iterators and returns a parallel iterators over + /// both in sequence. + /// + /// See [`Iterator::chain()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.chain) + // TODO: Use IntoParallelIterator for U + fn chain(self, other: U) -> Chain + where + U: ParallelIterator, + { + Chain { + left: self, + right: other, + left_in_progress: true, + } + } + + /// Takes a closure and creates a parallel iterator which calls that + /// closure on each item. + /// + /// See [`Iterator::map()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.map) + fn map(self, f: F) -> Map + where + F: FnMut(Self::Item) -> T + Send + Clone, + { + Map { iter: self, f } + } + + /// Calls a closure on each item of a parallel iterator. 
+ /// + /// See [`Iterator::for_each()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.for_each) + fn for_each(mut self, pool: &TaskPool, f: F) + where + F: FnMut(Self::Item) + Send + Clone + Sync, + { + pool.scope(|s| { + while let Some(batch) = self.next_batch() { + let newf = f.clone(); + s.spawn(async move { + batch.for_each(newf); + }); + } + }); + } + + /// Creates a parallel iterator which uses a closure to determine + /// if an element should be yielded. + /// + /// See [`Iterator::filter()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.filter) + fn filter(self, predicate: F) -> Filter + where + F: FnMut(&Self::Item) -> bool, + { + Filter { + iter: self, + predicate, + } + } + + /// Creates a parallel iterator that both filters and maps. + /// + /// See [`Iterator::filter_map()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.filter_map) + fn filter_map(self, f: F) -> FilterMap + where + F: FnMut(Self::Item) -> Option, + { + FilterMap { iter: self, f } + } + + /// Creates a parallel iterator that works like map, but flattens + /// nested structure. + /// + /// See [`Iterator::flat_map()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.flat_map) + fn flat_map(self, f: F) -> FlatMap + where + F: FnMut(Self::Item) -> U, + U: IntoIterator, + { + FlatMap { iter: self, f } + } + + /// Creates a parallel iterator that flattens nested structure. + /// + /// See [`Iterator::flatten()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.flatten) + fn flatten(self) -> Flatten + where + Self::Item: IntoIterator, + { + Flatten { iter: self } + } + + /// Creates a parallel iterator which ends after the first None. + /// + /// See [`Iterator::fuse()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.fuse) + fn fuse(self) -> Fuse { + Fuse { iter: Some(self) } + } + + /// Does something with each item of a parallel iterator, passing + /// the value on. + /// + /// See [`Iterator::inspect()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.inspect) + fn inspect(self, f: F) -> Inspect + where + F: FnMut(&Self::Item), + { + Inspect { iter: self, f } + } + + /// Borrows a parallel iterator, rather than consuming it. + /// + /// See [`Iterator::by_ref()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.by_ref) + fn by_ref(&mut self) -> &mut Self { + self + } + + /// Transforms a parallel iterator into a collection. + /// + /// See [`Iterator::collect()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.collect) + // TODO: Investigate optimizations for less copying + fn collect(mut self, pool: &TaskPool) -> C + where + C: std::iter::FromIterator, + Self::Item: Send + 'static, + { + pool.scope(|s| { + while let Some(batch) = self.next_batch() { + s.spawn(async move { batch.collect::>() }); + } + }) + .into_iter() + .flatten() + .collect() + } + + /// Consumes a parallel iterator, creating two collections from it. 
+ /// + /// See [`Iterator::partition()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.partition) + // TODO: Investigate optimizations for less copying + fn partition(mut self, pool: &TaskPool, f: F) -> (C, C) + where + C: Default + Extend + Send, + F: FnMut(&Self::Item) -> bool + Send + Sync + Clone, + Self::Item: Send + 'static, + { + let (mut a, mut b) = <(C, C)>::default(); + pool.scope(|s| { + while let Some(batch) = self.next_batch() { + let newf = f.clone(); + s.spawn(async move { batch.partition::, F>(newf) }) + } + }) + .into_iter() + .for_each(|(c, d)| { + a.extend(c); + b.extend(d); + }); + (a, b) + } + + /// Repeatedly applies a function to items of each batch of a parallel + /// iterator, producing a Vec of final values. + /// + /// *Note that this folds each batch independently and returns a Vec of + /// results (in batch order).* + /// + /// See [`Iterator::fold()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.fold) + fn fold(mut self, pool: &TaskPool, init: C, f: F) -> Vec + where + F: FnMut(C, Self::Item) -> C + Send + Sync + Clone, + C: Clone + Send + Sync + 'static, + { + pool.scope(|s| { + while let Some(batch) = self.next_batch() { + let newf = f.clone(); + let newi = init.clone(); + s.spawn(async move { batch.fold(newi, newf) }); + } + }) + } + + /// Tests if every element of the parallel iterator matches a predicate. + /// + /// *Note that all is **not** short circuiting.* + /// + /// See [`Iterator::all()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.all) + fn all(mut self, pool: &TaskPool, f: F) -> bool + where + F: FnMut(Self::Item) -> bool + Send + Sync + Clone, + { + pool.scope(|s| { + while let Some(mut batch) = self.next_batch() { + let newf = f.clone(); + s.spawn(async move { batch.all(newf) }); + } + }) + .into_iter() + .all(std::convert::identity) + } + + /// Tests if any element of the parallel iterator matches a predicate. + /// + /// *Note that any is **not** short circuiting.* + /// + /// See [`Iterator::any()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.any) + fn any(mut self, pool: &TaskPool, f: F) -> bool + where + F: FnMut(Self::Item) -> bool + Send + Sync + Clone, + { + pool.scope(|s| { + while let Some(mut batch) = self.next_batch() { + let newf = f.clone(); + s.spawn(async move { batch.any(newf) }); + } + }) + .into_iter() + .any(std::convert::identity) + } + + /// Searches for an element in a parallel iterator, returning its index. + /// + /// *Note that position consumes the whole iterator.* + /// + /// See [`Iterator::position()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.position) + // TODO: Investigate optimizations for less copying + fn position(mut self, pool: &TaskPool, f: F) -> Option + where + F: FnMut(Self::Item) -> bool + Send + Sync + Clone, + { + let poses = pool.scope(|s| { + while let Some(batch) = self.next_batch() { + let mut newf = f.clone(); + s.spawn(async move { + let mut len = 0; + let mut pos = None; + for item in batch { + if pos.is_none() && newf(item) { + pos = Some(len); + } + len += 1; + } + (len, pos) + }); + } + }); + let mut start = 0; + for (len, pos) in poses { + if let Some(pos) = pos { + return Some(start + pos); + } + start += len; + } + None + } + + /// Returns the maximum item of a parallel iterator. 
+ /// + /// See [`Iterator::max()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.max) + fn max(mut self, pool: &TaskPool) -> Option + where + Self::Item: Ord + Send + 'static, + { + pool.scope(|s| { + while let Some(batch) = self.next_batch() { + s.spawn(async move { batch.max() }); + } + }) + .into_iter() + .flatten() + .max() + } + + /// Returns the minimum item of a parallel iterator. + /// + /// See [`Iterator::min()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.min) + fn min(mut self, pool: &TaskPool) -> Option + where + Self::Item: Ord + Send + 'static, + { + pool.scope(|s| { + while let Some(batch) = self.next_batch() { + s.spawn(async move { batch.min() }); + } + }) + .into_iter() + .flatten() + .min() + } + + /// Returns the item that gives the maximum value from the specified function. + /// + /// See [`Iterator::max_by_key()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.max_by_key) + fn max_by_key(mut self, pool: &TaskPool, f: F) -> Option + where + R: Ord, + F: FnMut(&Self::Item) -> R + Send + Sync + Clone, + Self::Item: Send + 'static, + { + pool.scope(|s| { + while let Some(batch) = self.next_batch() { + let newf = f.clone(); + s.spawn(async move { batch.max_by_key(newf) }); + } + }) + .into_iter() + .flatten() + .max_by_key(f) + } + + /// Returns the item that gives the maximum value with respect to the specified comparison function. + /// + /// See [`Iterator::max_by()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.max_by) + fn max_by(mut self, pool: &TaskPool, f: F) -> Option + where + F: FnMut(&Self::Item, &Self::Item) -> std::cmp::Ordering + Send + Sync + Clone, + Self::Item: Send + 'static, + { + pool.scope(|s| { + while let Some(batch) = self.next_batch() { + let newf = f.clone(); + s.spawn(async move { batch.max_by(newf) }); + } + }) + .into_iter() + .flatten() + .max_by(f) + } + + /// Returns the item that gives the minimum value from the specified function. + /// + /// See [`Iterator::min_by_key()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.min_by_key) + fn min_by_key(mut self, pool: &TaskPool, f: F) -> Option + where + R: Ord, + F: FnMut(&Self::Item) -> R + Send + Sync + Clone, + Self::Item: Send + 'static, + { + pool.scope(|s| { + while let Some(batch) = self.next_batch() { + let newf = f.clone(); + s.spawn(async move { batch.min_by_key(newf) }); + } + }) + .into_iter() + .flatten() + .min_by_key(f) + } + + /// Returns the item that gives the minimum value with respect to the specified comparison function. + /// + /// See [`Iterator::min_by()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.min_by) + fn min_by(mut self, pool: &TaskPool, f: F) -> Option + where + F: FnMut(&Self::Item, &Self::Item) -> std::cmp::Ordering + Send + Sync + Clone, + Self::Item: Send + 'static, + { + pool.scope(|s| { + while let Some(batch) = self.next_batch() { + let newf = f.clone(); + s.spawn(async move { batch.min_by(newf) }); + } + }) + .into_iter() + .flatten() + .min_by(f) + } + + /// Creates a parallel iterator which copies all of its items. + /// + /// See [`Iterator::copied()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.copied) + fn copied<'a, T>(self) -> Copied + where + Self: ParallelIterator, + T: 'a + Copy, + { + Copied { iter: self } + } + + /// Creates a parallel iterator which clones all of its items. 
+ /// + /// See [`Iterator::cloned()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.cloned) + fn cloned<'a, T>(self) -> Cloned + where + Self: ParallelIterator, + T: 'a + Copy, + { + Cloned { iter: self } + } + + /// Repeats a parallel iterator endlessly. + /// + /// See [`Iterator::cycle()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.cycle) + fn cycle(self) -> Cycle + where + Self: Clone, + { + Cycle { + iter: self, + curr: None, + } + } + + /// Sums the items of a parallel iterator. + /// + /// See [`Iterator::sum()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.sum) + fn sum(mut self, pool: &TaskPool) -> R + where + S: std::iter::Sum + Send + 'static, + R: std::iter::Sum, + { + pool.scope(|s| { + while let Some(batch) = self.next_batch() { + s.spawn(async move { batch.sum() }); + } + }) + .into_iter() + .sum() + } + + /// Multiplies all the items of a parallel iterator. + /// + /// See [`Iterator::product()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.product) + fn product(mut self, pool: &TaskPool) -> R + where + S: std::iter::Product + Send + 'static, + R: std::iter::Product, + { + pool.scope(|s| { + while let Some(batch) = self.next_batch() { + s.spawn(async move { batch.product() }); + } + }) + .into_iter() + .product() + } +} diff --git a/crates/bevy_tasks/src/lib.rs b/crates/bevy_tasks/src/lib.rs index d90c86204c4d2..28ffdda055fbc 100644 --- a/crates/bevy_tasks/src/lib.rs +++ b/crates/bevy_tasks/src/lib.rs @@ -13,8 +13,12 @@ pub use usages::{AsyncComputeTaskPool, ComputeTaskPool, IOTaskPool}; mod countdown_event; pub use countdown_event::CountdownEvent; +mod iter; +pub use iter::ParallelIterator; + pub mod prelude { pub use crate::{ + iter::ParallelIterator, slice::{ParallelSlice, ParallelSliceMut}, usages::{AsyncComputeTaskPool, ComputeTaskPool, IOTaskPool}, }; diff --git a/examples/ecs/parallel_query.rs b/examples/ecs/parallel_query.rs new file mode 100644 index 0000000000000..1ee58ac6b877b --- /dev/null +++ b/examples/ecs/parallel_query.rs @@ -0,0 +1,74 @@ +use bevy::{prelude::*, tasks::prelude::*}; +use rand::random; + +struct Velocity(Vec2); + +fn spawn_system( + mut commands: Commands, + asset_server: Res, + mut materials: ResMut>, +) { + commands.spawn(Camera2dComponents::default()); + let texture_handle = asset_server.load("assets/branding/icon.png").unwrap(); + let material = materials.add(texture_handle.into()); + for _ in 0..128 { + commands + .spawn(SpriteComponents { + material, + translation: Translation::new(0.0, 0.0, 0.0), + scale: Scale(0.1), + ..Default::default() + }) + .with(Velocity( + 20.0 * Vec2::new(random::() - 0.5, random::() - 0.5), + )); + } +} + +// Move sprites according to their velocity +fn move_system(pool: Res, mut sprites: Query<(&mut Translation, &Velocity)>) { + // Compute the new location of each sprite in parallel on the + // ComputeTaskPool using batches of 32 sprties + // + // This example is only for demonstrative purposes. Using a + // ParallelIterator for an inexpensive operation like addition on only 128 + // elements will not typically be faster than just using a normal Iterator. + // See the ParallelIterator documentation for more information on when + // to use or not use ParallelIterator over a normal Iterator. 
+ sprites.iter().par_iter(32).for_each(&pool, |(mut t, v)| { + t.0 += v.0.extend(0.0); + }); +} + +// Bounce sprites outside the window +fn bounce_system( + pool: Res<ComputeTaskPool>, + windows: Res<Windows>, + mut sprites: Query<(&Translation, &mut Velocity)>, +) { + let Window { width, height, .. } = windows.get_primary().expect("No primary window"); + let left = *width as f32 / -2.0; + let right = *width as f32 / 2.0; + let bottom = *height as f32 / -2.0; + let top = *height as f32 / 2.0; + sprites + .iter() + // Batch size of 32 is chosen to limit the overhead of + // ParallelIterator, since negating a vector is very inexpensive. + .par_iter(32) + // Filter out sprites that don't need to be bounced + .filter(|(t, _)| !(left < t.x() && t.x() < right && bottom < t.y() && t.y() < top)) + // For simplicity, just reverse the velocity; don't use realistic bounces + .for_each(&pool, |(_, mut v)| { + v.0 = -v.0; + }); +} + +fn main() { + App::build() + .add_default_plugins() + .add_startup_system(spawn_system.system()) + .add_system(move_system.system()) + .add_system(bounce_system.system()) + .run(); +}
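The new trait can also be driven directly from a TaskPool outside the ECS. The following sketch mirrors the ParChunks helper from the benchmark in this diff; the thread count, chunk size, and the standalone main are arbitrary choices for illustration rather than part of the change:

use bevy_tasks::{ParallelIterator, TaskPoolBuilder};

// Each batch is one chunk of the slice; items within a chunk are processed
// serially on whichever worker thread picks up that batch.
struct ParChunks<'a, T>(std::slice::Chunks<'a, T>);

impl<'a, T> ParallelIterator<std::slice::Iter<'a, T>> for ParChunks<'a, T>
where
    T: 'a + Send + Sync,
{
    type Item = &'a T;

    fn next_batch(&mut self) -> Option<std::slice::Iter<'a, T>> {
        self.0.next().map(|s| s.iter())
    }
}

fn main() {
    let pool = TaskPoolBuilder::new().num_threads(4).build();
    let v: Vec<usize> = (0..10_000).collect();

    // Adapters like `map` and `filter` are applied lazily to every batch;
    // the terminal `count` spawns one task per batch on the pool and sums
    // the per-batch counts.
    let even_squares = ParChunks(v.chunks(100))
        .map(|x| x * x)
        .filter(|x| x % 2 == 0)
        .count(&pool);
    println!("even squares: {}", even_squares);
}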