Skip to content

Commit

Permalink
Make complex math functions host-device
Browse files Browse the repository at this point in the history
  • Loading branch information
chillenzer committed Feb 17, 2025
1 parent 95c0bf2 commit 052ab15
Showing 1 changed file with 25 additions and 25 deletions.
50 changes: 25 additions & 25 deletions include/alpaka/math/Complex.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ namespace alpaka
{
//! Take context as original (accelerator) type, since we call other math functions
template<typename TCtx>
ALPAKA_FN_ACC auto operator()(TCtx const& ctx, Complex<T> const& arg)
ALPAKA_FN_HOST_ACC auto operator()(TCtx const& ctx, Complex<T> const& arg)
{
return sqrt(ctx, arg.real() * arg.real() + arg.imag() * arg.imag());
}
Expand All @@ -605,7 +605,7 @@ namespace alpaka
{
//! Take context as original (accelerator) type, since we call other math functions
template<typename TCtx>
ALPAKA_FN_ACC auto operator()(TCtx const& ctx, Complex<T> const& arg)
ALPAKA_FN_HOST_ACC auto operator()(TCtx const& ctx, Complex<T> const& arg)
{
// This holds everywhere, including the branch cuts: acos(z) = -i * ln(z + i * sqrt(1 - z^2))
return Complex<T>{static_cast<T>(0.0), static_cast<T>(-1.0)}
Expand All @@ -623,7 +623,7 @@ namespace alpaka
{
//! Take context as original (accelerator) type, since we call other math functions
template<typename TCtx>
ALPAKA_FN_ACC auto operator()(TCtx const& ctx, Complex<T> const& arg)
ALPAKA_FN_HOST_ACC auto operator()(TCtx const& ctx, Complex<T> const& arg)
{
// acos(z) = ln(z + sqrt(z-1) * sqrt(z+1))
return log(ctx, arg + sqrt(ctx, arg - static_cast<T>(1.0)) * sqrt(ctx, arg + static_cast<T>(1.0)));
Expand All @@ -636,7 +636,7 @@ namespace alpaka
{
//! Take context as original (accelerator) type, since we call other math functions
template<typename TCtx>
ALPAKA_FN_ACC auto operator()(TCtx const& ctx, Complex<T> const& argument)
ALPAKA_FN_HOST_ACC auto operator()(TCtx const& ctx, Complex<T> const& argument)
{
return atan2(ctx, argument.imag(), argument.real());
}
Expand All @@ -648,7 +648,7 @@ namespace alpaka
{
//! Take context as original (accelerator) type, since we call other math functions
template<typename TCtx>
ALPAKA_FN_ACC auto operator()(TCtx const& ctx, Complex<T> const& arg)
ALPAKA_FN_HOST_ACC auto operator()(TCtx const& ctx, Complex<T> const& arg)
{
// This holds everywhere, including the branch cuts: asin(z) = i * ln(sqrt(1 - z^2) - i * z)
return Complex<T>{static_cast<T>(0.0), static_cast<T>(1.0)}
Expand All @@ -665,7 +665,7 @@ namespace alpaka
{
//! Take context as original (accelerator) type, since we call other math functions
template<typename TCtx>
ALPAKA_FN_ACC auto operator()(TCtx const& ctx, Complex<T> const& arg)
ALPAKA_FN_HOST_ACC auto operator()(TCtx const& ctx, Complex<T> const& arg)
{
// asinh(z) = ln(z + sqrt(z^2 + 1))
return log(ctx, arg + sqrt(ctx, arg * arg + static_cast<T>(1.0)));
Expand All @@ -678,7 +678,7 @@ namespace alpaka
{
//! Take context as original (accelerator) type, since we call other math functions
template<typename TCtx>
ALPAKA_FN_ACC auto operator()(TCtx const& ctx, Complex<T> const& arg)
ALPAKA_FN_HOST_ACC auto operator()(TCtx const& ctx, Complex<T> const& arg)
{
// This holds everywhere, including the branch cuts: atan(z) = -i/2 * ln((i - z) / (i + z))
return Complex<T>{static_cast<T>(0.0), static_cast<T>(-0.5)}
Expand All @@ -695,7 +695,7 @@ namespace alpaka
{
//! Take context as original (accelerator) type, since we call other math functions
template<typename TCtx>
ALPAKA_FN_ACC auto operator()(TCtx const& ctx, Complex<T> const& arg)
ALPAKA_FN_HOST_ACC auto operator()(TCtx const& ctx, Complex<T> const& arg)
{
// atanh(z) = 0.5 * (ln(1 + z) - ln(1 - z))
return static_cast<T>(0.5)
Expand All @@ -707,7 +707,7 @@ namespace alpaka
template<typename TAcc, typename T>
struct Conj<TAcc, Complex<T>>
{
ALPAKA_FN_ACC auto operator()(TAcc const& /* conj_ctx */, Complex<T> const& arg)
ALPAKA_FN_HOST_ACC auto operator()(TAcc const& /* conj_ctx */, Complex<T> const& arg)
{
return Complex<T>{arg.real(), -arg.imag()};
}
Expand All @@ -719,7 +719,7 @@ namespace alpaka
{
//! Take context as original (accelerator) type, since we call other math functions
template<typename TCtx>
ALPAKA_FN_ACC auto operator()(TCtx const& ctx, Complex<T> const& arg)
ALPAKA_FN_HOST_ACC auto operator()(TCtx const& ctx, Complex<T> const& arg)
{
// cos(z) = 0.5 * (exp(i * z) + exp(-i * z))
return T(0.5)
Expand All @@ -734,7 +734,7 @@ namespace alpaka
{
//! Take context as original (accelerator) type, since we call other math functions
template<typename TCtx>
ALPAKA_FN_ACC auto operator()(TCtx const& ctx, Complex<T> const& arg)
ALPAKA_FN_HOST_ACC auto operator()(TCtx const& ctx, Complex<T> const& arg)
{
// cosh(z) = 0.5 * (exp(z) + exp(-z))
return T(0.5) * (exp(ctx, arg) + exp(ctx, static_cast<T>(-1.0) * arg));
Expand All @@ -747,7 +747,7 @@ namespace alpaka
{
//! Take context as original (accelerator) type, since we call other math functions
template<typename TCtx>
ALPAKA_FN_ACC auto operator()(TCtx const& ctx, Complex<T> const& arg)
ALPAKA_FN_HOST_ACC auto operator()(TCtx const& ctx, Complex<T> const& arg)
{
// exp(z) = exp(x + iy) = exp(x) * (cos(y) + i * sin(y))
auto re = T{}, im = T{};
Expand All @@ -762,7 +762,7 @@ namespace alpaka
{
//! Take context as original (accelerator) type, since we call other math functions
template<typename TCtx>
ALPAKA_FN_ACC auto operator()(TCtx const& ctx, Complex<T> const& argument)
ALPAKA_FN_HOST_ACC auto operator()(TCtx const& ctx, Complex<T> const& argument)
{
// Branch cut along the negative real axis (same as for std::complex),
// principal value of ln(z) = ln(|z|) + i * arg(z)
Expand All @@ -777,7 +777,7 @@ namespace alpaka
{
//! Take context as original (accelerator) type, since we call other math functions
template<typename TCtx>
ALPAKA_FN_ACC auto operator()(TCtx const& ctx, Complex<T> const& argument)
ALPAKA_FN_HOST_ACC auto operator()(TCtx const& ctx, Complex<T> const& argument)
{
return log(ctx, argument) / log(ctx, static_cast<T>(2));
}
Expand All @@ -789,7 +789,7 @@ namespace alpaka
{
//! Take context as original (accelerator) type, since we call other math functions
template<typename TCtx>
ALPAKA_FN_ACC auto operator()(TCtx const& ctx, Complex<T> const& argument)
ALPAKA_FN_HOST_ACC auto operator()(TCtx const& ctx, Complex<T> const& argument)
{
return log(ctx, argument) / log(ctx, static_cast<T>(10));
}
Expand All @@ -801,7 +801,7 @@ namespace alpaka
{
//! Take context as original (accelerator) type, since we call other math functions
template<typename TCtx>
ALPAKA_FN_ACC auto operator()(TCtx const& ctx, Complex<T> const& base, Complex<U> const& exponent)
ALPAKA_FN_HOST_ACC auto operator()(TCtx const& ctx, Complex<T> const& base, Complex<U> const& exponent)
{
// Type promotion matching rules of complex std::pow but simplified given our math only supports float
// and double, no long double.
Expand All @@ -818,7 +818,7 @@ namespace alpaka
{
//! Take context as original (accelerator) type, since we call other math functions
template<typename TCtx>
ALPAKA_FN_ACC auto operator()(TCtx const& ctx, Complex<T> const& base, U const& exponent)
ALPAKA_FN_HOST_ACC auto operator()(TCtx const& ctx, Complex<T> const& base, U const& exponent)
{
return pow(ctx, base, Complex<U>{exponent});
}
Expand All @@ -830,7 +830,7 @@ namespace alpaka
{
//! Take context as original (accelerator) type, since we call other math functions
template<typename TCtx>
ALPAKA_FN_ACC auto operator()(TCtx const& ctx, T const& base, Complex<U> const& exponent)
ALPAKA_FN_HOST_ACC auto operator()(TCtx const& ctx, T const& base, Complex<U> const& exponent)
{
return pow(ctx, Complex<T>{base}, exponent);
}
Expand All @@ -842,7 +842,7 @@ namespace alpaka
{
//! Take context as original (accelerator) type, since we call other math functions
template<typename TCtx>
ALPAKA_FN_ACC auto operator()(TCtx const& ctx, Complex<T> const& arg)
ALPAKA_FN_HOST_ACC auto operator()(TCtx const& ctx, Complex<T> const& arg)
{
return static_cast<T>(1.0) / sqrt(ctx, arg);
}
Expand All @@ -854,7 +854,7 @@ namespace alpaka
{
//! Take context as original (accelerator) type, since we call other math functions
template<typename TCtx>
ALPAKA_FN_ACC auto operator()(TCtx const& ctx, Complex<T> const& arg)
ALPAKA_FN_HOST_ACC auto operator()(TCtx const& ctx, Complex<T> const& arg)
{
// sin(z) = (exp(i * z) - exp(-i * z)) / 2i
return (exp(ctx, Complex<T>{static_cast<T>(0.0), static_cast<T>(1.0)} * arg)
Expand All @@ -869,7 +869,7 @@ namespace alpaka
{
//! Take context as original (accelerator) type, since we call other math functions
template<typename TCtx>
ALPAKA_FN_ACC auto operator()(TCtx const& ctx, Complex<T> const& arg)
ALPAKA_FN_HOST_ACC auto operator()(TCtx const& ctx, Complex<T> const& arg)
{
// sinh(z) = (exp(z) - exp(-i * z)) / 2
return (exp(ctx, arg) - exp(ctx, static_cast<T>(-1.0) * arg)) / static_cast<T>(2.0);
Expand All @@ -882,7 +882,7 @@ namespace alpaka
{
//! Take context as original (accelerator) type, since we call other math functions
template<typename TCtx>
ALPAKA_FN_ACC auto operator()(
ALPAKA_FN_HOST_ACC auto operator()(
TCtx const& ctx,
Complex<T> const& arg,
Complex<T>& result_sin,
Expand All @@ -899,7 +899,7 @@ namespace alpaka
{
//! Take context as original (accelerator) type, since we call other math functions
template<typename TCtx>
ALPAKA_FN_ACC auto operator()(TCtx const& ctx, Complex<T> const& argument)
ALPAKA_FN_HOST_ACC auto operator()(TCtx const& ctx, Complex<T> const& argument)
{
// Branch cut along the negative real axis (same as for std::complex),
// principal value of sqrt(z) = sqrt(|z|) * e^(i * arg(z) / 2)
Expand All @@ -916,7 +916,7 @@ namespace alpaka
{
//! Take context as original (accelerator) type, since we call other math functions
template<typename TCtx>
ALPAKA_FN_ACC auto operator()(TCtx const& ctx, Complex<T> const& arg)
ALPAKA_FN_HOST_ACC auto operator()(TCtx const& ctx, Complex<T> const& arg)
{
// tan(z) = i * (e^-iz - e^iz) / (e^-iz + e^iz) = i * (1 - e^2iz) / (1 + e^2iz)
// Warning: this straightforward implementation can easily result in NaN as 0/0 or inf/inf.
Expand All @@ -932,7 +932,7 @@ namespace alpaka
{
//! Take context as original (accelerator) type, since we call other math functions
template<typename TCtx>
ALPAKA_FN_ACC auto operator()(TCtx const& ctx, Complex<T> const& arg)
ALPAKA_FN_HOST_ACC auto operator()(TCtx const& ctx, Complex<T> const& arg)
{
// tanh(z) = (e^z - e^-z)/(e^z+e^-z)
return (exp(ctx, arg) - exp(ctx, static_cast<T>(-1.0) * arg))
Expand Down

0 comments on commit 052ab15

Please sign in to comment.