Skip to content

Commit

Permalink
Replace wrapper structs with tuple structs in python (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
prehner authored Aug 10, 2021
1 parent 1a49a21 commit c980b42
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 140 deletions.
14 changes: 8 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
//! }
//! ```
#![warn(clippy::all)]
#![allow(clippy::needless_range_loop)]

use num_traits::{Float, FromPrimitive, Inv, NumAssignOps, NumOps, Signed};
use std::fmt;
use std::iter::{Product, Sum};
Expand Down Expand Up @@ -61,7 +64,6 @@ pub use linalg::*;
#[cfg(feature = "python")]
pub mod python;


/// A generalized (hyper) dual number.
pub trait DualNum<F>:
NumOps
Expand Down Expand Up @@ -179,13 +181,13 @@ pub trait DualNum<F>:
/// Fused multiply-add
#[inline]
fn mul_add(&self, a: Self, b: Self) -> Self {
self.clone() * a + b
*self * a + b
}

/// Power with dual exponent `x^n`
#[inline]
fn powd(&self, exp: &Self) -> Self {
(self.ln() * exp.clone()).exp()
fn powd(&self, exp: Self) -> Self {
(self.ln() * exp).exp()
}
}

Expand Down Expand Up @@ -224,8 +226,8 @@ macro_rules! impl_dual_num_float {
fn powf(&self, n: Self) -> Self {
<$float>::powf(*self, n)
}
fn powd(&self, n: &Self) -> Self {
<$float>::powf(*self, *n)
fn powd(&self, n: Self) -> Self {
<$float>::powf(*self, n)
}
fn sqrt(&self) -> Self {
<$float>::sqrt(*self)
Expand Down
20 changes: 6 additions & 14 deletions src/python/dual.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,19 @@ use pyo3::prelude::*;
/// 18
/// >>> fx.first_derivative
/// 8.25
pub struct PyDual64 {
pub _data: Dual64,
}
pub struct PyDual64(Dual64);

#[pymethods]
impl PyDual64 {
#[new]
pub fn new(re: f64, eps: f64) -> Self {
Self {
_data: Dual64::new_scalar(re, eps),
}
Self(Dual64::new_scalar(re, eps))
}

#[getter]
/// Dual part.
pub fn get_first_derivative(&self) -> f64 {
self._data.eps[0]
self.0.eps[0]
}
}

Expand All @@ -56,15 +52,11 @@ macro_rules! impl_dual_n {
($py_type_name:ident, $n:literal) => {
#[pyclass(name = "DualVec64")]
#[derive(Clone, Copy)]
pub struct $py_type_name {
pub _data: DualVec64<$n>,
}
pub struct $py_type_name(DualVec64<$n>);

impl $py_type_name {
pub fn new(re: f64, eps: [f64; $n]) -> Self {
Self {
_data: DualVec64::new(re, StaticVec::new_vec(eps)),
}
DualVec64::new(re, StaticVec::new_vec(eps)).into()
}
}

Expand All @@ -73,7 +65,7 @@ macro_rules! impl_dual_n {
#[getter]
/// Dual part.
pub fn get_first_derivative(&self) -> [f64; $n] {
*self._data.eps.raw_array()
*self.0.eps.raw_array()
}
}

Expand Down
18 changes: 7 additions & 11 deletions src/python/dual2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ use pyo3::prelude::*;
#[pyclass(name = "Dual2_64")]
#[derive(Clone)]
/// Hyper dual number using 64-bit-floats.
pub struct PyDual2_64 {
pub _data: Dual2_64,
}
pub struct PyDual2_64(Dual2_64);

#[pymethods]
impl PyDual2_64 {
Expand All @@ -21,13 +19,13 @@ impl PyDual2_64 {
#[getter]
/// First hyperdual part.
fn get_first_derivative(&self) -> f64 {
self._data.v1[0]
self.0.v1[0]
}

#[getter]
/// Second hyperdual part.
fn get_second_derivative(&self) -> f64 {
self._data.v2[0]
self.0.v2[0]
}
}

Expand All @@ -36,27 +34,25 @@ impl_dual_num!(PyDual2_64, Dual2_64, f64);
#[pyclass(name = "Dual2Dual64")]
#[derive(Clone)]
/// Hyper dual number using 64-bit-floats.
pub struct PyDual2Dual64 {
pub _data: Dual2<Dual64, f64>,
}
pub struct PyDual2Dual64(Dual2<Dual64, f64>);

#[pymethods]
impl PyDual2Dual64 {
#[new]
pub fn new(v0: PyDual64, v1: PyDual64, v2: PyDual64) -> Self {
Dual2::new_scalar(v0._data, v1._data, v2._data).into()
Dual2::new_scalar(v0.into(), v1.into(), v2.into()).into()
}

#[getter]
/// First hyperdual part.
fn get_first_derivative(&self) -> PyDual64 {
self._data.v1[0].into()
self.0.v1[0].into()
}

#[getter]
/// Second hyperdual part.
fn get_second_derivative(&self) -> PyDual64 {
self._data.v2[(0, 0)].into()
self.0.v2[(0, 0)].into()
}
}

Expand Down
26 changes: 11 additions & 15 deletions src/python/dual3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ use pyo3::prelude::*;
#[pyclass(name = "Dual3_64")]
#[derive(Clone)]
/// Hyper dual number using 64-bit-floats.
pub struct PyDual3_64 {
pub _data: Dual3_64,
}
pub struct PyDual3_64(Dual3_64);

#[pymethods]
impl PyDual3_64 {
Expand All @@ -21,19 +19,19 @@ impl PyDual3_64 {
#[getter]
/// First hyperdual part.
fn get_first_derivative(&self) -> f64 {
self._data.v1
self.0.v1
}

#[getter]
/// Second hyperdual part.
fn get_second_derivative(&self) -> f64 {
self._data.v2
self.0.v2
}

#[getter]
/// Third hyperdual part.
fn get_third_derivative(&self) -> f64 {
self._data.v3
self.0.v3
}
}

Expand All @@ -42,33 +40,31 @@ impl_dual_num!(PyDual3_64, Dual3_64, f64);
#[pyclass(name = "Dual3Dual64")]
#[derive(Clone)]
/// Hyper dual number using 64-bit-floats.
pub struct PyDual3Dual64 {
pub _data: Dual3<Dual64, f64>,
}
pub struct PyDual3Dual64(Dual3<Dual64, f64>);

#[pymethods]
impl PyDual3Dual64 {
#[new]
pub fn new(v0: PyDual64, v1: PyDual64, v2: PyDual64, v3: PyDual64) -> Self {
Dual3::new(v0._data, v1._data, v2._data, v3._data).into()
Dual3::new(v0.into(), v1.into(), v2.into(), v3.into()).into()
}

#[getter]
/// First hyperdual part.
fn get_first_derivative(&self) -> PyDual64 {
self._data.v1.into()
self.0.v1.into()
}

#[getter]
/// Second hyperdual part.
fn get_second_derivative(&self) -> PyDual64 {
self._data.v2.into()
self.0.v2.into()
}

#[getter]
/// Third hyperdual part.
fn get_third_derivative(&self) -> PyDual64 {
self._data.v3.into()
self.0.v3.into()
}
}

Expand All @@ -85,10 +81,10 @@ fn derive3(x: &PyAny) -> PyResult<PyObject> {
};
if let Ok(x) = x.extract::<PyDual64>() {
return Ok(
PyCell::new(py, PyDual3Dual64::from(Dual3::from_re(x._data).derive()))?
PyCell::new(py, PyDual3Dual64::from(Dual3::from_re(x.into()).derive()))?
.to_object(py),
);
};
Err(PyErr::new::<PyTypeError, _>(format!("not implemented!")))
Err(PyErr::new::<PyTypeError, _>("not implemented!".to_string()))
})
}
Loading

0 comments on commit c980b42

Please sign in to comment.