Skip to content

Commit

Permalink
Adds triu and tril methods that mimic NumPy.
Browse files Browse the repository at this point in the history
Includes branched implementations for f- and c-order arrays.
  • Loading branch information
akern40 committed May 22, 2024
1 parent e734ce8 commit 6316fbe
Show file tree
Hide file tree
Showing 2 changed files with 283 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1616,3 +1616,6 @@ pub(crate) fn is_aligned<T>(ptr: *const T) -> bool
{
(ptr as usize) % ::std::mem::align_of::<T>() == 0
}

// Triangular constructors
mod tri;
280 changes: 280 additions & 0 deletions src/tri.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,280 @@
// Copyright 2014-2024 bluss and ndarray developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use core::cmp::{max, min};

use num_traits::Zero;

use crate::{
dimension::is_layout_f,
Array, ArrayBase, Axis, Data, Dimension, IntoDimension, Zip,
};

impl<S, A, D> ArrayBase<S, D>
where
S: Data<Elem = A>,
D: Dimension,
A: Clone + Zero,
D::Smaller: Copy,
{
/// Upper triangular of an array.
///
/// Return a copy of the array with elements below the *k*-th diagonal zeroed.
/// For arrays with `ndim` exceeding 2, `triu` will apply to the final two axes.
/// For 0D and 1D arrays, `triu` will return an unchanged clone.
///
/// See also [`ArrayBase::tril`]
///
/// ```
/// use ndarray::array;
///
/// let arr = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
/// let res = arr.triu(0);
/// assert_eq!(res, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]);
/// ```
pub fn triu(&self, k: isize) -> Array<A, D> {
match self.ndim() > 1 && is_layout_f(&self.dim, &self.strides) {
true => {
let n = self.ndim();
let mut x = self.view();
x.swap_axes(n - 2, n - 1);
let mut tril = x.tril(-k);
tril.swap_axes(n - 2, n - 1);

tril
},
false => {
let mut res = Array::zeros(self.raw_dim());
Zip::indexed(self.rows())
.and(res.rows_mut())
.for_each(|i, src, mut dst| {
let row_num = i.into_dimension().last_elem();
let lower = max(row_num as isize + k, 0);
dst.slice_mut(s![lower..]).assign(&src.slice(s![lower..]));
});

res
}
}
}

/// Lower triangular of an array.
///
/// Return a copy of the array with elements above the *k*-th diagonal zeroed.
/// For arrays with `ndim` exceeding 2, `tril` will apply to the final two axes.
/// For 0D and 1D arrays, `tril` will return an unchanged clone.
///
/// See also [`ArrayBase::triu`]
///
/// ```
/// use ndarray::array;
///
/// let arr = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
/// let res = arr.tril(0);
/// assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]);
/// ```
pub fn tril(&self, k: isize) -> Array<A, D> {
match self.ndim() > 1 && is_layout_f(&self.dim, &self.strides) {
true => {
let n = self.ndim();
let mut x = self.view();
x.swap_axes(n - 2, n - 1);
let mut tril = x.triu(-k);
tril.swap_axes(n - 2, n - 1);

tril
},
false => {
let mut res = Array::zeros(self.raw_dim());
Zip::indexed(self.rows())
.and(res.rows_mut())
.for_each(|i, src, mut dst| {
// This ncols must go inside the loop to avoid panic on 1D arrays.
// Statistically-neglible difference in performance vs defining ncols at top.
let ncols = src.len_of(Axis(src.ndim() - 1)) as isize;
let row_num = i.into_dimension().last_elem();
let upper = min(row_num as isize + k, ncols) + 1;
dst.slice_mut(s![..upper]).assign(&src.slice(s![..upper]));
});

res
}
}
}
}

#[cfg(test)]
mod tests {
use crate::{array, dimension, Array0, Array1, Array2, Array3, ShapeBuilder};

#[test]
fn test_keep_order() {
let x = Array2::<f64>::ones((3, 3).f());
let res = x.triu(0);
assert!(dimension::is_layout_f(&res.dim, &res.strides));

let res = x.tril(0);
assert!(dimension::is_layout_f(&res.dim, &res.strides));
}

#[test]
fn test_0d() {
let x = Array0::<f64>::ones(());
let res = x.triu(0);
assert_eq!(res, x);

let res = x.tril(0);
assert_eq!(res, x);

let x = Array0::<f64>::ones(().f());
let res = x.triu(0);
assert_eq!(res, x);

let res = x.tril(0);
assert_eq!(res, x);
}

#[test]
fn test_1d() {
let x = array![1, 2, 3];
let res = x.triu(0);
assert_eq!(res, x);

let res = x.triu(0);
assert_eq!(res, x);

let x = Array1::<f64>::ones(3.f());
let res = x.triu(0);
assert_eq!(res, x);

let res = x.triu(0);
assert_eq!(res, x);
}

#[test]
fn test_2d() {
let x = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];

// Upper
let res = x.triu(0);
assert_eq!(res, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]);

// Lower
let res = x.tril(0);
assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]);

let x = Array2::from_shape_vec((3, 3).f(), vec![1, 4, 7, 2, 5, 8, 3, 6, 9]).unwrap();

// Upper
let res = x.triu(0);
assert_eq!(res, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]);

// Lower
let res = x.tril(0);
assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]);
}

#[test]
fn test_3d() {
let x = array![
[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
[[10, 11, 12], [13, 14, 15], [16, 17, 18]],
[[19, 20, 21], [22, 23, 24], [25, 26, 27]]
];

// Upper
let res = x.triu(0);
assert_eq!(
res,
array![
[[1, 2, 3], [0, 5, 6], [0, 0, 9]],
[[10, 11, 12], [0, 14, 15], [0, 0, 18]],
[[19, 20, 21], [0, 23, 24], [0, 0, 27]]
]
);

// Lower
let res = x.tril(0);
assert_eq!(
res,
array![
[[1, 0, 0], [4, 5, 0], [7, 8, 9]],
[[10, 0, 0], [13, 14, 0], [16, 17, 18]],
[[19, 0, 0], [22, 23, 0], [25, 26, 27]]
]
);

let x = Array3::from_shape_vec(
(3, 3, 3).f(),
vec![1, 10, 19, 4, 13, 22, 7, 16, 25, 2, 11, 20, 5, 14, 23, 8, 17, 26, 3, 12, 21, 6, 15, 24, 9, 18, 27],
)
.unwrap();

// Upper
let res = x.triu(0);
assert_eq!(
res,
array![
[[1, 2, 3], [0, 5, 6], [0, 0, 9]],
[[10, 11, 12], [0, 14, 15], [0, 0, 18]],
[[19, 20, 21], [0, 23, 24], [0, 0, 27]]
]
);

// Lower
let res = x.tril(0);
assert_eq!(
res,
array![
[[1, 0, 0], [4, 5, 0], [7, 8, 9]],
[[10, 0, 0], [13, 14, 0], [16, 17, 18]],
[[19, 0, 0], [22, 23, 0], [25, 26, 27]]
]
);
}

#[test]
fn test_off_axis() {
let x = array![
[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
[[10, 11, 12], [13, 14, 15], [16, 17, 18]],
[[19, 20, 21], [22, 23, 24], [25, 26, 27]]
];

let res = x.triu(1);
assert_eq!(
res,
array![
[[0, 2, 3], [0, 0, 6], [0, 0, 0]],
[[0, 11, 12], [0, 0, 15], [0, 0, 0]],
[[0, 20, 21], [0, 0, 24], [0, 0, 0]]
]
);

let res = x.triu(-1);
assert_eq!(
res,
array![
[[1, 2, 3], [4, 5, 6], [0, 8, 9]],
[[10, 11, 12], [13, 14, 15], [0, 17, 18]],
[[19, 20, 21], [22, 23, 24], [0, 26, 27]]
]
);
}

#[test]
fn test_odd_shape() {
let x = array![[1, 2, 3], [4, 5, 6]];
let res = x.triu(0);
assert_eq!(res, array![[1, 2, 3], [0, 5, 6]]);

let x = array![[1, 2], [3, 4], [5, 6]];
let res = x.triu(0);
assert_eq!(res, array![[1, 2], [0, 4], [0, 0]]);
}
}

0 comments on commit 6316fbe

Please sign in to comment.