Skip to content

Commit

Permalink
beat.summarize: allow summarize --calc_dervied for duplicate varnames
Browse files Browse the repository at this point in the history
  • Loading branch information
hvasbath committed Mar 25, 2024
1 parent 85db0bc commit 539972f
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 27 deletions.
38 changes: 21 additions & 17 deletions beat/apps/beat.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,7 +1076,7 @@ def result_check(mtrace, min_length):

def command_summarize(args):
from arviz import summary
from numpy import ravel, split, vstack
from numpy import hstack, ravel, split, vstack
from pyrocko.gf import RectangularSource

command_str = "summarize"
Expand Down Expand Up @@ -1288,23 +1288,27 @@ def setup(parser):
# BEAT sources calculate derived params
if options.calc_derived:
composite.point2sources(point)
if hasattr(source, "get_derived_parameters"):
if options.mode == geometry_mode_str:
for source in sources:
deri = source.get_derived_parameters(
point=reference, # need to pass correction params
store=store,
target=target,
event=problem.config.event,
)
derived.append(deri)

# pyrocko Rectangular source, TODO use BEAT RS ...
elif isinstance(source, RectangularSource):
for source in sources:
source.magnitude = None
derived.append(
source.get_magnitude(store=store, target=target)
)
if hasattr(source, "get_derived_parameters"):
deri = source.get_derived_parameters(
point=reference, # need to pass correction params
store=store,
target=target,
event=problem.config.event,
)
derived.append(deri)

# pyrocko Rectangular source, TODO use BEAT RS ...
elif isinstance(source, RectangularSource):
source.magnitude = None
derived.append(
source.get_magnitude(store=store, target=target)
)

if len(pc.source_types) > 1:
derived = [hstack(derived)]

elif options.mode == bem_mode_str:
response = composite.engine.process(
sources=composite.sources, targets=composite.targets
Expand Down
28 changes: 18 additions & 10 deletions beat/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,7 @@ def __init__(
self.draws = 0
self._df = None
self.filename = None
self.derived_mapping = None

def __len__(self):
if self.filename is None:
Expand All @@ -339,18 +340,20 @@ def add_derived_variables(self, varnames, shapes):
"Inconsistent number of variables %i and shapes %i!" % (nvars, nshapes)
)

self.derived_mapping = {}
for varname, shape in zip(varnames, shapes):
if varname in self.varnames:
# TODO for mixed source setups needs resolving
raise ValueError(
"Sampled stage contains parameter `%s` to be summarized! "
"--calc_derived cannot be used! Needs patching ..." % varname
)
else:
self.flat_names[varname] = _create_flat_names(varname, shape)
self.var_shapes[varname] = shape
self.var_dtypes[varname] = "float64"
self.varnames.append(varname)
exist_idx = self.varnames.index(varname)
self.varnames.pop(exist_idx)
exist_shape = self.var_shapes[varname]
shape = tuple(map(sum, zip(exist_shape, shape)))
concat_idx = len(self.varnames)
self.derived_mapping[exist_idx] = concat_idx

self.flat_names[varname] = _create_flat_names(varname, shape)
self.var_shapes[varname] = shape
self.var_dtypes[varname] = "float64"
self.varnames.append(varname)

def _load_df(self):
raise ValueError("This method must be defined in inheriting classes!")
Expand Down Expand Up @@ -391,6 +394,11 @@ def write(self, lpoint, draw):
If buffer is full write samples to file.
"""
self.count += 1
if self.derived_mapping:
for exist_idx, concat_idx in self.derived_mapping.items():
value = lpoint.pop(exist_idx)
lpoint[concat_idx] = num.hstack((value, lpoint[concat_idx]))

self.buffer.append((lpoint, draw))
if self.count == self.buffer_size:
self.record_buffer()
Expand Down

0 comments on commit 539972f

Please sign in to comment.