Commit b328d9d: squash
lukasheinrich committed Sep 16, 2019
1 parent 82a8bf0
Showing 6 changed files with 213 additions and 43 deletions.
74 changes: 56 additions & 18 deletions src/pyhf/constraints.py
@@ -87,15 +87,17 @@ def _dataprojection(self, auxdata):
normal_data = tensorlib.gather(auxdata, self.normal_data)
return normal_data

def logpdf(self, auxdata, pars):
def make_pdf(self, pars):
'''
Args:
pars: the parameter tensor
Returns:
pdf: the pdf object for the Normal Constraint
'''
tensorlib, _ = get_backend()
if not self.param_viewer.index_selection:
return (
tensorlib.zeros(self.batch_size)
if self.batch_size is not None
else tensorlib.astensor(0.0)[0]
)

return None
pars = tensorlib.astensor(pars)
if self.batch_size == 1 or self.batch_size is None:
batched_pars = tensorlib.reshape(
@@ -116,9 +118,27 @@ def logpdf(self, auxdata, pars):

result = prob.Independent(
prob.Normal(normal_means, self.sigmas), batch_size=self.batch_size
).log_prob(self._dataprojection(auxdata))
)
return result

def logpdf(self, auxdata, pars):
'''
Args:
auxdata: the auxiliary data (a subset of the full data in a HistFactory model)
Returns:
log pdf value: the log of the pdf value of the Normal constraints
'''
tensorlib, _ = get_backend()
pdf = self.make_pdf(pars)
if pdf is None:
return (
tensorlib.zeros(self.batch_size)
if self.batch_size is not None
else tensorlib.astensor(0.0)[0]
)
return pdf.log_prob(self._dataprojection(auxdata))
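The value logpdf produces is unchanged by the make_pdf/logpdf split: the Normal constraint term is a sum of per-parameter Gaussian log-densities evaluated at the auxiliary data. A minimal standalone sketch of that arithmetic in numpy/scipy, with illustrative auxdata, means, and sigmas rather than the module's internals:

import numpy as np
from scipy import stats

auxdata = np.array([0.1, -0.2])   # observed auxiliary data (illustrative)
means = np.array([0.0, 0.0])      # parameter values acting as the Gaussian means
sigmas = np.array([1.0, 1.0])     # per-parameter constraint widths

# Independent(Normal(means, sigmas)).log_prob(auxdata) amounts to summing the
# per-parameter Gaussian log-densities.
log_prob = stats.norm.logpdf(auxdata, loc=means, scale=sigmas).sum()
print(log_prob)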


class poisson_constraint_combined(object):
def __init__(self, pdfconfig, batch_size=None):
@@ -202,14 +222,17 @@ def _dataprojection(self, auxdata):
poisson_data = tensorlib.gather(auxdata, self.poisson_data)
return poisson_data

def logpdf(self, auxdata, pars):
def make_pdf(self, pars):
'''
Args:
pars: the parameter tensor
Returns:
pdf: the pdf object for the Poisson Constraint
'''
tensorlib, _ = get_backend()
if not self.param_viewer.index_selection:
return (
tensorlib.zeros(self.batch_size)
if self.batch_size is not None
else tensorlib.astensor(0.0)[0]
)
return None
tensorlib, _ = get_backend()

pars = tensorlib.astensor(pars)
@@ -231,7 +254,22 @@ def logpdf(self, auxdata, pars):
if self.batch_size is None:
pois_rates = pois_rates[0]
# pdf pars are done, now get data and compute
result = prob.Independent(
prob.Poisson(pois_rates), batch_size=self.batch_size
).log_prob(self._dataprojection(auxdata))
return result
return prob.Independent(prob.Poisson(pois_rates), batch_size=self.batch_size)

def logpdf(self, auxdata, pars):
'''
Args:
auxdata: the auxiliary data (a subset of the full data in a HistFactory model)
Returns:
log pdf value: the log of the pdf value of the Poisson constraints
'''
tensorlib, _ = get_backend()
pdf = self.make_pdf(pars)
if pdf is None:
return (
tensorlib.zeros(self.batch_size)
if self.batch_size is not None
else tensorlib.astensor(0.0)[0]
)
return pdf.log_prob(self._dataprojection(auxdata))
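The Poisson constraint follows the same pattern: make_pdf builds the distribution from the parameter-scaled rates and logpdf evaluates it on the auxiliary counts. A rough scipy equivalent of the resulting number, with hypothetical counts and rates (the exact rate/factor bookkeeping is the code above):

import numpy as np
from scipy import stats

aux_counts = np.array([50.0, 60.0])   # auxiliary (pseudo-)measurements, hypothetical
pois_rates = np.array([52.0, 58.0])   # expected rates after applying the parameters, hypothetical

# Independent(Poisson(pois_rates)).log_prob(aux_counts) sums the per-term Poisson log-pmfs.
log_prob = stats.poisson.logpmf(aux_counts, mu=pois_rates).sum()
print(log_prob)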
105 changes: 88 additions & 17 deletions src/pyhf/pdf.py
@@ -175,6 +175,10 @@ def _create_and_register_paramsets(


class _ConstraintModel(object):
'''
Factory class to create pdfs for the constraint terms
'''

def __init__(self, config, batch_size):
self.batch_size = batch_size
self.config = config
@@ -217,14 +221,50 @@ def _dataprojection(self, data):
cut = tensorlib.shape(data)[0] - len(self.config.auxdata)
return data[cut:]

def make_pdf(self, pars):
'''
Args:
pars: the parameter tensor
Returns:
pdf: A distribution object implementing the constraint pdf of HistFactory.
Either a Poisson, a Gaussian, or a joint pdf of both, depending on the
constraints used in the specification.
'''
indices = []
pdfobjs = []

gaussian_pdf = self.constraints_gaussian.make_pdf(pars)
if gaussian_pdf:
indices.append(self.constraints_gaussian._normal_data)
pdfobjs.append(gaussian_pdf)

poisson_pdf = self.constraints_poisson.make_pdf(pars)
if poisson_pdf:
indices.append(self.constraints_poisson._poisson_data)
pdfobjs.append(poisson_pdf)

if pdfobjs:
simpdf = prob.Simultaneous(pdfobjs, indices)
return simpdf

def logpdf(self, auxdata, pars):
tensorlib, _ = get_backend()
normal = self.constraints_gaussian.logpdf(auxdata, pars)
poisson = self.constraints_poisson.logpdf(auxdata, pars)
return prob.joint_logpdf([normal, poisson])
'''
Args:
auxdata: the auxiliary data (a subset of the full data in a HistFactory model)
Returns:
log pdf value: the log of the pdf value
'''
simpdf = self.make_pdf(pars)
return simpdf.log_prob(auxdata)
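In the combined constraint model, the Gaussian and Poisson pieces each see only their own slice of the auxiliary data, and the joint log-density is the sum of the two contributions. A plain numpy/scipy sketch of that bookkeeping, with a hypothetical index partition and values:

import numpy as np
from scipy import stats

auxdata = np.array([0.1, -0.2, 50.0])   # two Gaussian aux entries followed by one Poisson entry
normal_idx, poisson_idx = [0, 1], [2]   # hypothetical partition of the aux indices

gaussian_part = stats.norm.logpdf(auxdata[normal_idx], loc=0.0, scale=1.0).sum()
poisson_part = stats.poisson.logpmf(auxdata[poisson_idx], mu=52.0).sum()

# Simultaneous([gauss_pdf, pois_pdf], [normal_idx, poisson_idx]).log_prob(auxdata)
# reduces to this sum of the constituent log-probabilities.
print(gaussian_part + poisson_part)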


class _MainModel(object):
'''
Factory class to create pdfs for the main measurement
'''

def __init__(self, config, mega_mods, nominal_rates, batch_size):
self.config = config
self._factor_mods = [
@@ -261,11 +301,19 @@ def _precompute(self):
tensorlib, _ = get_backend()
self.nominal_rates = tensorlib.astensor(self._nominal_rates)

def logpdf(self, maindata, pars):
tensorlib, _ = get_backend()
def make_pdf(self, pars):
lambdas_data = self.expected_data(pars)
result = prob.Independent(prob.Poisson(lambdas_data)).log_prob(maindata)
return result
return prob.Independent(prob.Poisson(lambdas_data))

def logpdf(self, maindata, pars):
'''
Args:
maindata: the main channel data (a subset of the full data in a HistFactory model)
Returns:
log pdf value: the log of the pdf value
'''
return self.make_pdf(pars).log_prob(maindata)
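The main-measurement term has the same shape as the Poisson constraint above: a sum of per-bin Poisson log-pmfs, with rates lambdas_data = expected_data(pars). In the same scipy notation, with made-up counts and expectations:

import numpy as np
from scipy import stats

observed = np.array([120.0, 180.0])   # main-channel bin counts (hypothetical)
lambdas = np.array([112.0, 190.0])    # expected_data(pars) (hypothetical)

# What make_pdf(pars).log_prob(observed) returns for these inputs.
print(stats.poisson.logpmf(observed, mu=lambdas).sum())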

def _dataprojection(self, data):
tensorlib, _ = get_backend()
@@ -534,18 +582,41 @@ def constraint_logpdf(self, auxdata, pars):
def mainlogpdf(self, maindata, pars):
return self.main_model.logpdf(maindata, pars)

def logpdf(self, pars, data):
try:
tensorlib, _ = get_backend()
pars, data = tensorlib.astensor(pars), tensorlib.astensor(data)
def make_pdf(self, pars):
'''
Args:
pars: the parameter tensor
Returns:
pdf: A distribution object implementing the main measurement pdf of HistFactory
'''
tensorlib, _ = get_backend()
pars = tensorlib.astensor(pars)

cut = self.nominal_rates.shape[-1]
total_size = cut + len(self.config.auxdata)
position = list(range(total_size))

pdfobjs = []
indices = []

actual_data = self.main_model._dataprojection(data)
aux_data = self.constraint_model._dataprojection(data)
mainpdf = self.main_model.make_pdf(pars)
pdfobjs.append(mainpdf)
indices.append(position[:cut])

mainpdf = self.main_model.logpdf(actual_data, pars)
constraintpdf = self.constraint_model.logpdf(aux_data, pars)
constraintpdf = self.constraint_model.make_pdf(pars)
if constraintpdf:
pdfobjs.append(constraintpdf)
indices.append(position[cut:])

result = prob.joint_logpdf([mainpdf, constraintpdf])
simpdf = prob.Simultaneous(pdfobjs, indices)
return simpdf
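With make_pdf in place, Model.logpdf reduces to building the simultaneous distribution and evaluating it on the concatenation of main-channel data and auxiliary data. A usage sketch against the public API of this era (hepdata_like is assumed to be the simple-model helper available at the time of this commit):

import pyhf

# A toy one-channel model: signal plus background with an uncorrelated background uncertainty.
model = pyhf.simplemodels.hepdata_like(
    signal_data=[12.0], bkg_data=[50.0], bkg_uncerts=[3.0]
)
pars = model.config.suggested_init()

# Data layout: main-channel counts first, then the auxiliary data, matching the index split above.
data = [55.0] + model.config.auxdata

print(model.logpdf(pars, data))   # internally: make_pdf(pars).log_prob(data)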

def logpdf(self, pars, data):
try:
tensorlib, _ = get_backend()
result = self.make_pdf(pars).log_prob(data)

if (
not self.batch_size
16 changes: 14 additions & 2 deletions src/pyhf/probability.py
@@ -1,4 +1,5 @@
from . import get_backend
from .tensor.common import TensorViewer


class Poisson(object):
@@ -24,11 +25,11 @@ def log_prob(self, value):


class Independent(object):
'''
"""
A probability density corresponding to the joint
distribution of a batch of identically distributed random
numbers.
'''
"""

def __init__(self, batched_pdf, batch_size=None):
self.batch_size = batch_size
@@ -40,6 +41,17 @@ def log_prob(self, value):
return tensorlib.sum(_log_prob, axis=-1)


class Simultaneous(object):
def __init__(self, pdfobjs, indices):
self.tv = TensorViewer(indices)
self.pdfobjs = pdfobjs

def log_prob(self, data):
constituent_data = self.tv.split(data)
pdfvals = [p.log_prob(d) for p, d in zip(self.pdfobjs, constituent_data)]
return joint_logpdf(pdfvals)
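A small usage sketch of Simultaneous with the default numpy backend: each constituent pdf receives the slice of the data vector selected by its index list, and the constituent log-probabilities are summed via joint_logpdf. The values and index partition below are made up, and prob.Normal is assumed to follow the (means, sigmas) signature used in constraints.py:

import pyhf
from pyhf import probability as prob

tensorlib, _ = pyhf.get_backend()   # numpy backend by default

# A Poisson block owning data slots 0 and 2, and a unit-width Normal owning slot 1.
pois = prob.Independent(prob.Poisson(tensorlib.astensor([5.0, 8.0])))
norm = prob.Independent(prob.Normal(tensorlib.astensor([0.0]), tensorlib.astensor([1.0])))

sim = prob.Simultaneous([pois, norm], [[0, 2], [1]])
data = tensorlib.astensor([4.0, 0.5, 9.0])   # slots: [poisson, normal, poisson]

# split -> [[4., 9.], [0.5]]; log_prob is the sum of the two blocks' log-probabilities.
print(sim.log_prob(data))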


def joint_logpdf(terms):
tensorlib, _ = get_backend()
terms = tensorlib.stack(terms)
48 changes: 48 additions & 0 deletions src/pyhf/tensor/common.py
@@ -0,0 +1,48 @@
from .. import default_backend, get_backend
from .. import events


class TensorViewer(object):
def __init__(self, indices):
# self.partition_indices has the "target" indices
# of the stitched vector. In order to .gather()
# a concatenation of source arrays into the
# desired form, one needs to gather on the "sorted"
# indices
# >>> source = np.asarray([9,8,7,6])
# >>> target = np.asarray([2,1,3,0])
# >>> source[target.argsort()]
# array([6, 8, 9, 7])

self.partition_indices = indices
a = default_backend.astensor(
default_backend.concatenate(self.partition_indices), dtype='int'
)
self._sorted_indices = default_backend.tolist(a.argsort())

self._precompute()
events.subscribe('tensorlib_changed')(self._precompute)

def _precompute(self):
tensorlib, _ = get_backend()
self.sorted_indices = tensorlib.astensor(self._sorted_indices)

def stitch(self, data):
tensorlib, _ = get_backend()
assert len(self.partition_indices) == len(data)

data = tensorlib.concatenate(data, axis=-1)
data = tensorlib.einsum('...j->j...', data)
stitched = tensorlib.gather(data, tensorlib.astensor(self.sorted_indices))
stitched = tensorlib.einsum('...j->j...', stitched)
return stitched

def split(self, data):
tensorlib, _ = get_backend()
data = tensorlib.astensor(data)
data = tensorlib.einsum('...j->j...', tensorlib.astensor(data))
split = [
tensorlib.einsum('...j->j...', tensorlib.gather(data, idx))
for idx in self.partition_indices
]
return split
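For intuition, a quick example with the default numpy backend: split hands each index list its slice of a combined vector, which is what Simultaneous.log_prob in probability.py relies on. The index partition and values below are arbitrary:

import numpy as np
from pyhf.tensor.common import TensorViewer  # module path introduced in this commit

# Piece A owns target slots 0 and 2, piece B owns target slots 1 and 3.
tv = TensorViewer([[0, 2], [1, 3]])

pieces = tv.split(np.array([9.0, 8.0, 7.0, 6.0]))
print([p.tolist() for p in pieces])   # [[9.0, 7.0], [8.0, 6.0]]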
3 changes: 2 additions & 1 deletion src/pyhf/tensor/pytorch_backend.py
@@ -95,7 +95,8 @@ def astensor(self, tensor_in, dtype='float'):
return tensor

def gather(self, tensor, indices):
return torch.take(tensor, indices.type(torch.LongTensor))
indices = self.astensor(indices, dtype='int')
return tensor[indices.type(torch.LongTensor)]
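The behavioural change here is that gather now indexes along the first dimension with advanced indexing, preserving the trailing shape, instead of using torch.take, which flattens its input; it also accepts plain Python lists by routing them through astensor. A quick illustration of the difference, assuming torch is installed:

import torch

t = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
idx = torch.tensor([1, 0])

print(torch.take(t, idx))   # tensor([2., 1.])  -> indexes the flattened tensor
print(t[idx])               # tensor([[3., 4.], [1., 2.]])  -> reorders rows, shape preserved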

def boolean_mask(self, tensor, mask):
mask = self.astensor(mask, dtype='bool')
10 changes: 5 additions & 5 deletions tests/test_validation.py
@@ -84,19 +84,19 @@ def spec_1bin_lumi(source=source_1bin_example1()):
"name": "channel1",
"samples": [
{
"data": [20.0, 10.0],
"data": [20.0],
"modifiers": [
{"data": None, "name": "mu", "type": "normfactor"}
],
"name": "signal",
},
{
"data": [100.0, 0.0],
"data": [100.0],
"modifiers": [{"data": None, "name": "lumi", "type": "lumi"}],
"name": "background1",
},
{
"data": [0.0, 100.0],
"data": [0.0],
"modifiers": [{"data": None, "name": "lumi", "type": "lumi"}],
"name": "background2",
},
@@ -120,8 +120,8 @@ def expected_result_1bin_lumi(mu=1.0):
def expected_result_1bin_lumi(mu=1.0):
if mu == 1:
expected_result = {
"exp": [0.00905976, 0.0357287, 0.12548957, 0.35338293, 0.69589171],
"obs": 0.00941757,
"exp": [0.01060338, 0.04022273, 0.13614217, 0.37078321, 0.71104119],
"obs": 0.01047275,
}
return expected_result

