From dbc6f488ecff53ea0d7b5d7a7b1ae44d8b7c8dfd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 28 Dec 2023 11:52:06 -0500 Subject: [PATCH] Add "trng" support to mux() as well --- lhotse/lazy.py | 15 +++++++++------ lhotse/utils.py | 8 ++++++++ 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/lhotse/lazy.py b/lhotse/lazy.py index 07b6a3232..d5bd55528 100644 --- a/lhotse/lazy.py +++ b/lhotse/lazy.py @@ -12,7 +12,13 @@ extension_contains, open_best, ) -from lhotse.utils import Pathlike, fastcopy, is_module_available, streaming_shuffle +from lhotse.utils import ( + Pathlike, + build_rng, + fastcopy, + is_module_available, + streaming_shuffle, +) T = TypeVar("T") @@ -360,7 +366,7 @@ def __init__( assert len(self.iterators) == len(self.weights) def __iter__(self): - rng = random.Random(self.seed) + rng = build_rng(self.seed) iters = [iter(it) for it in self.iterators] exhausted = [False for _ in range(len(iters))] @@ -450,10 +456,7 @@ def __iter__(self): - each stream may be interpreted as a shard belonging to some larger group of streams (e.g. multiple shards of a given dataset). """ - if self.seed == "trng": - rng = secrets.SystemRandom() - else: - rng = random.Random(self.seed) + rng = build_rng(self.seed) def shuffled_streams(): # Create an infinite iterable of our streams. diff --git a/lhotse/utils.py b/lhotse/utils.py index 584d862b4..80558e0b7 100644 --- a/lhotse/utils.py +++ b/lhotse/utils.py @@ -6,6 +6,7 @@ import math import os import random +import secrets import sys import urllib import uuid @@ -1092,3 +1093,10 @@ def type_cast_value(self, ctx, value): def is_torchaudio_available() -> bool: return is_module_available("torchaudio") + + +def build_rng(seed: Union[int, Literal["trng"]]) -> random.Random: + if seed == "trng": + return secrets.SystemRandom() + else: + return random.Random(seed)