Skip to content

Commit

Permalink
Merge pull request #1359 from cuthbertLab/gebc-multiple-classes
Browse files Browse the repository at this point in the history
Allow iterables of qualified class names in `Stream.__getitem__` searches
  • Loading branch information
mscuthbert authored Aug 12, 2022
2 parents ee04424 + 6155e9d commit c538fb8
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 27 deletions.
92 changes: 71 additions & 21 deletions music21/stream/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,8 +439,31 @@ def __getitem__(
x = t.cast(iterator.RecursiveIterator[ChangedM21ObjType], self.recurse())
return x # dummy code

@overload
def __getitem__(
self,
k: t.Type # getting something that is a subclass of something that is not a m21 object
) -> iterator.RecursiveIterator[M21ObjType]:
x = t.cast(iterator.RecursiveIterator[M21ObjType], self.recurse())
return x # dummy code


@overload
def __getitem__(
self,
k: t.Collection[t.Type]
) -> iterator.RecursiveIterator[M21ObjType]:
# Remove this code and replace with ... once Astroid #1015 is fixed.
x: iterator.RecursiveIterator[M21ObjType] = self.recurse()
return x


def __getitem__(self,
k: t.Union[str, int, slice, t.Type[ChangedM21ObjType]]
k: t.Union[str,
int,
slice,
t.Type[ChangedM21ObjType],
t.Collection[t.Type]]
) -> t.Union[iterator.RecursiveIterator[M21ObjType],
iterator.RecursiveIterator[ChangedM21ObjType],
M21ObjType,
Expand Down Expand Up @@ -486,8 +509,8 @@ def __getitem__(self,
3.0


If a class is given then an iterator of elements
that match the requested class(es) is returned, similar
If a class is given, then a :class:`~music21.stream.iterator.RecursiveIterator`
of elements matching the requested class is returned, similar
to `Stream().recurse().getElementsByClass()`.

>>> len(s)
Expand All @@ -501,15 +524,32 @@ def __getitem__(self,
... print(n.name, end=' ')
C D E F G A

Note that this iterator is recursive by default.
Note that this iterator is recursive: it will find elements inside of streams
within this stream:

>>> c_sharp = note.Note('C#')
>>> v = stream.Voice([c_sharp])
>>> s.insert(0.5, c_sharp)

>>> v = stream.Voice()
>>> v.insert(0, c_sharp)
>>> s.insert(0.5, v)
>>> len(s[note.Note])
7

When using a single Music21 class in this way, your type checker will
be able to infer that the only objects in any loop are in fact `note.Note`
objects, and catch programming errors before running.

Multiple classes can be provided, separated by commas. Any element matching
any of the requested classes will be matched.

>>> len(s[note.Note, note.Rest])
9

>>> for note_or_rest in s[note.Note, note.Rest]:
... if isinstance(note_or_rest, note.Note):
... print(note_or_rest.name, end=' ')
... else:
... print('Rest', end=' ')
C C# D E Rest F G Rest A

The actual object returned by `s[module.Class]` is a
:class:`~music21.stream.iterator.RecursiveIterator` and has all the functions
Expand Down Expand Up @@ -556,7 +596,8 @@ def __getitem__(self,

>>> s[0.5]
Traceback (most recent call last):
TypeError: Streams can get items by int, slice, class, or string query; got <class 'float'>
TypeError: Streams can get items by int, slice, class, class iterable, or string query;
got <class 'float'>

Changed in v7:
- out of range indexes now raise an IndexError, not StreamException
Expand All @@ -573,6 +614,7 @@ def __getitem__(self,
.recurse().getElementsByClass to get the earlier behavior. Old behavior
still works until v9. This is an attempt to unify __getitem__ behavior in
StreamIterators and Streams.
- allowed iterables of qualified class names, e.g. `[note.Note, note.Rest]`
'''
# need to sort if not sorted, as this call may rely on index positions
if not self.isSorted and self.autoSort:
Expand Down Expand Up @@ -607,7 +649,10 @@ def __getitem__(self,

return t.cast(M21ObjType, searchElements[k])

elif isinstance(k, type) and issubclass(k, base.Music21Object):
elif isinstance(k, type):
return self.recurse().getElementsByClass(k)

elif common.isIterable(k) and all(isinstance(maybe_type, type) for maybe_type in k):
return self.recurse().getElementsByClass(k)

elif isinstance(k, str):
Expand All @@ -619,7 +664,8 @@ def __getitem__(self,
return querySelectorIterator

raise TypeError(
f'Streams can get items by int, slice, class, or string query; got {type(k)}'
'Streams can get items by int, slice, class, class iterable, or string query; '
f'got {type(k)}'
)

def first(self) -> t.Optional[M21ObjType]:
Expand Down Expand Up @@ -4724,7 +4770,7 @@ def optionalAddRest():

# Replace old measures in spanners with new measures
# Example: out is a Part, out.spannerBundle has RepeatBrackets spanning measures
# TODO: when dropping support for Py3.9 add strict=True
# TODO: when dropping support for Py3.9 (min=3.10) add strict=True
for oldM, newM in zip(
self.getElementsByClass(Measure),
out.getElementsByClass(Measure)
Expand Down Expand Up @@ -7707,26 +7753,27 @@ def semiFlat(self):
@overload
def recurse(self,
*,
streamsOnly: t.Literal[True],
streamsOnly: t.Literal[False] = False,
restoreActiveSites=True,
classFilter=(),
includeSelf=None) -> iterator.RecursiveIterator[Stream]:
return iterator.RecursiveIterator(self).getElementsByClass(Stream)
includeSelf=None) -> iterator.RecursiveIterator[M21ObjType]:
return t.cast(iterator.RecursiveIterator[M21ObjType], iterator.RecursiveIterator(self))

@overload
def recurse(self,
*,
streamsOnly: t.Literal[False] = False,
streamsOnly: t.Literal[True],
restoreActiveSites=True,
classFilter=(),
includeSelf=None) -> iterator.RecursiveIterator[M21ObjType]:
return iterator.RecursiveIterator(self)
includeSelf=None) -> iterator.RecursiveIterator[Stream]:
return t.cast(iterator.RecursiveIterator[Stream],
iterator.RecursiveIterator(self).getElementsByClass(Stream))

def recurse(self: StreamType,
def recurse(self,
*,
streamsOnly=False,
restoreActiveSites=True,
classFilter=(),
streamsOnly: bool = False,
restoreActiveSites: bool = True,
classFilter: t.Tuple = (),
includeSelf=None) -> t.Union[iterator.RecursiveIterator[M21ObjType],
iterator.RecursiveIterator[Stream]]:
'''
Expand Down Expand Up @@ -7875,6 +7922,9 @@ def recurse(self: StreamType,
)
if classFilter:
ri = ri.getElementsByClass(classFilter)

if t.TYPE_CHECKING and streamsOnly:
return t.cast(iterator.RecursiveIterator[Stream], ri)
return ri

def containerInHierarchy(
Expand Down
39 changes: 33 additions & 6 deletions music21/stream/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,15 @@ def getElementsByClass(self,
x: StreamIterator[M21ObjType] = self.__class__(self.streamObj)
return x

# @overload
# def getElementsByClass(self,
# classFilterList: t.Type,
# *,
# returnClone: bool = True) -> StreamIterator[M21ObjType]:
# # putting a non-music21 type into classFilterList, defaults to the previous type
# x: StreamIterator[M21ObjType] = self.__class__(self.streamObj)
# return x

@overload
def getElementsByClass(self,
classFilterList: t.Type[ChangedM21ObjType],
Expand All @@ -1018,9 +1027,10 @@ def getElementsByClass(self,

@overload
def getElementsByClass(self,
classFilterList: t.Iterable[t.Type[ChangedM21ObjType]],
classFilterList: t.Iterable[t.Type],
*,
returnClone: bool = True) -> StreamIterator[M21ObjType]:
# putting multiple types into classFilterList, defaults to the previous type
x: StreamIterator[M21ObjType] = self.__class__(self.streamObj)
return x

Expand All @@ -1031,7 +1041,7 @@ def getElementsByClass(
str,
t.Type[ChangedM21ObjType],
t.Iterable[str],
t.Iterable[t.Type[ChangedM21ObjType]],
t.Iterable[t.Type],
],
*,
returnClone: bool = True
Expand Down Expand Up @@ -1653,9 +1663,18 @@ def getElementsByClass(self,
x = t.cast(OffsetIterator[ChangedM21ObjType], self.__class__(self.streamObj))
return x

# @overload
# def getElementsByClass(self,
# classFilterList: t.Type,
# *,
# returnClone: bool = True) -> OffsetIterator[M21ObjType]:
# x: OffsetIterator[M21ObjType] = self.__class__(self.streamObj)
# return x


@overload
def getElementsByClass(self,
classFilterList: t.Iterable[t.Type[ChangedM21ObjType]],
classFilterList: t.Iterable[t.Type],
*,
returnClone: bool = True) -> OffsetIterator[M21ObjType]:
x: OffsetIterator[M21ObjType] = self.__class__(self.streamObj)
Expand All @@ -1667,12 +1686,12 @@ def getElementsByClass(self,
str,
t.Type[ChangedM21ObjType],
t.Iterable[str],
t.Iterable[t.Type[ChangedM21ObjType]],
t.Iterable[t.Type],
],
*,
returnClone: bool = True
) -> t.Union[OffsetIterator[M21ObjType],
OffsetIterator[ChangedM21ObjType]]:
OffsetIterator[ChangedM21ObjType]]:
'''
Identical to the same method in StreamIterator, but needs to be duplicated
for now.
Expand Down Expand Up @@ -2055,9 +2074,17 @@ def getElementsByClass(self,
x = t.cast(RecursiveIterator[ChangedM21ObjType], self.__class__(self.streamObj))
return x # dummy code

# @overload
# def getElementsByClass(self,
# classFilterList: t.Type,
# *,
# returnClone: bool = True) -> RecursiveIterator[M21ObjType]:
# x: RecursiveIterator[M21ObjType] = self.__class__(self.streamObj)
# return x # dummy code

@overload
def getElementsByClass(self,
classFilterList: t.Iterable[t.Type[ChangedM21ObjType]],
classFilterList: t.Iterable[t.Type],
*,
returnClone: bool = True) -> RecursiveIterator[M21ObjType]:
x: RecursiveIterator[M21ObjType] = self.__class__(self.streamObj)
Expand Down

0 comments on commit c538fb8

Please sign in to comment.