From 1e01e69fbe8976bb94f08934a046a55d6672263e Mon Sep 17 00:00:00 2001 From: Muhammad Ragib Hasin Date: Mon, 28 Oct 2024 22:35:57 +0600 Subject: [PATCH] Add nalgebra compatibility for Dual2 and Dual2Vec --- src/dual2.rs | 706 ++++++++++++++++++++++++++++++++++++++++++++- src/dual2_vec.rs | 727 ++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 1430 insertions(+), 3 deletions(-) diff --git a/src/dual2.rs b/src/dual2.rs index 8373fe9..1b980de 100644 --- a/src/dual2.rs +++ b/src/dual2.rs @@ -1,4 +1,6 @@ use crate::{DualNum, DualNumFloat}; +use approx::{AbsDiffEq, RelativeEq, UlpsEq}; +use nalgebra::*; use num_traits::{Float, FloatConst, FromPrimitive, Inv, Num, One, Signed, Zero}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -11,7 +13,7 @@ use std::ops::{ }; /// A scalar second order dual number for the calculation of second derivatives. -#[derive(PartialEq, Eq, Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Dual2, F> { /// Real part of the second order dual number @@ -174,3 +176,705 @@ impl, F: fmt::Display> fmt::Display for Dual2 { impl_second_derivatives!(Dual2, [v1, v2]); impl_dual!(Dual2, [v1, v2]); + +/** + * The SimdValue trait is for rearranging data into a form more suitable for Simd, + * and rearranging it back into a usable form. It is not documented particularly well. + * + * The primary job of this SimdValue impl is to allow people to use `simba::simd::f32x4` etc, + * instead of f32/f64. Those types implement nalgebra::SimdRealField/ComplexField, so they + * behave like scalars. When we use them, we would have `Dual` etc, with our + * F parameter set to `::Element`. We will need to be able to split up that type + * into four of Dual in order to get out of simd-land. That's what the SimdValue trait is for. + * + * Ultimately, someone will have to implement SimdRealField on Dual and call the + * simd_ functions of ``.
That's future work for someone who finds + * num_dual is not fast enough. + * + * Unfortunately, doing anything with SIMD is blocked on + * <https://github.com/dimforge/simba/issues/44>. + * + */ +impl nalgebra::SimdValue for Dual2 +where + T: DualNum + SimdValue + Scalar, + T::Element: DualNum + Scalar, +{ + // Say T = simba::f32x4. T::Element is f32. T::SimdBool is AutoSimd<[bool; 4]>. + // AutoSimd<[f32; 4]> stores an actual [f32; 4], i.e. four floats in one slot. + // So our Dual has 4 * (1+N) floats in it, stored in blocks of + // four. When we want to do any math on it but ignore its f32x4 storage mode, we need to break + // that type into FOUR of Dual; then we do math on it, then we bring it back + // together. + // + // Hence this definition of Element: + type Element = Dual2; + type SimdBool = T::SimdBool; + + const LANES: usize = T::LANES; + + #[inline] + fn splat(val: Self::Element) -> Self { + // Need to make `lanes` copies of each of: + // - the real part + // - each of the N epsilon parts + let re = T::splat(val.re); + let v1 = T::splat(val.v1); + let v2 = T::splat(val.v2); + Self::new(re, v1, v2) + } + + #[inline] + fn extract(&self, i: usize) -> Self::Element { + let re = self.re.extract(i); + let v1 = self.v1.extract(i); + let v2 = self.v2.extract(i); + Self::Element { + re, + v1, + v2, + f: PhantomData, + } + } + + #[inline] + unsafe fn extract_unchecked(&self, i: usize) -> Self::Element { + let re = self.re.extract_unchecked(i); + let v1 = self.v1.extract_unchecked(i); + let v2 = self.v2.extract_unchecked(i); + Self::Element { + re, + v1, + v2, + f: PhantomData, + } + } + + #[inline] + fn replace(&mut self, i: usize, val: Self::Element) { + self.re.replace(i, val.re); + self.v1.replace(i, val.v1); + self.v2.replace(i, val.v2); + } + + #[inline] + unsafe fn replace_unchecked(&mut self, i: usize, val: Self::Element) { + self.re.replace_unchecked(i, val.re); + self.v1.replace_unchecked(i, val.v1); + self.v2.replace_unchecked(i, val.v2); + } + + #[inline] + fn select(self, cond: Self::SimdBool,
other: Self) -> Self { + let re = self.re.select(cond, other.re); + let v1 = self.v1.select(cond, other.v1); + let v2 = self.v2.select(cond, other.v2); + Self::new(re, v1, v2) + } +} + +/// Comparisons are only made based on the real part. This allows the code to follow the +/// same execution path as real-valued code would. +impl + PartialEq, F: Float> PartialEq for Dual2 { + #[inline] + fn eq(&self, other: &Self) -> bool { + self.re.eq(&other.re) + } +} +/// Like PartialEq, comparisons are only made based on the real part. This allows the code to follow the +/// same execution path as real-valued code would. +impl + PartialOrd, F: Float> PartialOrd for Dual2 { + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + self.re.partial_cmp(&other.re) + } +} +/// Like PartialEq, comparisons are only made based on the real part. This allows the code to follow the +/// same execution path as real-valued code would. +impl + approx::AbsDiffEq, F: Float> approx::AbsDiffEq for Dual2 { + type Epsilon = Self; + #[inline] + fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool { + self.re.abs_diff_eq(&other.re, epsilon.re) + } + + #[inline] + fn default_epsilon() -> Self::Epsilon { + Self::from_re(T::default_epsilon()) + } +} +/// Like PartialEq, comparisons are only made based on the real part. This allows the code to follow the +/// same execution path as real-valued code would. 
+impl + approx::RelativeEq, F: Float> approx::RelativeEq for Dual2 { + #[inline] + fn default_max_relative() -> Self::Epsilon { + Self::from_re(T::default_max_relative()) + } + + #[inline] + fn relative_eq( + &self, + other: &Self, + epsilon: Self::Epsilon, + max_relative: Self::Epsilon, + ) -> bool { + self.re.relative_eq(&other.re, epsilon.re, max_relative.re) + } +} +impl + UlpsEq, F: Float> UlpsEq for Dual2 { + #[inline] + fn default_max_ulps() -> u32 { + T::default_max_ulps() + } + + #[inline] + fn ulps_eq(&self, other: &Self, epsilon: Self::Epsilon, max_ulps: u32) -> bool { + T::ulps_eq(&self.re, &other.re, epsilon.re, max_ulps) + } +} + +impl nalgebra::Field for Dual2 +where + T: DualNum + SimdValue, + T::Element: DualNum + Scalar + Float, +{ +} + +use simba::scalar::{SubsetOf, SupersetOf}; + +impl SubsetOf> for Dual2 +where + TSuper: DualNum + SupersetOf, + T: DualNum, +{ + #[inline(always)] + fn to_superset(&self) -> Dual2 { + let re = TSuper::from_subset(&self.re); + let v1 = TSuper::from_subset(&self.v1); + let v2 = TSuper::from_subset(&self.v2); + Dual2 { + re, + v1, + v2, + f: PhantomData, + } + } + #[inline(always)] + fn from_superset(element: &Dual2) -> Option { + let re = TSuper::to_subset(&element.re)?; + let v1 = TSuper::to_subset(&element.v1)?; + let v2 = TSuper::to_subset(&element.v2)?; + Some(Self::new(re, v1, v2)) + } + #[inline(always)] + fn from_superset_unchecked(element: &Dual2) -> Self { + let re = TSuper::to_subset_unchecked(&element.re); + let v1 = TSuper::to_subset_unchecked(&element.v1); + let v2 = TSuper::to_subset_unchecked(&element.v2); + Self::new(re, v1, v2) + } + #[inline(always)] + fn is_in_subset(element: &Dual2) -> bool { + TSuper::is_in_subset(&element.re) + && TSuper::is_in_subset(&element.v1) + && TSuper::is_in_subset(&element.v2) + } +} + +impl SupersetOf for Dual2 +where + TSuper: DualNum + SupersetOf, +{ + #[inline(always)] + fn is_in_subset(&self) -> bool { + self.re.is_in_subset() + } + + #[inline(always)] + fn 
to_subset_unchecked(&self) -> f32 { + self.re.to_subset_unchecked() + } + + #[inline(always)] + fn from_subset(element: &f32) -> Self { + // Interpret as a purely real number + let re = TSuper::from_subset(element); + let v1 = TSuper::zero(); + let v2 = TSuper::zero(); + Self::new(re, v1, v2) + } +} + +impl SupersetOf for Dual2 +where + TSuper: DualNum + SupersetOf, +{ + #[inline(always)] + fn is_in_subset(&self) -> bool { + self.re.is_in_subset() + } + + #[inline(always)] + fn to_subset_unchecked(&self) -> f64 { + self.re.to_subset_unchecked() + } + + #[inline(always)] + fn from_subset(element: &f64) -> Self { + // Interpret as a purely real number + let re = TSuper::from_subset(element); + let v1 = TSuper::zero(); + let v2 = TSuper::zero(); + Self::new(re, v1, v2) + } +} + +// We can't do a simd implementation until simba lets us implement SimdPartialOrd +// using _T_'s SimdBool. The blanket impl gets in the way. So we must constrain +// T to SimdValue, which is basically the same as +// saying f32 or f64 only. +// +// Limitation of simba. See https://github.com/dimforge/simba/issues/44 + +use nalgebra::{ComplexField, RealField}; +// This impl is modelled on `impl ComplexField for f32`. The imaginary part is nothing. 
+impl ComplexField for Dual2 +where + T: DualNum + SupersetOf + AbsDiffEq + Sync + Send, + T::Element: DualNum + Scalar + DualNumFloat + Sync + Send, + T: SupersetOf, + T: SupersetOf, + T: SupersetOf, + T: SimdPartialOrd + PartialOrd, + T: SimdValue, + T: RelativeEq + UlpsEq + AbsDiffEq, +{ + type RealField = Self; + + #[inline] + fn from_real(re: Self::RealField) -> Self { + re + } + + #[inline] + fn real(self) -> Self::RealField { + self + } + + #[inline] + fn imaginary(self) -> Self::RealField { + Self::zero() + } + + #[inline] + fn modulus(self) -> Self::RealField { + self.abs() + } + + #[inline] + fn modulus_squared(self) -> Self::RealField { + self * self + } + + #[inline] + fn argument(self) -> Self::RealField { + Self::zero() + } + + #[inline] + fn norm1(self) -> Self::RealField { + self.abs() + } + + #[inline] + fn scale(self, factor: Self::RealField) -> Self { + self * factor + } + + #[inline] + fn unscale(self, factor: Self::RealField) -> Self { + self / factor + } + + #[inline] + fn floor(self) -> Self { + panic!("called floor() on a dual number") + } + + #[inline] + fn ceil(self) -> Self { + panic!("called ceil() on a dual number") + } + + #[inline] + fn round(self) -> Self { + panic!("called round() on a dual number") + } + + #[inline] + fn trunc(self) -> Self { + panic!("called trunc() on a dual number") + } + + #[inline] + fn fract(self) -> Self { + panic!("called fract() on a dual number") + } + + #[inline] + fn mul_add(self, a: Self, b: Self) -> Self { + DualNum::mul_add(&self, a, b) + } + + #[inline] + fn abs(self) -> Self::RealField { + Signed::abs(&self) + } + + #[inline] + fn hypot(self, other: Self) -> Self::RealField { + let sum_sq = self.powi(2) + other.powi(2); + DualNum::sqrt(&sum_sq) + } + + #[inline] + fn recip(self) -> Self { + DualNum::recip(&self) + } + + #[inline] + fn conjugate(self) -> Self { + self + } + + #[inline] + fn sin(self) -> Self { + DualNum::sin(&self) + } + + #[inline] + fn cos(self) -> Self { + DualNum::cos(&self) + } 
+ + #[inline] + fn sin_cos(self) -> (Self, Self) { + DualNum::sin_cos(&self) + } + + #[inline] + fn tan(self) -> Self { + DualNum::tan(&self) + } + + #[inline] + fn asin(self) -> Self { + DualNum::asin(&self) + } + + #[inline] + fn acos(self) -> Self { + DualNum::acos(&self) + } + + #[inline] + fn atan(self) -> Self { + DualNum::atan(&self) + } + + #[inline] + fn sinh(self) -> Self { + DualNum::sinh(&self) + } + + #[inline] + fn cosh(self) -> Self { + DualNum::cosh(&self) + } + + #[inline] + fn tanh(self) -> Self { + DualNum::tanh(&self) + } + + #[inline] + fn asinh(self) -> Self { + DualNum::asinh(&self) + } + + #[inline] + fn acosh(self) -> Self { + DualNum::acosh(&self) + } + + #[inline] + fn atanh(self) -> Self { + DualNum::atanh(&self) + } + + #[inline] + fn log(self, base: Self::RealField) -> Self { + DualNum::ln(&self) / DualNum::ln(&base) + } + + #[inline] + fn log2(self) -> Self { + DualNum::log2(&self) + } + + #[inline] + fn log10(self) -> Self { + DualNum::log10(&self) + } + + #[inline] + fn ln(self) -> Self { + DualNum::ln(&self) + } + + #[inline] + fn ln_1p(self) -> Self { + DualNum::ln_1p(&self) + } + + #[inline] + fn sqrt(self) -> Self { + DualNum::sqrt(&self) + } + + #[inline] + fn exp(self) -> Self { + DualNum::exp(&self) + } + + #[inline] + fn exp2(self) -> Self { + DualNum::exp2(&self) + } + + #[inline] + fn exp_m1(self) -> Self { + DualNum::exp_m1(&self) + } + + #[inline] + fn powi(self, n: i32) -> Self { + DualNum::powi(&self, n) + } + + #[inline] + fn powf(self, n: Self::RealField) -> Self { + // n could be a dual. 
+ DualNum::powd(&self, n) + } + + #[inline] + fn powc(self, n: Self) -> Self { + // same as powf, Self isn't complex + self.powf(n) + } + + #[inline] + fn cbrt(self) -> Self { + DualNum::cbrt(&self) + } + + #[inline] + fn is_finite(&self) -> bool { + self.re.is_finite() + } + + #[inline] + fn try_sqrt(self) -> Option { + if self > Self::zero() { + Some(DualNum::sqrt(&self)) + } else { + None + } + } +} + +impl RealField for Dual2 +where + T: DualNum + SupersetOf + Sync + Send, + T::Element: DualNum + Scalar + DualNumFloat, + T: SupersetOf, + T: SupersetOf, + T: SupersetOf, + T: SimdPartialOrd + PartialOrd, + T: RelativeEq + AbsDiffEq, + T: SimdValue, + T: UlpsEq, + T: AbsDiffEq, +{ + #[inline] + fn copysign(self, sign: Self) -> Self { + if sign.re.is_sign_positive() { + self.simd_abs() + } else { + -self.simd_abs() + } + } + + #[inline] + fn atan2(self, other: Self) -> Self { + let re = self.re.atan2(other.re); + let den = self.re.powi(2) + other.re.powi(2); + + let da = other.re / den; + let db = -self.re / den; + let v1 = self.v1 * da + other.v1 * db; + + let daa = db * da * (T::one() + T::one()); + let dab = db * db - da * da; + let dbb = -daa; + let ca = self.v1 * daa + other.v1 * dab; + let cb = self.v1 * dab + other.v1 * dbb; + let v2 = self.v2 * da + other.v2 * db + ca * self.v1 + cb * other.v1; + + Self::new(re, v1, v2) + } + + #[inline] + fn pi() -> Self { + Self::from_re(::PI()) + } + + #[inline] + fn two_pi() -> Self { + Self::from_re(::TAU()) + } + + #[inline] + fn frac_pi_2() -> Self { + Self::from_re(::FRAC_PI_2()) // pi/2 as a purely real dual number; must be FRAC_PI_2, not FRAC_PI_4 + } + + #[inline] + fn frac_pi_3() -> Self { + Self::from_re(::FRAC_PI_3()) + } + + #[inline] + fn frac_pi_4() -> Self { + Self::from_re(::FRAC_PI_4()) + } + + #[inline] + fn frac_pi_6() -> Self { + Self::from_re(::FRAC_PI_6()) + } + + #[inline] + fn frac_pi_8() -> Self { + Self::from_re(::FRAC_PI_8()) + } + + #[inline] + fn frac_1_pi() -> Self { + Self::from_re(::FRAC_1_PI()) + } + + #[inline] + fn frac_2_pi() -> Self { +
Self::from_re(::FRAC_2_PI()) + } + + #[inline] + fn frac_2_sqrt_pi() -> Self { + Self::from_re(::FRAC_2_SQRT_PI()) + } + + #[inline] + fn e() -> Self { + Self::from_re(::E()) + } + + #[inline] + fn log2_e() -> Self { + Self::from_re(::LOG2_E()) + } + + #[inline] + fn log10_e() -> Self { + Self::from_re(::LOG10_E()) + } + + #[inline] + fn ln_2() -> Self { + Self::from_re(::LN_2()) + } + + #[inline] + fn ln_10() -> Self { + Self::from_re(::LN_10()) + } + + #[inline] + fn is_sign_positive(&self) -> bool { + self.re.is_sign_positive() + } + + #[inline] + fn is_sign_negative(&self) -> bool { + self.re.is_sign_negative() + } + + /// Got to be careful using this, because it throws away the derivatives of the one not chosen + #[inline] + fn max(self, other: Self) -> Self { + if other > self { + other + } else { + self + } + } + + /// Got to be careful using this, because it throws away the derivatives of the one not chosen + #[inline] + fn min(self, other: Self) -> Self { + if other < self { + other + } else { + self + } + } + + /// If the min/max values are constants and the clamping has an effect, you lose your gradients. 
+ #[inline] + fn clamp(self, min: Self, max: Self) -> Self { + if self < min { + min + } else if self > max { + max + } else { + self + } + } + + #[inline] + fn min_value() -> Option { + Some(Self::from_re(T::min_value())) + } + + #[inline] + fn max_value() -> Option { + Some(Self::from_re(T::max_value())) + } +} + +#[cfg(test)] +mod test { + use super::*; + use approx::assert_relative_eq; + + #[test] + fn test_atan2() { + let x = Dual2_64::from(2.0).derivative(); + let y = Dual2_64::from(-3.0); + let z = x.atan2(y); + let z2 = (x / y).atan(); + assert_relative_eq!(z.v1, z2.v1, epsilon = 1e-14); + assert_relative_eq!(z.v2, z2.v2, epsilon = 1e-14); + } +} diff --git a/src/dual2_vec.rs b/src/dual2_vec.rs index 8428f32..bb9ec7a 100644 --- a/src/dual2_vec.rs +++ b/src/dual2_vec.rs @@ -1,6 +1,7 @@ use crate::{Derivative, DualNum, DualNumFloat}; +use approx::{AbsDiffEq, RelativeEq, UlpsEq}; use nalgebra::allocator::Allocator; -use nalgebra::{Const, DefaultAllocator, Dim, Dyn, OMatrix, OVector, U1}; +use nalgebra::*; use num_traits::{Float, FloatConst, FromPrimitive, Inv, Num, One, Signed, Zero}; use std::convert::Infallible; use std::fmt; @@ -11,7 +12,7 @@ use std::ops::{ }; /// A vector second order dual number for the calculation of Hessians. -#[derive(PartialEq, Eq, Clone, Debug)] +#[derive(Clone, Debug)] pub struct Dual2Vec, F, D: Dim> where DefaultAllocator: Allocator + Allocator, @@ -184,3 +185,725 @@ where impl_second_derivatives!(Dual2Vec, [v1, v2], [D]); impl_dual!(Dual2Vec, [v1, v2], [D]); + +/** + * The SimdValue trait is for rearranging data into a form more suitable for Simd, + * and rearranging it back into a usable form. It is not documented particularly well. + * + * The primary job of this SimdValue impl is to allow people to use `simba::simd::f32x4` etc, + * instead of f32/f64. Those types implement nalgebra::SimdRealField/ComplexField, so they + * behave like scalars. 
When we use them, we would have `DualVec` etc, with our + * F parameter set to `::Element`. We will need to be able to split up that type + * into four of DualVec in order to get out of simd-land. That's what the SimdValue trait is for. + * + * Ultimately, someone will have to implement SimdRealField on DualVec and call the + * simd_ functions of ``. That's future work for someone who finds + * num_dual is not fast enough. + * + * Unfortunately, doing anything with SIMD is blocked on + * <https://github.com/dimforge/simba/issues/44>. + * + */ +impl nalgebra::SimdValue for Dual2Vec +where + DefaultAllocator: Allocator + Allocator + Allocator, + T: DualNum + SimdValue + Scalar, + T::Element: DualNum + Scalar, +{ + // Say T = simba::f32x4. T::Element is f32. T::SimdBool is AutoSimd<[bool; 4]>. + // AutoSimd<[f32; 4]> stores an actual [f32; 4], i.e. four floats in one slot. + // So our DualVec has 4 * (1+N) floats in it, stored in blocks of + // four. When we want to do any math on it but ignore its f32x4 storage mode, we need to break + // that type into FOUR of DualVec; then we do math on it, then we bring it back + // together.
+ // + // Hence this definition of Element: + type Element = Dual2Vec; + type SimdBool = T::SimdBool; + + const LANES: usize = T::LANES; + + #[inline] + fn splat(val: Self::Element) -> Self { + // Need to make `lanes` copies of each of: + // - the real part + // - each of the N epsilon parts + let re = T::splat(val.re); + let v1 = Derivative::splat(val.v1); + let v2 = Derivative::splat(val.v2); + Self::new(re, v1, v2) + } + + #[inline] + fn extract(&self, i: usize) -> Self::Element { + let re = self.re.extract(i); + let v1 = self.v1.extract(i); + let v2 = self.v2.extract(i); + Self::Element { + re, + v1, + v2, + f: PhantomData, + } + } + + #[inline] + unsafe fn extract_unchecked(&self, i: usize) -> Self::Element { + let re = self.re.extract_unchecked(i); + let v1 = self.v1.extract_unchecked(i); + let v2 = self.v2.extract_unchecked(i); + Self::Element { + re, + v1, + v2, + f: PhantomData, + } + } + + #[inline] + fn replace(&mut self, i: usize, val: Self::Element) { + self.re.replace(i, val.re); + self.v1.replace(i, val.v1); + self.v2.replace(i, val.v2); + } + + #[inline] + unsafe fn replace_unchecked(&mut self, i: usize, val: Self::Element) { + self.re.replace_unchecked(i, val.re); + self.v1.replace_unchecked(i, val.v1); + self.v2.replace_unchecked(i, val.v2); + } + + #[inline] + fn select(self, cond: Self::SimdBool, other: Self) -> Self { + let re = self.re.select(cond, other.re); + let v1 = self.v1.select(cond, other.v1); + let v2 = self.v2.select(cond, other.v2); + Self::new(re, v1, v2) + } +} + +/// Comparisons are only made based on the real part. This allows the code to follow the +/// same execution path as real-valued code would. +impl + PartialEq, F: Float, D: Dim> PartialEq for Dual2Vec +where + DefaultAllocator: Allocator + Allocator + Allocator, +{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.re.eq(&other.re) + } +} +/// Like PartialEq, comparisons are only made based on the real part. 
This allows the code to follow the +/// same execution path as real-valued code would. +impl + PartialOrd, F: Float, D: Dim> PartialOrd for Dual2Vec +where + DefaultAllocator: Allocator + Allocator + Allocator, +{ + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + self.re.partial_cmp(&other.re) + } +} +/// Like PartialEq, comparisons are only made based on the real part. This allows the code to follow the +/// same execution path as real-valued code would. +impl + approx::AbsDiffEq, F: Float, D: Dim> approx::AbsDiffEq + for Dual2Vec +where + DefaultAllocator: Allocator + Allocator + Allocator, +{ + type Epsilon = Self; + #[inline] + fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool { + self.re.abs_diff_eq(&other.re, epsilon.re) + } + + #[inline] + fn default_epsilon() -> Self::Epsilon { + Self::from_re(T::default_epsilon()) + } +} +/// Like PartialEq, comparisons are only made based on the real part. This allows the code to follow the +/// same execution path as real-valued code would. 
+impl + approx::RelativeEq, F: Float, D: Dim> approx::RelativeEq + for Dual2Vec +where + DefaultAllocator: Allocator + Allocator + Allocator, +{ + #[inline] + fn default_max_relative() -> Self::Epsilon { + Self::from_re(T::default_max_relative()) + } + + #[inline] + fn relative_eq( + &self, + other: &Self, + epsilon: Self::Epsilon, + max_relative: Self::Epsilon, + ) -> bool { + self.re.relative_eq(&other.re, epsilon.re, max_relative.re) + } +} +impl + UlpsEq, F: Float, D: Dim> UlpsEq for Dual2Vec +where + DefaultAllocator: Allocator + Allocator + Allocator, +{ + #[inline] + fn default_max_ulps() -> u32 { + T::default_max_ulps() + } + + #[inline] + fn ulps_eq(&self, other: &Self, epsilon: Self::Epsilon, max_ulps: u32) -> bool { + T::ulps_eq(&self.re, &other.re, epsilon.re, max_ulps) + } +} + +impl nalgebra::Field for Dual2Vec +where + T: DualNum + SimdValue, + T::Element: DualNum + Scalar + Float, + DefaultAllocator: Allocator + Allocator + Allocator + Allocator, +{ +} + +use simba::scalar::{SubsetOf, SupersetOf}; + +impl SubsetOf> for Dual2Vec +where + TSuper: DualNum + SupersetOf, + T: DualNum, + DefaultAllocator: Allocator + Allocator + Allocator + Allocator, +{ + #[inline(always)] + fn to_superset(&self) -> Dual2Vec { + let re = TSuper::from_subset(&self.re); + let v1 = Derivative::from_subset(&self.v1); + let v2 = Derivative::from_subset(&self.v2); + Dual2Vec { + re, + v1, + v2, + f: PhantomData, + } + } + #[inline(always)] + fn from_superset(element: &Dual2Vec) -> Option { + let re = TSuper::to_subset(&element.re)?; + let v1 = Derivative::to_subset(&element.v1)?; + let v2 = Derivative::to_subset(&element.v2)?; + Some(Self::new(re, v1, v2)) + } + #[inline(always)] + fn from_superset_unchecked(element: &Dual2Vec) -> Self { + let re = TSuper::to_subset_unchecked(&element.re); + let v1 = Derivative::to_subset_unchecked(&element.v1); + let v2 = Derivative::to_subset_unchecked(&element.v2); + Self::new(re, v1, v2) + } + #[inline(always)] + fn is_in_subset(element: 
&Dual2Vec) -> bool { + TSuper::is_in_subset(&element.re) + && as SupersetOf>>::is_in_subset( + &element.v1, + ) + && as SupersetOf>>::is_in_subset( + &element.v2, + ) + } +} + +impl SupersetOf for Dual2Vec +where + TSuper: DualNum + SupersetOf, + DefaultAllocator: Allocator + Allocator + Allocator + Allocator, +{ + #[inline(always)] + fn is_in_subset(&self) -> bool { + self.re.is_in_subset() + } + + #[inline(always)] + fn to_subset_unchecked(&self) -> f32 { + self.re.to_subset_unchecked() + } + + #[inline(always)] + fn from_subset(element: &f32) -> Self { + // Interpret as a purely real number + let re = TSuper::from_subset(element); + let v1 = Derivative::none(); + let v2 = Derivative::none(); + Self::new(re, v1, v2) + } +} + +impl SupersetOf for Dual2Vec +where + TSuper: DualNum + SupersetOf, + DefaultAllocator: Allocator + Allocator + Allocator + Allocator, +{ + #[inline(always)] + fn is_in_subset(&self) -> bool { + self.re.is_in_subset() + } + + #[inline(always)] + fn to_subset_unchecked(&self) -> f64 { + self.re.to_subset_unchecked() + } + + #[inline(always)] + fn from_subset(element: &f64) -> Self { + // Interpret as a purely real number + let re = TSuper::from_subset(element); + let v1 = Derivative::none(); + let v2 = Derivative::none(); + Self::new(re, v1, v2) + } +} + +// We can't do a simd implementation until simba lets us implement SimdPartialOrd +// using _T_'s SimdBool. The blanket impl gets in the way. So we must constrain +// T to SimdValue, which is basically the same as +// saying f32 or f64 only. +// +// Limitation of simba. See https://github.com/dimforge/simba/issues/44 + +use nalgebra::{ComplexField, RealField}; +// This impl is modelled on `impl ComplexField for f32`. The imaginary part is nothing. 
+impl ComplexField for Dual2Vec +where + T: DualNum + SupersetOf + AbsDiffEq + Sync + Send, + T::Element: DualNum + Scalar + DualNumFloat + Sync + Send, + T: SupersetOf, + T: SupersetOf, + T: SupersetOf, + T: SimdPartialOrd + PartialOrd, + T: SimdValue, + T: RelativeEq + UlpsEq + AbsDiffEq, + DefaultAllocator: Allocator + Allocator + Allocator + Allocator, + >::Buffer: Sync + Send, + >::Buffer: Sync + Send, + >::Buffer: Sync + Send, + >::Buffer: Sync + Send, +{ + type RealField = Self; + + #[inline] + fn from_real(re: Self::RealField) -> Self { + re + } + + #[inline] + fn real(self) -> Self::RealField { + self + } + + #[inline] + fn imaginary(self) -> Self::RealField { + Self::zero() + } + + #[inline] + fn modulus(self) -> Self::RealField { + self.abs() + } + + #[inline] + fn modulus_squared(self) -> Self::RealField { + &self * &self + } + + #[inline] + fn argument(self) -> Self::RealField { + Self::zero() + } + + #[inline] + fn norm1(self) -> Self::RealField { + self.abs() + } + + #[inline] + fn scale(self, factor: Self::RealField) -> Self { + self * factor + } + + #[inline] + fn unscale(self, factor: Self::RealField) -> Self { + self / factor + } + + #[inline] + fn floor(self) -> Self { + panic!("called floor() on a dual number") + } + + #[inline] + fn ceil(self) -> Self { + panic!("called ceil() on a dual number") + } + + #[inline] + fn round(self) -> Self { + panic!("called round() on a dual number") + } + + #[inline] + fn trunc(self) -> Self { + panic!("called trunc() on a dual number") + } + + #[inline] + fn fract(self) -> Self { + panic!("called fract() on a dual number") + } + + #[inline] + fn mul_add(self, a: Self, b: Self) -> Self { + DualNum::mul_add(&self, a, b) + } + + #[inline] + fn abs(self) -> Self::RealField { + Signed::abs(&self) + } + + #[inline] + fn hypot(self, other: Self) -> Self::RealField { + let sum_sq = self.powi(2) + other.powi(2); + DualNum::sqrt(&sum_sq) + } + + #[inline] + fn recip(self) -> Self { + DualNum::recip(&self) + } + + 
#[inline] + fn conjugate(self) -> Self { + self + } + + #[inline] + fn sin(self) -> Self { + DualNum::sin(&self) + } + + #[inline] + fn cos(self) -> Self { + DualNum::cos(&self) + } + + #[inline] + fn sin_cos(self) -> (Self, Self) { + DualNum::sin_cos(&self) + } + + #[inline] + fn tan(self) -> Self { + DualNum::tan(&self) + } + + #[inline] + fn asin(self) -> Self { + DualNum::asin(&self) + } + + #[inline] + fn acos(self) -> Self { + DualNum::acos(&self) + } + + #[inline] + fn atan(self) -> Self { + DualNum::atan(&self) + } + + #[inline] + fn sinh(self) -> Self { + DualNum::sinh(&self) + } + + #[inline] + fn cosh(self) -> Self { + DualNum::cosh(&self) + } + + #[inline] + fn tanh(self) -> Self { + DualNum::tanh(&self) + } + + #[inline] + fn asinh(self) -> Self { + DualNum::asinh(&self) + } + + #[inline] + fn acosh(self) -> Self { + DualNum::acosh(&self) + } + + #[inline] + fn atanh(self) -> Self { + DualNum::atanh(&self) + } + + #[inline] + fn log(self, base: Self::RealField) -> Self { + DualNum::ln(&self) / DualNum::ln(&base) + } + + #[inline] + fn log2(self) -> Self { + DualNum::log2(&self) + } + + #[inline] + fn log10(self) -> Self { + DualNum::log10(&self) + } + + #[inline] + fn ln(self) -> Self { + DualNum::ln(&self) + } + + #[inline] + fn ln_1p(self) -> Self { + DualNum::ln_1p(&self) + } + + #[inline] + fn sqrt(self) -> Self { + DualNum::sqrt(&self) + } + + #[inline] + fn exp(self) -> Self { + DualNum::exp(&self) + } + + #[inline] + fn exp2(self) -> Self { + DualNum::exp2(&self) + } + + #[inline] + fn exp_m1(self) -> Self { + DualNum::exp_m1(&self) + } + + #[inline] + fn powi(self, n: i32) -> Self { + DualNum::powi(&self, n) + } + + #[inline] + fn powf(self, n: Self::RealField) -> Self { + // n could be a dual. 
+ DualNum::powd(&self, n) + } + + #[inline] + fn powc(self, n: Self) -> Self { + // same as powf, Self isn't complex + self.powf(n) + } + + #[inline] + fn cbrt(self) -> Self { + DualNum::cbrt(&self) + } + + #[inline] + fn is_finite(&self) -> bool { + self.re.is_finite() + } + + #[inline] + fn try_sqrt(self) -> Option { + if self > Self::zero() { + Some(DualNum::sqrt(&self)) + } else { + None + } + } +} + +impl RealField for Dual2Vec +where + T: DualNum + SupersetOf + Sync + Send, + T::Element: DualNum + Scalar + DualNumFloat, + T: SupersetOf, + T: SupersetOf, + T: SupersetOf, + T: SimdPartialOrd + PartialOrd, + T: RelativeEq + AbsDiffEq, + T: SimdValue, + T: UlpsEq, + T: AbsDiffEq, + DefaultAllocator: Allocator + Allocator + Allocator + Allocator, + >::Buffer: Sync + Send, + >::Buffer: Sync + Send, + >::Buffer: Sync + Send, + >::Buffer: Sync + Send, +{ + #[inline] + fn copysign(self, sign: Self) -> Self { + if sign.re.is_sign_positive() { + self.simd_abs() + } else { + -self.simd_abs() + } + } + + #[inline] + fn atan2(self, other: Self) -> Self { + let re = self.re.atan2(other.re); + let den = self.re.powi(2) + other.re.powi(2); + + let da = other.re / den; + let db = -self.re / den; + let v1 = &self.v1 * da + &other.v1 * db; + + let daa = db * da * (T::one() + T::one()); + let dab = db * db - da * da; + let dbb = -daa; + let ca = &self.v1 * daa + &other.v1 * dab; + let cb = &self.v1 * dab + &other.v1 * dbb; + let v2 = self.v2 * da + other.v2 * db + ca.tr_mul(&self.v1) + cb.tr_mul(&other.v1); + + Self::new(re, v1, v2) + } + + #[inline] + fn pi() -> Self { + Self::from_re(::PI()) + } + + #[inline] + fn two_pi() -> Self { + Self::from_re(::TAU()) + } + + #[inline] + fn frac_pi_2() -> Self { + Self::from_re(::FRAC_PI_2()) // pi/2 as a purely real dual number; must be FRAC_PI_2, not FRAC_PI_4 + } + + #[inline] + fn frac_pi_3() -> Self { + Self::from_re(::FRAC_PI_3()) + } + + #[inline] + fn frac_pi_4() -> Self { + Self::from_re(::FRAC_PI_4()) + } + + #[inline] + fn frac_pi_6() -> Self { + Self::from_re(::FRAC_PI_6()) + } + + #[inline] + fn
frac_pi_8() -> Self { + Self::from_re(::FRAC_PI_8()) + } + + #[inline] + fn frac_1_pi() -> Self { + Self::from_re(::FRAC_1_PI()) + } + + #[inline] + fn frac_2_pi() -> Self { + Self::from_re(::FRAC_2_PI()) + } + + #[inline] + fn frac_2_sqrt_pi() -> Self { + Self::from_re(::FRAC_2_SQRT_PI()) + } + + #[inline] + fn e() -> Self { + Self::from_re(::E()) + } + + #[inline] + fn log2_e() -> Self { + Self::from_re(::LOG2_E()) + } + + #[inline] + fn log10_e() -> Self { + Self::from_re(::LOG10_E()) + } + + #[inline] + fn ln_2() -> Self { + Self::from_re(::LN_2()) + } + + #[inline] + fn ln_10() -> Self { + Self::from_re(::LN_10()) + } + + #[inline] + fn is_sign_positive(&self) -> bool { + self.re.is_sign_positive() + } + + #[inline] + fn is_sign_negative(&self) -> bool { + self.re.is_sign_negative() + } + + /// Got to be careful using this, because it throws away the derivatives of the one not chosen + #[inline] + fn max(self, other: Self) -> Self { + if other > self { + other + } else { + self + } + } + + /// Got to be careful using this, because it throws away the derivatives of the one not chosen + #[inline] + fn min(self, other: Self) -> Self { + if other < self { + other + } else { + self + } + } + + /// If the min/max values are constants and the clamping has an effect, you lose your gradients. + #[inline] + fn clamp(self, min: Self, max: Self) -> Self { + if self < min { + min + } else if self > max { + max + } else { + self + } + } + + #[inline] + fn min_value() -> Option { + Some(Self::from_re(T::min_value())) + } + + #[inline] + fn max_value() -> Option { + Some(Self::from_re(T::max_value())) + } +}