Skip to content

Commit

Permalink
Changed to append seed & index values only once for a particular …
Browse files Browse the repository at this point in the history
…`RPacketTracker` instance (#1881)

* Update numba_interface.py

* Changed implementation to append seed & index value only once for a particular RPacketTracker instance

Co-authored-by: Wolfgang Kerzendorf <wkerzendorf@gmail.com>
  • Loading branch information
DhruvSondhi and wkerzendorf authored Jan 31, 2022
1 parent 583c739 commit 60b76ef
Showing 1 changed file with 6 additions and 14 deletions.
20 changes: 6 additions & 14 deletions tardis/montecarlo/montecarlo_numba/numba_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,8 @@ def set_properties(

rpacket_collection_spec = [
("length", int64),
("seed", int64[:]),
("index", int64[:]),
("seed", int64),
("index", int64),
("status", int64[:]),
("r", float64[:]),
("nu", float64[:]),
Expand Down Expand Up @@ -350,8 +350,8 @@ class RPacketTracker(object):

def __init__(self):
self.length = montecarlo_configuration.INITIAL_TRACKING_ARRAY_LENGTH
self.seed = np.empty(self.length, dtype=np.int64)
self.index = np.empty(self.length, dtype=np.int64)
self.seed = np.int64(0)
self.index = np.int64(0)
self.status = np.empty(self.length, dtype=np.int64)
self.r = np.empty(self.length, dtype=np.float64)
self.nu = np.empty(self.length, dtype=np.float64)
Expand All @@ -363,26 +363,20 @@ def __init__(self):
def track(self, r_packet):
if self.interact_id >= self.length:
temp_length = self.length * 2
temp_index = np.empty(temp_length, dtype=np.int64)
temp_seed = np.empty(temp_length, dtype=np.int64)
temp_status = np.empty(temp_length, dtype=np.int64)
temp_r = np.empty(temp_length, dtype=np.float64)
temp_nu = np.empty(temp_length, dtype=np.float64)
temp_mu = np.empty(temp_length, dtype=np.float64)
temp_energy = np.empty(temp_length, dtype=np.float64)
temp_shell_id = np.empty(temp_length, dtype=np.int64)

temp_index[: self.length] = self.index
temp_seed[: self.length] = self.seed
temp_status[: self.length] = self.status
temp_r[: self.length] = self.r
temp_nu[: self.length] = self.nu
temp_mu[: self.length] = self.mu
temp_energy[: self.length] = self.energy
temp_shell_id[: self.length] = self.shell_id

self.index = temp_index
self.seed = temp_seed
self.status = temp_status
self.r = temp_r
self.nu = temp_nu
Expand All @@ -391,8 +385,8 @@ def track(self, r_packet):
self.shell_id = temp_shell_id
self.length = temp_length

self.index[self.interact_id] = r_packet.index
self.seed[self.interact_id] = r_packet.seed
self.index = r_packet.index
self.seed = r_packet.seed
self.status[self.interact_id] = r_packet.status
self.r[self.interact_id] = r_packet.r
self.nu[self.interact_id] = r_packet.nu
Expand All @@ -402,8 +396,6 @@ def track(self, r_packet):
self.interact_id += 1

def finalize_array(self):
self.index = self.index[: self.interact_id]
self.seed = self.seed[: self.interact_id]
self.status = self.status[: self.interact_id]
self.r = self.r[: self.interact_id]
self.nu = self.nu[: self.interact_id]
Expand Down

0 comments on commit 60b76ef

Please sign in to comment.