From e4c22731374919bcc0e1c32146abddabd820fc00 Mon Sep 17 00:00:00 2001
From: John Bauer
Date: Tue, 7 Nov 2023 11:47:11 -0800
Subject: [PATCH] Make a dataset shuffler which can shuffle per batch without
 reading all of the batches into memory first. Saves memory on some of the
 excessively large datasets, such as DE_HDT

---
 stanza/models/pos/data.py | 18 ++++++++++++++++++
 stanza/models/tagger.py   | 14 +++-----------
 2 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/stanza/models/pos/data.py b/stanza/models/pos/data.py
index 567d2605e7..07ccd4e2b7 100644
--- a/stanza/models/pos/data.py
+++ b/stanza/models/pos/data.py
@@ -280,4 +280,22 @@ def resolve_none(data):
                     data[sent_idx][tok_idx][feat_idx] = '_'
     return data
 
+class ShuffledDataset:
+    def __init__(self, datasets, batch_size):
+        self.batch_size = batch_size
+        self.datasets = datasets
+        self.loaders = [x.to_loader(batch_size=self.batch_size, shuffle=True) for x in self.datasets]
+
+    def __iter__(self):
+        iterators = [iter(x) for x in self.loaders]
+        lengths = [len(x) for x in self.loaders]
+        indices = [[x] * y for x, y in enumerate(lengths)]
+        indices = [idx for inner in indices for idx in inner]
+        random.shuffle(indices)
+
+        for idx in indices:
+            yield next(iterators[idx])
+
+    def __len__(self):
+        return sum(len(x) for x in self.datasets)
 
diff --git a/stanza/models/tagger.py b/stanza/models/tagger.py
index 7212538edd..a315154293 100644
--- a/stanza/models/tagger.py
+++ b/stanza/models/tagger.py
@@ -19,7 +19,7 @@
 from torch import nn, optim
 
 import stanza.models.pos.data as data
-from stanza.models.pos.data import Dataset
+from stanza.models.pos.data import Dataset, ShuffledDataset
 from stanza.models.pos.trainer import Trainer
 from stanza.models.pos import scorer
 from stanza.models.common import utils
@@ -205,8 +205,7 @@ def load_training_data(args, pretrain):
         for td in train_data:
             td.has_feats = True
     # calculate the batches
-    train_batches = [i.to_loader(batch_size=args["batch_size"], shuffle=True)
-                     for i in train_data]
+    train_batches = ShuffledDataset(train_data, args["batch_size"])
     return vocab, train_data, train_batches
 
 def train(args):
@@ -284,14 +283,7 @@ def train(args):
         trainer.model.log_norms()
     while True:
         do_break = False
-        # we now merge all train batches together into one giant list
-        # this allows us to mix batches which have or don't have individual training columns,
-        # such as if XPOS or UPOS are missing from a training file,
-        # as we shuffle all of those batches together
-        # the downside being that it loses the efficiency benefit of the pytorch dataloader
-        all_train_batches = [x for train_batch in train_batches for x in iter(train_batch)]
-        random.shuffle(all_train_batches)
-        for i, batch in enumerate(all_train_batches):
+        for i, batch in enumerate(train_batches):
             start_time = time.time()
             global_step += 1
             loss = trainer.update(batch, eval=False) # update step
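
The core idea in ShuffledDataset is to build one index entry per batch, shuffle those indices, and then draw the next batch lazily from whichever loader an index points at, so at most one batch per loader is materialized at a time instead of the old single giant list of all batches. Below is a minimal, self-contained sketch of that interleaving pattern; the class name ShuffledBatches and the toy list-based "loaders" are illustrative stand-ins rather than stanza APIs (the real class wraps the PyTorch DataLoaders returned by Dataset.to_loader):

    import random

    class ShuffledBatches:
        """Yield batches from several pre-batched iterables in a random global order."""

        def __init__(self, loaders):
            self.loaders = loaders

        def __iter__(self):
            iterators = [iter(loader) for loader in self.loaders]
            # one index entry per batch, e.g. [0, 0, 0, 1, 1] for loaders of length 3 and 2
            indices = [i for i, loader in enumerate(self.loaders) for _ in range(len(loader))]
            random.shuffle(indices)
            for idx in indices:
                # each underlying iterator advances lazily, so the batches are
                # never collected into one big in-memory list
                yield next(iterators[idx])

    if __name__ == "__main__":
        loader_a = [["a1"], ["a2"], ["a3"]]   # stand-ins for DataLoader batches
        loader_b = [["b1"], ["b2"]]
        for batch in ShuffledBatches([loader_a, loader_b]):
            print(batch)

This preserves what the removed tagger.py comment cared about, namely that batches with and without particular training columns still get mixed across training files, while avoiding the expansion of every loader into one list, which is what the commit message means by saving memory on large treebanks such as DE_HDT.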