diff --git a/tardis/montecarlo/montecarlo_numba/nonhomologous_grid.py b/tardis/montecarlo/montecarlo_numba/nonhomologous_grid.py new file mode 100644 index 00000000000..69501f3cfec --- /dev/null +++ b/tardis/montecarlo/montecarlo_numba/nonhomologous_grid.py @@ -0,0 +1,76 @@ +import numpy as np +from numba import njit + +from tardis.montecarlo.montecarlo_numba import ( + njit_dict_no_parallel, +) + + +@njit(**njit_dict_no_parallel) +def velocity_dvdr(r_packet, numba_model): + """ + Velocity at radius r and dv/dr of current shell + + Parameters + ---------- + r_packet: RPacket + numba_model: NumbaModel + + Returns + ----------- + v: float, current velocity + frac: float, dv/dr for current shell + """ + shell_id = r_packet.current_shell_id + v_inner = numba_model.v_inner[shell_id] + v_outer = numba_model.v_outer[shell_id] + r_inner = numba_model.r_inner[shell_id] + r_outer = numba_model.r_outer[shell_id] + r = r_packet.r + frac = (v_outer - v_inner) / (r_outer - r_inner) + return v_inner + frac * (r - r_inner), frac + + +@njit(**njit_dict_no_parallel) +def tau_sobolev_factor(r_packet, numba_model): + """ + The angle and velocity dependent Tau Sobolev factor component. Is called when ENABLE_NONHOMOLOGOUS_EXPANSION is set to True. + + Note: to get Tau Sobolev, this needs to be multiplied by tau_sobolevs found from plasma + Parameters + ---------- + r_packet: RPacket + numba_model: NumbaModel + + Returns + ----------- + factor = 1.0 / ((1 - mu * mu) * v / r + mu * mu * dvdr) + """ + + v, dvdr = velocity_dvdr(r_packet, numba_model) + r = r_packet.r + mu = r_packet.mu + factor = 1.0 / ((1 - mu * mu) * v / r + mu * mu * dvdr) + return factor + + +# @njit(**njit_dict_no_parallel) +def quartic_roots(a, b, c, d, e, threshold): + """ + Solves ax^4 + bx^3 + cx^2 + dx + e = 0, for the real roots greater than the threshold returns (x - threshold). + Uses: https://en.wikipedia.org/wiki/Quartic_function#General_formula_for_roots + + Parameters + ----------- + a, b, c, d, e: coefficients of the equations ax^4 + bx^3 + cx^2 + dx + e = 0, float + threshold: lower needed limit on roots, float + Returns + ----------- + roots: real positive roots of ax^4 + bx^3 + cx^2 + dx + e = 0 + + """ + roots = np.roots((a, b, c, d, e)) + roots = [root for root in roots if isinstance(root, float)] + roots = [root for root in roots if root > threshold] + + return roots diff --git a/tardis/montecarlo/tests/test_nonhomologous.py b/tardis/montecarlo/tests/test_nonhomologous.py new file mode 100644 index 00000000000..06b757057cd --- /dev/null +++ b/tardis/montecarlo/tests/test_nonhomologous.py @@ -0,0 +1,35 @@ +import pytest +from numpy.testing import assert_almost_equal + +import tardis.montecarlo.montecarlo_numba.nonhomologous_grid as nonhomologous_grid + + +@pytest.mark.parametrize( + ["a", "b", "c", "d", "e", "threshold", "expected_roots"], + [ + ( + 0.0, + 0.0, + 0.0, + 2.0, + -1.0, + 0.0, + {"result": [0.5]}, + ), + ( + 1.0, + 2.0, + 0.0, + 2.0, + 0.0, + 0.0, + {"result": []}, + ), + (1.0, -14.0, 71.0, -154.0, 120.0, 2.5, {"result": [3, 4, 5]}), + ], +) +def test_quartic_roots(a, b, c, d, e, threshold, expected_roots): + obtained_roots = nonhomologous_grid.quartic_roots(a, b, c, d, e, threshold) + obtained_roots.sort() + + assert_almost_equal(obtained_roots, expected_roots["result"])