From 644ef22e90c8493af8e14e81a1e10839a8f0b62a Mon Sep 17 00:00:00 2001 From: Mark Neumann Date: Thu, 5 Mar 2020 10:38:15 -0800 Subject: [PATCH] Sorting keys api (#3902) * new idea for sorting * add len to all fields * update references to sorting keys --- allennlp/data/fields/adjacency_field.py | 3 ++ allennlp/data/fields/array_field.py | 3 ++ allennlp/data/fields/field.py | 3 ++ allennlp/data/fields/index_field.py | 3 ++ allennlp/data/fields/label_field.py | 3 ++ allennlp/data/fields/multilabel_field.py | 3 ++ .../data/fields/namespace_swapping_field.py | 3 ++ allennlp/data/fields/span_field.py | 3 ++ .../data/samplers/bucket_batch_sampler.py | 47 +++++++++---------- allennlp/tests/common/params_test.py | 4 +- .../samplers/bucket_batch_sampler_test.py | 24 +++------- .../coref/coref_bert_lstm_small.jsonnet | 2 +- .../fixtures/simple_tagger/experiment.json | 2 +- .../predicting_paper_venues_pt1.md | 2 +- training_config/coref_bert_lstm.jsonnet | 2 +- training_config/coref_spanbert_large.jsonnet | 2 +- ...tanford_sentiment_treebank_roberta.jsonnet | 2 +- tutorials/tagger/exercise.jsonnet | 2 +- tutorials/tagger/experiment.jsonnet | 2 +- 19 files changed, 62 insertions(+), 53 deletions(-) diff --git a/allennlp/data/fields/adjacency_field.py b/allennlp/data/fields/adjacency_field.py index afc646c6f34..570c315f933 100644 --- a/allennlp/data/fields/adjacency_field.py +++ b/allennlp/data/fields/adjacency_field.py @@ -143,3 +143,6 @@ def __str__(self) -> str: f"\t\twith indices:\n {formatted_indices}\n" f"\t\tand labels:\n {formatted_labels} \t\tin namespace: '{self._label_namespace}'." ) + + def __len__(self): + return len(self.sequence_field) diff --git a/allennlp/data/fields/array_field.py b/allennlp/data/fields/array_field.py index 5438623fc68..fedffd5cd8c 100644 --- a/allennlp/data/fields/array_field.py +++ b/allennlp/data/fields/array_field.py @@ -58,3 +58,6 @@ def empty_field(self): def __str__(self) -> str: return f"ArrayField with shape: {self.array.shape} and dtype: {self.dtype}." + + def __len__(self): + return self.array.shape[0] diff --git a/allennlp/data/fields/field.py b/allennlp/data/fields/field.py index 0405c78a123..358bd89937e 100644 --- a/allennlp/data/fields/field.py +++ b/allennlp/data/fields/field.py @@ -117,3 +117,6 @@ def __eq__(self, other) -> bool: if isinstance(self, other.__class__): return self.__dict__ == other.__dict__ return NotImplemented + + def __len__(self): + raise NotImplementedError diff --git a/allennlp/data/fields/index_field.py b/allennlp/data/fields/index_field.py index edca6d54aa8..7017557d458 100644 --- a/allennlp/data/fields/index_field.py +++ b/allennlp/data/fields/index_field.py @@ -60,3 +60,6 @@ def __eq__(self, other) -> bool: if isinstance(other, int): return self.sequence_index == other return super().__eq__(other) + + def __len__(self): + return 1 diff --git a/allennlp/data/fields/label_field.py b/allennlp/data/fields/label_field.py index 6816045e47f..fef5f20539d 100644 --- a/allennlp/data/fields/label_field.py +++ b/allennlp/data/fields/label_field.py @@ -105,3 +105,6 @@ def empty_field(self): def __str__(self) -> str: return f"LabelField with label: {self.label} in namespace: '{self._label_namespace}'.'" + + def __len__(self): + return 1 diff --git a/allennlp/data/fields/multilabel_field.py b/allennlp/data/fields/multilabel_field.py index 4a5cd468b21..5402fbfe189 100644 --- a/allennlp/data/fields/multilabel_field.py +++ b/allennlp/data/fields/multilabel_field.py @@ -136,3 +136,6 @@ def __str__(self) -> str: return ( f"MultiLabelField with labels: {self.labels} in namespace: '{self._label_namespace}'.'" ) + + def __len__(self): + return 1 diff --git a/allennlp/data/fields/namespace_swapping_field.py b/allennlp/data/fields/namespace_swapping_field.py index bf2ff12a09f..2b588aa3ab7 100644 --- a/allennlp/data/fields/namespace_swapping_field.py +++ b/allennlp/data/fields/namespace_swapping_field.py @@ -48,3 +48,6 @@ def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor: @overrides def empty_field(self) -> "NamespaceSwappingField": return NamespaceSwappingField([], self._target_namespace) + + def __len__(self): + return len(self._source_tokens) diff --git a/allennlp/data/fields/span_field.py b/allennlp/data/fields/span_field.py index 5191d57b77a..7b5de4205e7 100644 --- a/allennlp/data/fields/span_field.py +++ b/allennlp/data/fields/span_field.py @@ -68,3 +68,6 @@ def __eq__(self, other) -> bool: if isinstance(other, tuple) and len(other) == 2: return other == (self.span_start, self.span_end) return super().__eq__(other) + + def __len__(self): + return 2 diff --git a/allennlp/data/samplers/bucket_batch_sampler.py b/allennlp/data/samplers/bucket_batch_sampler.py index ed18c2509b5..dcc449126af 100644 --- a/allennlp/data/samplers/bucket_batch_sampler.py +++ b/allennlp/data/samplers/bucket_batch_sampler.py @@ -1,16 +1,23 @@ import logging -from typing import List, Iterable, Tuple, Dict, cast - +from typing import List, Iterable +import random import math + from torch.utils import data -from allennlp.common.util import add_noise_to_dict_values, lazy_groups_of +from allennlp.common.util import lazy_groups_of from allennlp.data.instance import Instance from allennlp.data.samplers import BatchSampler logger = logging.getLogger(__name__) +def add_noise_to_value(value: int, noise_param: float): + noise_value = value * noise_param + noise = random.uniform(-noise_value, noise_value) + return value + noise + + @BatchSampler.register("bucket") class BucketBatchSampler(BatchSampler): """ @@ -26,7 +33,7 @@ class BucketBatchSampler(BatchSampler): The pytorch `Dataset` of allennlp Instances to bucket. batch_size : int, required. The size of each batch of instances yielded when calling the dataloader. - sorting_keys : List[Tuple[str, str]], optional + sorting_keys : List[str], optional To bucket inputs into batches, we want to group the instances by padding length, so that we minimize the amount of padding necessary per batch. In order to do this, we need to know which fields need what type of padding, and in what order. @@ -54,7 +61,7 @@ def __init__( self, data_source: data.Dataset, batch_size: int, - sorting_keys: List[Tuple[str, str]] = None, + sorting_keys: List[str] = None, padding_noise: float = 0.1, drop_last: bool = False, ): @@ -79,19 +86,10 @@ def _argsort_by_padding(self, instances: Iterable[Instance]) -> List[int]: instances_with_lengths = [] for instance in instances: # Make sure instance is indexed before calling .get_padding - instance.index_fields(self.vocab) - padding_lengths = cast(Dict[str, Dict[str, float]], instance.get_padding_lengths()) - if self.padding_noise > 0.0: - noisy_lengths = {} - for field_name, field_lengths in padding_lengths.items(): - noisy_lengths[field_name] = add_noise_to_dict_values( - field_lengths, self.padding_noise - ) - padding_lengths = noisy_lengths instance_with_lengths = ( [ - padding_lengths[field_name][padding_key] - for (field_name, padding_key) in self.sorting_keys + add_noise_to_value(len(instance.fields.get(field_name)), self.padding_noise) + for field_name in self.sorting_keys ], instance, ) @@ -124,27 +122,26 @@ def _guess_sorting_keys(self, instances: Iterable[Instance], num_instances: int are not homogeneous, you might need more. """ max_length = 0.0 - longest_padding_key: Tuple[str, str] = None + longest_field: str = None for i, instance in enumerate(instances): instance.index_fields(self.vocab) - padding_lengths = cast(Dict[str, Dict[str, float]], instance.get_padding_lengths()) - for field_name, field_padding in padding_lengths.items(): - for padding_key, length in field_padding.items(): - if length > max_length: - max_length = length - longest_padding_key = (field_name, padding_key) + for field_name, field in instance.fields.items(): + length = len(field) + if length > max_length: + max_length = length + longest_field = field_name if i > num_instances: # Only use num_instances instances to guess the sorting keys. break - if not longest_padding_key: + if not longest_field: # This shouldn't ever happen (you basically have to have an empty instance list), but # just in case... raise AssertionError( "Found no field that needed padding; we are surprised you got this error, please " "open an issue on github" ) - self.sorting_keys = [longest_padding_key] + self.sorting_keys = [longest_field] def __len__(self): batch_count_float = len(self.data_source) / self.batch_size diff --git a/allennlp/tests/common/params_test.py b/allennlp/tests/common/params_test.py index 68d8d61f6db..600656d679c 100644 --- a/allennlp/tests/common/params_test.py +++ b/allennlp/tests/common/params_test.py @@ -38,14 +38,14 @@ def test_overrides(self): overrides = ( '{ "train_data_path": "FOO", "model": { "type": "BAR" },' '"model.text_field_embedder.tokens.type": "BAZ",' - '"data_loader.batch_sampler.sorting_keys.0.0": "question"}' + '"data_loader.batch_sampler.sorting_keys.0": "question"}' ) params = Params.from_file(filename, overrides) assert "dataset_reader" in params assert "trainer" in params assert params["train_data_path"] == "FOO" - assert params["data_loader"]["batch_sampler"]["sorting_keys"][0][0] == "question" + assert params["data_loader"]["batch_sampler"]["sorting_keys"][0] == "question" model_params = params.pop("model") assert model_params.pop("type") == "BAR" diff --git a/allennlp/tests/data/samplers/bucket_batch_sampler_test.py b/allennlp/tests/data/samplers/bucket_batch_sampler_test.py index 1074c2178bd..d8a232aec4d 100644 --- a/allennlp/tests/data/samplers/bucket_batch_sampler_test.py +++ b/allennlp/tests/data/samplers/bucket_batch_sampler_test.py @@ -88,9 +88,7 @@ class TestBucketSampler(SamplerTest): def test_create_batches_groups_correctly(self): dataset = AllennlpDataset(self.instances, vocab=self.vocab) - sampler = BucketBatchSampler( - dataset, batch_size=2, padding_noise=0, sorting_keys=[("text", "tokens___tokens")] - ) + sampler = BucketBatchSampler(dataset, batch_size=2, padding_noise=0, sorting_keys=["text"]) grouped_instances = [] for indices in sampler: @@ -133,13 +131,13 @@ def test_guess_sorting_key_picks_the_longest_key(self): ) assert sampler.sorting_keys is None sampler._guess_sorting_keys(instances) - assert sampler.sorting_keys == [("passage", "tokens___tokens")] + assert sampler.sorting_keys == ["passage"] def test_from_params(self): dataset = AllennlpDataset(self.instances, self.vocab) params = Params({}) - sorting_keys = [("s1", "nt"), ("s2", "nt2")] + sorting_keys = ["s1", "s2"] params["sorting_keys"] = sorting_keys params["batch_size"] = 32 sampler = BucketBatchSampler.from_params(params=params, data_source=dataset) @@ -166,11 +164,7 @@ def test_from_params(self): def test_drop_last_works(self): dataset = AllennlpDataset(self.instances, vocab=self.vocab) sampler = BucketBatchSampler( - dataset, - batch_size=2, - padding_noise=0, - sorting_keys=[("text", "tokens___tokens")], - drop_last=True, + dataset, batch_size=2, padding_noise=0, sorting_keys=["text"], drop_last=True, ) # We use a custom collate_fn for testing, which doesn't actually create tensors, # just the allennlp Batches. @@ -186,9 +180,7 @@ def test_drop_last_works(self): def test_batch_count(self): dataset = AllennlpDataset(self.instances, vocab=self.vocab) - sampler = BucketBatchSampler( - dataset, batch_size=2, padding_noise=0, sorting_keys=[("text", "tokens___tokens")] - ) + sampler = BucketBatchSampler(dataset, batch_size=2, padding_noise=0, sorting_keys=["text"]) # We use a custom collate_fn for testing, which doesn't actually create tensors, # just the allennlp Batches. dataloader = DataLoader(dataset, batch_sampler=sampler, collate_fn=lambda x: Batch(x)) @@ -198,11 +190,7 @@ def test_batch_count(self): def test_batch_count_with_drop_last(self): dataset = AllennlpDataset(self.instances, vocab=self.vocab) sampler = BucketBatchSampler( - dataset, - batch_size=2, - padding_noise=0, - sorting_keys=[("text", "tokens___tokens")], - drop_last=True, + dataset, batch_size=2, padding_noise=0, sorting_keys=["text"], drop_last=True, ) # We use a custom collate_fn for testing, which doesn't actually create tensors, # just the allennlp Batches. diff --git a/allennlp/tests/fixtures/coref/coref_bert_lstm_small.jsonnet b/allennlp/tests/fixtures/coref/coref_bert_lstm_small.jsonnet index 12dd663c2c1..d93b80168e1 100644 --- a/allennlp/tests/fixtures/coref/coref_bert_lstm_small.jsonnet +++ b/allennlp/tests/fixtures/coref/coref_bert_lstm_small.jsonnet @@ -78,7 +78,7 @@ local span_pair_embedding_dim = 3 * span_embedding_dim + feature_size; "data_loader": { "batch_sampler": { "type": "bucket", - "sorting_keys": [["text", "tokens___token_ids"]], + "sorting_keys": ["text"], "batch_size": 1, "padding_noise": 0.0 } diff --git a/allennlp/tests/fixtures/simple_tagger/experiment.json b/allennlp/tests/fixtures/simple_tagger/experiment.json index 166b2d125e9..a90581baa56 100644 --- a/allennlp/tests/fixtures/simple_tagger/experiment.json +++ b/allennlp/tests/fixtures/simple_tagger/experiment.json @@ -25,7 +25,7 @@ "data_loader": { "batch_sampler": { "type": "bucket", - "sorting_keys": [["tokens", "tokens___tokens"]], + "sorting_keys": ["tokens"], "padding_noise": 0.0, "batch_size" : 80 } diff --git a/docs/tutorials/getting_started/predicting_paper_venues/predicting_paper_venues_pt1.md b/docs/tutorials/getting_started/predicting_paper_venues/predicting_paper_venues_pt1.md index 5c05b67cbd7..0e610f7a5c8 100644 --- a/docs/tutorials/getting_started/predicting_paper_venues/predicting_paper_venues_pt1.md +++ b/docs/tutorials/getting_started/predicting_paper_venues/predicting_paper_venues_pt1.md @@ -462,7 +462,7 @@ were to split strings into words and represent words as single ids under the nam "validation_data_path": "https://allennlp.s3.amazonaws.com/datasets/academic-papers-example/dev.jsonl", "iterator": { "type": "bucket", - "sorting_keys": [["abstract", "num_tokens"], ["title", "num_tokens"]], + "sorting_keys": ["abstract", "title"], "batch_size": 64 }, "trainer": { diff --git a/training_config/coref_bert_lstm.jsonnet b/training_config/coref_bert_lstm.jsonnet index d574ab01c0e..349321c8d98 100644 --- a/training_config/coref_bert_lstm.jsonnet +++ b/training_config/coref_bert_lstm.jsonnet @@ -76,7 +76,7 @@ local span_pair_embedding_dim = 3 * span_embedding_dim + feature_size; "data_loader": { "batch_sampler": { "type": "bucket", - "sorting_keys": [["text", "tokens___token_ids"]], + "sorting_keys": ["text"], "padding_noise": 0.0, "batch_size": 1 } diff --git a/training_config/coref_spanbert_large.jsonnet b/training_config/coref_spanbert_large.jsonnet index 814523c1221..dcb65b5b833 100644 --- a/training_config/coref_spanbert_large.jsonnet +++ b/training_config/coref_spanbert_large.jsonnet @@ -90,7 +90,7 @@ local span_pair_embedding_dim = 3 * span_embedding_dim + feature_size; "type": "bucket", # Explicitly specifying sorting keys since the guessing heuristic could get it wrong # as we a span field. - "sorting_keys": [["text", "tokens___token_ids"]], + "sorting_keys": ["text"], "batch_size": 1 } }, diff --git a/training_config/stanford_sentiment_treebank_roberta.jsonnet b/training_config/stanford_sentiment_treebank_roberta.jsonnet index c1a920c6c26..a7c311a27ce 100644 --- a/training_config/stanford_sentiment_treebank_roberta.jsonnet +++ b/training_config/stanford_sentiment_treebank_roberta.jsonnet @@ -50,7 +50,7 @@ local cls_is_last_token = false; "data_loader": { "batch_sampler": { "type": "bucket", - "sorting_keys": [["tokens", "tokens___token_ids"]], + "sorting_keys": ["tokens"], "batch_size" : 32 } }, diff --git a/tutorials/tagger/exercise.jsonnet b/tutorials/tagger/exercise.jsonnet index ed4c337e3b5..7ed485ad440 100644 --- a/tutorials/tagger/exercise.jsonnet +++ b/tutorials/tagger/exercise.jsonnet @@ -49,7 +49,7 @@ local learning_rate = 0.1; "iterator": { "type": "bucket", "batch_size": batch_size, - "sorting_keys": [["sentence", "num_tokens"]] + "sorting_keys": ["sentence"] }, "trainer": { "num_epochs": num_epochs, diff --git a/tutorials/tagger/experiment.jsonnet b/tutorials/tagger/experiment.jsonnet index 65f2c77324a..a5efbdab4ac 100644 --- a/tutorials/tagger/experiment.jsonnet +++ b/tutorials/tagger/experiment.jsonnet @@ -34,7 +34,7 @@ local learning_rate = 0.1; "iterator": { "type": "bucket", "batch_size": batch_size, - "sorting_keys": [["sentence", "num_tokens"]] + "sorting_keys": ["sentence"] }, "trainer": { "num_epochs": num_epochs,