diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py index 7ba08998ae..82afdaed44 100755 --- a/src/super_gradients/training/sg_trainer/sg_trainer.py +++ b/src/super_gradients/training/sg_trainer/sg_trainer.py @@ -483,7 +483,31 @@ def _train_epoch(self, context: PhaseContext, silent_mode: bool = False) -> tupl context.update_context(loss_avg_meter=loss_avg_meter, metrics_compute_fn=self.train_metrics) - for batch_idx, batch_items in enumerate(progress_bar_train_loader): + class PrefetchIterator: + def __init__(self, iterator): + self.iterator = iterator + import concurrent.futures + self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + self.prefetch() + + def prefetch(self): + self.prefetch_future = self.executor.submit(self._prefetch) + + def _prefetch(self): + return next(self.iterator) + + def __iter__(self): + return self + + def __next__(self): + value = self.prefetch_future.result() + self.prefetch() + return value + + def close(self): + self.executor.shutdown() + + for batch_idx, batch_items in PrefetchIterator(enumerate(progress_bar_train_loader)): if expected_iterations <= batch_idx: break