The Bag Partition is now configurable. #33805

Merged: 8 commits, Feb 1, 2025
1 change: 1 addition & 0 deletions CHANGES.md
@@ -86,6 +86,7 @@
* Support the Process Environment for execution in Prism ([#33651](https://github.com/apache/beam/pull/33651))
* Support the AnyOf Environment for execution in Prism ([#33705](https://github.com/apache/beam/pull/33705))
  * This improves support for developing Xlang pipelines when using a compatible cross-language service.
* Partitions are now configurable for the DaskRunner in the Python SDK ([#33805](https://github.com/apache/beam/pull/33805)).

## Breaking Changes

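To make the changelog entry concrete, here is a minimal sketch of driving the new knob from user code. The pipeline is hypothetical; it assumes `dask[distributed]` is installed (the runner spins up a local `Client` when no scheduler address is given):

```python
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.runners.dask.dask_runner import DaskRunner

# Hypothetical usage: pin the number of `dask.Bag` partitions to 4.
options = PipelineOptions(['--dask_npartitions', '4'])

with beam.Pipeline(runner=DaskRunner(), options=options) as p:
  _ = (
      p
      | beam.Create(range(1_000))
      | beam.Map(lambda x: x * x))
```

Alternatively, `--dask_partition_size` caps the number of elements per partition; as the diff below shows, the two flags are mutually exclusive.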
39 changes: 35 additions & 4 deletions sdks/python/apache_beam/runners/dask/dask_runner.py
@@ -58,6 +58,18 @@ def _parse_timeout(candidate):
import dask
return dask.config.no_default

@staticmethod
def _extract_bag_kwargs(dask_options: t.Dict) -> t.Dict:
"""Parse keyword arguments for `dask.Bag`s; used in graph translation."""
out = {}

if npartitions := dask_options.pop('npartitions', None):
out['npartitions'] = npartitions
if partition_size := dask_options.pop('partition_size', None):
out['partition_size'] = partition_size

return out

@classmethod
def _add_argparse_args(cls, parser: argparse.ArgumentParser) -> None:
parser.add_argument(
@@ -93,6 +105,21 @@ def _add_argparse_args(cls, parser: argparse.ArgumentParser) -> None:
default=512,
help='The number of open comms to maintain at once in the connection '
'pool.')
partitions_parser = parser.add_mutually_exclusive_group()
partitions_parser.add_argument(
'--dask_npartitions',
dest='npartitions',
type=int,
default=None,
help='The desired number of `dask.Bag` partitions. When unspecified, '
'an educated guess is made.')
partitions_parser.add_argument(
'--dask_partition_size',
dest='partition_size',
type=int,
default=None,
help='The length of each `dask.Bag` partition. When unspecified, '
'an educated guess is made.')


@dataclasses.dataclass
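Taken together, the two hunks above mean the new flags are parsed into `DaskOptions` and then popped back out of the option dict before the remaining options are forwarded to `dask.distributed.Client`. A small sketch mirroring the tests later in this diff:

```python
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.runners.dask.dask_runner import DaskOptions

options = PipelineOptions('--dask_npartitions 8'.split())
dask_options = options.view_as(DaskOptions).get_all_options(drop_default=True)

bag_kwargs = DaskOptions._extract_bag_kwargs(dask_options)
print(bag_kwargs)                     # {'npartitions': 8}
print('npartitions' in dask_options)  # False: popped, so Client never sees it
```

Passing both `--dask_npartitions` and `--dask_partition_size` fails at parse time, since they share a mutually exclusive argparse group.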
@@ -139,17 +166,20 @@ def metrics(self):
class DaskRunner(BundleBasedDirectRunner):
"""Executes a pipeline on a Dask distributed client."""
@staticmethod
-  def to_dask_bag_visitor() -> PipelineVisitor:
+  def to_dask_bag_visitor(bag_kwargs=None) -> PipelineVisitor:
from dask import bag as db

if bag_kwargs is None:
bag_kwargs = {}

@dataclasses.dataclass
class DaskBagVisitor(PipelineVisitor):
bags: t.Dict[AppliedPTransform, db.Bag] = dataclasses.field(
default_factory=collections.OrderedDict)

def visit_transform(self, transform_node: AppliedPTransform) -> None:
op_class = TRANSLATIONS.get(transform_node.transform.__class__, NoOp)
-        op = op_class(transform_node)
+        op = op_class(transform_node, bag_kwargs=bag_kwargs)

op_kws = {"input_bag": None, "side_inputs": None}
inputs = list(transform_node.inputs)
@@ -195,7 +225,7 @@ def is_fnapi_compatible():
def run_pipeline(self, pipeline, options):
import dask

-    # TODO(alxr): Create interactive notebook support.
+    # TODO(alxmrs): Create interactive notebook support.
if is_in_notebook():
raise NotImplementedError('interactive support will come later!')

@@ -207,11 +237,12 @@ def run_pipeline(self, pipeline, options):

dask_options = options.view_as(DaskOptions).get_all_options(
drop_default=True)
bag_kwargs = DaskOptions._extract_bag_kwargs(dask_options)
client = ddist.Client(**dask_options)

pipeline.replace_all(dask_overrides())

-    dask_visitor = self.to_dask_bag_visitor()
+    dask_visitor = self.to_dask_bag_visitor(bag_kwargs)
pipeline.visit(dask_visitor)
# The dictionary in this visitor keeps a mapping of every Beam
# PTransform to the equivalent Bag operation. This is highly
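For intuition about what `bag_kwargs` ultimately does: for a `Create`, the visitor ends up at a plain `dask.bag.from_sequence` call, so the runner change is roughly equivalent to this sketch (not the runner's exact code path):

```python
import dask.bag as db

# bag_kwargs as extracted from DaskOptions, e.g. {'npartitions': 4}.
bag = db.from_sequence(range(1_000), npartitions=4)

print(bag.npartitions)                           # 4
print(bag.map(lambda x: x * x).sum().compute())  # 332833500
```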
19 changes: 19 additions & 0 deletions sdks/python/apache_beam/runners/dask/dask_runner_test.py
@@ -66,6 +66,25 @@ def test_parser_destinations__agree_with_dask_client(self):
with self.subTest(f'{opt_name} in dask.distributed.Client constructor'):
self.assertIn(opt_name, client_args)

def test_parser_extract_bag_kwargs__deletes_dask_kwargs(self):
options = PipelineOptions('--dask_npartitions 8'.split())
dask_options = options.view_as(DaskOptions).get_all_options()

self.assertIn('npartitions', dask_options)
bag_kwargs = DaskOptions._extract_bag_kwargs(dask_options)
self.assertNotIn('npartitions', dask_options)
self.assertEqual(bag_kwargs, {'npartitions': 8})

def test_parser_extract_bag_kwargs__unconfigured(self):
options = PipelineOptions()
dask_options = options.view_as(DaskOptions).get_all_options()

# It's present as a default option.
self.assertIn('npartitions', dask_options)
bag_kwargs = DaskOptions._extract_bag_kwargs(dask_options)
self.assertNotIn('npartitions', dask_options)
self.assertEqual(bag_kwargs, {})


class DaskRunnerRunPipelineTest(unittest.TestCase):
"""Test class used to introspect the dask runner via a debugger."""
30 changes: 27 additions & 3 deletions sdks/python/apache_beam/runners/dask/transform_evaluator.py
@@ -22,6 +22,7 @@
"""
import abc
import dataclasses
import logging
import math
import typing as t
from dataclasses import field
@@ -52,6 +53,8 @@
# Value types for PCollections (possibly Windowed Values).
PCollVal = t.Union[WindowedValue, t.Any]

_LOGGER = logging.getLogger(__name__)


def get_windowed_value(item: t.Any, window_fn: WindowFn) -> WindowedValue:
"""Wraps a value (item) inside a Window."""
@@ -127,8 +130,11 @@ class DaskBagOp(abc.ABC):
Attributes
applied: The underlying `AppliedPTransform` which holds the code for the
target operation.
bag_kwargs: (optional) Keyword arguments applied to input bags, usually
from the pipeline's `DaskOptions`.
"""
applied: AppliedPTransform
bag_kwargs: t.Dict = dataclasses.field(default_factory=dict)

@property
def transform(self):
@@ -151,10 +157,28 @@ def apply(self, input_bag: OpInput, side_inputs: OpSide = None) -> db.Bag:
assert input_bag is None, 'Create expects no input!'
original_transform = t.cast(_Create, self.transform)
items = original_transform.values

npartitions = self.bag_kwargs.get('npartitions')
partition_size = self.bag_kwargs.get('partition_size')
if npartitions and partition_size:
      raise ValueError(
          'Please specify either `dask_npartitions` or '
          '`dask_partition_size`, but not both: '
          f'{npartitions=}, {partition_size=}.')
if not npartitions and not partition_size:
      # `partition_size` is inversely related to `npartitions`.
      # Ideal "chunk sizes" in Dask are around 10-100 MB.
      # Assume ~128 items per partition keeps partitions near that range.
default_size = 128
partition_size = max(default_size, math.ceil(math.sqrt(len(items)) / 10))
if partition_size == default_size:
        _LOGGER.warning(
            'The new default partition size is %d; it was 1 in previous '
            'DaskRunner versions.', default_size)

    return db.from_sequence(
-       items,
-       partition_size=max(
-           1, math.ceil(math.sqrt(len(items)) / math.sqrt(100))))
+       items, npartitions=npartitions, partition_size=partition_size)


def apply_dofn_to_bundle(
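The sizing heuristic in the last hunk is easier to read with numbers. A quick check of the rule as written, wrapped in a hypothetical helper for illustration:

```python
import math

def default_partition_size(n_items: int, default_size: int = 128) -> int:
  # Mirrors the hunk above: grow with sqrt(n) / 10, but never below 128.
  return max(default_size, math.ceil(math.sqrt(n_items) / 10))

for n in (1_000, 100_000, 1_638_401, 10_000_000):
  print(n, default_partition_size(n))
# 1000 -> 128, 100000 -> 128, 1638401 -> 129, 10000000 -> 317
```

The replaced expression divided by `math.sqrt(100)` (i.e. 10) with a floor of 1, so the growth rate is unchanged; only the floor moves from 1 to 128, which is exactly what the new warning calls out.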