Skip to content

Commit

Permalink
clean up copyreg of weakref, control with __getstate__/__setstate__ i…
Browse files Browse the repository at this point in the history
…nstead
  • Loading branch information
lgray committed Jun 5, 2023
1 parent 1b416db commit 1c7b204
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 11 deletions.
2 changes: 0 additions & 2 deletions src/coffea/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,9 @@
deprecations_as_errors = False

import copyreg
import weakref

import dask_awkward

copyreg.pickle(weakref.ref, lambda x: (lambda y: y, (None,)))
copyreg.pickle(dask_awkward.Array, lambda x: (lambda y: y, (None,)))


Expand Down
35 changes: 26 additions & 9 deletions src/coffea/lookup_tools/lookup_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,28 @@ def getfunction(

class _LookupXformFn:
def __init__(self, *args, arg_indices, thelookup_dask, thelookup_wref, **kwargs):
self.func = partial(
getfunction,
thelookup_dask=thelookup_dask,
thelookup_wref=thelookup_wref,
__non_array_args__=args,
__arg_indices__=arg_indices,
**kwargs,
)
self.getfunction = getfunction
self._thelookup_dask = thelookup_dask
self._thelookup_wref = thelookup_wref
self.__non_array_args__ = args
self.__arg_indices__ = arg_indices
self.kwargs = kwargs

def __getstate__(self):
out = self.__dict__.copy()
out["_thelookup_wref"] = None
return out

def __call__(self, *args):
return awkward.transform(self.func, *args)
func = partial(
self.getfunction,
thelookup_dask=self._thelookup_dask,
thelookup_wref=self._thelookup_wref,
__non_array_args__=self.__non_array_args__,
__arg_indices__=self.__arg_indices__,
**self.kwargs,
)
return awkward.transform(func, *args)


class lookup_base:
Expand All @@ -88,6 +99,12 @@ def __init__(self, dask_future):
self._dask_future = dask_future
self._weakref = weakref.ref(self)

def __getstate__(self):
out = self.__dict__.copy()
if "_weakref" in out:
out["_weakref"] = None
return out

def __call__(self, *args, **kwargs):
dask_label = kwargs.pop("dask_label", None)

Expand Down

0 comments on commit 1c7b204

Please sign in to comment.