Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add invariant check for IndexVariable.name #8906

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,7 @@ def create_variables(
encoding = None

data = PandasIndexingAdapter(self.index, dtype=self.coord_dtype)
var = IndexVariable(self.dim, data, attrs=attrs, encoding=encoding)
var = IndexVariable(self.dim, data, attrs=attrs, encoding=encoding, name=name)
return {name: var}

def to_pandas_index(self) -> pd.Index:
Expand Down Expand Up @@ -1153,6 +1153,7 @@ def create_variables(
attrs=attrs,
encoding=encoding,
fastpath=True,
name=name,
)

return index_vars
Expand Down
30 changes: 26 additions & 4 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2590,12 +2590,14 @@ class IndexVariable(Variable):
unless another name is given.
"""

__slots__ = ()
__slots__ = ("_name",)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we remove the name property instead or is that much much harder?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be better indeed but this is likely more work (not sure how much, though).

IndexVariable still needs a deeper refactor (#8124), or even be eventually dropped?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is the .name property used?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After a quick check: internally only in a few places (conventions / multiindex check, dataarray creation, maybe groupby?) actually.

Externally I have no idea (IndexVariable is public API in theory).


# TODO: PandasIndexingAdapter doesn't match the array api:
_data: PandasIndexingAdapter # type: ignore[assignment]

def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
def __init__(
self, dims, data, attrs=None, encoding=None, fastpath=False, name=None
):
super().__init__(dims, data, attrs, encoding, fastpath)
if self.ndim != 1:
raise ValueError(f"{type(self).__name__} objects must be 1-dimensional")
Expand All @@ -2604,6 +2606,11 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
if not isinstance(self._data, PandasIndexingAdapter):
self._data = PandasIndexingAdapter(self._data)

if name is None:
self._name = self.dims[0]
else:
self._name = name

def __dask_tokenize__(self) -> object:
from dask.base import normalize_token

Expand Down Expand Up @@ -2753,7 +2760,22 @@ def copy(self, deep: bool = True, data: T_DuckArray | ArrayLike | None = None):
attrs = copy.deepcopy(self._attrs) if deep else copy.copy(self._attrs)
encoding = copy.deepcopy(self._encoding) if deep else copy.copy(self._encoding)

return self._replace(data=ndata, attrs=attrs, encoding=encoding)
copied = self._replace(data=ndata, attrs=attrs, encoding=encoding)

return copied

def _replace(
self,
dims=_default,
data=_default,
attrs=_default,
encoding=_default,
) -> Self:
replaced = super()._replace(
dims=dims, data=data, attrs=attrs, encoding=encoding
)
replaced._name = self._name
return replaced

def equals(self, other, equiv=None):
# if equiv is specified, super up
Expand Down Expand Up @@ -2825,7 +2847,7 @@ def get_level_variable(self, level):

@property
def name(self) -> Hashable:
return self.dims[0]
return self._name

@name.setter
def name(self, value) -> NoReturn:
Expand Down
12 changes: 9 additions & 3 deletions xarray/testing/assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,15 @@ def _assert_indexes_invariants_checks(
}

index_vars = {
k for k, v in possible_coord_variables.items() if isinstance(v, IndexVariable)
k: v
for k, v in possible_coord_variables.items()
if isinstance(v, IndexVariable)
}
assert indexes.keys() <= index_vars, (set(indexes), index_vars)
index_var_names = set(index_vars)
assert indexes.keys() <= index_var_names, (set(indexes), index_var_names)

for k, v in index_vars.items():
assert k == v.name, (k, v.name)

# check pandas index wrappers vs. coordinate data adapters
for k, index in indexes.items():
Expand All @@ -283,7 +289,7 @@ def _assert_indexes_invariants_checks(
if isinstance(index, PandasMultiIndex):
pd_index = index.index
for name in index.index.names:
assert name in possible_coord_variables, (pd_index, index_vars)
assert name in possible_coord_variables, (pd_index, index_var_names)
var = possible_coord_variables[name]
assert (index.dim,) == var.dims, (pd_index, var)
assert index.level_coords_dtype[name] == var.dtype, (
Expand Down
Loading