Skip to content

Commit

Permalink
Force super().__init__ for Dataset subclasses.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 711849887
  • Loading branch information
iindyk authored and copybara-github committed Jan 3, 2025
1 parent 20f1174 commit a943790
Showing 1 changed file with 9 additions and 15 deletions.
24 changes: 9 additions & 15 deletions grain/_src/python/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LazyDataset base classes.
"""Dataset base classes.
There are 3 main classes:
- `MapDataset` define a dataset that supports efficient random access. It
Expand Down Expand Up @@ -99,10 +99,10 @@ def _default_seed(self) -> int | None:
aggregated_seed = []
# Note that the traversal order must be determisitic.
# pylint: disable=protected-access
to_visit = [(self, 0)]
to_visit: list[tuple[_Dataset, int]] = [(self, 0)]
while to_visit:
node, depth = to_visit.pop(0)
if (node_seed := getattr(node, "_seed_rng_seed", None)) is not None:
if (node_seed := node._seed_rng_seed) is not None:
aggregated_seed.extend((node_seed, depth))
else:
to_visit.extend((n, depth + 1) for n in node._parents)
Expand Down Expand Up @@ -693,12 +693,9 @@ def _initialize_stats(
Returns:
The initialized stats object.
"""
# There may be parent `MapDataset` nodes introduced by users that did not
# call super init and thus don't have `_parents`.
parents_stats = []
if hasattr(self, "_parents"):
for p in self._parents:
parents_stats.append(p._initialize_stats(execution_tracking_mode)) # pylint: disable=protected-access
parents_stats = [
p._initialize_stats(execution_tracking_mode) for p in self.parents # pylint: disable=protected-access
]
self._stats = dataset_stats.make_stats(
dataset_stats.StatsConfig(
name=str(self), transform_mutates_spec=self._MUTATES_ELEMENT_SPEC
Expand Down Expand Up @@ -1081,7 +1078,7 @@ def __init__(
for p in self._parents:
# Not all user iterators call super().__init__ and thus don't have the
# options set.
if (p_options := getattr(p, "_options", None)) is not None:
if (p_options := p._options) is not None: # pylint: disable=protected-access
parent_options.append(p_options)
if parent_options:
self._options = functools.reduce(lambda x, y: x.merge(y), parent_options)
Expand All @@ -1095,7 +1092,7 @@ def _options_with_default(self) -> DatasetOptions:
"""
# TODO: Relax the requirement to access options after all iterators
# in the pipeline have been initialized.
return getattr(self, "_options", None) or DatasetOptions()
return self._options or DatasetOptions()

@property
def _parent(self) -> DatasetIterator:
Expand Down Expand Up @@ -1147,10 +1144,7 @@ def _stats(self):
"""Returns the Stats object for recording statistics about this iterator."""
# There may be parent `DatasetIterator` nodes introduced by users that did
# not call super init and thus don't have `_stats`.
parents_stats = []
if hasattr(self, "_parents"):
for p in self._parents:
parents_stats.append(p._stats) # pylint: disable=protected-access
parents_stats = [p._stats for p in self._parents] # pylint: disable=protected-access
return dataset_stats.make_stats(
dataset_stats.StatsConfig(
name=str(self),
Expand Down

0 comments on commit a943790

Please sign in to comment.