Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Model setstate shim #216

Merged
merged 2 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions refl1d/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,15 @@ def add(self, other):
L = [other]
self._layers.extend(_check_layer(el) for el in L)

def __setstate__(self, state):
# this is a temporary shim (2024-10-29), to accomodate objects that were pickled
# before the custom __getstate__ was removed.
# TODO: the entire __setstate__ can be removed someday (e.g. in 2026?)
if isinstance(state, tuple):
self.interface, self._layers, self.name, self.thickness = state
else:
self.__dict__.update(state)

def __copy__(self):
stack = Stack()
stack.interface = self.interface
Expand Down Expand Up @@ -663,11 +672,14 @@ def to_dict(self):
}
)

def __getstate__(self):
return self.interface, self.repeat, self.name, self.stack

def __setstate__(self, state):
self.interface, self.repeat, self.name, self.stack = state
# this is a temporary shim (2024-10-29), to accomodate objects that were pickled
# before the custom __getstate__ was removed.
# TODO: the entire __setstate__ can be removed someday (e.g. in 2026?)
if isinstance(state, tuple):
self.interface, self.repeat, self.name, self.stack = state
else:
self.__dict__.update(state)

def penalty(self):
return self.stack.penalty()
Expand Down
33 changes: 33 additions & 0 deletions tests/refl1d/stack_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,36 @@ def test_stack_serialization():
assert thickness_plus.value == 160
sample[0].thickness.value += 40
assert thickness_plus.value == 200


def test_repeat_serialization():
"""test that stack can be serialized and deserialized with all the methods we use,
preserving the functioning of the Calculation object for the total thickness"""
unit_cell = Slab(SLD(rho=10), thickness=10) | Slab(SLD(rho=10), thickness=20) | Slab(SLD(rho=10), thickness=30)
# This creates a Repeat object from the Stack:
sample = unit_cell * 4
thickness_plus = sample.thickness + 100 # expression

ser_t, ser_s = deserialize(serialize([thickness_plus, sample]))
assert ser_t.value == 340 # (10+20+30 * 4) + 100
ser_s.stack[0].thickness.value += 40
assert ser_t.value == 500 # (50+20+30 * 4) + 100

dc_t, dc_s = deepcopy([thickness_plus, sample])
assert dc_t.value == 340
dc_s.stack[0].thickness.value += 40
assert dc_t.value == 500

pickle_t, pickle_s = pickle.loads(pickle.dumps([thickness_plus, sample]))
assert pickle_t.value == 340
pickle_s.stack[0].thickness.value += 40
assert pickle_t.value == 500

dill_t, dill_s = dill.loads(dill.dumps([thickness_plus, sample]))
assert dill_t.value == 340
dill_s.stack[0].thickness.value += 40
assert dill_t.value == 500

assert thickness_plus.value == 340
sample.stack[0].thickness.value += 40
assert thickness_plus.value == 500
Loading