Skip to content

Commit

Permalink
prefetch next batch from the data loader while running current batch
Browse files Browse the repository at this point in the history
  • Loading branch information
koush committed May 10, 2024
1 parent 9e73792 commit afa9eba
Showing 1 changed file with 25 additions and 1 deletion.
26 changes: 25 additions & 1 deletion src/super_gradients/training/sg_trainer/sg_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,31 @@ def _train_epoch(self, context: PhaseContext, silent_mode: bool = False) -> tupl

context.update_context(loss_avg_meter=loss_avg_meter, metrics_compute_fn=self.train_metrics)

for batch_idx, batch_items in enumerate(progress_bar_train_loader):
class PrefetchIterator:
def __init__(self, iterator):
self.iterator = iterator
import concurrent.futures
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self.prefetch()

def prefetch(self):
self.prefetch_future = self.executor.submit(self._prefetch)

def _prefetch(self):
return next(self.iterator)

def __iter__(self):
return self

def __next__(self):
value = self.prefetch_future.result()
self.prefetch()
return value

def close(self):
self.executor.shutdown()

for batch_idx, batch_items in PrefetchIterator(enumerate(progress_bar_train_loader)):
if expected_iterations <= batch_idx:
break

Expand Down

0 comments on commit afa9eba

Please sign in to comment.