Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for nessai sampler in pycbc inference #4567

Merged
merged 24 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
4db5808
add basic support for nessai sampler
mj-will Sep 6, 2023
155de12
enable all options and resuming in nessai
mj-will Sep 12, 2023
6d612e6
fix prior bounds in nessai model
mj-will Sep 26, 2023
54f18b2
tweak resuming and samples in nessai interface
mj-will Sep 29, 2023
94c27b2
change outdir to avoid namespace conflicts
mj-will Oct 4, 2023
809aef1
tweaks to nessai sampler class
mj-will Oct 4, 2023
7eb9bad
fix nessai checkpointing and other minor tweaks
mj-will Oct 6, 2023
4169605
fix for reading in nessai result files
mj-will Oct 16, 2023
25b8edb
use callback for checkpointing in nessai
mj-will Nov 29, 2023
2d757fd
start addressing codeclimate issues
mj-will Nov 29, 2023
4cf44cd
add nessai to auxiliary samplers
mj-will Nov 29, 2023
a1fe131
add additional comments for nessai
mj-will Nov 29, 2023
0735c55
make simple sampler example 2d
mj-will Nov 29, 2023
f9f204a
fix call to rng.random
mj-will Nov 30, 2023
bc369b6
add nessai to samplers example and update plot
mj-will Nov 30, 2023
cae7be6
set minimum version for nessai
mj-will Nov 30, 2023
b84fda6
force cpu-only version of torch
mj-will Nov 30, 2023
25c86ac
add missing epsie jump proposal
mj-will Nov 30, 2023
c871f50
add plot-marginal to samplers plot
mj-will Dec 4, 2023
7343d12
fix whitespace
mj-will Dec 11, 2023
2598c1a
use lazy formatting in logging functions
mj-will Dec 11, 2023
d1b33ff
move functions to common nested class
mj-will Dec 11, 2023
fc7bcf0
update for change common nested class
mj-will Dec 11, 2023
894f29d
address more code climate issues
mj-will Dec 11, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions companion.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ cpnest
pymultinest
ultranest
https://github.com/willvousden/ptemcee/archive/master.tar.gz
# Force the cpu-only version of PyTorch
--extra-index-url https://download.pytorch.org/whl/cpu
torch
nessai>=0.11.0

# useful to look at PyCBC Live with htop
setproctitle
Expand Down
3 changes: 3 additions & 0 deletions examples/inference/samplers/epsie_stub.ini
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@ ntemps = 4

[jump_proposal-x]
name = normal

[jump_proposal-y]
name = normal
3 changes: 3 additions & 0 deletions examples/inference/samplers/nessai_stub.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[sampler]
name = nessai
nlive = 200
9 changes: 7 additions & 2 deletions examples/inference/samplers/run.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/bin/sh
for f in cpnest_stub.ini emcee_stub.ini emcee_pt_stub.ini dynesty_stub.ini ultranest_stub.ini epsie_stub.ini; do
for f in cpnest_stub.ini emcee_stub.ini emcee_pt_stub.ini dynesty_stub.ini ultranest_stub.ini epsie_stub.ini nessai_stub.ini; do
echo $f
pycbc_inference \
--config-files `dirname $0`/simp.ini `dirname $0`/$f \
Expand All @@ -16,4 +16,9 @@ dynesty_stub.ini.hdf:dynesty \
ultranest_stub.ini.hdf:ultranest \
epsie_stub.ini.hdf:espie \
cpnest_stub.ini.hdf:cpnest \
--output-file sample.png
nessai_stub.ini.hdf:nessai \
--output-file sample.png \
--plot-contours \
--plot-marginal \
--no-contour-labels \
--no-marginal-titles
6 changes: 6 additions & 0 deletions examples/inference/samplers/simp.ini
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,14 @@ name = test_normal

[variable_params]
x =
y =

[prior-x]
name = uniform
min-x = -10
max-x = 10

[prior-y]
name = uniform
min-y = -10
max-y = 10
2 changes: 2 additions & 0 deletions pycbc/inference/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from .multinest import MultinestFile
from .dynesty import DynestyFile
from .ultranest import UltranestFile
from .nessai import NessaiFile
from .posterior import PosteriorFile
from .txt import InferenceTXTFile

Expand All @@ -49,6 +50,7 @@
DynestyFile.name: DynestyFile,
PosteriorFile.name: PosteriorFile,
UltranestFile.name: UltranestFile,
NessaiFile.name: NessaiFile,
}

try:
Expand Down
76 changes: 38 additions & 38 deletions pycbc/inference/io/dynesty.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,44 @@ def extra_args_parser(parser=None, skip_args=None, **kwargs):
"extracted instead.")
return parser, actions

def write_pickled_data_into_checkpoint_file(self, state):
"""Dump the sampler state into checkpoint file
"""
if 'sampler_info/saved_state' not in self:
self.create_group('sampler_info/saved_state')
dump_state(state, self, path='sampler_info/saved_state')

def read_pickled_data_from_checkpoint_file(self):
"""Load the sampler state (pickled) from checkpoint file
"""
return load_state(self, path='sampler_info/saved_state')

def write_raw_samples(self, data, parameters=None):
"""Write the nested samples to the file
"""
if 'samples' not in self:
self.create_group('samples')
write_samples_to_file(self, data, parameters=parameters,
group='samples')
def validate(self):
"""Runs a validation test.
This checks that a samples group exist, and that pickeled data can
be loaded.

Returns
-------
bool :
Whether or not the file is valid as a checkpoint file.
"""
try:
if 'sampler_info/saved_state' in self:
load_state(self, path='sampler_info/saved_state')
checkpoint_valid = True
except KeyError:
checkpoint_valid = False
return checkpoint_valid


class DynestyFile(CommonNestedMetadataIO, BaseNestedSamplerFile):
"""Class to handle file IO for the ``dynesty`` sampler."""

Expand Down Expand Up @@ -148,41 +186,3 @@ def read_raw_samples(self, fields, raw_samples=False, seed=0):
return post
else:
return samples

def write_pickled_data_into_checkpoint_file(self, state):
"""Dump the sampler state into checkpoint file
"""
if 'sampler_info/saved_state' not in self:
self.create_group('sampler_info/saved_state')
dump_state(state, self, path='sampler_info/saved_state')

def read_pickled_data_from_checkpoint_file(self):
"""Load the sampler state (pickled) from checkpoint file
"""
return load_state(self, path='sampler_info/saved_state')

def write_raw_samples(self, data, parameters=None):
"""Write the nested samples to the file
"""
if 'samples' not in self:
self.create_group('samples')
write_samples_to_file(self, data, parameters=parameters,
group='samples')

def validate(self):
"""Runs a validation test.
This checks that a samples group exist, and that pickeled data can
be loaded.

Returns
-------
bool :
Whether or not the file is valid as a checkpoint file.
"""
try:
if 'sampler_info/saved_state' in self:
load_state(self, path='sampler_info/saved_state')
checkpoint_valid = True
except KeyError:
checkpoint_valid = False
return checkpoint_valid
49 changes: 49 additions & 0 deletions pycbc/inference/io/nessai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Provides IO for the nessai sampler"""
import numpy

from .base_nested_sampler import BaseNestedSamplerFile

from .posterior import read_raw_samples_from_file
from .dynesty import CommonNestedMetadataIO


class NessaiFile(CommonNestedMetadataIO, BaseNestedSamplerFile):
"""Class to handle file IO for the ``nessai`` sampler."""

name = "nessai_file"

def read_raw_samples(self, fields, raw_samples=False, seed=0):
"""Reads samples from a nessai file and constructs a posterior.

Using rejection sampling to resample the nested samples

Parameters
----------
fields : list of str
The names of the parameters to load. Names must correspond to
dataset names in the file's ``samples`` group.
raw_samples : bool, optional
Return the raw (unweighted) samples instead of the estimated
posterior samples. Default is False.

Returns
-------
dict :
Dictionary of parameter fields -> samples.
"""
samples = read_raw_samples_from_file(self, fields)
logwt = read_raw_samples_from_file(self, ['logwt'])['logwt']
loglikelihood = read_raw_samples_from_file(
self, ['loglikelihood'])['loglikelihood']
if not raw_samples:
n_samples = len(logwt)
# Rejection sample
rng = numpy.random.default_rng(seed)
logwt -= logwt.max()
logu = numpy.log(rng.random(n_samples))
keep = logwt > logu
post = {'loglikelihood': loglikelihood[keep]}
for param in fields:
post[param] = samples[param][keep]
return post
return samples
7 changes: 7 additions & 0 deletions pycbc/inference/sampler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,13 @@
except ImportError:
pass

try:
from .nessai import NessaiSampler
samplers[NessaiSampler.name] = NessaiSampler
except ImportError:
pass


def load_from_config(cp, model, **kwargs):
"""Loads a sampler from the given config file.

Expand Down
Loading
Loading