Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix infinite recursion and off-by-one error in triu/tril #1418

Merged
merged 6 commits into from
Aug 11, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 129 additions & 54 deletions src/tri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,25 @@
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use core::cmp::{max, min};
use core::cmp::min;

use num_traits::Zero;

use crate::{dimension::is_layout_f, Array, ArrayBase, Axis, Data, Dimension, IntoDimension, Zip};
use crate::{
dimension::{is_layout_c, is_layout_f},
Array,
ArrayBase,
Axis,
Data,
Dimension,
Zip,
};

impl<S, A, D> ArrayBase<S, D>
where
S: Data<Elem = A>,
D: Dimension,
A: Clone + Zero,
D::Smaller: Copy,
{
/// Upper triangular of an array.
///
Expand All @@ -30,38 +37,56 @@ where
/// ```
bluss marked this conversation as resolved.
Show resolved Hide resolved
/// use ndarray::array;
///
/// let arr = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
/// let res = arr.triu(0);
/// assert_eq!(res, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]);
/// let arr = array![
/// [1, 2, 3],
/// [4, 5, 6],
/// [7, 8, 9]
/// ];
/// assert_eq!(
/// arr.triu(0),
/// array![
/// [1, 2, 3],
/// [0, 5, 6],
/// [0, 0, 9]
/// ]
/// );
bluss marked this conversation as resolved.
Show resolved Hide resolved
/// ```
pub fn triu(&self, k: isize) -> Array<A, D>
{
if self.ndim() <= 1 {
return self.to_owned();
}
match is_layout_f(&self.dim, &self.strides) {
true => {
let n = self.ndim();
let mut x = self.view();
x.swap_axes(n - 2, n - 1);
let mut tril = x.tril(-k);
tril.swap_axes(n - 2, n - 1);

tril
}
false => {
let mut res = Array::zeros(self.raw_dim());
Zip::indexed(self.rows())
.and(res.rows_mut())
.for_each(|i, src, mut dst| {
let row_num = i.into_dimension().last_elem();
let lower = max(row_num as isize + k, 0);
dst.slice_mut(s![lower..]).assign(&src.slice(s![lower..]));
});

res
}

// Performance optimization for F-order arrays.
// C-order array check prevents infinite recursion in edge cases like [[1]].
// k-size check prevents underflow when k == isize::MIN
let n = self.ndim();
if is_layout_f(&self.dim, &self.strides) && !is_layout_c(&self.dim, &self.strides) && k > isize::MIN {
let mut x = self.view();
x.swap_axes(n - 2, n - 1);
let mut tril = x.tril(-k);
bluss marked this conversation as resolved.
Show resolved Hide resolved
tril.swap_axes(n - 2, n - 1);

return tril;
}

let mut res = Array::zeros(self.raw_dim());
let ncols = self.len_of(Axis(n - 1));
let nrows = self.len_of(Axis(n - 2));
let indices = Array::from_iter(0..nrows);
Zip::from(self.rows())
.and(res.rows_mut())
.and_broadcast(&indices)
.for_each(|src, mut dst, row_num| {
let mut lower = match k >= 0 {
true => row_num.saturating_add(k as usize), // Avoid overflow
false => row_num.saturating_sub(k.unsigned_abs()), // Avoid underflow, go to 0
};
lower = min(lower, ncols);
dst.slice_mut(s![lower..]).assign(&src.slice(s![lower..]));
bluss marked this conversation as resolved.
Show resolved Hide resolved
});

res
}

/// Lower triangular of an array.
Expand All @@ -75,45 +100,65 @@ where
/// ```
/// use ndarray::array;
///
/// let arr = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
/// let res = arr.tril(0);
/// assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]);
/// let arr = array![
/// [1, 2, 3],
/// [4, 5, 6],
/// [7, 8, 9]
/// ];
/// assert_eq!(
/// arr.tril(0),
/// array![
/// [1, 0, 0],
/// [4, 5, 0],
/// [7, 8, 9]
/// ]
/// );
/// ```
pub fn tril(&self, k: isize) -> Array<A, D>
{
if self.ndim() <= 1 {
return self.to_owned();
}
match is_layout_f(&self.dim, &self.strides) {
true => {
let n = self.ndim();
let mut x = self.view();
x.swap_axes(n - 2, n - 1);
let mut tril = x.triu(-k);
tril.swap_axes(n - 2, n - 1);

tril
}
false => {
let mut res = Array::zeros(self.raw_dim());
let ncols = self.len_of(Axis(self.ndim() - 1)) as isize;
Zip::indexed(self.rows())
.and(res.rows_mut())
.for_each(|i, src, mut dst| {
let row_num = i.into_dimension().last_elem();
let upper = min(row_num as isize + k, ncols) + 1;
dst.slice_mut(s![..upper]).assign(&src.slice(s![..upper]));
});

res
}

// Performance optimization for F-order arrays.
// C-order array check prevents infinite recursion in edge cases like [[1]].
// k-size check prevents underflow when k == isize::MIN
let n = self.ndim();
if is_layout_f(&self.dim, &self.strides) && !is_layout_c(&self.dim, &self.strides) && k > isize::MIN {
let mut x = self.view();
x.swap_axes(n - 2, n - 1);
let mut tril = x.triu(-k);
tril.swap_axes(n - 2, n - 1);

return tril;
}

let mut res = Array::zeros(self.raw_dim());
let ncols = self.len_of(Axis(n - 1));
let nrows = self.len_of(Axis(n - 2));
let indices = Array::from_iter(0..nrows);
Zip::from(self.rows())
.and(res.rows_mut())
.and_broadcast(&indices)
.for_each(|src, mut dst, row_num| {
// let row_num = i.into_dimension().last_elem();
let mut upper = match k >= 0 {
true => row_num.saturating_add(k as usize).saturating_add(1), // Avoid overflow
false => row_num.saturating_sub((k + 1).unsigned_abs()), // Avoid underflow
};
upper = min(upper, ncols);
dst.slice_mut(s![..upper]).assign(&src.slice(s![..upper]));
});

res
}
}

#[cfg(test)]
mod tests
{
use core::isize;

use crate::{array, dimension, Array0, Array1, Array2, Array3, ShapeBuilder};
use alloc::vec;

Expand Down Expand Up @@ -188,6 +233,19 @@ mod tests
assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]);
}

#[test]
fn test_2d_single()
{
let x = array![[1]];

assert_eq!(x.triu(0), array![[1]]);
assert_eq!(x.tril(0), array![[1]]);
assert_eq!(x.triu(1), array![[0]]);
assert_eq!(x.tril(1), array![[1]]);
assert_eq!(x.triu(-1), array![[1]]);
assert_eq!(x.tril(-1), array![[0]]);
}

#[test]
fn test_3d()
{
Expand Down Expand Up @@ -285,8 +343,25 @@ mod tests
let res = x.triu(0);
assert_eq!(res, array![[1, 2, 3], [0, 5, 6]]);

let res = x.tril(0);
assert_eq!(res, array![[1, 0, 0], [4, 5, 0]]);

let x = array![[1, 2], [3, 4], [5, 6]];
let res = x.triu(0);
assert_eq!(res, array![[1, 2], [0, 4], [0, 0]]);

let res = x.tril(0);
assert_eq!(res, array![[1, 0], [3, 4], [5, 6]]);
}

#[test]
fn test_odd_k()
{
bluss marked this conversation as resolved.
Show resolved Hide resolved
let x = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
let z = Array2::zeros([3, 3]);
assert_eq!(x.triu(isize::MIN), x);
assert_eq!(x.tril(isize::MIN), z);
assert_eq!(x.triu(isize::MAX), z);
assert_eq!(x.tril(isize::MAX), x);
}
}