Skip to content

Commit

Permalink
Merge branch 'main' into jpivarski/update-taxi-dataset-url
Browse files Browse the repository at this point in the history
  • Loading branch information
ianna authored Jan 7, 2025
2 parents b1ee575 + 432c07a commit 441e7e6
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 68 deletions.
10 changes: 6 additions & 4 deletions src/awkward/_meta/numpymeta.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

from functools import cached_property

from awkward._meta.meta import Meta
from awkward._nplikes.shape import ShapeItem
from awkward._typing import JSONSerializable
Expand Down Expand Up @@ -31,13 +33,13 @@ def purelist_depth(self) -> int:
def is_identity_like(self) -> bool:
return False

@property
def minmax_depth(self) -> tuple[int, int]:
@cached_property
def minmax_depth(self) -> tuple[int, int]: # type: ignore[override]
depth = len(self.inner_shape) + 1
return (depth, depth)

@property
def branch_depth(self) -> tuple[bool, int]:
@cached_property
def branch_depth(self) -> tuple[bool, int]: # type: ignore[override]
return (False, len(self.inner_shape) + 1)

@property
Expand Down
30 changes: 17 additions & 13 deletions src/awkward/_meta/recordmeta.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

from functools import cached_property

from awkward._meta.meta import Meta
from awkward._regularize import is_integer
from awkward._typing import Generic, JSONSerializable, TypeVar
Expand Down Expand Up @@ -39,19 +41,21 @@ def purelist_depth(self) -> int:
def is_identity_like(self) -> bool:
return False

@property
def minmax_depth(self) -> tuple[int, int]:
@cached_property
def minmax_depth(self) -> tuple[int, int]: # type: ignore[override]
if len(self._contents) == 0:
return (1, 1)
mins, maxs = [], []
for content in self._contents:
mindepth, maxdepth = content.minmax_depth
mins.append(mindepth)
maxs.append(maxdepth)
return (min(mins), max(maxs))

@property
def branch_depth(self) -> tuple[bool, int]:
mindepth, maxdepth = self._contents[0].minmax_depth
for content in self._contents[1:]:
mindepth_, maxdepth_ = content.minmax_depth
if mindepth_ < mindepth:
mindepth = mindepth_
if maxdepth_ > maxdepth:
maxdepth = maxdepth_
return (mindepth, maxdepth)

@cached_property
def branch_depth(self) -> tuple[bool, int]: # type: ignore[override]
if len(self._contents) == 0:
return False, 1

Expand Down Expand Up @@ -80,8 +84,8 @@ def is_leaf(self) -> bool: # type: ignore[override]
def contents(self) -> list[T]:
return self._contents

@property
def fields(self) -> list[str]:
@cached_property
def fields(self) -> list[str]: # type: ignore[override]
if self._fields is None:
return [str(i) for i in range(len(self._contents))]
else:
Expand Down
39 changes: 21 additions & 18 deletions src/awkward/_meta/unionmeta.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from collections import Counter
from functools import cached_property

from awkward._meta.meta import Meta
from awkward._typing import Generic, JSONSerializable, TypeVar
Expand Down Expand Up @@ -31,15 +32,15 @@ def purelist_parameters(self, *keys: str) -> JSONSerializable:

return None

@property
def purelist_isregular(self) -> bool:
@cached_property
def purelist_isregular(self) -> bool: # type: ignore[override]
for content in self._contents:
if not content.purelist_isregular:
return False
return True

@property
def purelist_depth(self) -> int:
@cached_property
def purelist_depth(self) -> int: # type: ignore[override]
out = None
for content in self._contents:
if out is None:
Expand All @@ -53,19 +54,21 @@ def purelist_depth(self) -> int:
def is_identity_like(self) -> bool:
return False

@property
def minmax_depth(self) -> tuple[int, int]:
@cached_property
def minmax_depth(self) -> tuple[int, int]: # type: ignore[override]
if len(self._contents) == 0:
return (0, 0)
mins, maxs = [], []
for content in self._contents:
mindepth, maxdepth = content.minmax_depth
mins.append(mindepth)
maxs.append(maxdepth)
return (min(mins), max(maxs))

@property
def branch_depth(self) -> tuple[bool, int]:
mindepth, maxdepth = self._contents[0].minmax_depth
for content in self._contents[1:]:
mindepth_, maxdepth_ = content.minmax_depth
if mindepth_ < mindepth:
mindepth = mindepth_
if maxdepth_ > maxdepth:
maxdepth = maxdepth_
return (mindepth, maxdepth)

@cached_property
def branch_depth(self) -> tuple[bool, int]: # type: ignore[override]
if len(self._contents) == 0:
return False, 1

Expand All @@ -83,8 +86,8 @@ def branch_depth(self) -> tuple[bool, int]:
assert min_depth is not None
return any_branch, min_depth

@property
def fields(self) -> list[str]:
@cached_property
def fields(self) -> list[str]: # type: ignore[override]
field_counts = Counter([f for c in self._contents for f in c.fields])
return [f for f, n in field_counts.items() if n == len(self._contents)]

Expand All @@ -102,6 +105,6 @@ def dimension_optiontype(self) -> bool:
def content(self, index: int) -> T:
return self._contents[index]

@property
@cached_property
def contents(self) -> list[T]:
return self._contents
3 changes: 3 additions & 0 deletions src/awkward/_nplikes/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ def __str__(self) -> str:
def __repr__(self):
return self._instance_name

def __hash__(self):
return hash(self._instance_name)

def __eq__(self, other) -> bool:
if other is self:
return True
Expand Down
74 changes: 41 additions & 33 deletions src/awkward/_nplikes/typetracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from collections.abc import Collection, Iterator, Sequence, Set
from functools import lru_cache
from numbers import Number
from typing import Callable

Expand Down Expand Up @@ -55,7 +56,9 @@ def is_unknown_scalar(array: Any) -> TypeGuard[TypeTracerArray]:


def is_unknown_integer(array: Any) -> TypeGuard[TypeTracerArray]:
return is_unknown_scalar(array) and np.issubdtype(array.dtype, np.integer)
return cast(
bool, is_unknown_scalar(array) and np.issubdtype(array.dtype, np.integer)
)


def is_unknown_array(array: Any) -> TypeGuard[TypeTracerArray]:
Expand Down Expand Up @@ -1147,38 +1150,7 @@ def derive_slice_for_length(
return start, stop, step, self.index_as_shape_item(slice_length)

def broadcast_shapes(self, *shapes: tuple[ShapeItem, ...]) -> tuple[ShapeItem, ...]:
ndim = max((len(s) for s in shapes), default=0)
result: list[ShapeItem] = [1] * ndim

for shape in shapes:
# Right broadcasting
missing_dim = ndim - len(shape)
if missing_dim > 0:
head: tuple[int, ...] = (1,) * missing_dim
shape = head + shape

# Fail if we absolutely know the shapes aren't compatible
for i, item in enumerate(shape):
# Item is unknown, take it
if is_unknown_length(item):
result[i] = item
# Existing item is unknown, keep it
elif is_unknown_length(result[i]):
continue
# Items match, continue
elif result[i] == item:
continue
# Item is broadcastable, take existing
elif item == 1:
continue
# Existing is broadcastable, take it
elif result[i] == 1:
result[i] = item
else:
raise ValueError(
"known component of shape does not match broadcast result"
)
return tuple(result)
return _broadcast_shapes(*shapes)

def broadcast_arrays(self, *arrays: TypeTracerArray) -> list[TypeTracerArray]:
for x in arrays:
Expand Down Expand Up @@ -1706,6 +1678,42 @@ def __dlpack__(self, stream=None):
raise NotImplementedError


@lru_cache
def _broadcast_shapes(*shapes):
ndim = max((len(s) for s in shapes), default=0)
result: list[ShapeItem] = [1] * ndim

for shape in shapes:
# Right broadcasting
missing_dim = ndim - len(shape)
if missing_dim > 0:
head: tuple[int, ...] = (1,) * missing_dim
shape = head + shape

# Fail if we absolutely know the shapes aren't compatible
for i, item in enumerate(shape):
# Item is unknown, take it
if is_unknown_length(item):
result[i] = item
# Existing item is unknown, keep it
elif is_unknown_length(result[i]):
continue
# Items match, continue
elif result[i] == item:
continue
# Item is broadcastable, take existing
elif item == 1:
continue
# Existing is broadcastable, take it
elif result[i] == 1:
result[i] = item
else:
raise ValueError(
"known component of shape does not match broadcast result"
)
return tuple(result)


def _attach_report(
layout: Content,
form: Form,
Expand Down
2 changes: 2 additions & 0 deletions src/awkward/forms/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from collections import defaultdict
from collections.abc import Callable, Iterable, Mapping
from fnmatch import fnmatchcase
from functools import lru_cache
from glob import escape as escape_glob

import awkward as ak
Expand Down Expand Up @@ -202,6 +203,7 @@ def from_dict(input: Mapping) -> Form:
)


@lru_cache
def from_json(input: str) -> Form:
return from_dict(json.loads(input))

Expand Down
3 changes: 3 additions & 0 deletions src/awkward/types/numpytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import re
from collections.abc import Mapping
from functools import lru_cache

from awkward._behavior import find_array_typestr
from awkward._nplikes.numpy_like import NumpyMetadata
Expand All @@ -25,6 +26,7 @@ def is_primitive(primitive):
return primitive in _primitive_to_dtype_dict


@lru_cache
def primitive_to_dtype(primitive):
if _primitive_to_dtype_datetime.match(primitive) is not None:
return np.dtype(primitive)
Expand All @@ -42,6 +44,7 @@ def primitive_to_dtype(primitive):
return out


@lru_cache
def dtype_to_primitive(dtype):
if dtype.kind.upper() == "M" and dtype == dtype.newbyteorder("="):
return str(dtype)
Expand Down

0 comments on commit 441e7e6

Please sign in to comment.