Commit

wip: clean-up
Revert some changes made in pydata#5102 + additional (temporary) fixes.
benbovy committed Jul 29, 2021
1 parent b71bf3f commit 84cbf15
Showing 14 changed files with 142 additions and 117 deletions.
41 changes: 29 additions & 12 deletions xarray/core/alignment.py
@@ -18,7 +18,7 @@
import pandas as pd

from . import dtypes
from .indexes import Index, PandasIndex, get_indexer_nd, wrap_pandas_index
from .indexes import Index, PandasIndex, get_indexer_nd
from .utils import is_dict_like, is_full_slice, maybe_coerce_to_str, safe_cast_to_index
from .variable import IndexVariable, Variable

@@ -53,7 +53,10 @@ def _get_joiner(join, index_cls):
def _override_indexes(objects, all_indexes, exclude):
for dim, dim_indexes in all_indexes.items():
if dim not in exclude:
lengths = {index.size for index in dim_indexes}
lengths = {
getattr(index, "size", index.to_pandas_index().size)
for index in dim_indexes
}
if len(lengths) != 1:
raise ValueError(
f"Indexes along dimension {dim!r} don't have the same length."
@@ -300,11 +303,12 @@ def align(
joined_indexes = {}
for dim, matching_indexes in all_indexes.items():
if dim in indexes:
# TODO: benbovy - flexible indexes. maybe move this logic in util func
if isinstance(indexes[dim], Index):
index = indexes[dim]
else:
index = PandasIndex(safe_cast_to_index(indexes[dim]))
index, _ = PandasIndex.from_pandas_index(
safe_cast_to_index(indexes[dim]), dim
)
if (
any(not index.equals(other) for other in matching_indexes)
or dim in unlabeled_dim_sizes
@@ -323,17 +327,18 @@
joiner = _get_joiner(join, type(matching_indexes[0]))
index = joiner(matching_indexes)
# make sure str coords are not cast to object
index = maybe_coerce_to_str(index, all_coords[dim])
index = maybe_coerce_to_str(index.to_pandas_index(), all_coords[dim])
joined_indexes[dim] = index
else:
index = all_coords[dim][0]

if dim in unlabeled_dim_sizes:
unlabeled_sizes = unlabeled_dim_sizes[dim]
# TODO: benbovy - flexible indexes: expose a size property for xarray.Index?
# Some indexes may not have a defined size (e.g., built from multiple coords of
# different sizes)
labeled_size = index.size
# TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5647
if isinstance(index, PandasIndex):
labeled_size = index.to_pandas_index().size
else:
labeled_size = index.size
if len(unlabeled_sizes | {labeled_size}) > 1:
raise ValueError(
f"arguments without labels along dimension {dim!r} cannot be "
@@ -350,7 +355,14 @@

result = []
for obj in objects:
valid_indexers = {k: v for k, v in joined_indexes.items() if k in obj.dims}
# TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5647
valid_indexers = {}
for k, index in joined_indexes.items():
if k in obj.dims:
if isinstance(index, Index):
valid_indexers[k] = index.to_pandas_index()
else:
valid_indexers[k] = index
if not valid_indexers:
# fast path for no reindexing necessary
new_obj = obj.copy(deep=copy)
@@ -471,7 +483,11 @@ def reindex_like_indexers(
ValueError
If any dimensions without labels have different sizes.
"""
indexers = {k: v for k, v in other.xindexes.items() if k in target.dims}
# TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5647
# this doesn't support yet indexes other than pd.Index
indexers = {
k: v.to_pandas_index() for k, v in other.xindexes.items() if k in target.dims
}

for dim in other.dims:
if dim not in indexers and dim in target.dims:
@@ -560,7 +576,8 @@ def reindex_variables(
"from that to be indexed along {:s}".format(str(indexer.dims), dim)
)

target = new_indexes[dim] = wrap_pandas_index(safe_cast_to_index(indexers[dim]))
target = safe_cast_to_index(indexers[dim])
new_indexes[dim] = PandasIndex(target, dim)

if dim in indexes:
# TODO (benbovy - flexible indexes): support other indexes than pd.Index?
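
For reference, a minimal sketch (under the signatures used in the hunks above) of the two construction paths that replace wrap_pandas_index in this file: the explicit PandasIndex(pd_index, dim) constructor and the PandasIndex.from_pandas_index classmethod, which also returns the matching index variables. Variable names here are illustrative, not part of the commit.

    import pandas as pd
    from xarray.core.indexes import PandasIndex
    from xarray.core.utils import safe_cast_to_index

    # Illustrative only; mirrors the calls in align/reindex_variables above.
    target = safe_cast_to_index([10, 20, 30])          # plain pandas.Index
    xr_index = PandasIndex(target, "x")                # dim name passed explicitly
    xr_index2, index_vars = PandasIndex.from_pandas_index(target, "x")
    assert xr_index.to_pandas_index().equals(target)   # recover the wrapped pd.Index
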
5 changes: 2 additions & 3 deletions xarray/core/combine.py
@@ -77,9 +77,8 @@ def _infer_concat_order_from_coords(datasets):
"inferring concatenation order"
)

# TODO (benbovy, flexible indexes): all indexes should be Pandas.Index
# get pd.Index objects from Index objects
indexes = [index.array for index in indexes]
# TODO (benbovy, flexible indexes): support flexible indexes?
indexes = [index.to_pandas_index() for index in indexes]

# If dimension coordinate values are same on every dataset then
# should be leaving this dimension alone (it's just a "bystander")
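
A hedged illustration of the conversion used above: each entry of a Dataset's xindexes mapping is an xarray index wrapper, and to_pandas_index() returns the underlying pandas.Index that the concat-order logic needs.

    import xarray as xr

    # Illustrative only, assuming the .xindexes mapping used elsewhere in this commit.
    ds = xr.Dataset(coords={"t": [1, 2, 3]})
    wrapped = ds.xindexes["t"]            # xarray PandasIndex wrapper
    plain = wrapped.to_pandas_index()     # plain pandas.Index, as consumed above
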
22 changes: 5 additions & 17 deletions xarray/core/dataarray.py
@@ -51,13 +51,7 @@
)
from .dataset import Dataset, split_indexes
from .formatting import format_item
from .indexes import (
Index,
Indexes,
default_indexes,
propagate_indexes,
wrap_pandas_index,
)
from .indexes import Index, Indexes, default_indexes, propagate_indexes
from .indexing import is_fancy_indexer
from .merge import PANDAS_TYPES, MergeError, _extract_indexes_from_coords
from .options import OPTIONS, _get_keep_attrs
@@ -473,15 +467,14 @@ def _overwrite_indexes(self, indexes: Mapping[Hashable, Any]) -> "DataArray":
return self
coords = self._coords.copy()
for name, idx in indexes.items():
coords[name] = IndexVariable(name, idx)
coords[name] = IndexVariable(name, idx.to_pandas_index())
obj = self._replace(coords=coords)

# switch from dimension to level names, if necessary
dim_names: Dict[Any, str] = {}
for dim, idx in indexes.items():
# TODO: benbovy - flexible indexes: update when MultiIndex has its own class
pd_idx = idx.array
if not isinstance(pd_idx, pd.MultiIndex) and pd_idx.name != dim:
pd_idx = idx.to_pandas_index()
if not isinstance(idx, pd.MultiIndex) and pd_idx.name != dim:
dim_names[dim] = idx.name
if dim_names:
obj = obj.rename(dim_names)
@@ -1046,12 +1039,7 @@ def copy(self: T_DataArray, deep: bool = True, data: Any = None) -> T_DataArray:
if self._indexes is None:
indexes = self._indexes
else:
# TODO: benbovy: flexible indexes: support all xarray indexes (not just pandas.Index)
# xarray Index needs a copy method.
indexes = {
k: wrap_pandas_index(v.to_pandas_index().copy(deep=deep))
for k, v in self._indexes.items()
}
indexes = {k: v.copy(deep=deep) for k, v in self._indexes.items()}
return self._replace(variable, coords, indexes=indexes)

def __copy__(self) -> "DataArray":
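
The simplified copy above relies on the Index.copy method kept in indexes.py (see the copy(self, deep=True) context further down). A minimal sketch of that call, assuming those signatures:

    import pandas as pd
    from xarray.core.indexes import PandasIndex

    # Illustrative only.
    idx = PandasIndex(pd.Index(["a", "b", "c"]), "x")
    idx_copy = idx.copy(deep=True)        # copies the wrapped pandas.Index
    assert idx_copy.to_pandas_index().equals(idx.to_pandas_index())
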
52 changes: 32 additions & 20 deletions xarray/core/dataset.py
@@ -71,7 +71,6 @@
propagate_indexes,
remove_unused_levels_categories,
roll_index,
wrap_pandas_index,
)
from .indexing import is_fancy_indexer
from .merge import (
@@ -1184,7 +1183,7 @@ def _overwrite_indexes(self, indexes: Mapping[Any, Index]) -> "Dataset":
variables = self._variables.copy()
new_indexes = dict(self.xindexes)
for name, idx in indexes.items():
variables[name] = IndexVariable(name, idx)
variables[name] = IndexVariable(name, idx.to_pandas_index())
new_indexes[name] = idx
obj = self._replace(variables, indexes=new_indexes)

@@ -2474,6 +2473,10 @@ def sel(
pos_indexers, new_indexes = remap_label_indexers(
self, indexers=indexers, method=method, tolerance=tolerance
)
# TODO: benbovy - flexible indexes: also use variables returned by Index.query
# (temporary dirty fix).
new_indexes = {k: v[0] for k, v in new_indexes.items()}

result = self.isel(indexers=pos_indexers, drop=drop)
return result._overwrite_indexes(new_indexes)

@@ -3297,20 +3300,21 @@ def _rename_dims(self, name_dict):
return {name_dict.get(k, k): v for k, v in self.dims.items()}

def _rename_indexes(self, name_dict, dims_set):
# TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5645
if self._indexes is None:
return None
indexes = {}
for k, v in self.xindexes.items():
# TODO: benbovy - flexible indexes: make it compatible with any xarray Index
index = v.to_pandas_index()
for k, v in self.indexes.items():
new_name = name_dict.get(k, k)
if new_name not in dims_set:
continue
if isinstance(index, pd.MultiIndex):
new_names = [name_dict.get(k, k) for k in index.names]
indexes[new_name] = PandasMultiIndex(index.rename(names=new_names))
if isinstance(v, pd.MultiIndex):
new_names = [name_dict.get(k, k) for k in v.names]
indexes[new_name] = PandasMultiIndex(
v.rename(names=new_names), new_name
)
else:
indexes[new_name] = PandasIndex(index.rename(new_name))
indexes[new_name] = PandasIndex(v.rename(new_name), new_name)
return indexes

def _rename_all(self, name_dict, dims_dict):
@@ -3539,7 +3543,10 @@ def swap_dims(
if new_index.nlevels == 1:
# make sure index name matches dimension name
new_index = new_index.rename(k)
indexes[k] = wrap_pandas_index(new_index)
if isinstance(new_index, pd.MultiIndex):
indexes[k] = PandasMultiIndex(new_index, k)
else:
indexes[k] = PandasIndex(new_index, k)
else:
var = v.to_base_variable()
var.dims = dims
@@ -3812,7 +3819,7 @@ def reorder_levels(
raise ValueError(f"coordinate {dim} has no MultiIndex")
new_index = index.reorder_levels(order)
variables[dim] = IndexVariable(coord.dims, new_index)
indexes[dim] = PandasMultiIndex(new_index)
indexes[dim] = PandasMultiIndex(new_index, dim)

return self._replace(variables, indexes=indexes)

@@ -3840,7 +3847,7 @@ def _stack_once(self, dims, new_dim):
coord_names = set(self._coord_names) - set(dims) | {new_dim}

indexes = {k: v for k, v in self.xindexes.items() if k not in dims}
indexes[new_dim] = wrap_pandas_index(idx)
indexes[new_dim] = PandasMultiIndex(idx, new_dim)

return self._replace_with_new_dims(
variables, coord_names=coord_names, indexes=indexes
@@ -4029,8 +4036,9 @@ def _unstack_once(self, dim: Hashable, fill_value) -> "Dataset":
variables[name] = var

for name, lev in zip(index.names, index.levels):
variables[name] = IndexVariable(name, lev)
indexes[name] = PandasIndex(lev)
idx, idx_vars = PandasIndex.from_pandas_index(lev, name)
variables[name] = idx_vars[name]
indexes[name] = idx

coord_names = set(self._coord_names) - {dim} | set(index.names)

@@ -4068,8 +4076,9 @@ def _unstack_full_reindex(
variables[name] = var

for name, lev in zip(new_dim_names, index.levels):
variables[name] = IndexVariable(name, lev)
indexes[name] = PandasIndex(lev)
idx, idx_vars = PandasIndex.from_pandas_index(lev, name)
variables[name] = idx_vars[name]
indexes[name] = idx

coord_names = set(self._coord_names) - {dim} | set(new_dim_names)

@@ -5839,10 +5848,13 @@ def diff(self, dim, n=1, label="upper"):

indexes = dict(self.xindexes)
if dim in indexes:
# TODO: benbovy - flexible indexes: check slicing of xarray indexes?
# or only allow this for pandas indexes?
index = indexes[dim].to_pandas_index()
indexes[dim] = PandasIndex(index[kwargs_new[dim]])
if isinstance(indexes[dim], PandasIndex):
# maybe optimize? (pandas index already indexed above with var.isel)
new_index = indexes[dim].index[kwargs_new[dim]]
if isinstance(new_index, pd.MultiIndex):
indexes[dim] = PandasMultiIndex(new_index, dim)
else:
indexes[dim] = PandasIndex(new_index, dim)

difference = self._replace_with_new_dims(variables, indexes=indexes)

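
With wrap_pandas_index removed (see indexes.py below), the single-index vs. multi-index dispatch is written inline in swap_dims and diff above. The helper below is only a hypothetical sketch of that repeated pattern, not part of the commit's API:

    import pandas as pd
    from xarray.core.indexes import PandasIndex, PandasMultiIndex

    def wrap_with_dim(pd_index: pd.Index, dim):
        # Hypothetical helper name; mirrors the inline pattern used in this file.
        if isinstance(pd_index, pd.MultiIndex):
            return PandasMultiIndex(pd_index, dim)
        return PandasIndex(pd_index, dim)

    x_index = wrap_with_dim(pd.Index([1, 2, 3]), "x")
    y_index = wrap_with_dim(pd.MultiIndex.from_arrays([[1, 2], ["a", "b"]]), "y")
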
30 changes: 17 additions & 13 deletions xarray/core/indexes.py
@@ -180,6 +180,8 @@ def from_variables(cls, variables: Mapping[Hashable, "Variable"]):

@classmethod
def from_pandas_index(cls, index: pd.Index, dim: Hashable):
from .variable import IndexVariable

if index.name is None:
name = dim
else:
@@ -247,15 +249,15 @@ def union(self, other):

new_index = self.index.union(other)

return type(self).from_pandas_index(new_index, self.dim)
return type(self)(new_index, self.dim)

def intersection(self, other):
if isinstance(other, PandasIndex):
other = other.index

new_index = self.index.intersection(other)

return type(self).from_pandas_index(new_index, self.dim)
return type(self)(new_index, self.dim)

def copy(self, deep=True):
return type(self)(self.index.copy(deep=deep), self.dim)
@@ -421,13 +423,6 @@ def query(self, labels, method=None, tolerance=None):
return indexer, None


def wrap_pandas_index(index):
if isinstance(index, pd.MultiIndex):
return PandasMultiIndex(index)
else:
return PandasIndex(index)


def remove_unused_levels_categories(index: pd.Index) -> pd.Index:
"""
Remove unused levels from MultiIndex and unused categories from CategoricalIndex
@@ -512,7 +507,13 @@ def isel_variable_and_index(
index: Index,
indexers: Mapping[Hashable, Union[int, slice, np.ndarray, "Variable"]],
) -> Tuple["Variable", Optional[Index]]:
"""Index a Variable and pandas.Index together."""
"""Index a Variable and an Index together.
If the index cannot be indexed, return None (it will be dropped).
(note: not compatible yet with xarray flexible indexes).
"""
from .variable import Variable

if not indexers:
@@ -535,8 +536,11 @@
indexer = indexers[dim]
if isinstance(indexer, Variable):
indexer = indexer.data
pd_index = index.to_pandas_index()
new_index = wrap_pandas_index(pd_index[indexer])
try:
new_index = index[indexer]
except NotImplementedError:
new_index = None

return new_variable, new_index


Expand All @@ -548,7 +552,7 @@ def roll_index(index: PandasIndex, count: int, axis: int = 0) -> PandasIndex:
new_idx = pd_index[-count:].append(pd_index[:-count])
else:
new_idx = pd_index[:]
return PandasIndex(new_idx)
return PandasIndex(new_idx, index.dim)


def propagate_indexes(
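
A small usage sketch of roll_index after the change above, assuming the PandasIndex(array, dim) constructor and the .dim attribute shown in these hunks; the rolled values follow the pd_index[-count:].append(pd_index[:-count]) logic visible in the context lines:

    import pandas as pd
    from xarray.core.indexes import PandasIndex, roll_index

    idx = PandasIndex(pd.Index([1, 2, 3, 4]), "x")
    rolled = roll_index(idx, count=1)
    # Expected: values rolled to [4, 1, 2, 3]; the dim name is now propagated.
    assert rolled.dim == "x"
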
6 changes: 2 additions & 4 deletions xarray/core/indexing.py
@@ -572,9 +572,7 @@ def as_indexable(array):
if isinstance(array, np.ndarray):
return NumpyIndexingAdapter(array)
if isinstance(array, pd.Index):
from .indexes import PandasIndex

return PandasIndex(array)
return PandasIndexingAdapter(array)
if isinstance(array, dask_array_type):
return DaskIndexingAdapter(array)
if hasattr(array, "__array_function__"):
@@ -1270,7 +1268,7 @@ class PandasIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
__slots__ = ("array", "_dtype")

def __init__(self, array: pd.Index, dtype: DTypeLike = None):
self.array = array
self.array = utils.safe_cast_to_index(array)

if dtype is None:
if isinstance(array, pd.PeriodIndex):
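
The as_indexable change above means a bare pandas.Index is now wrapped by the low-level PandasIndexingAdapter (which coerces its input via safe_cast_to_index) rather than by the high-level PandasIndex class. A minimal sketch of that dispatch:

    import pandas as pd
    from xarray.core.indexing import PandasIndexingAdapter, as_indexable

    # Illustrative only: pd.Index gets the indexing adapter, not an xarray Index.
    adapter = as_indexable(pd.Index([1.0, 2.0, 3.0]))
    assert isinstance(adapter, PandasIndexingAdapter)
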