Disable automatic cache with dask (pydata#1024)
* Disabled auto-caching of dask arrays when pickling and when invoking the .values property.
Added a new .compute() method.

* Minor tweaks

* Simplified Dataset.copy() and Dataset.compute()

* Minor cleanup

* Cleaned up dask test

* Integrate no_dask_resolve with dask_broadcast branches

* Don't chunk coords

* Added performance warning to release notes

* Fix bug that caused dask array to be computed and then discarded when pickling

* Eagerly cache IndexVariables only

Eagerly cache only IndexVariables (i.e. coords that are in dims). Coords that are not in dims are chunked and not cached.

* Load IndexVariable.data into memory in init

IndexVariables now eagerly load their data into memory (from disk or dask) as soon as they are created.
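
A minimal sketch of the resulting compute() semantics (hypothetical variable names; assumes dask is installed):

import numpy as np
import xarray as xr

ds = xr.Dataset({'foo': ('x', np.arange(100))}).chunk({'x': 10})

# compute() returns a new, fully loaded dataset and leaves the original lazy
loaded = ds.compute()
print(isinstance(loaded['foo'].data, np.ndarray))   # True
print(isinstance(ds['foo'].data, np.ndarray))       # False - still a dask array

# load() keeps its old semantics: it loads in place and returns self
ds.load()
print(isinstance(ds['foo'].data, np.ndarray))       # True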
shoyer authored Nov 14, 2016
2 parents 0ed1e2c + 376200a commit d66f673
Showing 8 changed files with 278 additions and 95 deletions.
16 changes: 11 additions & 5 deletions doc/whats-new.rst
@@ -25,6 +25,13 @@ Breaking changes
merges will now succeed in cases that previously raised
``xarray.MergeError``. Set ``compat='broadcast_equals'`` to restore the
previous default.
- Pickling an xarray object backed by dask, or reading its
:py:attr:`values` property, no longer automatically converts the array from
dask to numpy in the original object.
If a dask object is used as a coord of a :py:class:`~xarray.DataArray` or
:py:class:`~xarray.Dataset`, its values are eagerly computed and cached,
but only if it is used to index a dim (i.e. it is used for alignment).
By `Guido Imperiale <https://github.com/crusaderky>`_.
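
An illustrative sketch of the behaviour described above (hypothetical names; assumes dask is installed):

import pickle
import dask.array as da
import xarray as xr

arr = xr.DataArray(da.ones((3, 4), chunks=2), dims=['x', 'y'],
                   coords={'x': [10, 20, 30],
                           'z': ('y', da.zeros(4, chunks=2))})

arr.values                  # computes the values ...
print(type(arr.data))       # ... but no longer caches them: still a dask array
pickle.dumps(arr)           # pickling likewise leaves the original lazy
print(type(arr.data))       # still a dask array

# Coords are cached eagerly only when they index a dim
print(type(arr['x'].data))  # numpy.ndarray
print(type(arr['z'].data))  # dask array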

Deprecations
~~~~~~~~~~~~
@@ -52,32 +59,31 @@ Enhancements
- Add checking of ``attr`` names and values when saving to netCDF, raising useful
error messages if they are invalid. (:issue:`911`).
By `Robin Wilson <https://github.com/robintw>`_.

- Added ability to save ``DataArray`` objects directly to netCDF files using
:py:meth:`~xarray.DataArray.to_netcdf`, and to load directly from netCDF files
using :py:func:`~xarray.open_dataarray` (:issue:`915`). These remove the need
to convert a ``DataArray`` to a ``Dataset`` before saving as a netCDF file,
and deal with names to ensure a perfect 'roundtrip' capability.
By `Robin Wilson <https://github.com/robintw>`_.

- Multi-index levels are now accessible as "virtual" coordinate variables,
e.g., ``ds['time']`` can pull out the ``'time'`` level of a multi-index
(see :ref:`coordinates`). ``sel`` also accepts providing multi-index levels
as keyword arguments, e.g., ``ds.sel(time='2000-01')``
(see :ref:`multi-level indexing`).
By `Benoit Bovy <https://github.com/benbovy>`_.

- Added the ``compat`` option ``'no_conflicts'`` to ``merge``, allowing the
combination of xarray objects with disjoint (:issue:`742`) or
overlapping (:issue:`835`) coordinates as long as all present data agrees.
By `Johnnie Gray <https://github.com/jcmgray>`_. See
:ref:`combining.no_conflicts` for more details.

- It is now possible to set ``concat_dim=None`` explicitly in
:py:func:`~xarray.open_mfdataset` to disable inferring a dimension along
which to concatenate.
By `Stephan Hoyer <https://github.com/shoyer>`_.

- Added methods :py:meth:`DataArray.compute`, :py:meth:`Dataset.compute`, and
:py:meth:`Variable.compute` as a non-mutating alternative to
:py:meth:`~DataArray.load`.
By `Guido Imperiale <https://github.com/crusaderky>`_.
- Adds DataArray and Dataset methods :py:meth:`~xarray.DataArray.cumsum` and
:py:meth:`~xarray.DataArray.cumprod`. By `Phillip J. Wolfram
<https://github.com/pwolfram>`_.
13 changes: 13 additions & 0 deletions xarray/core/dataarray.py
@@ -570,6 +570,19 @@ def load(self):
self._coords = new._coords
return self

def compute(self):
"""Manually trigger loading of this array's data from disk or a
remote source into memory and return a new array. The original is
left unaltered.
Normally, it should not be necessary to call this method in user code,
because all xarray functions should either work on deferred data or
load data automatically. However, this method can be necessary when
working with many file objects on disk.
"""
new = self.copy(deep=False)
return new.load()

def copy(self, deep=True):
"""Returns a copy of this array.
41 changes: 27 additions & 14 deletions xarray/core/dataset.py
@@ -260,8 +260,11 @@ def load_store(cls, store, decoder=None):
return obj

def __getstate__(self):
"""Always load data in-memory before pickling"""
self.load()
"""Load data in-memory before pickling (except for Dask data)"""
for v in self.variables.values():
if not isinstance(v.data, dask_array_type):
v.load()

# self.__dict__ is the default pickle object, we don't need to
# implement our own __setstate__ method to make pickle work
state = self.__dict__.copy()
@@ -342,6 +345,19 @@ def load(self):

return self

def compute(self):
"""Manually trigger loading of this dataset's data from disk or a
remote source into memory and return a new dataset. The original is
left unaltered.
Normally, it should not be necessary to call this method in user code,
because all xarray functions should either work on deferred data or
load data automatically. However, this method can be necessary when
working with many file objects on disk.
"""
new = self.copy(deep=False)
return new.load()

@classmethod
def _construct_direct(cls, variables, coord_names, dims=None, attrs=None,
file_obj=None):
@@ -424,14 +440,12 @@ def copy(self, deep=False):
"""Returns a copy of this dataset.
If `deep=True`, a deep copy is made of each of the component variables.
Otherwise, a shallow copy is made, so each variable in the new dataset
is also a variable in the original dataset.
Otherwise, a shallow copy of each of the component variables is made, so
that the underlying memory region of the new dataset is the same as in
the original dataset.
"""
if deep:
variables = OrderedDict((k, v.copy(deep=True))
for k, v in iteritems(self._variables))
else:
variables = self._variables.copy()
variables = OrderedDict((k, v.copy(deep=deep))
for k, v in iteritems(self._variables))
# skip __init__ to avoid costly validation
return self._construct_direct(variables, self._coord_names.copy(),
self._dims.copy(), self._attrs_copy())
@@ -817,11 +831,10 @@ def chunks(self):
chunks = {}
for v in self.variables.values():
if v.chunks is not None:
new_chunks = list(zip(v.dims, v.chunks))
if any(chunk != chunks[d] for d, chunk in new_chunks
if d in chunks):
raise ValueError('inconsistent chunks')
chunks.update(new_chunks)
for dim, c in zip(v.dims, v.chunks):
if dim in chunks and c != chunks[dim]:
raise ValueError('inconsistent chunks')
chunks[dim] = c
return Frozen(SortedKeysDict(chunks))
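
A rough sketch of what this property reports (hypothetical names; assumes dask is installed):

import dask.array as da
import xarray as xr

ds = xr.Dataset({'a': ('x', da.zeros(8, chunks=4)),
                 'b': (('x', 'y'), da.zeros((8, 2), chunks=(4, 2)))})
print(dict(ds.chunks))       # {'x': (4, 4), 'y': (2,)}

# Variables that disagree on the chunking of a shared dim trigger the error
bad = xr.Dataset({'a': ('x', da.zeros(8, chunks=4)),
                  'b': ('x', da.zeros(8, chunks=3))})
try:
    bad.chunks
except ValueError as err:
    print(err)               # inconsistent chunks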

def chunk(self, chunks=None, name_prefix='xarray-', token=None,
50 changes: 42 additions & 8 deletions xarray/core/variable.py
@@ -277,10 +277,21 @@ def data(self, data):
"replacement data must match the Variable's shape")
self._data = data

def _data_cast(self):
if isinstance(self._data, (np.ndarray, PandasIndexAdapter)):
return self._data
else:
return np.asarray(self._data)

def _data_cached(self):
if not isinstance(self._data, (np.ndarray, PandasIndexAdapter)):
self._data = np.asarray(self._data)
return self._data
"""Load data into memory and return it.
Do not cache dask arrays automatically; that should
require an explicit load() call.
"""
new_data = self._data_cast()
if not isinstance(self._data, dask_array_type):
self._data = new_data
return new_data
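
A minimal sketch of the new semantics, assuming dask is installed (_in_memory is xarray's internal flag, also used in the tests further below):

import dask.array as da
import xarray as xr

v = xr.Variable(('x',), da.arange(6, chunks=3))

v.values                # computes the values ...
print(v._in_memory)     # ... but no longer caches them: False

v.load()                # an explicit load() still caches in place
print(v._in_memory)     # True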

@property
def _indexable_data(self):
@@ -294,12 +305,26 @@ def load(self):
because all xarray functions should either work on deferred data or
load data automatically.
"""
self._data_cached()
self._data = self._data_cast()
return self

def compute(self):
"""Manually trigger loading of this variable's data from disk or a
remote source into memory and return a new variable. The original is
left unaltered.
Normally, it should not be necessary to call this method in user code,
because all xarray functions should either work on deferred data or
load data automatically.
"""
new = self.copy(deep=False)
return new.load()

def __getstate__(self):
"""Always cache data as an in-memory array before pickling"""
self._data_cached()
"""Always cache data as an in-memory array before pickling
(with the exception of the dask backend)"""
if not isinstance(self._data, dask_array_type):
self._data_cached()
# self.__dict__ is the default pickle object, we don't need to
# implement our own __setstate__ method to make pickle work
return self.__dict__
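
A small sketch of the pickling behaviour for a dask-backed Variable (hypothetical names; behaviour as described in the docstring above):

import pickle
import dask.array as da
import xarray as xr

v = xr.Variable(('x',), da.arange(6, chunks=3))
restored = pickle.loads(pickle.dumps(v))

print(v._in_memory)         # False - pickling no longer computes the original
print(restored._in_memory)  # False - the dask graph itself is what gets pickled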
@@ -1075,10 +1100,19 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
raise ValueError('%s objects must be 1-dimensional' %
type(self).__name__)

def _data_cached(self):
# Unlike in Variable, always eagerly load values into memory
if not isinstance(self._data, PandasIndexAdapter):
self._data = PandasIndexAdapter(self._data)
return self._data

@Variable.data.setter
def data(self, data):
Variable.data.fset(self, data)
if not isinstance(self._data, PandasIndexAdapter):
self._data = PandasIndexAdapter(self._data)

def chunk(self, chunks=None, name=None, lock=False):
# Dummy - do not chunk. This method is invoked e.g. by Dataset.chunk()
return self.copy(deep=False)
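
A brief sketch of the effect on index coordinates (hypothetical names; _in_memory is an internal flag, shown only for illustration):

import dask.array as da
import xarray as xr

# A dask array used as a dim coordinate is computed eagerly at construction
ds = xr.Dataset(coords={'x': ('x', da.arange(4, chunks=2))})
print(ds['x'].variable._in_memory)       # True

# chunk() is a no-op for index variables: they stay in memory
chunked = ds.chunk({'x': 2})
print(chunked['x'].variable._in_memory)  # True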

def __getitem__(self, key):
key = self._item_key_to_tuple(key)
75 changes: 60 additions & 15 deletions xarray/test/test_backends.py
@@ -128,8 +128,12 @@ def assert_loads(vars=None):
if vars is None:
vars = expected
with self.roundtrip(expected) as actual:
for v in actual.variables.values():
self.assertFalse(v._in_memory)
for k, v in actual.variables.items():
# IndexVariables are eagerly loaded into memory
if k in actual.dims:
self.assertTrue(v._in_memory)
else:
self.assertFalse(v._in_memory)
yield actual
for k, v in actual.variables.items():
if k in vars:
@@ -152,6 +156,31 @@ def assert_loads(vars=None):
actual = ds.load()
self.assertDatasetAllClose(expected, actual)

def test_dataset_compute(self):
expected = create_test_data()

with self.roundtrip(expected) as actual:
# Test Dataset.compute()
for k, v in actual.variables.items():
# IndexVariables are eagerly cached
if k in actual.dims:
self.assertTrue(v._in_memory)
else:
self.assertFalse(v._in_memory)

computed = actual.compute()

for k, v in actual.variables.items():
if k in actual.dims:
self.assertTrue(v._in_memory)
else:
self.assertFalse(v._in_memory)
for v in computed.variables.values():
self.assertTrue(v._in_memory)

self.assertDatasetAllClose(expected, actual)
self.assertDatasetAllClose(expected, computed)

def test_roundtrip_None_variable(self):
expected = Dataset({None: (('x', 'y'), [[0, 1], [2, 3]])})
with self.roundtrip(expected) as actual:
@@ -233,18 +262,6 @@ def test_roundtrip_coordinates(self):
with self.roundtrip(expected) as actual:
self.assertDatasetIdentical(expected, actual)

expected = original.copy()
expected.attrs['coordinates'] = 'something random'
with self.assertRaisesRegexp(ValueError, 'cannot serialize'):
with self.roundtrip(expected):
pass

expected = original.copy(deep=True)
expected['foo'].attrs['coordinates'] = 'something random'
with self.assertRaisesRegexp(ValueError, 'cannot serialize'):
with self.roundtrip(expected):
pass

def test_roundtrip_boolean_dtype(self):
original = create_boolean_data()
self.assertEqual(original['x'].dtype, 'bool')
@@ -875,7 +892,26 @@ def test_read_byte_attrs_as_unicode(self):
@requires_dask
@requires_scipy
@requires_netCDF4
class DaskTest(TestCase):
class DaskTest(TestCase, DatasetIOTestCases):
@contextlib.contextmanager
def create_store(self):
yield Dataset()

@contextlib.contextmanager
def roundtrip(self, data, save_kwargs={}, open_kwargs={}):
yield data.chunk()

def test_roundtrip_datetime_data(self):
# Override method in DatasetIOTestCases - remove save_kwds that do not apply here
times = pd.to_datetime(['2000-01-01', '2000-01-02', 'NaT'])
expected = Dataset({'t': ('t', times), 't0': times[0]})
with self.roundtrip(expected) as actual:
self.assertDatasetIdentical(expected, actual)

def test_write_store(self):
# Override method in DatasetIOTestCases - not applicable to dask
pass

def test_open_mfdataset(self):
original = Dataset({'foo': ('x', np.random.randn(10))})
with create_tmp_file() as tmp1:
@@ -995,6 +1031,15 @@ def test_deterministic_names(self):
self.assertIn(tmp, dask_name)
self.assertEqual(original_names, repeat_names)

def test_dataarray_compute(self):
# Test DataArray.compute() on the dask backend.
# The test for Dataset.compute() is already in DatasetIOTestCases;
# however, dask is the only tested backend that supports DataArrays.
actual = DataArray([1, 2]).chunk()
computed = actual.compute()
self.assertFalse(actual._in_memory)
self.assertTrue(computed._in_memory)
self.assertDataArrayAllClose(actual, computed)

@requires_scipy_or_netCDF4
@requires_pydap
(diffs for the remaining 3 changed files not shown)
