Skip to content

Commit

Permalink
Override __bool__ of TypedListType
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jan 16, 2023
1 parent 01011aa commit 60e80f5
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
20 changes: 20 additions & 0 deletions pytensor/typed_list/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ def __getitem__(self, index):
def __len__(self):
return length(self)

def __bool__(self):
# Truthiness of typedLists cannot depend on length,
# just like truthiness of TensorVariables does not depend on size or contents
return True

def append(self, toAppend):
return append(self, toAppend)

Expand Down Expand Up @@ -677,3 +682,18 @@ def perform(self, node, inputs, outputs):
All PyTensor variables must have the same type.
"""


class MakeEmptyList(Op):
__props__ = ()

def make_node(self, ttype):
tl = TypedListType(ttype)()
return Apply(self, [], [tl])

def perform(self, node, inputs, outputs):
(out,) = outputs
out[0] = []


make_empty_list = MakeEmptyList()
4 changes: 4 additions & 0 deletions tests/typed_list/test_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,7 @@ def test_variable_is_Typed_List_variable(self):
)()

assert isinstance(mySymbolicVariable, TypedListVariable)

def test_any(self):
tlist = TypedListType(TensorType(dtype="int64", shape=(None,)))()
assert any([tlist])

0 comments on commit 60e80f5

Please sign in to comment.