-
-
Notifications
You must be signed in to change notification settings - Fork 18.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Backport PR #47762 on branch 1.5.x (REGR: preserve reindexed array ob…
…ject (instead of creating new array) for concat with all-NA array) (#48309) Backport PR #47762: REGR: preserve reindexed array object (instead of creating new array) for concat with all-NA array Co-authored-by: Joris Van den Bossche <jorisvandenbossche@gmail.com>
- Loading branch information
1 parent
3ca5773
commit 46f7167
Showing
5 changed files
with
139 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from pandas.tests.extension.array_with_attr.array import ( | ||
FloatAttrArray, | ||
FloatAttrDtype, | ||
) | ||
|
||
__all__ = ["FloatAttrArray", "FloatAttrDtype"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
""" | ||
Test extension array that has custom attribute information (not stored on the dtype). | ||
""" | ||
from __future__ import annotations | ||
|
||
import numbers | ||
|
||
import numpy as np | ||
|
||
from pandas._typing import type_t | ||
|
||
from pandas.core.dtypes.base import ExtensionDtype | ||
|
||
import pandas as pd | ||
from pandas.core.arrays import ExtensionArray | ||
|
||
|
||
class FloatAttrDtype(ExtensionDtype): | ||
type = float | ||
name = "float_attr" | ||
na_value = np.nan | ||
|
||
@classmethod | ||
def construct_array_type(cls) -> type_t[FloatAttrArray]: | ||
""" | ||
Return the array type associated with this dtype. | ||
Returns | ||
------- | ||
type | ||
""" | ||
return FloatAttrArray | ||
|
||
|
||
class FloatAttrArray(ExtensionArray): | ||
dtype = FloatAttrDtype() | ||
__array_priority__ = 1000 | ||
|
||
def __init__(self, values, attr=None) -> None: | ||
if not isinstance(values, np.ndarray): | ||
raise TypeError("Need to pass a numpy array of float64 dtype as values") | ||
if not values.dtype == "float64": | ||
raise TypeError("Need to pass a numpy array of float64 dtype as values") | ||
self.data = values | ||
self.attr = attr | ||
|
||
@classmethod | ||
def _from_sequence(cls, scalars, dtype=None, copy=False): | ||
data = np.array(scalars, dtype="float64", copy=copy) | ||
return cls(data) | ||
|
||
def __getitem__(self, item): | ||
if isinstance(item, numbers.Integral): | ||
return self.data[item] | ||
else: | ||
# slice, list-like, mask | ||
item = pd.api.indexers.check_array_indexer(self, item) | ||
return type(self)(self.data[item], self.attr) | ||
|
||
def __len__(self) -> int: | ||
return len(self.data) | ||
|
||
def isna(self): | ||
return np.isnan(self.data) | ||
|
||
def take(self, indexer, allow_fill=False, fill_value=None): | ||
from pandas.api.extensions import take | ||
|
||
data = self.data | ||
if allow_fill and fill_value is None: | ||
fill_value = self.dtype.na_value | ||
|
||
result = take(data, indexer, fill_value=fill_value, allow_fill=allow_fill) | ||
return type(self)(result, self.attr) | ||
|
||
def copy(self): | ||
return type(self)(self.data.copy(), self.attr) | ||
|
||
@classmethod | ||
def _concat_same_type(cls, to_concat): | ||
data = np.concatenate([x.data for x in to_concat]) | ||
attr = to_concat[0].attr if len(to_concat) else None | ||
return cls(data, attr) |
33 changes: 33 additions & 0 deletions
33
pandas/tests/extension/array_with_attr/test_array_with_attr.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import numpy as np | ||
|
||
import pandas as pd | ||
import pandas._testing as tm | ||
from pandas.tests.extension.array_with_attr import FloatAttrArray | ||
|
||
|
||
def test_concat_with_all_na(): | ||
# https://github.com/pandas-dev/pandas/pull/47762 | ||
# ensure that attribute of the column array is preserved (when it gets | ||
# preserved in reindexing the array) during merge/concat | ||
arr = FloatAttrArray(np.array([np.nan, np.nan], dtype="float64"), attr="test") | ||
|
||
df1 = pd.DataFrame({"col": arr, "key": [0, 1]}) | ||
df2 = pd.DataFrame({"key": [0, 1], "col2": [1, 2]}) | ||
result = pd.merge(df1, df2, on="key") | ||
expected = pd.DataFrame({"col": arr, "key": [0, 1], "col2": [1, 2]}) | ||
tm.assert_frame_equal(result, expected) | ||
assert result["col"].array.attr == "test" | ||
|
||
df1 = pd.DataFrame({"col": arr, "key": [0, 1]}) | ||
df2 = pd.DataFrame({"key": [0, 2], "col2": [1, 2]}) | ||
result = pd.merge(df1, df2, on="key") | ||
expected = pd.DataFrame({"col": arr.take([0]), "key": [0], "col2": [1]}) | ||
tm.assert_frame_equal(result, expected) | ||
assert result["col"].array.attr == "test" | ||
|
||
result = pd.concat([df1.set_index("key"), df2.set_index("key")], axis=1) | ||
expected = pd.DataFrame( | ||
{"col": arr.take([0, 1, -1]), "col2": [1, np.nan, 2], "key": [0, 1, 2]} | ||
).set_index("key") | ||
tm.assert_frame_equal(result, expected) | ||
assert result["col"].array.attr == "test" |