Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace DrainFilter with ExtractIf #341

Merged
merged 2 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ documentation = "https://docs.rs/smallvec/"
write = []
specialization = []
may_dangle = []
drain_filter = []
drain_keep_rest = ["drain_filter"]
extract_if = []

[dependencies]
serde = { version = "1", optional = true, default-features = false }
Expand Down
172 changes: 44 additions & 128 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@
//! When this feature is enabled, `SmallVec<u8, _>` implements the `std::io::Write` trait.
//! This feature is not compatible with `#![no_std]` programs.
//!
//! ### `drain_filter`
//! ### `extract_if`
//!
//! **This feature is unstable.** It may change to match the unstable `drain_filter` method in libstd.
//! **This feature is unstable.** It may change to match the unstable `extract_if` method in libstd.
//!
//! Enables the `drain_filter` method, which produces an iterator that calls a user-provided
//! Enables the `extract_if` method, which produces an iterator that calls a user-provided
//! closure to determine which elements of the vector to remove and yield from the iterator.
//!
//! ### `specialization`
Expand Down Expand Up @@ -380,13 +380,13 @@ impl<'a, T: 'a, const N: usize> Drop for Drain<'a, T, N> {
}
}

#[cfg(feature = "drain_filter")]
#[cfg(feature = "extract_if")]
/// An iterator which uses a closure to determine if an element should be removed.
///
/// Returned from [`SmallVec::drain_filter`][1].
/// Returned from [`SmallVec::extract_if`][1].
///
/// [1]: struct.SmallVec.html#method.drain_filter
pub struct DrainFilter<'a, T, const N: usize, F>
/// [1]: struct.SmallVec.html#method.extract_if
pub struct ExtractIf<'a, T, const N: usize, F>
where
F: FnMut(&mut T) -> bool,
{
Expand All @@ -399,29 +399,23 @@ where
old_len: usize,
/// The filter test predicate.
pred: F,
/// A flag that indicates a panic has occurred in the filter test predicate.
/// This is used as a hint in the drop implementation to prevent consumption
/// of the remainder of the `DrainFilter`. Any unprocessed items will be
/// backshifted in the `vec`, but no further items will be dropped or
/// tested by the filter predicate.
panic_flag: bool,
}

#[cfg(feature = "drain_filter")]
impl<T, const N: usize, F> core::fmt::Debug for DrainFilter<'_, T, N, F>
#[cfg(feature = "extract_if")]
impl<T, const N: usize, F> core::fmt::Debug for ExtractIf<'_, T, N, F>
where
F: FnMut(&mut T) -> bool,
T: core::fmt::Debug,
{
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_tuple("DrainFilter")
f.debug_tuple("ExtractIf")
.field(&self.vec.as_slice())
.finish()
}
}

#[cfg(feature = "drain_filter")]
impl<T, F, const N: usize> Iterator for DrainFilter<'_, T, N, F>
#[cfg(feature = "extract_if")]
impl<T, F, const N: usize> Iterator for ExtractIf<'_, T, N, F>
where
F: FnMut(&mut T) -> bool,
{
Expand All @@ -432,9 +426,7 @@ where
while self.idx < self.old_len {
let i = self.idx;
let v = core::slice::from_raw_parts_mut(self.vec.as_mut_ptr(), self.old_len);
self.panic_flag = true;
let drained = (self.pred)(&mut v[i]);
self.panic_flag = false;
// Update the index *after* the predicate is called. If the index
// is updated prior and the predicate panics, the element at this
// index would be leaked.
Expand All @@ -444,8 +436,8 @@ where
return Some(core::ptr::read(&v[i]));
} else if self.del > 0 {
let del = self.del;
let src: *const Self::Item = &v[i];
let dst: *mut Self::Item = &mut v[i - del];
let src: *const T = &v[i];
let dst: *mut T = &mut v[i - del];
core::ptr::copy_nonoverlapping(src, dst, 1);
}
}
Expand All @@ -458,109 +450,27 @@ where
}
}

#[cfg(feature = "drain_filter")]
impl<T, F, const N: usize> Drop for DrainFilter<'_, T, N, F>
#[cfg(feature = "extract_if")]
impl<T, F, const N: usize> Drop for ExtractIf<'_, T, N, F>
where
F: FnMut(&mut T) -> bool,
{
fn drop(&mut self) {
struct BackshiftOnDrop<'a, 'b, T, const N: usize, F>
where
F: FnMut(&mut T) -> bool,
{
drain: &'b mut DrainFilter<'a, T, N, F>,
}

impl<'a, 'b, T, const N: usize, F> Drop for BackshiftOnDrop<'a, 'b, T, N, F>
where
F: FnMut(&mut T) -> bool,
{
fn drop(&mut self) {
unsafe {
if self.drain.idx < self.drain.old_len && self.drain.del > 0 {
// This is a pretty messed up state, and there isn't really an
// obviously right thing to do. We don't want to keep trying
// to execute `pred`, so we just backshift all the unprocessed
// elements and tell the vec that they still exist. The backshift
// is required to prevent a double-drop of the last successfully
// drained item prior to a panic in the predicate.
let ptr = self.drain.vec.as_mut_ptr();
let src = ptr.add(self.drain.idx);
let dst = src.sub(self.drain.del);
let tail_len = self.drain.old_len - self.drain.idx;
src.copy_to(dst, tail_len);
}
self.drain.vec.set_len(self.drain.old_len - self.drain.del);
}
}
}

let backshift = BackshiftOnDrop { drain: self };

// Attempt to consume any remaining elements if the filter predicate
// has not yet panicked. We'll backshift any remaining elements
// whether we've already panicked or if the consumption here panics.
if !backshift.drain.panic_flag {
backshift.drain.for_each(drop);
}
}
}

#[cfg(feature = "drain_keep_rest")]
impl<T, F, const N: usize> DrainFilter<'_, T, N, F>
where
F: FnMut(&mut T) -> bool,
{
/// Keep unyielded elements in the source `Vec`.
///
/// # Examples
///
/// ```
/// # use smallvec::{smallvec, SmallVec};
///
/// let mut vec: SmallVec<char, 2> = smallvec!['a', 'b', 'c'];
/// let mut drain = vec.drain_filter(|_| true);
///
/// assert_eq!(drain.next().unwrap(), 'a');
///
/// // This call keeps 'b' and 'c' in the vec.
/// drain.keep_rest();
///
/// // If we wouldn't call `keep_rest()`,
/// // `vec` would be empty.
/// assert_eq!(vec, SmallVec::<char, 2>::from_slice(&['b', 'c']));
/// ```
pub fn keep_rest(self) {
// At this moment layout looks like this:
//
// _____________________/-- old_len
// / \
// [kept] [yielded] [tail]
// \_______/ ^-- idx
// \-- del
//
// Normally `Drop` impl would drop [tail] (via .for_each(drop), ie still calling `pred`)
//
// 1. Move [tail] after [kept]
// 2. Update length of the original vec to `old_len - del`
// a. In case of ZST, this is the only thing we want to do
// 3. Do *not* drop self, as everything is put in a consistent state already, there is nothing to do
let mut this = ManuallyDrop::new(self);

unsafe {
// ZSTs have no identity, so we don't need to move them around.
let needs_move = core::mem::size_of::<T>() != 0;

if needs_move && this.idx < this.old_len && this.del > 0 {
let ptr = this.vec.as_mut_ptr();
let src = ptr.add(this.idx);
let dst = src.sub(this.del);
let tail_len = this.old_len - this.idx;
if self.idx < self.old_len && self.del > 0 {
// This is a pretty messed up state, and there isn't really an
// obviously right thing to do. We don't want to keep trying
// to execute `pred`, so we just backshift all the unprocessed
// elements and tell the vec that they still exist. The backshift
// is required to prevent a double-drop of the last successfully
// drained item prior to a panic in the predicate.
let ptr = self.vec.as_mut_ptr();
let src = ptr.add(self.idx);
let dst = src.sub(self.del);
let tail_len = self.old_len - self.idx;
src.copy_to(dst, tail_len);
}

let new_len = this.old_len - this.del;
this.vec.set_len(new_len);
self.vec.set_len(self.old_len - self.del);
}
}
}
Expand Down Expand Up @@ -961,11 +871,18 @@ impl<T, const N: usize> SmallVec<T, N> {
}
}

#[cfg(feature = "drain_filter")]
#[cfg(feature = "extract_if")]
/// Creates an iterator which uses a closure to determine if an element should be removed.
///
/// If the closure returns true, the element is removed and yielded. If the closure returns
/// false, the element will remain in the vector and will not be yielded by the iterator.
/// If the closure returns true, the element is removed and yielded.
/// If the closure returns false, the element will remain in the vector and will not be yielded
/// by the iterator.
///
/// If the returned `ExtractIf` is not exhausted, e.g. because it is dropped without iterating
/// or the iteration short-circuits, then the remaining elements will be retained.
/// Use [`retain`] with a negated predicate if you do not need the returned iterator.
///
/// [`retain`]: SmallVec::retain
///
/// Using this method is equivalent to the following code:
/// ```
Expand All @@ -984,11 +901,11 @@ impl<T, const N: usize> SmallVec<T, N> {
///
/// # assert_eq!(vec, SmallVec::<i32, 8>::from_slice(&[1i32, 4, 5]));
/// ```
/// ///
/// But `drain_filter` is easier to use. `drain_filter` is also more efficient,
///
/// But `extract_if` is easier to use. `extract_if` is also more efficient,
/// because it can backshift the elements of the array in bulk.
///
/// Note that `drain_filter` also lets you mutate every element in the filter closure,
/// Note that `extract_if` also lets you mutate every element in the filter closure,
/// regardless of whether you choose to keep or remove it.
///
/// # Examples
Expand All @@ -999,13 +916,13 @@ impl<T, const N: usize> SmallVec<T, N> {
/// # use smallvec::SmallVec;
/// let mut numbers: SmallVec<i32, 16> = SmallVec::from_slice(&[1i32, 2, 3, 4, 5, 6, 8, 9, 11, 13, 14, 15]);
///
/// let evens = numbers.drain_filter(|x| *x % 2 == 0).collect::<SmallVec<i32, 16>>();
/// let evens = numbers.extract_if(|x| *x % 2 == 0).collect::<SmallVec<i32, 16>>();
/// let odds = numbers;
///
/// assert_eq!(evens, SmallVec::<i32, 16>::from_slice(&[2i32, 4, 6, 8, 14]));
/// assert_eq!(odds, SmallVec::<i32, 16>::from_slice(&[1i32, 3, 5, 9, 11, 13, 15]));
/// ```
pub fn drain_filter<F>(&mut self, filter: F) -> DrainFilter<'_, T, N, F>
pub fn extract_if<F>(&mut self, filter: F) -> ExtractIf<'_, T, N, F>
where
F: FnMut(&mut T) -> bool,
{
Expand All @@ -1016,13 +933,12 @@ impl<T, const N: usize> SmallVec<T, N> {
self.set_len(0);
}

DrainFilter {
ExtractIf {
vec: self,
idx: 0,
del: 0,
old_len,
pred: filter,
panic_flag: false,
}
}

Expand Down
22 changes: 4 additions & 18 deletions src/tests.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{smallvec, SmallVec};

use std::iter::FromIterator;
use core::iter::FromIterator;

use alloc::borrow::ToOwned;
use alloc::boxed::Box;
Expand Down Expand Up @@ -1060,27 +1060,13 @@ fn test_clone_from() {
assert_eq!(&*b, &[20, 21, 22]);
}

#[cfg(feature = "drain_filter")]
#[cfg(feature = "extract_if")]
#[test]
fn drain_filter() {
fn test_extract_if() {
let mut a: SmallVec<u8, 2> = smallvec![1u8, 2, 3, 4, 5, 6, 7, 8];

let b: SmallVec<u8, 2> = a.drain_filter(|x| *x % 3 == 0).collect();
let b: SmallVec<u8, 2> = a.extract_if(|x| *x % 3 == 0).collect();

assert_eq!(a, SmallVec::<u8, 2>::from_slice(&[1u8, 2, 4, 5, 7, 8]));
assert_eq!(b, SmallVec::<u8, 2>::from_slice(&[3u8, 6]));
}

#[cfg(feature = "drain_keep_rest")]
#[test]
fn drain_keep_rest() {
let mut a: SmallVec<i32, 3> = smallvec![1i32, 2, 3, 4, 5, 6, 7, 8];
let mut df = a.drain_filter(|x| *x % 2 == 0);

assert_eq!(df.next().unwrap(), 2);
assert_eq!(df.next().unwrap(), 4);

df.keep_rest();

assert_eq!(a, SmallVec::<i32, 3>::from_slice(&[1i32, 3, 5, 6, 7, 8]));
}