Skip to content

Commit

Permalink
Merge branch 'master' into pre-commit-ci-update-config
Browse files Browse the repository at this point in the history
  • Loading branch information
lgray authored Jan 8, 2024
2 parents c850db4 + 5682d83 commit ef73366
Show file tree
Hide file tree
Showing 9 changed files with 22 additions and 91 deletions.
7 changes: 1 addition & 6 deletions src/coffea/lookup_tools/correctionlib_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
import dask

from coffea.lookup_tools.lookup_base import lookup_base


class correctionlib_wrapper(lookup_base):
def __init__(self, payload):
super().__init__()
self._corr = payload
dask_future = dask.delayed(
self, pure=True, name=f"{self._corr.name}-{dask.base.tokenize(self)}"
).persist()
super().__init__(dask_future)

def _evaluate(self, *args, **kwargs):
return self._corr.evaluate(*args)
Expand Down
6 changes: 1 addition & 5 deletions src/coffea/lookup_tools/dense_evaluated_lookup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from copy import deepcopy

import dask
import numba
import numpy

Expand Down Expand Up @@ -31,6 +30,7 @@ def numbaize(fstr, varlist):
# methods for dealing with b-tag SFs
class dense_evaluated_lookup(lookup_base):
def __init__(self, values, dims, feval_dim=None):
super().__init__()
self._dimension = 0
whattype = type(dims)
if whattype == numpy.ndarray:
Expand Down Expand Up @@ -66,10 +66,6 @@ def __init__(self, values, dims, feval_dim=None):
"lookup_tools.evaluator only accepts 1D functions right now!"
)
self._feval_dim = feval_dim[0]
dask_future = dask.delayed(
self, pure=True, name=f"denseevallookup-{dask.base.tokenize(self)}"
).persist()
super().__init__(dask_future)

def _evaluate(self, *args, **kwargs):
indices = []
Expand Down
9 changes: 3 additions & 6 deletions src/coffea/lookup_tools/dense_lookup.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from copy import deepcopy

import dask
import numpy

from coffea.lookup_tools.lookup_base import lookup_base


class dense_lookup(lookup_base):
def __init__(self, values, dims, feval_dim=None):
super().__init__()
self._dimension = 0
whattype = type(dims)
if whattype == numpy.ndarray:
Expand All @@ -29,10 +29,6 @@ def __init__(self, values, dims, feval_dim=None):
if vals_are_strings:
raise Exception("dense_lookup cannot handle string values!")
self._values = deepcopy(values)
dask_future = dask.delayed(
self, pure=True, name=f"denselookup-{dask.base.tokenize(self)}"
).persist()
super().__init__(dask_future)

def _evaluate(self, *args, **kwargs):
if len(args) != self._dimension:
Expand Down Expand Up @@ -62,7 +58,8 @@ def _evaluate(self, *args, **kwargs):
return self._values[tuple(indices)]

def __repr__(self):
myrepr = f"{self._dimension} dimensional histogram with axes:\n"
myrepr = object.__repr__(self)
myrepr += f" {self._dimension} dimensional histogram with axes:\n"
temp = ""
if self._dimension == 1:
temp = f"\t1: {self._axes}\n"
Expand Down
6 changes: 1 addition & 5 deletions src/coffea/lookup_tools/dense_mapped_lookup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import numbers
from threading import Lock

import dask
import numba
import numpy

Expand All @@ -13,14 +12,11 @@ class dense_mapped_lookup(lookup_base):
_formulaCache = {}

def __init__(self, axes, mapping, formulas, feval_dim):
super().__init__()
self._axes = axes
self._mapping = mapping
self._formulas = formulas
self._feval_dim = feval_dim
dask_future = dask.delayed(
self, pure=True, name=f"densemappedlookup-{dask.base.tokenize(self)}"
).persist()
super().__init__(dask_future)

@classmethod
def _compile(cls, formula):
Expand Down
6 changes: 1 addition & 5 deletions src/coffea/lookup_tools/jec_uncertainty_lookup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from copy import deepcopy

import dask
import numpy
from scipy.interpolate import interp1d

Expand Down Expand Up @@ -37,6 +36,7 @@ def __init__(self, formula, bins_and_orders, knots_and_vars):
The constructor takes the output of the "convert_junc_txt_file"
text file converter, which returns a formula, bins, and an interpolation table.
"""
super().__init__()
self._dim_order = bins_and_orders[1]
self._bins = bins_and_orders[0]
self._eval_vars = knots_and_vars[1]
Expand Down Expand Up @@ -78,10 +78,6 @@ def __init__(self, formula, bins_and_orders, knots_and_vars):
self._eval_args[argname] = i + len(self._dim_order)
if argname in self._dim_args.keys():
self._eval_args[argname] = self._dim_args[argname]
dask_future = dask.delayed(
self, pure=True, name=f"junclookup-{dask.base.tokenize(self)}"
).persist()
super().__init__(dask_future)

def _evaluate(self, *args, **kwargs):
"""uncertainties = f(args)"""
Expand Down
6 changes: 1 addition & 5 deletions src/coffea/lookup_tools/jersf_lookup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from copy import deepcopy

import dask
import numpy

from coffea.lookup_tools.lookup_base import lookup_base
Expand Down Expand Up @@ -33,6 +32,7 @@ def __init__(self, formula, bins_and_orders, clamps_and_vars, params_and_orders)
The constructor takes the output of the "convert_jersf_txt_file"
text file converter, which returns a formula, bins, and values.
"""
super().__init__()
self._dim_order = bins_and_orders[1]
self._bins = bins_and_orders[0]
self._eval_vars = clamps_and_vars[2]
Expand Down Expand Up @@ -65,10 +65,6 @@ def __init__(self, formula, bins_and_orders, clamps_and_vars, params_and_orders)
self._eval_args[argname] = i + len(self._dim_order)
if argname in self._dim_args.keys():
self._eval_args[argname] = self._dim_args[argname]
dask_future = dask.delayed(
self, pure=True, name=f"jersflookup-{dask.base.tokenize(self)}"
).persist()
super().__init__(dask_future)

def _evaluate(self, *args, **kwargs):
"""SFs = f(args)"""
Expand Down
6 changes: 1 addition & 5 deletions src/coffea/lookup_tools/jme_standard_function.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from copy import deepcopy

import awkward
import dask
import numpy
from numpy import sqrt # noqa: F401
from numpy import abs, exp, log, log10 # noqa: F401
Expand Down Expand Up @@ -93,6 +92,7 @@ def __init__(self, formula, bins_and_orders, clamps_and_vars, params_and_orders)
The constructor takes the output of the "convert_jec(jr)_txt_file"
text file converter, which returns a formula, bins, and parameter values.
"""
super().__init__()
self._dim_order = bins_and_orders[1]
self._bins = bins_and_orders[0]
self._eval_vars = clamps_and_vars[2]
Expand Down Expand Up @@ -130,10 +130,6 @@ def __init__(self, formula, bins_and_orders, clamps_and_vars, params_and_orders)
self._eval_args[argname] = i + len(self._dim_order)
if argname in self._dim_args.keys():
self._eval_args[argname] = self._dim_args[argname]
dask_future = dask.delayed(
self, pure=True, name=f"jmestandardlookup-{dask.base.tokenize(self)}"
).persist()
super().__init__(dask_future)

def _evaluate(self, *args, **kwargs):
"""jec/jer = f(args)"""
Expand Down
65 changes: 12 additions & 53 deletions src/coffea/lookup_tools/lookup_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numbers
import weakref
from functools import partial

import awkward
Expand All @@ -9,8 +8,7 @@

def getfunction(
args,
thelookup_dask=None,
thelookup_wref=None,
thelookup,
__non_array_args__=tuple(),
__arg_indices__=tuple(),
**kwargs,
Expand Down Expand Up @@ -47,15 +45,6 @@ def getfunction(
for inaarg, naarg in enumerate(__non_array_args__):
repacked_args[__arg_indices__[inaarg + len(args)]] = naarg

thelookup = None
if thelookup_wref is not None:
thelookup = thelookup_wref()
else:
from dask.distributed import worker_client

with worker_client() as client:
thelookup = client.compute(thelookup_dask).result()

result = thelookup._evaluate(*repacked_args, **kwargs)
out = awkward.contents.NumpyArray(result)
if backend == "typetracer":
Expand All @@ -65,24 +54,17 @@ def getfunction(


class _LookupXformFn:
def __init__(self, *args, arg_indices, thelookup_dask, thelookup_wref, **kwargs):
def __init__(self, *args, arg_indices, thelookup, **kwargs):
self.getfunction = getfunction
self._thelookup_dask = thelookup_dask
self._thelookup_wref = thelookup_wref
self._thelookup = thelookup
self.__non_array_args__ = args
self.__arg_indices__ = arg_indices
self.kwargs = kwargs

def __getstate__(self):
out = self.__dict__.copy()
out["_thelookup_wref"] = None
return out

def __call__(self, *args):
func = partial(
self.getfunction,
thelookup_dask=self._thelookup_dask,
thelookup_wref=self._thelookup_wref,
thelookup=self._thelookup,
__non_array_args__=self.__non_array_args__,
__arg_indices__=self.__arg_indices__,
**self.kwargs,
Expand All @@ -93,9 +75,8 @@ def __call__(self, *args):
class lookup_base:
"""Base class for all objects that do some sort of value or function lookup"""

def __init__(self, dask_future):
self._dask_future = dask_future
self._weakref = weakref.ref(self)
def __init__(self):
pass

def __getstate__(self):
out = self.__dict__.copy()
Expand Down Expand Up @@ -124,43 +105,21 @@ def __call__(self, *args, **kwargs):
tomap = _LookupXformFn(
*delay_args,
arg_indices=arg_indices,
thelookup_dask=self._dask_future,
thelookup_wref=self._weakref,
thelookup=self,
**kwargs,
)

# if our inputs are all dask_awkward arrays, then we should map_partitions
if any(isinstance(x, (dask_awkward.Array)) for x in args):
from dask.base import tokenize

zlargs = [
awkward.Array(
arg._meta.layout.form.length_zero_array(highlevel=False),
behavior=arg.behavior,
)
for arg in actual_args
]
zlout = tomap(*zlargs)
meta = awkward.Array(
zlout.layout.to_typetracer(forget_length=True), behavior=zlout.behavior
return dask_awkward.map_partitions(
tomap,
*actual_args,
label=dask_label,
token=tokenize(repr(self), *args),
)

if dask_label is not None:
return dask_awkward.map_partitions(
tomap,
*actual_args,
label=dask_label,
token=tokenize(self._dask_future.name, *args),
meta=meta,
)
else:
return dask_awkward.map_partitions(
tomap,
*actual_args,
token=tokenize(self._dask_future.name, *args),
meta=meta,
)

if all(isinstance(x, (numpy.ndarray, numbers.Number, str)) for x in args):
return self._evaluate(*args, **kwargs)
elif any(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_lookup_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def test_evaluate_noimpl():
from coffea.lookup_tools.lookup_base import lookup_base

try:
lookup_base(None)._evaluate()
lookup_base()._evaluate()
except NotImplementedError:
pass

Expand Down

0 comments on commit ef73366

Please sign in to comment.