Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow iterables of qualified class names in Stream.__getitem__ searches #1359

Merged
merged 6 commits into from
Aug 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 71 additions & 21 deletions music21/stream/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,8 +439,31 @@ def __getitem__(
x = t.cast(iterator.RecursiveIterator[ChangedM21ObjType], self.recurse())
return x # dummy code

@overload
def __getitem__(
self,
k: t.Type # getting something that is a subclass of something that is not a m21 object
) -> iterator.RecursiveIterator[M21ObjType]:
x = t.cast(iterator.RecursiveIterator[M21ObjType], self.recurse())
return x # dummy code


@overload
def __getitem__(
self,
k: t.Collection[t.Type]
) -> iterator.RecursiveIterator[M21ObjType]:
# Remove this code and replace with ... once Astroid #1015 is fixed.
x: iterator.RecursiveIterator[M21ObjType] = self.recurse()
return x


def __getitem__(self,
k: t.Union[str, int, slice, t.Type[ChangedM21ObjType]]
k: t.Union[str,
int,
slice,
t.Type[ChangedM21ObjType],
t.Collection[t.Type]]
) -> t.Union[iterator.RecursiveIterator[M21ObjType],
iterator.RecursiveIterator[ChangedM21ObjType],
M21ObjType,
Expand Down Expand Up @@ -486,8 +509,8 @@ def __getitem__(self,
3.0


If a class is given then an iterator of elements
that match the requested class(es) is returned, similar
If a class is given, then a :class:`~music21.stream.iterator.RecursiveIterator`
of elements matching the requested class is returned, similar
to `Stream().recurse().getElementsByClass()`.

>>> len(s)
Expand All @@ -501,15 +524,32 @@ def __getitem__(self,
... print(n.name, end=' ')
C D E F G A

Note that this iterator is recursive by default.
Note that this iterator is recursive: it will find elements inside of streams
within this stream:

>>> c_sharp = note.Note('C#')
>>> v = stream.Voice([c_sharp])
>>> s.insert(0.5, c_sharp)

>>> v = stream.Voice()
>>> v.insert(0, c_sharp)
>>> s.insert(0.5, v)
>>> len(s[note.Note])
7

When using a single Music21 class in this way, your type checker will
be able to infer that the only objects in any loop are in fact `note.Note`
objects, and catch programming errors before running.

Multiple classes can be provided, separated by commas. Any element matching
any of the requested classes will be matched.

>>> len(s[note.Note, note.Rest])
9

>>> for note_or_rest in s[note.Note, note.Rest]:
... if isinstance(note_or_rest, note.Note):
... print(note_or_rest.name, end=' ')
... else:
... print('Rest', end=' ')
C C# D E Rest F G Rest A

The actual object returned by `s[module.Class]` is a
:class:`~music21.stream.iterator.RecursiveIterator` and has all the functions
Expand Down Expand Up @@ -556,7 +596,8 @@ def __getitem__(self,

>>> s[0.5]
Traceback (most recent call last):
TypeError: Streams can get items by int, slice, class, or string query; got <class 'float'>
TypeError: Streams can get items by int, slice, class, class iterable, or string query;
got <class 'float'>

Changed in v7:
- out of range indexes now raise an IndexError, not StreamException
Expand All @@ -573,6 +614,7 @@ def __getitem__(self,
.recurse().getElementsByClass to get the earlier behavior. Old behavior
still works until v9. This is an attempt to unify __getitem__ behavior in
StreamIterators and Streams.
- allowed iterables of qualified class names, e.g. `[note.Note, note.Rest]`
'''
# need to sort if not sorted, as this call may rely on index positions
if not self.isSorted and self.autoSort:
Expand Down Expand Up @@ -607,7 +649,10 @@ def __getitem__(self,

return t.cast(M21ObjType, searchElements[k])

elif isinstance(k, type) and issubclass(k, base.Music21Object):
elif isinstance(k, type):
return self.recurse().getElementsByClass(k)

elif common.isIterable(k) and all(isinstance(maybe_type, type) for maybe_type in k):
return self.recurse().getElementsByClass(k)

elif isinstance(k, str):
Expand All @@ -619,7 +664,8 @@ def __getitem__(self,
return querySelectorIterator

raise TypeError(
f'Streams can get items by int, slice, class, or string query; got {type(k)}'
'Streams can get items by int, slice, class, class iterable, or string query; '
f'got {type(k)}'
)

def first(self) -> t.Optional[M21ObjType]:
Expand Down Expand Up @@ -4724,7 +4770,7 @@ def optionalAddRest():

# Replace old measures in spanners with new measures
# Example: out is a Part, out.spannerBundle has RepeatBrackets spanning measures
# TODO: when dropping support for Py3.9 add strict=True
# TODO: when dropping support for Py3.9 (min=3.10) add strict=True
for oldM, newM in zip(
self.getElementsByClass(Measure),
out.getElementsByClass(Measure)
Expand Down Expand Up @@ -7707,26 +7753,27 @@ def semiFlat(self):
@overload
def recurse(self,
*,
streamsOnly: t.Literal[True],
streamsOnly: t.Literal[False] = False,
restoreActiveSites=True,
classFilter=(),
includeSelf=None) -> iterator.RecursiveIterator[Stream]:
return iterator.RecursiveIterator(self).getElementsByClass(Stream)
includeSelf=None) -> iterator.RecursiveIterator[M21ObjType]:
return t.cast(iterator.RecursiveIterator[M21ObjType], iterator.RecursiveIterator(self))

@overload
def recurse(self,
*,
streamsOnly: t.Literal[False] = False,
streamsOnly: t.Literal[True],
restoreActiveSites=True,
classFilter=(),
includeSelf=None) -> iterator.RecursiveIterator[M21ObjType]:
return iterator.RecursiveIterator(self)
includeSelf=None) -> iterator.RecursiveIterator[Stream]:
return t.cast(iterator.RecursiveIterator[Stream],
iterator.RecursiveIterator(self).getElementsByClass(Stream))

def recurse(self: StreamType,
def recurse(self,
*,
streamsOnly=False,
restoreActiveSites=True,
classFilter=(),
streamsOnly: bool = False,
restoreActiveSites: bool = True,
classFilter: t.Tuple = (),
includeSelf=None) -> t.Union[iterator.RecursiveIterator[M21ObjType],
iterator.RecursiveIterator[Stream]]:
'''
Expand Down Expand Up @@ -7875,6 +7922,9 @@ def recurse(self: StreamType,
)
if classFilter:
ri = ri.getElementsByClass(classFilter)

if t.TYPE_CHECKING and streamsOnly:
return t.cast(iterator.RecursiveIterator[Stream], ri)
return ri

def containerInHierarchy(
Expand Down
39 changes: 33 additions & 6 deletions music21/stream/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,15 @@ def getElementsByClass(self,
x: StreamIterator[M21ObjType] = self.__class__(self.streamObj)
return x

# @overload
# def getElementsByClass(self,
# classFilterList: t.Type,
# *,
# returnClone: bool = True) -> StreamIterator[M21ObjType]:
# # putting a non-music21 type into classFilterList, defaults to the previous type
# x: StreamIterator[M21ObjType] = self.__class__(self.streamObj)
# return x

@overload
def getElementsByClass(self,
classFilterList: t.Type[ChangedM21ObjType],
Expand All @@ -1018,9 +1027,10 @@ def getElementsByClass(self,

@overload
def getElementsByClass(self,
classFilterList: t.Iterable[t.Type[ChangedM21ObjType]],
classFilterList: t.Iterable[t.Type],
*,
returnClone: bool = True) -> StreamIterator[M21ObjType]:
# putting multiple types into classFilterList, defaults to the previous type
x: StreamIterator[M21ObjType] = self.__class__(self.streamObj)
return x

Expand All @@ -1031,7 +1041,7 @@ def getElementsByClass(
str,
t.Type[ChangedM21ObjType],
t.Iterable[str],
t.Iterable[t.Type[ChangedM21ObjType]],
t.Iterable[t.Type],
],
*,
returnClone: bool = True
Expand Down Expand Up @@ -1653,9 +1663,18 @@ def getElementsByClass(self,
x = t.cast(OffsetIterator[ChangedM21ObjType], self.__class__(self.streamObj))
return x

# @overload
# def getElementsByClass(self,
# classFilterList: t.Type,
# *,
# returnClone: bool = True) -> OffsetIterator[M21ObjType]:
# x: OffsetIterator[M21ObjType] = self.__class__(self.streamObj)
# return x


@overload
def getElementsByClass(self,
classFilterList: t.Iterable[t.Type[ChangedM21ObjType]],
classFilterList: t.Iterable[t.Type],
*,
returnClone: bool = True) -> OffsetIterator[M21ObjType]:
x: OffsetIterator[M21ObjType] = self.__class__(self.streamObj)
Expand All @@ -1667,12 +1686,12 @@ def getElementsByClass(self,
str,
t.Type[ChangedM21ObjType],
t.Iterable[str],
t.Iterable[t.Type[ChangedM21ObjType]],
t.Iterable[t.Type],
],
*,
returnClone: bool = True
) -> t.Union[OffsetIterator[M21ObjType],
OffsetIterator[ChangedM21ObjType]]:
OffsetIterator[ChangedM21ObjType]]:
'''
Identical to the same method in StreamIterator, but needs to be duplicated
for now.
Expand Down Expand Up @@ -2055,9 +2074,17 @@ def getElementsByClass(self,
x = t.cast(RecursiveIterator[ChangedM21ObjType], self.__class__(self.streamObj))
return x # dummy code

# @overload
# def getElementsByClass(self,
# classFilterList: t.Type,
# *,
# returnClone: bool = True) -> RecursiveIterator[M21ObjType]:
# x: RecursiveIterator[M21ObjType] = self.__class__(self.streamObj)
# return x # dummy code

@overload
def getElementsByClass(self,
classFilterList: t.Iterable[t.Type[ChangedM21ObjType]],
classFilterList: t.Iterable[t.Type],
*,
returnClone: bool = True) -> RecursiveIterator[M21ObjType]:
x: RecursiveIterator[M21ObjType] = self.__class__(self.streamObj)
Expand Down