Skip to content

Commit

Permalink
Move structs to struct.py, init default values.
Browse files Browse the repository at this point in the history
- __init__ in RPacket and StorageModel resemble the
  same init methods in test_cmontecarlo.c

- The way arguments are passed in paramterization is
  changed from individual params to a dict for better
  aesthetic sense.
  • Loading branch information
karandesai-96 committed Mar 30, 2016
1 parent 8101adb commit d78b547
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 138 deletions.
Empty file added tardis/montecarlo/struct.py
Empty file.
163 changes: 25 additions & 138 deletions tardis/montecarlo/tests/test_cmontecarlo.py
Original file line number Diff line number Diff line change
@@ -1,178 +1,65 @@
import os
import random
from ctypes import *

import pytest
import numpy as np
import numpy.ctypeslib as npclib
import numpy.testing as nptesting
from tardis import __path__ as path

test_path = os.path.join(path[0], 'montecarlo', 'test_montecarlo.so')
tests = CDLL(test_path)

cmontecarlo_filepath = os.path.join(path[0], 'montecarlo', 'montecarlo.so')
cmontecarlo_methods = CDLL(cmontecarlo_filepath)


class RPacket(Structure):
_fields_ = [
('nu', c_double),
('mu', c_double),
('energy', c_double),
('r', c_double),
('tau_event', c_double),
('nu_line', c_double),
('current_shell_id', c_int64),
('next_line_id', c_int64),
('last_line', c_int64),
('close_line', c_int64),
('recently_crossed_boundary', c_int64),
('current_continuum_id', c_int64),
('virtual_packet_flag', c_int64),
('virtual_packet', c_int64),
('d_line', c_double),
('d_electron', c_double),
('d_boundary', c_double),
('d_cont', c_double),
('next_shell_id', c_int64),
('status', c_int),
('id', c_int64),
('chi_th', c_double),
('chi_cont', c_double),
('chi_ff', c_double),
('chi_bf', c_double)
]

def __init__(self, **kwargs):
super(RPacket, self).__init__()
for key, value in kwargs.iteritems():
setattr(self, key, value)


class StorageModel(Structure):
_fields_ = [
('packet_nus', POINTER(c_double)),
('packet_mus', POINTER(c_double)),
('packet_energies', POINTER(c_double)),
('output_nus', POINTER(c_double)),
('output_energies', POINTER(c_double)),
('last_interaction_in_nu', POINTER(c_double)),
('last_line_interaction_in_id', POINTER(c_int64)),
('last_line_interaction_out_id', POINTER(c_int64)),
('last_line_interaction_shell_id', POINTER(c_int64)),
('last_line_interaction_type', POINTER(c_int64)),
('no_of_packets', c_int64),
('no_of_shells', c_int64),
('r_inner', POINTER(c_double)),
('r_outer', POINTER(c_double)),
('v_inner', POINTER(c_double)),
('time_explosion', c_double),
('inverse_time_explosion', c_double),
('electron_densities', POINTER(c_double)),
('inverse_electron_densities', POINTER(c_double)),
('line_list_nu', POINTER(c_double)),
('continuum_list_nu', POINTER(c_double)),
('line_lists_tau_sobolevs', POINTER(c_double)),
('line_lists_tau_sobolevs_nd', c_int64),
('line_lists_j_blues', POINTER(c_double)),
('line_lists_j_blues_nd', c_int64),
('no_of_lines', c_int64),
('no_of_edges', c_int64),
('line_interaction_id', c_int64),
('transition_probabilities', POINTER(c_double)),
('transition_probabilities_nd', c_int64),
('line2macro_level_upper', POINTER(c_int64)),
('macro_block_references', POINTER(c_int64)),
('transition_type', POINTER(c_int64)),
('destination_level_id', POINTER(c_int64)),
('transition_line_id', POINTER(c_int64)),
('js', POINTER(c_double)),
('nubars', POINTER(c_double)),
('spectrum_start_nu', c_double),
('spectrum_delta_nu', c_double),
('spectrum_end_nu', c_double),
('spectrum_virt_start_nu', c_double),
('spectrum_virt_end_nu', c_double),
('spectrum_virt_nu', POINTER(c_double)),
('sigma_thomson', c_double),
('inverse_sigma_thomson', c_double),
('inner_boundary_albedo', c_double),
('reflective_inner_boundary', c_int64),
('current_packet_id', c_int64),
('chi_bf_tmp_partial', POINTER(c_double)),
('t_electrons', POINTER(c_double)),
('l_pop', POINTER(c_double)),
('l_pop_r', POINTER(c_double)),
('cont_status', c_int),
('virt_packet_nus', POINTER(c_double)),
('virt_packet_energies', POINTER(c_double)),
('virt_packet_last_interaction_in_nu', POINTER(c_double)),
('virt_packet_last_interaction_type', POINTER(c_int64)),
('virt_packet_last_line_interaction_in_id', POINTER(c_int64)),
('virt_packet_last_line_interaction_out_id', POINTER(c_int64)),
('virt_packet_count', c_int64),
('virt_array_size', c_int64)
]

def __init__(self, **kwargs):
super(StorageModel, self).__init__()
for key, value in kwargs.iteritems():
setattr(self, key, value)
from tardis.montecarlo.struct import *


@pytest.mark.parametrize(
['mu', 'r', 'recently_crossed_boundary', 'r_inner', 'r_outer', 'expected'],
[(0.3, 7.5e14, 1, [6.912e14, 8.64e14], [8.64e14, 1.0368e15], 259376919351035.88),
(0.3, 7.5e14, 0, [6.912e14, 8.64e14], [8.64e14, 1.0368e15], 259376919351035.88),
(-0.3, 7.5e13, 0, [6.912e14, 8.64e14], [8.64e14, 1.0368e15], -838532664885601.1),
(-0.3, 7.5e14, 0, [6.912e14, 8.64e14], [8.64e14, 1.0368e15], -259376919351035.88)]
['packet_params', 'expected'],
[({'mu': 0.3,
'r': 7.5e14,
'recently_crossed_boundary': 1}, 259376919351035.88),
({'mu': 0.3,
'r': 7.5e14,
'recently_crossed_boundary': 0}, 259376919351035.88),
({'mu': -0.3,
'r': 7.5e13,
'recently_crossed_boundary': 0}, -838532664885601.1),
({'mu': -0.3,
'r': 7.5e14,
'recently_crossed_boundary': 0}, -259376919351035.88)]
)
def test_compute_distance2boundary(mu, r, recently_crossed_boundary, r_inner, r_outer, expected):
packet = RPacket(mu=mu, r=r, current_shell_id=0, next_shell_id=1,
recently_crossed_boundary=recently_crossed_boundary)
model = StorageModel(r_inner=npclib.as_ctypes(np.array(r_inner)),
r_outer=npclib.as_ctypes(np.array(r_outer)))
def test_compute_distance2boundary(packet_params, expected):
packet = RPacket(**packet_params)
model = StorageModel()

cmontecarlo_methods.compute_distance2boundary.restype = c_double
obtained = cmontecarlo_methods.compute_distance2boundary(byref(packet), byref(model))

nptesting.assert_almost_equal(obtained, expected)
np.testing.assert_almost_equal(obtained, expected)


def test_compute_distance2line():
distance_to_line = 7.792353908000001e+17
tests.test_compute_distance2line.restype = c_double
nptesting.assert_almost_equal(tests.test_compute_distance2line(),
np.testing.assert_almost_equal(tests.test_compute_distance2line(),
distance_to_line)


def test_compute_distance2continuum():
distance_to_electron = 4.359272608766106e+28
tests.test_compute_distance2continuum.restype = c_double
nptesting.assert_almost_equal(tests.test_compute_distance2continuum(),
np.testing.assert_almost_equal(tests.test_compute_distance2continuum(),
distance_to_electron)


def test_rpacket_doppler_factor():
doppler_factor = 0.9998556693818854
tests.test_rpacket_doppler_factor.restype = c_double
nptesting.assert_almost_equal(tests.test_rpacket_doppler_factor(),
np.testing.assert_almost_equal(tests.test_rpacket_doppler_factor(),
doppler_factor)


@pytest.mark.skipif(True, reason='Bad test design')
def test_move_packet():
doppler_factor = 0.9998556693818854
tests.test_move_packet.restype = c_double
nptesting.assert_almost_equal(tests.test_move_packet(),
np.testing.assert_almost_equal(tests.test_move_packet(),
doppler_factor)


def test_increment_j_blue_estimator():
j_blue = 1.1249855669381885
tests.test_increment_j_blue_estimator.restype = c_double
nptesting.assert_almost_equal(tests.test_increment_j_blue_estimator(),
np.testing.assert_almost_equal(tests.test_increment_j_blue_estimator(),
j_blue)


Expand All @@ -199,7 +86,7 @@ def test_montecarlo_thomson_scatter():
def test_calculate_chi_bf():
chi_bf = 1.0006697327643788
tests.test_calculate_chi_bf.restype = c_double
nptesting.assert_almost_equal(tests.test_calculate_chi_bf(),
np.testing.assert_almost_equal(tests.test_calculate_chi_bf(),
chi_bf)


Expand All @@ -212,7 +99,7 @@ def test_montecarlo_bound_free_scatter():
def test_bf_cross_section():
bf_cross_section = 0.0
tests.test_bf_cross_section.restype = c_double
nptesting.assert_almost_equal(tests.test_bf_cross_section(),
np.testing.assert_almost_equal(tests.test_bf_cross_section(),
bf_cross_section)


Expand Down

0 comments on commit d78b547

Please sign in to comment.