Skip to content

Commit

Permalink
ENH: torch: add uintN type to __array_namespace_info__
Browse files Browse the repository at this point in the history
  • Loading branch information
ev-br committed Feb 22, 2025
1 parent bf43770 commit e2dc3ad
Showing 1 changed file with 26 additions and 7 deletions.
33 changes: 26 additions & 7 deletions array_api_compat/torch/_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,16 +169,26 @@ def _dtypes(self, kind):
int32 = torch.int32
int64 = torch.int64
uint8 = torch.uint8
# uint16, uint32, and uint64 are present in newer versions of pytorch,
# but they aren't generally supported by the array API functions, so
# we omit them from this function.
try:
# pytorch >= 2.3
uint16 = torch.uint16
uint32 = torch.uint32
uint64 = torch.uint64
uint_kinds = {
"uint16": uint16,
"uint32": uint32,
"uint64": uint64,
}
except AttributeError:
uint_kinds = {}

float32 = torch.float32
float64 = torch.float64
complex64 = torch.complex64
complex128 = torch.complex128

if kind is None:
return {
kinds = {
"bool": bool,
"int8": int8,
"int16": int16,
Expand All @@ -190,6 +200,8 @@ def _dtypes(self, kind):
"complex64": complex64,
"complex128": complex128,
}
kinds.update(uint_kinds)
return kinds
if kind == "bool":
return {"bool": bool}
if kind == "signed integer":
Expand All @@ -200,17 +212,21 @@ def _dtypes(self, kind):
"int64": int64,
}
if kind == "unsigned integer":
return {
kinds= {
"uint8": uint8,
}
kinds.update(uint_kinds)
return kinds
if kind == "integral":
return {
kinds= {
"int8": int8,
"int16": int16,
"int32": int32,
"int64": int64,
"uint8": uint8,
}
kinds.update(uint_kinds)
return kinds
if kind == "real floating":
return {
"float32": float32,
Expand All @@ -222,7 +238,7 @@ def _dtypes(self, kind):
"complex128": complex128,
}
if kind == "numeric":
return {
kinds = {
"int8": int8,
"int16": int16,
"int32": int32,
Expand All @@ -233,6 +249,9 @@ def _dtypes(self, kind):
"complex64": complex64,
"complex128": complex128,
}
kinds.update(uint_kinds)
return kinds

if isinstance(kind, tuple):
res = {}
for k in kind:
Expand Down

0 comments on commit e2dc3ad

Please sign in to comment.