Skip to content

Commit

Permalink
Fix the return type of python generator function (#2378)
Browse files Browse the repository at this point in the history
The return type of 'List[tensorflow.TensorSpec]' for generator function
`StableHLOToTFSavedModel._make_input_signatures` is error-ed out during
internal testing.

The fix is to annotating the return type as Iterator as all generators
are basically iterators.
  • Loading branch information
sdasgup3 authored Jun 4, 2024
1 parent 42bf96f commit 14e2323
Showing 1 changed file with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import itertools
import logging
import os
from typing import Any, Dict, List
from typing import Any, Dict, Iterator, List
import mlir.dialects.stablehlo as stablehlo
import mlir.ir as ir

Expand Down Expand Up @@ -141,7 +141,7 @@ def inner(*args):
def _make_tf_function(self):
return self._wrap_as_tf_func()

def _make_input_signatures(self) -> List[tf.TensorSpec]:
def _make_input_signatures(self) -> Iterator[tf.TensorSpec]:
input_pos_to_spec = {
loc.position: spec
for loc, spec in itertools.chain(
Expand Down

0 comments on commit 14e2323

Please sign in to comment.