Skip to content

Commit

Permalink
fix for reading in nessai result files
Browse files Browse the repository at this point in the history
  • Loading branch information
mj-will committed Nov 14, 2023
1 parent 7eb9bad commit 4169605
Showing 1 changed file with 41 additions and 2 deletions.
43 changes: 41 additions & 2 deletions pycbc/inference/io/nessai.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,55 @@
"""Provides IO for the nessai sampler"""
import numpy

from .base_nested_sampler import BaseNestedSamplerFile

from ...io.hdf import dump_state, load_state
from .posterior import write_samples_to_file
from .posterior import read_raw_samples_from_file, write_samples_to_file
from .dynesty import CommonNestedMetadataIO


class NessaiFile(BaseNestedSamplerFile):
class NessaiFile(CommonNestedMetadataIO, BaseNestedSamplerFile):
"""Class to handle file IO for the ``nessai`` sampler."""

name = "nessai_file"

def read_raw_samples(self, fields, raw_samples=False, seed=0):
"""Reads samples from a nessai file and constructs a posterior.
Using rejection sampling to resample the nested samples
Parameters
----------
fields : list of str
The names of the parameters to load. Names must correspond to
dataset names in the file's ``samples`` group.
raw_samples : bool, optional
Return the raw (unweighted) samples instead of the estimated
posterior samples. Default is False.
Returns
-------
dict :
Dictionary of parameter fields -> samples.
"""
samples = read_raw_samples_from_file(self, fields)
logwt = read_raw_samples_from_file(self, ['logwt'])['logwt']
loglikelihood = read_raw_samples_from_file(
self, ['loglikelihood'])['loglikelihood']
if not raw_samples:
N = len(logwt)
# Rejection sample
rng = numpy.random.default_rng(seed)
logwt -= logwt.max()
logu = numpy.log(rng.rand(N))
keep = logwt > logu
post = {'loglikelihood': loglikelihood[keep]}
for param in fields:
post[param] = samples[param][keep]
return post
else:
return samples

def write_pickled_data_into_checkpoint_file(self, data):
"""Write the pickled data into a checkpoint file"""
if "sampler_info/saved_state" not in self:
Expand Down

0 comments on commit 4169605

Please sign in to comment.