Skip to content

Commit

Permalink
Arepo mass average fix (#2250)
Browse files Browse the repository at this point in the history
* Added optional mass averaged densities

* Added typo fixes

* Updated docs

* Fixed missing averaging

* Rebuild docs

* Fixed mass binning

* Masses no longer optional

* Removed statistics keyword

* Added volume quantity

* Updated docs and tests

* Updated test model data

* Rerun tests

* Rerun docs
  • Loading branch information
AlexHls authored Apr 3, 2023
1 parent fbb784b commit 3174757
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 95 deletions.

Large diffs are not rendered by default.

176 changes: 131 additions & 45 deletions tardis/io/parsers/arepo.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import os
import sys
import argparse
import warnings

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy import stats


Expand Down Expand Up @@ -51,7 +48,6 @@ def __init__(

try:
import gadget_snap
import calcGrid
except ModuleNotFoundError:
raise ImportError(
"Please make sure you have arepo-snap-util installed if you want to directly import Arepo snapshots."
Expand All @@ -70,6 +66,7 @@ def __init__(
hdf5=True,
quiet=True,
lazy_load=True,
loadonlytype=[0],
)

rz_yaw = np.array(
Expand Down Expand Up @@ -99,13 +96,14 @@ def __init__(
self.s.rotateto(rotmat[0], dir2=rotmat[1], dir3=rotmat[2])

self.time = self.s.time
self.pos = np.array(self.s.data["pos"][: self.s.nparticlesall[0]])
self.pos = np.array(self.s.data["pos"])
self.pos = self.pos.T
# Update position to CoM frame
for i in range(3):
self.pos[i] -= self.s.centerofmass()[i]
self.rho = np.array(self.s.data["rho"])
self.vel = np.array(self.s.data["vel"][: self.s.nparticlesall[0]])
self.mass = np.array(self.s.data["mass"])
self.vel = np.array(self.s.data["vel"])
self.vel = self.vel.T
self.nuc_dict = {}

Expand All @@ -118,7 +116,7 @@ def get_grids(self):
"""
Returns all relevant data to create Profile objects
"""
return self.pos, self.vel, self.rho, self.nuc_dict, self.time
return self.pos, self.vel, self.rho, self.mass, self.nuc_dict, self.time


class Profile:
Expand All @@ -127,7 +125,7 @@ class Profile:
e.g. for plotting and export.
"""

def __init__(self, pos, vel, rho, xnuc, time):
def __init__(self, pos, vel, rho, mass, xnuc, time):
"""
Parameters
----------
Expand All @@ -138,6 +136,8 @@ def __init__(self, pos, vel, rho, xnuc, time):
Meshgrid of velocities/ velocity vectors
rho : list of float
Meshgrid of density
mass : list of float
Meshgrid of masses.
xnuc : dict
Dictonary containing all the nuclear fraction
meshgrids of the relevant species.
Expand All @@ -151,7 +151,8 @@ def __init__(self, pos, vel, rho, xnuc, time):
self.rho = rho
self.xnuc = xnuc
self.time = time

self.mass = mass
self.vol = self.mass / self.rho
self.species = list(self.xnuc.keys())

# Empty values to be filled with the create_profile function
Expand All @@ -161,6 +162,12 @@ def __init__(self, pos, vel, rho, xnuc, time):
self.vel_prof_p = None
self.vel_prof_n = None

self.vol_prof_p = None
self.vol_prof_n = None

self.mass_prof_p = None
self.mass_prof_n = None

self.rho_prof_p = None
self.rho_prof_n = None

Expand Down Expand Up @@ -211,7 +218,7 @@ def plot_profile(self, save=None, dpi=600, **kwargs):
ax1.set_ylabel("Profile (arb. unit)")
ax1.set_title("Profiles along the positive axis")

# Positive direction plots
# Negative direction plots
ax2.plot(
self.pos_prof_n,
self.rho_prof_n / max(self.rho_prof_n),
Expand Down Expand Up @@ -257,7 +264,7 @@ def plot_profile(self, save=None, dpi=600, **kwargs):

return fig

def rebin(self, nshells, statistic="mean"):
def rebin(self, nshells):
"""
Rebins the data to nshells. Uses the scipy.stats.binned_statistic
to bin the data. The standard deviation of each bin can be obtained
Expand All @@ -267,8 +274,6 @@ def rebin(self, nshells, statistic="mean"):
----------
nshells : int
Number of bins of new data.
statistic : str
Scipy keyword for scipy.stats.binned_statistic. Default: mean
Returns
-------
Expand All @@ -278,43 +283,87 @@ def rebin(self, nshells, statistic="mean"):

self.vel_prof_p, bins_p = stats.binned_statistic(
self.pos_prof_p,
self.vel_prof_p,
statistic=statistic,
self.vel_prof_p * self.mass_prof_p,
statistic="mean",
bins=nshells,
)[:2]
self.vel_prof_p /= stats.binned_statistic(
self.pos_prof_p,
self.mass_prof_p,
statistic="mean",
bins=nshells,
)[0]
self.vel_prof_n, bins_n = stats.binned_statistic(
self.pos_prof_n,
self.vel_prof_n,
statistic=statistic,
self.vel_prof_n * self.mass_prof_n,
statistic="mean",
bins=nshells,
)[:2]
self.vel_prof_n /= stats.binned_statistic(
self.pos_prof_n,
self.mass_prof_n,
statistic="mean",
bins=nshells,
)[0]

for spec in self.species:
self.xnuc_prof_p[spec] = (
stats.binned_statistic(
self.pos_prof_p,
self.xnuc_prof_p[spec] * self.mass_prof_p,
statistic="mean",
bins=nshells,
)[0]
/ stats.binned_statistic(
self.pos_prof_p,
self.mass_prof_p,
statistic="mean",
bins=nshells,
)[0]
)
self.xnuc_prof_n[spec] = (
stats.binned_statistic(
self.pos_prof_n,
self.xnuc_prof_n[spec] * self.mass_prof_n,
statistic="mean",
bins=nshells,
)[0]
/ stats.binned_statistic(
self.pos_prof_n,
self.mass_prof_n,
statistic="mean",
bins=nshells,
)[0]
)

self.vol_prof_p = np.array(
[
4 / 3 * np.pi * (bins_p[i + 1] ** 3 - bins_p[i] ** 3)
for i in range(len(bins_p) - 1)
]
)
self.vol_prof_n = np.array(
[
4 / 3 * np.pi * (bins_n[i + 1] ** 3 - bins_n[i] ** 3)
for i in range(len(bins_n) - 1)
]
)

self.rho_prof_p = stats.binned_statistic(
self.mass_prof_p = stats.binned_statistic(
self.pos_prof_p,
self.rho_prof_p,
statistic=statistic,
self.mass_prof_p,
statistic="sum",
bins=nshells,
)[0]
self.rho_prof_n = stats.binned_statistic(
self.mass_prof_n = stats.binned_statistic(
self.pos_prof_n,
self.rho_prof_n,
statistic=statistic,
self.mass_prof_n,
statistic="sum",
bins=nshells,
)[0]

for spec in self.species:
self.xnuc_prof_p[spec] = stats.binned_statistic(
self.pos_prof_p,
self.xnuc_prof_p[spec],
statistic=statistic,
bins=nshells,
)[0]
self.xnuc_prof_n[spec] = stats.binned_statistic(
self.pos_prof_n,
self.xnuc_prof_n[spec],
statistic=statistic,
bins=nshells,
)[0]
self.rho_prof_p = self.mass_prof_p / self.vol_prof_p
self.rho_prof_n = self.mass_prof_n / self.vol_prof_n

self.pos_prof_p = np.array(
[(bins_p[i] + bins_p[i + 1]) / 2 for i in range(len(bins_p) - 1)]
Expand All @@ -330,7 +379,6 @@ def export(
nshells,
filename,
direction="pos",
statistic="mean",
overwrite=False,
):
"""
Expand All @@ -349,9 +397,6 @@ def export(
Specifies if either the positive or negative
direction is to be exported. Available
options: ['pos', 'neg']. Default: pos
statistic : str
Scipy keyword for scipy.stats.binned_statistic. If
statistic=None, data is not rebinned. Default: "mean"
overwrite: bool
If true, will overwrite if a file of the same name exists.
By default False.
Expand Down Expand Up @@ -429,8 +474,7 @@ def export(
f.write("".join(datastring))

# Rebin data to nshells
if statistic is not None:
self.rebin(nshells, statistic=statistic)
self.rebin(nshells)

if direction == "pos":
exp = [
Expand Down Expand Up @@ -468,6 +512,8 @@ def get_profiles(self):
self.vel_prof_n,
self.rho_prof_p,
self.rho_prof_n,
self.mass_prof_p,
self.mass_prof_n,
self.xnuc_prof_p,
self.xnuc_prof_n,
)
Expand Down Expand Up @@ -548,6 +594,12 @@ def create_profile(
+ self.vel[2][cmask_n] ** 2
)

vol_p = self.vol[cmask_p]
vol_n = self.vol[cmask_n]

mass_p = self.mass[cmask_p]
mass_n = self.mass[cmask_n]

rho_p = self.rho[cmask_p]
rho_n = self.rho[cmask_n]

Expand Down Expand Up @@ -585,6 +637,20 @@ def create_profile(
if not mask_p.any() or not mask_n.any():
raise ValueError("No points left between inner and outer radius.")

self.vol_prof_p = np.array(
[x for _, x in sorted(zip(pos_p, vol_p), key=lambda pair: pair[0])]
)[mask_p]
self.vol_prof_n = np.array(
[x for _, x in sorted(zip(pos_n, vol_n), key=lambda pair: pair[0])]
)[mask_n]

self.mass_prof_p = np.array(
[x for _, x in sorted(zip(pos_p, mass_p), key=lambda pair: pair[0])]
)[mask_p]
self.mass_prof_n = np.array(
[x for _, x in sorted(zip(pos_n, mass_n), key=lambda pair: pair[0])]
)[mask_n]

self.rho_prof_p = np.array(
[x for _, x in sorted(zip(pos_p, rho_p), key=lambda pair: pair[0])]
)[mask_p]
Expand Down Expand Up @@ -669,6 +735,12 @@ def create_profile(
self.vel[0] ** 2 + self.vel[1] ** 2 + self.vel[2] ** 2
).flatten()

vol_p = self.vol.flatten()
vol_n = self.vol.flatten()

mass_p = self.mass.flatten()
mass_n = self.mass.flatten()

rho_p = self.rho.flatten()
rho_n = self.rho.flatten()

Expand Down Expand Up @@ -706,6 +778,20 @@ def create_profile(
if not mask_p.any() or not mask_n.any():
raise ValueError("No points left between inner and outer radius.")

self.vol_prof_p = np.array(
[x for _, x in sorted(zip(pos_p, vol_p), key=lambda pair: pair[0])]
)[mask_p]
self.vol_prof_n = np.array(
[x for _, x in sorted(zip(pos_n, vol_n), key=lambda pair: pair[0])]
)[mask_n]

self.mass_prof_p = np.array(
[x for _, x in sorted(zip(pos_p, mass_p), key=lambda pair: pair[0])]
)[mask_p]
self.mass_prof_n = np.array(
[x for _, x in sorted(zip(pos_n, mass_n), key=lambda pair: pair[0])]
)[mask_n]

self.rho_prof_p = np.array(
[x for _, x in sorted(zip(pos_p, rho_p), key=lambda pair: pair[0])]
)[mask_p]
Expand Down Expand Up @@ -844,12 +930,12 @@ def create_profile(
numthreads=args.numthreads,
)

pos, vel, rho, xnuc, time = snapshot.get_grids()
pos, vel, rho, mass, xnuc, time = snapshot.get_grids()

if args.profile == "cone":
profile = ConeProfile(pos, vel, rho, xnuc, time)
profile = ConeProfile(pos, vel, rho, xnuc, time, mass=mass)
elif args.profile == "full":
profile = FullProfile(pos, vel, rho, xnuc, time)
profile = FullProfile(pos, vel, rho, xnuc, time, mass=mass)

if args.profile == "cone":
profile.create_profile(
Expand Down
40 changes: 20 additions & 20 deletions tardis/io/tests/data/arepo_cone_reference_model.csvy
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,23 @@ datatype:

---
velocity,density,Ni56,Si28
1.0274e+09,0.143612,0.835847,0.0298718
1.07789e+09,0.128768,0.697991,0.0780773
1.12815e+09,0.118203,0.474194,0.177468
1.17816e+09,0.107473,0.249352,0.30694
1.22805e+09,0.0925248,0.136718,0.393771
1.27823e+09,0.076736,0.0999367,0.431174
1.32841e+09,0.06336,0.0785446,0.455648
1.37871e+09,0.0517966,0.0522142,0.459173
1.42888e+09,0.0418666,0.0206574,0.440613
1.47917e+09,0.0339696,0.00318851,0.447642
1.52987e+09,0.0316967,0.000112875,0.475979
1.57926e+09,0.0319146,1.73664e-06,0.422262
1.62842e+09,0.0251404,3.71752e-08,0.311981
1.67827e+09,0.0158305,1.4744e-09,0.248895
1.72846e+09,0.0094452,1.77961e-10,0.229273
1.77912e+09,0.00575855,4.84339e-11,0.214382
1.82919e+09,0.00374244,1.92704e-11,0.193081
1.87927e+09,0.00259477,1.20535e-11,0.165905
1.93113e+09,0.00186891,6.7766e-12,0.136604
1.98181e+09,0.00139947,5.09661e-12,0.110707
1.02744e+09,0.039332,0.784383,0.0458134
1.07784e+09,0.0360818,0.608716,0.113525
1.12803e+09,0.0334366,0.369558,0.231707
1.1779e+09,0.0300155,0.190785,0.348145
1.2279e+09,0.0256713,0.121151,0.408203
1.2781e+09,0.0212718,0.0992874,0.431875
1.32831e+09,0.0178633,0.0794732,0.455453
1.37856e+09,0.0149001,0.0573543,0.462553
1.42874e+09,0.0120893,0.028824,0.44569
1.47872e+09,0.00986893,0.00774774,0.437655
1.52935e+09,0.00835016,0.000694398,0.465911
1.58016e+09,0.0081748,1.77478e-05,0.476364
1.62958e+09,0.00811114,4.8074e-07,0.415122
1.67864e+09,0.00625829,2.06007e-08,0.313869
1.72832e+09,0.00423129,1.79194e-09,0.255375
1.77909e+09,0.00278635,3.505e-10,0.235798
1.82952e+09,0.00188048,1.06163e-10,0.227405
1.87985e+09,0.00137231,4.77628e-11,0.218543
1.93045e+09,0.00100844,2.41338e-11,0.204091
1.98087e+09,0.000773981,1.42833e-11,0.182151
Loading

0 comments on commit 3174757

Please sign in to comment.