Skip to content

Commit

Permalink
Update with NumPy 1.25 (#382)
Browse files Browse the repository at this point in the history
* Fix behaviour with new find_common_dtype

* Replace amax and amin aliases by max and min

* Remove exceptions namespace for backward compatibility

* Broaden promote_types exception to TypeError for backward compatibility

* Use both amax and max aliases in handled functions

* Only use TypeError for DtypePromotionError
  • Loading branch information
rhugonnet authored Aug 8, 2023
1 parent 89dced1 commit 2f3a547
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 11 deletions.
2 changes: 1 addition & 1 deletion dev-environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ dependencies:
- matplotlib
- pyproj
- rasterio>=1.3
- numpy=1.24
- numpy
- scipy
- tqdm
- xarray
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ dependencies:
- matplotlib
- pyproj
- rasterio>=1.3
- numpy=1.24
- numpy
- scipy
- tqdm
- xarray
Expand Down
6 changes: 4 additions & 2 deletions geoutils/raster/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@
"sum",
"amax",
"amin",
"max",
"min",
"argmax",
"argmin",
"mean",
Expand Down Expand Up @@ -893,7 +895,7 @@ def _overloading_check(
dtype2 = rio.dtypes.get_minimum_dtype(other_data)

# Figure out output dtype
out_dtype = np.find_common_type([dtype1, dtype2], [])
out_dtype = np.promote_types(dtype1, dtype2)

# Figure output nodata
out_nodata = None
Expand Down Expand Up @@ -3091,7 +3093,7 @@ def polygonize(
gpd_dtypes = ["uint8", "uint16", "int16", "int32", "float32"]
list_common_dtype_index = []
for gpd_type in gpd_dtypes:
polygonize_dtype = np.find_common_type([gpd_type, self.dtypes[0]], [])
polygonize_dtype = np.promote_types(gpd_type, self.dtypes[0])
if str(polygonize_dtype) in gpd_dtypes:
list_common_dtype_index.append(gpd_dtypes.index(gpd_type))
if len(list_common_dtype_index) == 0:
Expand Down
33 changes: 26 additions & 7 deletions tests/test_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -3054,7 +3054,7 @@ def test_ops_2args_expl(self, op: str) -> None:
r1 = self.r1
r2 = self.r2
r3 = getattr(r1, op)(r2)
ctype = np.find_common_type([r1.data.dtype, r2.data.dtype], [])
ctype = np.promote_types(r1.data.dtype, r2.data.dtype)
numpy_output = getattr(r1.data.astype(ctype), op)(r2.data.astype(ctype))
assert isinstance(r3, gu.Raster)
assert np.all(r3.data == numpy_output)
Expand Down Expand Up @@ -3559,7 +3559,12 @@ def test_array_ufunc_1nin_1nout(self, ufunc_str: str, nodata_init: None | str, d
ufunc = getattr(np, ufunc_str)

# Find the common dtype between the Raster and the most constrained input type (first character is the input)
com_dtype = np.find_common_type([dtype] + [ufunc.types[0][0]], [])
try:
com_dtype = np.promote_types(dtype, ufunc.types[0][0])
# The promote_types function raises an error for object dtypes (previously returned by find_common_dtypes)
# (TypeError needed for backwards compatibility; also exceptions.DTypePromotionError for NumPy 1.25 and above)
except TypeError:
com_dtype = np.dtype("O")

# Catch warnings
with warnings.catch_warnings():
Expand Down Expand Up @@ -3618,13 +3623,27 @@ def test_array_ufunc_2nin_1nout(
ufunc = getattr(np, ufunc_str)

# Find the common dtype between the Raster and the most constrained input type (first character is the input)
com_dtype1 = np.find_common_type([dtype1] + [ufunc.types[0][0]], [])
com_dtype2 = np.find_common_type([dtype2] + [ufunc.types[0][1]], [])
try:
com_dtype1 = np.promote_types(dtype1, ufunc.types[0][0])
# The promote_types function raises an error for object dtypes (previously returned by find_common_dtypes)
# (TypeError needed for backwards compatibility; also exceptions.DTypePromotionError for NumPy 1.25 and above)
except TypeError:
com_dtype1 = np.dtype("O")

try:
com_dtype2 = np.promote_types(dtype2, ufunc.types[0][1])
# The promote_types function raises an error for object dtypes (previously returned by find_common_dtypes)
# (TypeError needed for backwards compatibility; also exceptions.DTypePromotionError for NumPy 1.25 and above)
except TypeError:
com_dtype2 = np.dtype("O")

# If the two input types can be the same type, pass a tuple with the common type of both
# Below we ignore datetime and timedelta types "m" and "M"
if all(t[0] == t[1] for t in ufunc.types if not ("m" or "M") in t[0:2]):
com_dtype_both = np.find_common_type([com_dtype1, com_dtype2], [])
# Below we ignore datetime and timedelta types "m" and "M", and int64 types "q" and "Q"
if all(t[0] == t[1] for t in ufunc.types if not any(x in t[0:2] for x in ["m", "M", "q", "Q"])):
try:
com_dtype_both = np.promote_types(com_dtype1, com_dtype2)
except TypeError:
com_dtype_both = np.dtype("O")
com_dtype_tuple = (com_dtype_both, com_dtype_both)

# Otherwise, pass the tuple with each common type
Expand Down

0 comments on commit 2f3a547

Please sign in to comment.