Skip to content

Commit

Permalink
FEAT: Add method squeeze_into
Browse files Browse the repository at this point in the history
This method can squeeze into a particular dimensionality.

Squeezing means removing axes of length 1. When squeezing to a
particular dimensionality, we may have to still pad out the shape with
extra 1-shape axes to fill the dimensionality.
  • Loading branch information
bluss committed Dec 6, 2021
1 parent f31add8 commit 2a9ba60
Showing 1 changed file with 134 additions and 9 deletions.
143 changes: 134 additions & 9 deletions src/dimension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -785,37 +785,77 @@ where
}

/// Remove axes with length one, except never removing the last axis.
///
/// This function is a no-op for const dim.
pub(crate) fn squeeze<D>(dim: &mut D, strides: &mut D)
where
D: Dimension,
{
if let Some(_) = D::NDIM {
return;
}

// infallible for dyn dim
let (d, s) = squeeze_into(dim, strides).unwrap();
*dim = d;
*strides = s;
}

/// Remove axes with length one, except never removing the last axis.
///
/// Return an error if there are more non-unitary dimensions than can be stored
/// in `E`. Infallible for dyn dim.
///
/// Squeeze does not shrink dyn dim down to smaller than 1D, but if the input is
/// dynamic 0D, the output can be too.
///
/// For const dim, this may instead pad the dimensionality with ones if it needs
/// to grow to fill the target dimensionality; the dimension is padded in the
/// start.
pub(crate) fn squeeze_into<D, E>(dim: &D, strides: &D) -> Result<(E, E), ShapeError>
where
D: Dimension,
E: Dimension,
{
debug_assert_eq!(dim.ndim(), strides.ndim());

// Count axes with dim == 1; we keep axes with d == 0 or d > 1
let mut ndim_new = 0;
for &d in dim.slice() {
if d != 1 { ndim_new += 1; }
}
ndim_new = Ord::max(1, ndim_new);
let mut new_dim = D::zeros(ndim_new);
let mut new_strides = D::zeros(ndim_new);
let mut fill_ones = 0;
if let Some(e_ndim) = E::NDIM {
if e_ndim < ndim_new {
return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape));
}
fill_ones = e_ndim - ndim_new;
ndim_new = e_ndim;
} else {
// dynamic-dimensional
// use minimum one dimension unless input has less than one dim
if dim.ndim() > 0 && ndim_new == 0 {
ndim_new = 1;
fill_ones = 1;
}
}

let mut new_dim = E::zeros(ndim_new);
let mut new_strides = E::zeros(ndim_new);
let mut i = 0;
while i < fill_ones {
new_dim[i] = 1;
new_strides[i] = 1;
i += 1;
}
for (&d, &s) in izip!(dim.slice(), strides.slice()) {
if d != 1 {
new_dim[i] = d;
new_strides[i] = s;
i += 1;
}
}
if i == 0 {
new_dim[i] = 1;
new_strides[i] = 1;
}
*dim = new_dim;
*strides = new_strides;
Ok((new_dim, new_strides))
}


Expand Down Expand Up @@ -1220,6 +1260,91 @@ mod test {
assert_eq!(s, sans);
}

#[test]
#[cfg(feature = "std")]
fn test_squeeze_into() {
use super::squeeze_into;

let dyndim = Dim::<&[usize]>;

// squeeze to ixdyn
let d = dyndim(&[1, 2, 1, 1, 3, 1]);
let s = dyndim(&[!0, !0, !0, 9, 10, !0]);
let dans = dyndim(&[2, 3]);
let sans = dyndim(&[!0, 10]);
let (d2, s2) = squeeze_into::<_, IxDyn>(&d, &s).unwrap();
assert_eq!(d2, dans);
assert_eq!(s2, sans);

// squeeze to ixdyn does not go below 1D
let d = dyndim(&[1, 1]);
let s = dyndim(&[3, 4]);
let dans = dyndim(&[1]);
let sans = dyndim(&[1]);
let (d2, s2) = squeeze_into::<_, IxDyn>(&d, &s).unwrap();
assert_eq!(d2, dans);
assert_eq!(s2, sans);

let d = Dim([1, 1]);
let s = Dim([3, 4]);
let dans = Dim([1]);
let sans = Dim([1]);
let (d2, s2) = squeeze_into::<_, Ix1>(&d, &s).unwrap();
assert_eq!(d2, dans);
assert_eq!(s2, sans);

// squeeze to zero-dim
let (d2, s2) = squeeze_into::<_, Ix0>(&d, &s).unwrap();
assert_eq!(d2, Ix0());
assert_eq!(s2, Ix0());

let d = Dim([0, 1, 3, 4]);
let s = Dim([2, 3, 4, 5]);
let dans = Dim([0, 3, 4]);
let sans = Dim([2, 4, 5]);
let (d2, s2) = squeeze_into::<_, Ix3>(&d, &s).unwrap();
assert_eq!(d2, dans);
assert_eq!(s2, sans);

// Pad with ones
let d = Dim([0, 1, 3, 1]);
let s = Dim([2, 3, 4, 5]);
let dans = Dim([1, 0, 3]);
let sans = Dim([1, 2, 4]);
let (d2, s2) = squeeze_into::<_, Ix3>(&d, &s).unwrap();
assert_eq!(d2, dans);
assert_eq!(s2, sans);

// Try something that doesn't fit
let d = Dim([0, 1, 3, 1]);
let s = Dim([2, 3, 4, 5]);
let res = squeeze_into::<_, Ix1>(&d, &s);
assert!(res.is_err());
let res = squeeze_into::<_, Ix0>(&d, &s);
assert!(res.is_err());

// Squeeze 0d to 0d
let d = Dim([]);
let s = Dim([]);
let res = squeeze_into::<_, Ix0>(&d, &s);
assert!(res.is_ok());
// grow 0d to 2d
let dans = Dim([1, 1]);
let sans = Dim([1, 1]);
let (d2, s2) = squeeze_into::<_, Ix2>(&d, &s).unwrap();
assert_eq!(d2, dans);
assert_eq!(s2, sans);

// Squeeze 0d to 0d dynamic
let d = dyndim(&[]);
let s = dyndim(&[]);
let (d2, s2) = squeeze_into::<_, IxDyn>(&d, &s).unwrap();
let dans = d;
let sans = s;
assert_eq!(d2, dans);
assert_eq!(s2, sans);
}

#[test]
fn test_merge_axes_from_the_back() {
let dyndim = Dim::<&[usize]>;
Expand Down

0 comments on commit 2a9ba60

Please sign in to comment.