Skip to content

Commit

Permalink
Changed implementation to append seed & index value only once for a p…
Browse files Browse the repository at this point in the history
…articular RPacketTracker instance
  • Loading branch information
DhruvSondhi committed Jan 30, 2022
1 parent 6696741 commit 73aaaf7
Showing 1 changed file with 7 additions and 15 deletions.
22 changes: 7 additions & 15 deletions tardis/montecarlo/montecarlo_numba/numba_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,8 @@ def set_properties(

rpacket_collection_spec = [
("length", int64),
("seed", int64[:]),
("index", int64[:]),
("seed", int64),
("index", int64),
("status", int64[:]),
("r", float64[:]),
("nu", float64[:]),
Expand Down Expand Up @@ -348,10 +348,10 @@ class RPacketTracker(object):
Internal counter for the interactions that a particular RPacket undergoes
"""

def __init__(self, seed, index):
def __init__(self):
self.length = montecarlo_configuration.INITIAL_TRACKING_ARRAY_LENGTH
self.seed = seed
self.index = index
self.seed = np.int64(0)
self.index = np.int64(0)
self.status = np.empty(self.length, dtype=np.int64)
self.r = np.empty(self.length, dtype=np.float64)
self.nu = np.empty(self.length, dtype=np.float64)
Expand All @@ -363,26 +363,20 @@ def __init__(self, seed, index):
def track(self, r_packet):
if self.interact_id >= self.length:
temp_length = self.length * 2
temp_index = np.empty(temp_length, dtype=np.int64)
temp_seed = np.empty(temp_length, dtype=np.int64)
temp_status = np.empty(temp_length, dtype=np.int64)
temp_r = np.empty(temp_length, dtype=np.float64)
temp_nu = np.empty(temp_length, dtype=np.float64)
temp_mu = np.empty(temp_length, dtype=np.float64)
temp_energy = np.empty(temp_length, dtype=np.float64)
temp_shell_id = np.empty(temp_length, dtype=np.int64)

temp_index[: self.length] = self.index
temp_seed[: self.length] = self.seed
temp_status[: self.length] = self.status
temp_r[: self.length] = self.r
temp_nu[: self.length] = self.nu
temp_mu[: self.length] = self.mu
temp_energy[: self.length] = self.energy
temp_shell_id[: self.length] = self.shell_id

self.index = temp_index
self.seed = temp_seed
self.status = temp_status
self.r = temp_r
self.nu = temp_nu
Expand All @@ -391,8 +385,8 @@ def track(self, r_packet):
self.shell_id = temp_shell_id
self.length = temp_length

self.index[self.interact_id] = r_packet.index
self.seed[self.interact_id] = r_packet.seed
self.index = r_packet.index
self.seed = r_packet.seed
self.status[self.interact_id] = r_packet.status
self.r[self.interact_id] = r_packet.r
self.nu[self.interact_id] = r_packet.nu
Expand All @@ -402,8 +396,6 @@ def track(self, r_packet):
self.interact_id += 1

def finalize_array(self):
self.index = self.index[: self.interact_id]
self.seed = self.seed[: self.interact_id]
self.status = self.status[: self.interact_id]
self.r = self.r[: self.interact_id]
self.nu = self.nu[: self.interact_id]
Expand Down

0 comments on commit 73aaaf7

Please sign in to comment.