Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove getstate and setstate from Stack, and base argument from init #201

Merged
merged 5 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 8 additions & 16 deletions refl1d/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,20 +276,21 @@ class Stack(Layer):
thickness: Parameter = field(metadata={"description": "always equals the sum of the layer thicknesses"})

def __init__(
self, base=None, layers=None, name="Stack", interface=None, thickness: Optional[Union[Parameter, float]] = None
self,
layers: Optional[Union["Stack", List[Union["Slab", "Repeat"]]]] = None,
name: str = "Stack",
interface=None,
thickness: Optional[Union[Parameter, float]] = None,
):
self.name = name
self.interface = None
self._layers = []
# make sure thickness.id is persistent through serialization:
thickness_id = getattr(thickness, "id", None)
thickness = thickness if thickness is not None else 0
self.thickness = Parameter.default(thickness, name=name + " thickness", id=thickness_id)
self.thickness = Parameter.default(thickness, name=name + " thickness")
self._set_thickness()
if layers is not None and base is None:
base = layers
if base is not None:
self.add(base)
if layers is not None:
self.add(layers)

@property
def layers(self):
Expand Down Expand Up @@ -330,15 +331,6 @@ def add(self, other):
L = [other]
self._layers.extend(_check_layer(el) for el in L)

def __getstate__(self):
return self.interface, self._layers, self.name, self.thickness

def __setstate__(self, state):
self.interface, self._layers, self.name, self.thickness = state
# TODO: not clear that this is needed here. The thickness parameter
# from __getstate__ should have a valid expression in it.
self._set_thickness()

def __copy__(self):
stack = Stack()
stack.interface = self.interface
Expand Down
36 changes: 36 additions & 0 deletions tests/refl1d/stack_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from refl1d.names import Slab, Stack, SLD
from copy import deepcopy
from bumps.serialize import serialize, deserialize
import pickle
import dill


def test_stack_serialization():
"""test that stack can be serialized and deserialized with all the methods we use,
preserving the functioning of the Calculation object for the total thickness"""
sample = Slab(SLD(rho=10), thickness=10) | Slab(SLD(rho=10), thickness=20) | Slab(SLD(rho=10), thickness=30)
thickness_plus = sample.thickness + 100 # expression

ser_t, ser_s = deserialize(serialize([thickness_plus, sample]))
assert ser_t.value == 160
ser_s[0].thickness.value += 40
assert ser_t.value == 200

dc_t, dc_s = deepcopy([thickness_plus, sample])
assert dc_t.value == 160
dc_s[0].thickness.value += 40
assert dc_t.value == 200

pickle_t, pickle_s = pickle.loads(pickle.dumps([thickness_plus, sample]))
assert pickle_t.value == 160
pickle_s[0].thickness.value += 40
assert pickle_t.value == 200

dill_t, dill_s = dill.loads(dill.dumps([thickness_plus, sample]))
assert dill_t.value == 160
dill_s[0].thickness.value += 40
assert dill_t.value == 200

assert thickness_plus.value == 160
sample[0].thickness.value += 40
assert thickness_plus.value == 200
Loading