Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multiindex scalar coords, fixes #1408 #1412

Closed
wants to merge 9 commits into from
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ By `Ryan Abernathey <https://github.com/rabernat>`_.
``data_vars``.
By `Keisuke Fujii <https://github.com/fujiisoup>`_.

- Fix a bug where selected levels of a Multi-Index were lost by ``isel()`` and ``sel()`` (:issue:`1408`).
Now, the selected levels are automatically converted to scalar coordinates.
By `Keisuke Fujii <https://github.com/fujiisoup>`_.

.. _whats-new.0.9.5:

v0.9.5 (17 April, 2017)
Expand Down
36 changes: 11 additions & 25 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,13 @@ def _remap_key(self, key):
return indexing.remap_label_indexers(self.data_array, key)

def __getitem__(self, key):
pos_indexers, new_indexes = self._remap_key(key)
return self.data_array[pos_indexers]._replace_indexes(new_indexes)
pos_indexers, new_indexes, selected_dims = self._remap_key(key)
ds = self.data_array._to_temp_dataset().isel(**pos_indexers)
return self.data_array._from_temp_dataset(
ds._replace_indexes(new_indexes, selected_dims))

def __setitem__(self, key, value):
pos_indexers, _ = self._remap_key(key)
pos_indexers, _, _ = self._remap_key(key)
self.data_array[pos_indexers] = value


Expand Down Expand Up @@ -256,23 +258,6 @@ def _replace_maybe_drop_dims(self, variable, name=__default):
if set(v.dims) <= allowed_dims)
return self._replace(variable, coords, name)

def _replace_indexes(self, indexes):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed this method and use Dataset._replace_indexes instead, to reduce duplicates.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I remember well, we used duplicates to avoid using _to_temp_dataset and _from_temp_dataset in __getitem__. But now that _replace_indexes has more logic implemented, maybe it is a good idea to reduce duplicates?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks.
Yes, in DataArray.__getitem__ and also in DataArray.isel, _to_temp_dataset and _from_temp_dataset are now being used.

if not len(indexes):
return self
coords = self._coords.copy()
for name, idx in indexes.items():
coords[name] = IndexVariable(name, idx)
obj = self._replace(coords=coords)

# switch from dimension to level names, if necessary
dim_names = {}
for dim, idx in indexes.items():
if not isinstance(idx, pd.MultiIndex) and idx.name != dim:
dim_names[dim] = idx.name
if dim_names:
obj = obj.rename(dim_names)
return obj

__this_array = _ThisArray()

def _to_temp_dataset(self):
Expand Down Expand Up @@ -679,11 +664,12 @@ def sel(self, method=None, tolerance=None, drop=False, **indexers):
Dataset.sel
DataArray.isel
"""
pos_indexers, new_indexes = indexing.remap_label_indexers(
self, indexers, method=method, tolerance=tolerance
)
result = self.isel(drop=drop, **pos_indexers)
return result._replace_indexes(new_indexes)
pos_indexers, new_indexes, selected_dims = \
indexing.remap_label_indexers(
self, indexers, method=method, tolerance=tolerance)
ds = self._to_temp_dataset().isel(drop=drop, **pos_indexers)
return self._from_temp_dataset(
ds._replace_indexes(new_indexes, selected_dims))

def isel_points(self, dim='points', **indexers):
"""Return a new DataArray whose dataset is given by pointwise integer
Expand Down
42 changes: 33 additions & 9 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,13 +575,30 @@ def _replace_vars_and_dims(self, variables, coord_names=None, dims=None,
obj = self._construct_direct(variables, coord_names, dims, attrs)
return obj

def _replace_indexes(self, indexes):
def _replace_indexes(self, indexes, selected_indexes={}):
"""
Replace coords and dims by indexes, which is a dict mapping
the original dim (str) to new dim (pandas.index).
selected_indexes is a dict which maps the original dims to the
selected dims that will be scalar coordinates, because they were
selected.
"""
if not len(indexes):
return self
variables = self._variables.copy()
coord_names = self._coord_names.copy()
for dim, selected_dim in selected_indexes.items():
for sd in selected_dim:
_, _, ary = _get_virtual_variable(
variables, sd, level_vars=self._level_coords)
variables[sd] = ary[0]
if coord_names is None:
coord_names = set([sd, ])
else:
coord_names.add(sd)
for name, idx in indexes.items():
variables[name] = IndexVariable(name, idx)
obj = self._replace_vars_and_dims(variables)
obj = self._replace_vars_and_dims(variables, coord_names=coord_names)

# switch from dimension to level names, if necessary
dim_names = {}
Expand Down Expand Up @@ -1138,12 +1155,19 @@ def isel(self, drop=False, **indexers):
for k, v in iteritems(indexers)]

variables = OrderedDict()
coord_names = self._coord_names
for name, var in iteritems(self._variables):
var_indexers = dict((k, v) for k, v in indexers if k in var.dims)
new_var = var.isel(**var_indexers)
if not (drop and name in var_indexers):
variables[name] = new_var
coord_names = set(self._coord_names) & set(variables)
if isinstance(new_var, OrderedDict):
# new_var is an OrderedDict if a single element is
# extracted from MultiIndex. See IndexVariable.__getitem__
variables.update(new_var)
coord_names = coord_names | set(new_var.keys())
else:
variables[name] = new_var
coord_names = set(coord_names) & set(variables)
return self._replace_vars_and_dims(variables, coord_names=coord_names)

def sel(self, method=None, tolerance=None, drop=False, **indexers):
Expand Down Expand Up @@ -1202,11 +1226,11 @@ def sel(self, method=None, tolerance=None, drop=False, **indexers):
Dataset.isel_points
DataArray.sel
"""
pos_indexers, new_indexes = indexing.remap_label_indexers(
self, indexers, method=method, tolerance=tolerance
)
pos_indexers, new_indexes, selected_dims = \
indexing.remap_label_indexers(
self, indexers, method=method, tolerance=tolerance)
result = self.isel(drop=drop, **pos_indexers)
return result._replace_indexes(new_indexes)
return result._replace_indexes(new_indexes, selected_dims)

def isel_points(self, dim='points', **indexers):
# type: (...) -> Dataset
Expand Down Expand Up @@ -1392,7 +1416,7 @@ def sel_points(self, dim='points', method=None, tolerance=None,
Dataset.isel_points
DataArray.sel_points
"""
pos_indexers, _ = indexing.remap_label_indexers(
pos_indexers, _, _ = indexing.remap_label_indexers(
self, indexers, method=method, tolerance=tolerance
)
return self.isel_points(dim=dim, **pos_indexers)
Expand Down
18 changes: 18 additions & 0 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,22 @@ def _maybe_unstack(self, obj):
del obj.coords[dim]
return obj

def _maybe_stack(self, applied):
    """
    Reconstruct the stacked (MultiIndex) group dimension on each applied
    object that lost it.

    This may happen when a single item is selected from a MultiIndex-ed
    array: the result of ``isel`` no longer carries ``self._group_dim``.

    NOTE(review): this method became necessary because
    ``xr.concat([ds.isel(yx=i) for i in range(n)], dim='yx')`` no longer
    works directly — ``ds.isel(yx=i)`` does not have the ``yx`` dimension
    any more once a single MultiIndex element has been selected.
    """
    # Group was not built from an index-backed coordinate: nothing to do.
    if not hasattr(self._group, 'to_index'):
        return applied
    index = self._group.to_index()
    if not isinstance(index, pd.MultiIndex):
        return applied
    else:
        # An object missing the group dimension gets it back by expanding
        # the MultiIndex level names into dimensions and stacking them
        # into the original group dimension; others pass through as-is.
        return [ds if self._group_dim in ds.coords
                else ds.expand_dims(index.names).stack(
                    **{self._group.name: index.names})
                for ds in applied]

def fillna(self, value):
"""Fill missing values in this object by group.

Expand Down Expand Up @@ -528,6 +544,7 @@ def _combine(self, applied, shortcut=False):
"""Recombine the applied objects like the original."""
applied_example, applied = peek_at(applied)
coord, dim, positions = self._infer_concat_args(applied_example)
applied = self._maybe_stack(applied)
if shortcut:
combined = self._concat_shortcut(applied, dim, positions)
else:
Expand Down Expand Up @@ -619,6 +636,7 @@ def apply(self, func, **kwargs):
def _combine(self, applied):
"""Recombine the applied objects like the original."""
applied_example, applied = peek_at(applied)
applied = self._maybe_stack(applied)
coord, dim, positions = self._infer_concat_args(applied_example)
combined = concat(applied, dim)
combined = _maybe_reorder(combined, dim, positions)
Expand Down
19 changes: 15 additions & 4 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,14 +265,18 @@ def get_dim_indexers(data_obj, indexers):

def remap_label_indexers(data_obj, indexers, method=None, tolerance=None):
"""Given an xarray data object and label based indexers, return a mapping
of equivalent location based indexers. Also return a mapping of updated
pandas index objects (in case of multi-index level drop).
of equivalent location based indexers.
In case of multi-index level drop, it also returns
(new_indexes) a mapping of updated pandas index objects and
(selected_dims) a mapping from the original dims to selected (dropped)
dims.
"""
if method is not None and not isinstance(method, str):
raise TypeError('``method`` must be a string')

pos_indexers = {}
new_indexes = {}
selected_dims = {}

dim_indexers = get_dim_indexers(data_obj, indexers)
for dim, label in iteritems(dim_indexers):
Expand All @@ -291,8 +295,15 @@ def remap_label_indexers(data_obj, indexers, method=None, tolerance=None):
pos_indexers[dim] = idxr
if new_idx is not None:
new_indexes[dim] = new_idx

return pos_indexers, new_indexes
if isinstance(new_idx, pd.MultiIndex):
selected_dims[dim] = [name for name in index.names
if name not in new_idx.names]
else:
selected_dims[dim] = [name for name in index.names
if name != new_idx.name]
if isinstance(idxr, int) and idxr in (0, 1):
selected_dims[dim] = index.names
return pos_indexers, new_indexes, selected_dims


def slice_slice(old_slice, applied_slice, size):
Expand Down
23 changes: 23 additions & 0 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,6 +1203,12 @@ def __getitem__(self, key):
key = self._item_key_to_tuple(key)
values = self._indexable_data[key]
if not hasattr(values, 'ndim') or values.ndim == 0:
level_names = self.level_names
if level_names:
# If a single item is selected from MultiIndex,
# returns an OrderedDict with multiple scalar variables
return variables_from_multiindex(level_names, values,
self._attrs, self._encoding)
return Variable((), values, self._attrs, self._encoding)
else:
return type(self)(self.dims, values, self._attrs,
Expand Down Expand Up @@ -1324,6 +1330,23 @@ def name(self):
def name(self, value):
raise AttributeError('cannot modify name of IndexVariable in-place')


def variables_from_multiindex(dims, data, attrs=None, encoding=None,
                              fastpath=False):
    """Build an OrderedDict of scalar Variables from one MultiIndex item.

    This conversion is necessary because a pandas.MultiIndex loses its
    hierarchical structure when a single element is selected from it.

    Parameters
    ----------
    dims : tuple
        Level names, typically ``IndexVariable.level_names``.
    data : np.ndarray
        0d array whose single item is the tuple of level values.

    Returns
    -------
    OrderedDict mapping each level name to a zero-dimensional Variable.
    """
    level_values = data.item()
    return OrderedDict(
        (dim, Variable((), value, attrs, encoding, fastpath))
        for dim, value in zip(dims, level_values))

# for backwards compatibility
Coordinate = utils.alias(IndexVariable, 'Coordinate')

Expand Down
18 changes: 18 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,24 @@ def test_sel(lab_indexer, pos_indexer, replaced_idx=False,

self.assertDataArrayIdentical(mdata.sel(x={'one': 'a', 'two': 1}),
mdata.sel(one='a', two=1))
self.assertTrue('one' in mdata.sel(one='a').coords)
self.assertTrue('one' in mdata.sel(one='a', two=1).coords)
self.assertTrue('two' in mdata.sel(one='a', two=1).coords)
self.assertTrue('three' in mdata.sel(one='a', two=1, three=-1).coords)

def test_isel_multiindex(self):
    # Selecting one element of a MultiIndex-ed dim should turn every
    # level into a scalar coordinate (GH1408).
    mindex = pd.MultiIndex.from_product([['a', 'b'], [1, 2], [-1, -2]],
                                        names=('one', 'two', 'three'))
    mdata = DataArray(range(8), dims=['x'], coords={'x': mindex})
    kept = mdata.isel(x=0)
    for level in ('one', 'two', 'three'):
        self.assertIn(level, kept.coords)
    # With drop=True the scalar level coordinates are discarded instead.
    dropped = mdata.isel(x=0, drop=True)
    for level in ('one', 'two', 'three'):
        self.assertNotIn(level, dropped.coords)

def test_virtual_default_coords(self):
array = DataArray(np.zeros((5,)), dims='x')
Expand Down
19 changes: 19 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1096,6 +1096,25 @@ def test_sel(lab_indexer, pos_indexer, replaced_idx=False,

self.assertDatasetIdentical(mdata.sel(x={'one': 'a', 'two': 1}),
mdata.sel(one='a', two=1))
self.assertTrue('one' in mdata.sel(one='a').coords)
self.assertTrue('one' in mdata.sel(one='a', two=1).coords)
self.assertTrue('two' in mdata.sel(one='a', two=1).coords)
self.assertTrue('three' in mdata.sel(one='a', two=1, three=-1).coords)

def test_isel_multiindex(self):
    # Selecting one element of a MultiIndex-ed dim should turn every
    # level into a scalar coordinate (GH1408).
    mindex = pd.MultiIndex.from_product([['a', 'b'], [1, 2], [-1, -2]],
                                        names=('one', 'two', 'three'))
    mdata = Dataset(data_vars={'var': ('x', range(8))},
                    coords={'x': mindex})
    kept = mdata.isel(x=0)
    for level in ('one', 'two', 'three'):
        self.assertIn(level, kept.coords)
    # With drop=True the scalar level coordinates are discarded instead.
    dropped = mdata.isel(x=0, drop=True)
    for level in ('one', 'two', 'three'):
        self.assertNotIn(level, dropped.coords)

def test_reindex_like(self):
data = create_test_data()
Expand Down
14 changes: 13 additions & 1 deletion xarray/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,18 @@ def test_consolidate_slices():
_consolidate_slices([slice(3), 4])


def test_multi_index_groupby_apply_dataarray():
    # regression test for GH873: groupby-apply over a stacked (MultiIndex)
    # dimension must round-trip through unstack unchanged.
    arr = xr.DataArray(np.random.randn(3, 4), dims=['x', 'y'],
                       coords={'x': ['a', 'b', 'c'], 'y': [1, 2, 3, 4]})
    expected = 2 * arr
    stacked = arr.stack(space=['x', 'y'])
    actual = stacked.groupby('space').apply(lambda v: 2 * v)
    assert expected.equals(actual.unstack('space'))


def test_multi_index_groupby_apply():
# regression test for GH873
ds = xr.Dataset({'foo': (('x', 'y'), np.random.randn(3, 4))},
Expand Down Expand Up @@ -70,5 +82,5 @@ def test_groupby_duplicate_coordinate_labels():
actual = array.groupby('x').sum()
assert expected.equals(actual)


# TODO: move other groupby tests from test_dataset and test_dataarray over here
Loading