Serialize and split #4541

Merged: 12 commits, Feb 26, 2021
33 changes: 11 additions & 22 deletions distributed/protocol/core.py
@@ -6,19 +6,17 @@

from .compression import compressions, maybe_compress, decompress
from .serialize import (
serialize,
deserialize,
Serialize,
Serialized,
extract_serialize,
msgpack_decode_default,
msgpack_encode_default,
merge_and_deserialize,
serialize_and_split,
)
from .utils import frame_split_size, merge_frames, msgpack_opts
from .utils import msgpack_opts
from ..utils import is_writeable, nbytes

_deserialize = deserialize


logger = logging.getLogger(__name__)

@@ -48,7 +46,7 @@ def dumps(msg, serializers=None, on_error="message", context=None):
}

data = {
key: serialize(
key: serialize_and_split(
value.data, serializers=serializers, on_error=on_error, context=context
)
for key, value in data.items()
@@ -67,26 +65,19 @@ def dumps(msg, serializers=None, on_error="message", context=None):

# Compress frames that are not yet compressed
out_compression = []
_out_frames = []
for frame, compression in zip(
frames, head.get("compression") or [None] * len(frames)
):
if compression is None: # default behavior
_frames = frame_split_size(frame)
_compression, _frames = zip(
*[maybe_compress(frame, **compress_opts) for frame in _frames]
)
out_compression.extend(_compression)
_out_frames.extend(_frames)
else: # already specified, so pass
out_compression.append(compression)
_out_frames.append(frame)
if compression is None:
compression, frame = maybe_compress(frame, **compress_opts)

out_compression.append(compression)
out_frames.append(frame)

head["compression"] = out_compression
head["count"] = len(_out_frames)
head["count"] = len(frames)
header["headers"][key] = head
header["keys"].append(key)
out_frames.extend(_out_frames)

for key, (head, frames) in pre.items():
if "writeable" not in head:
@@ -146,9 +137,7 @@ def loads(frames, deserialize=True, deserializers=None):
if deserialize or key in bytestrings:
if "compression" in head:
fs = decompress(head, fs)
if not any(hasattr(f, "__cuda_array_interface__") for f in fs):
fs = merge_frames(head, fs)
value = _deserialize(head, fs, deserializers=deserializers)
value = merge_and_deserialize(head, fs, deserializers=deserializers)
else:
value = Serialized(head, fs)

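Taken together, the core.py changes move frame splitting out of dumps(): serialization now goes through serialize_and_split(), each resulting frame is compressed whole with maybe_compress(), and loads() hands the frames to merge_and_deserialize() instead of calling merge_frames() itself. A minimal round-trip sketch of that path, assuming numpy is installed and using the public distributed.protocol helpers (the exact number of frames depends on the configured shard size):

import numpy as np
from distributed.protocol import dumps, loads, to_serialize

# Values wrapped in to_serialize() go through serialize_and_split() in dumps()
# and come back through merge_and_deserialize() in loads().
msg = {"op": "update", "data": to_serialize(np.arange(1_000_000))}
frames = dumps(msg)          # header frames plus (possibly split) data frames
result = loads(frames)
assert (result["data"] == np.arange(1_000_000)).all()
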
12 changes: 11 additions & 1 deletion distributed/protocol/numpy.py
@@ -93,7 +93,12 @@ def serialize_numpy_ndarray(x, context=None):
# "ValueError: cannot include dtype 'M' in a buffer"
data = data.view("u%d" % math.gcd(x.dtype.itemsize, 8)).data

header = {"dtype": dt, "shape": x.shape, "strides": strides}
header = {
"dtype": dt,
"shape": x.shape,
"strides": strides,
"writeable": [x.flags.writeable],
}

if broadcast_to is not None:
header["broadcast_to"] = broadcast_to
@@ -112,6 +117,7 @@ def deserialize_numpy_ndarray(header, frames):
return pickle.loads(frames[0], buffers=frames[1:])

(frame,) = frames
(writeable,) = header["writeable"]

is_custom, dt = header["dtype"]
if is_custom:
@@ -125,6 +131,10 @@
shape = header["shape"]

x = np.ndarray(shape, dtype=dt, buffer=frame, strides=header["strides"])
if not writeable:
x.flags.writeable = False
elif not x.flags.writeable:
x = x.copy()

return x

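The numpy change records the array's writeable flag in the header so deserialize_numpy_ndarray() can restore it: a read-only source array comes back read-only, and a writable array whose frame arrives as immutable bytes is copied rather than silently frozen. A small sketch of both branches, assuming numpy is installed and using the distributed.protocol serialize/deserialize entry points:

import numpy as np
from distributed.protocol import serialize, deserialize

# Read-only arrays stay read-only after the round trip.
x = np.arange(10)
x.flags.writeable = False
header, frames = serialize(x)
assert not deserialize(header, frames).flags.writeable

# Writable arrays stay writable even if the frame arrives as read-only bytes,
# because deserialize_numpy_ndarray() copies the buffer in that case.
header, frames = serialize(np.arange(10))
y = deserialize(header, [bytes(f) for f in frames])
assert y.flags.writeable
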
72 changes: 71 additions & 1 deletion distributed/protocol/serialize.py
@@ -54,19 +54,30 @@ def dask_loads(header, frames):


def pickle_dumps(x, context=None):
header = {"serializer": "pickle"}
frames = [None]
buffer_callback = lambda f: frames.append(memoryview(f))
frames[0] = pickle.dumps(
x,
buffer_callback=buffer_callback,
protocol=context.get("pickle-protocol", None) if context else None,
)
header = {
"serializer": "pickle",
"pickle-writeable": tuple(not f.readonly for f in frames[1:]),
}

Member: Should we do something similar in dask_dumps and cuda_dumps?

Contributor Author: I am not sure whether we want to do that here or in the individual registered dumps/loads functions, like the numpy serialization does. Anyway, I don't think it should block this PR.

Member: Yeah, it's a good question. I think support for NumPy arrays is a bit older, as it is a primary use case, so that function may just be a bit unusual because of that.

We should be OK pulling this out of the NumPy case and handling it generally. I would think that should yield simpler, easier-to-understand code, but I could be wrong about that.

For context, tracking writeable frames was needed to solve some gnarly issues (#1978, #3943), so if there is a general way to solve this, that would be ideal to ensure they don't resurface.

Contributor Author: I agree, but let's do that in a follow-up PR. It assumes that dask_dumps returns a memoryview-compatible object, is that right? Also, we apparently allow additional frames when deserializing: https://github.com/dask/distributed/blob/master/distributed/protocol/tests/test_serialize.py#L82

Member: Sure, sounds good 🙂

Yeah, though I think that is pretty closely enforced today.

I think that is just showing that we ignore empty frames, but I could be missing something.

return header, frames


def pickle_loads(header, frames):
x, buffers = frames[0], frames[1:]
writeable = header["pickle-writeable"]
for i in range(len(buffers)):
readonly = memoryview(buffers[i]).readonly
if writeable[i]:
if readonly:
buffers[i] = bytearray(buffers[i])
elif not readonly:
buffers[i] = bytes(buffers[i])
return pickle.loads(x, buffers=buffers)

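pickle_dumps() now records which out-of-band pickle buffers were writable in header["pickle-writeable"], and pickle_loads() copies buffers into bytearray or bytes as needed, so the unpickled object gets the right mutability back even after the frames have crossed the wire as immutable bytes. A sketch of that round trip, assuming Python 3.8+ (pickle protocol 5) and a numpy recent enough to emit out-of-band buffers:

import numpy as np
from distributed.protocol.serialize import pickle_dumps, pickle_loads

x = np.arange(10)                     # a writable array
header, frames = pickle_dumps(x)
wire = [bytes(f) for f in frames]     # transport turns every frame read-only

y = pickle_loads(header, wire)
assert y.flags.writeable              # restored via header["pickle-writeable"]
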

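The review thread above raises handling this generically for any registered dumps/loads pair rather than inside each serializer; that is deferred to a follow-up PR. Purely as an illustration of the idea (hypothetical helpers, not part of this diff, assuming every frame supports the buffer protocol):

def dumps_with_writeable(dumps_fn, x):
    # Hypothetical wrapper: record frame mutability alongside any serializer's output.
    header, frames = dumps_fn(x)
    header["frame-writeable"] = tuple(not memoryview(f).readonly for f in frames)
    return header, frames


def loads_with_writeable(loads_fn, header, frames):
    # Hypothetical wrapper: re-copy frames so mutability matches what was recorded,
    # mirroring the logic pickle_loads() uses above.
    fixed = []
    for writeable, frame in zip(header["frame-writeable"], frames):
        readonly = memoryview(frame).readonly
        if writeable and readonly:
            frame = bytearray(frame)      # writable copy
        elif not writeable and not readonly:
            frame = bytes(frame)          # read-only copy
        fixed.append(frame)
    return loads_fn(header, fixed)
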
Expand Down Expand Up @@ -374,6 +385,65 @@ def deserialize(header, frames, deserializers=None):
return loads(header, frames)


def serialize_and_split(x, serializers=None, on_error="message", context=None):
"""Serialize and split compressable frames

This function is a drop-in replacement for `serialize()` that calls `serialize()`
followed by `frame_split_size()` on frames that should be compressed.

Use `merge_and_deserialize()` to merge and deserialize the frames back.

See Also
--------
serialize
merge_and_deserialize
"""
header, frames = serialize(x, serializers, on_error, context)
num_sub_frames = []
offsets = []
out_frames = []
for frame, compression in zip(
frames, header.get("compression") or [None] * len(frames)
):
if compression is None: # default behavior
sub_frames = frame_split_size(frame)
num_sub_frames.append(len(sub_frames))
offsets.append(len(out_frames))
out_frames.extend(sub_frames)
else:
num_sub_frames.append(1)
offsets.append(len(out_frames))
out_frames.append(frame)

header["split-num-sub-frames"] = num_sub_frames
header["split-offsets"] = offsets
return header, out_frames


def merge_and_deserialize(header, frames, deserializers=None):
"""Merge and deserialize frames

This function is a drop-in replacement for `deserialize()` that merges
frames that were split by `serialize_and_split()`.

See Also
--------
deserialize
serialize_and_split
"""
merged_frames = []
if "split-num-sub-frames" not in header:
merged_frames = frames
else:
for n, offset in zip(header["split-num-sub-frames"], header["split-offsets"]):
if n == 1:
merged_frames.append(frames[offset])
else:
merged_frames.append(bytearray().join(frames[offset : offset + n]))

return deserialize(header, merged_frames, deserializers=deserializers)


class Serialize:
"""Mark an object that should be serialized

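serialize_and_split() records, per original frame, how many sub-frames it was split into ("split-num-sub-frames") and the index of its first sub-frame ("split-offsets"), which is exactly the bookkeeping merge_and_deserialize() needs to stitch the frames back together. A round-trip sketch; whether a frame is actually split depends on the configured shard size, assumed here to be large enough that this small payload passes through unsplit:

import numpy as np
from distributed.protocol.serialize import serialize_and_split, merge_and_deserialize

x = np.random.random(1_000)
header, frames = serialize_and_split(x)

print(header["split-num-sub-frames"])   # sub-frames produced per original frame
print(header["split-offsets"])          # index of each frame's first sub-frame

y = merge_and_deserialize(header, frames)
assert (x == y).all()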