This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Sorting keys api (#3902)
* new idea for sorting

* add len to all fields

* update references to sorting keys
DeNeutoy authored Mar 5, 2020
1 parent 72acdd1 commit 644ef22
Showing 19 changed files with 62 additions and 53 deletions.
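
In short: sorting_keys entries change from (field_name, padding_key) tuples to bare field names, backed by the new __len__ added to every field type below. A minimal before/after sketch of how BucketBatchSampler is constructed (the dataset setup is elided and hypothetical):

    from allennlp.data.samplers import BucketBatchSampler

    dataset = ...  # an AllennlpDataset of Instances with a "text" TextField (setup omitted)

    # Before this commit: each key was a (field_name, padding_key) tuple tied
    # to the field's internal padding bookkeeping.
    # sampler = BucketBatchSampler(dataset, batch_size=32,
    #                              sorting_keys=[("text", "tokens___tokens")])

    # After this commit: a key is just the field name; the sampler orders
    # instances by len(instance.fields["text"]).
    sampler = BucketBatchSampler(dataset, batch_size=32, sorting_keys=["text"])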
3 changes: 3 additions & 0 deletions allennlp/data/fields/adjacency_field.py
@@ -143,3 +143,6 @@ def __str__(self) -> str:
f"\t\twith indices:\n {formatted_indices}\n"
f"\t\tand labels:\n {formatted_labels} \t\tin namespace: '{self._label_namespace}'."
)

def __len__(self):
return len(self.sequence_field)
3 changes: 3 additions & 0 deletions allennlp/data/fields/array_field.py
@@ -58,3 +58,6 @@ def empty_field(self):

def __str__(self) -> str:
return f"ArrayField with shape: {self.array.shape} and dtype: {self.dtype}."

def __len__(self):
return self.array.shape[0]
3 changes: 3 additions & 0 deletions allennlp/data/fields/field.py
@@ -117,3 +117,6 @@ def __eq__(self, other) -> bool:
if isinstance(self, other.__class__):
return self.__dict__ == other.__dict__
return NotImplemented

def __len__(self):
raise NotImplementedError
3 changes: 3 additions & 0 deletions allennlp/data/fields/index_field.py
@@ -60,3 +60,6 @@ def __eq__(self, other) -> bool:
if isinstance(other, int):
return self.sequence_index == other
return super().__eq__(other)

def __len__(self):
return 1
3 changes: 3 additions & 0 deletions allennlp/data/fields/label_field.py
@@ -105,3 +105,6 @@ def empty_field(self):

def __str__(self) -> str:
return f"LabelField with label: {self.label} in namespace: '{self._label_namespace}'.'"

def __len__(self):
return 1
3 changes: 3 additions & 0 deletions allennlp/data/fields/multilabel_field.py
@@ -136,3 +136,6 @@ def __str__(self) -> str:
return (
f"MultiLabelField with labels: {self.labels} in namespace: '{self._label_namespace}'.'"
)

def __len__(self):
return 1
3 changes: 3 additions & 0 deletions allennlp/data/fields/namespace_swapping_field.py
@@ -48,3 +48,6 @@ def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor:
@overrides
def empty_field(self) -> "NamespaceSwappingField":
return NamespaceSwappingField([], self._target_namespace)

def __len__(self):
return len(self._source_tokens)
3 changes: 3 additions & 0 deletions allennlp/data/fields/span_field.py
@@ -68,3 +68,6 @@ def __eq__(self, other) -> bool:
if isinstance(other, tuple) and len(other) == 2:
return other == (self.span_start, self.span_end)
return super().__eq__(other)

def __len__(self):
return 2
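
Taken together, the field diffs above give every Field a uniform len(), which is what lets the sampler below drop its per-key padding lookups. A quick illustrative sketch (the token values are made up; TextField already exposed len() through its token list, which is presumably why it is not among the changed files):

    from allennlp.data.fields import LabelField, SpanField, TextField
    from allennlp.data.token_indexers import SingleIdTokenIndexer
    from allennlp.data.tokenizers import Token

    tokens = [Token(w) for w in "the quick brown fox".split()]
    text = TextField(tokens, {"tokens": SingleIdTokenIndexer()})

    print(len(text))                    # 4: one per token
    print(len(LabelField("positive")))  # 1: labels have a fixed length of 1
    print(len(SpanField(0, 2, text)))   # 2: a span is a (start, end) pair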
47 changes: 22 additions & 25 deletions allennlp/data/samplers/bucket_batch_sampler.py
@@ -1,16 +1,23 @@
import logging
from typing import List, Iterable, Tuple, Dict, cast

from typing import List, Iterable
import random
import math

from torch.utils import data

from allennlp.common.util import add_noise_to_dict_values, lazy_groups_of
from allennlp.common.util import lazy_groups_of
from allennlp.data.instance import Instance
from allennlp.data.samplers import BatchSampler

logger = logging.getLogger(__name__)


def add_noise_to_value(value: int, noise_param: float):
noise_value = value * noise_param
noise = random.uniform(-noise_value, noise_value)
return value + noise


@BatchSampler.register("bucket")
class BucketBatchSampler(BatchSampler):
"""
@@ -26,7 +33,7 @@ class BucketBatchSampler(BatchSampler):
The pytorch `Dataset` of allennlp Instances to bucket.
batch_size : int, required.
The size of each batch of instances yielded when calling the dataloader.
sorting_keys : List[Tuple[str, str]], optional
sorting_keys : List[str], optional
To bucket inputs into batches, we want to group the instances by padding length, so that we
minimize the amount of padding necessary per batch. In order to do this, we need to know
which fields need what type of padding, and in what order.
@@ -54,7 +61,7 @@ def __init__(
self,
data_source: data.Dataset,
batch_size: int,
sorting_keys: List[Tuple[str, str]] = None,
sorting_keys: List[str] = None,
padding_noise: float = 0.1,
drop_last: bool = False,
):
@@ -79,19 +86,10 @@ def _argsort_by_padding(self, instances: Iterable[Instance]) -> List[int]:
instances_with_lengths = []
for instance in instances:
# Make sure instance is indexed before calling .get_padding
instance.index_fields(self.vocab)
padding_lengths = cast(Dict[str, Dict[str, float]], instance.get_padding_lengths())
if self.padding_noise > 0.0:
noisy_lengths = {}
for field_name, field_lengths in padding_lengths.items():
noisy_lengths[field_name] = add_noise_to_dict_values(
field_lengths, self.padding_noise
)
padding_lengths = noisy_lengths
instance_with_lengths = (
[
padding_lengths[field_name][padding_key]
for (field_name, padding_key) in self.sorting_keys
add_noise_to_value(len(instance.fields.get(field_name)), self.padding_noise)
for field_name in self.sorting_keys
],
instance,
)
@@ -124,27 +122,26 @@ def _guess_sorting_keys(self, instances: Iterable[Instance], num_instances: int
are not homogeneous, you might need more.
"""
max_length = 0.0
longest_padding_key: Tuple[str, str] = None
longest_field: str = None
for i, instance in enumerate(instances):
instance.index_fields(self.vocab)
padding_lengths = cast(Dict[str, Dict[str, float]], instance.get_padding_lengths())
for field_name, field_padding in padding_lengths.items():
for padding_key, length in field_padding.items():
if length > max_length:
max_length = length
longest_padding_key = (field_name, padding_key)
for field_name, field in instance.fields.items():
length = len(field)
if length > max_length:
max_length = length
longest_field = field_name
if i > num_instances:
# Only use num_instances instances to guess the sorting keys.
break

if not longest_padding_key:
if not longest_field:
# This shouldn't ever happen (you basically have to have an empty instance list), but
# just in case...
raise AssertionError(
"Found no field that needed padding; we are surprised you got this error, please "
"open an issue on github"
)
self.sorting_keys = [longest_padding_key]
self.sorting_keys = [longest_field]

def __len__(self):
batch_count_float = len(self.data_source) / self.batch_size
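
The upshot of this rewrite: instead of indexing every instance and walking its nested padding-length dictionaries, the sampler now sorts on a noise-jittered len() of each named field. A self-contained sketch of that ordering logic (argsort_by_length is an illustrative stand-in, not a function from the commit):

    import random
    from typing import List

    def add_noise_to_value(value: int, noise_param: float) -> float:
        # Jitter the length by up to +/- noise_param * value so that
        # equal-length instances do not always fall into the same batch.
        noise_value = value * noise_param
        return value + random.uniform(-noise_value, noise_value)

    def argsort_by_length(lengths: List[int], padding_noise: float = 0.1) -> List[int]:
        # Mirror of _argsort_by_padding: sort indices by noisy field length,
        # where each length would be len(instance.fields[field_name]).
        noisy = [add_noise_to_value(length, padding_noise) for length in lengths]
        return sorted(range(len(noisy)), key=noisy.__getitem__)

    print(argsort_by_length([10, 3, 7, 3]))  # e.g. [1, 3, 2, 0], modulo noise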
4 changes: 2 additions & 2 deletions allennlp/tests/common/params_test.py
@@ -38,14 +38,14 @@ def test_overrides(self):
overrides = (
'{ "train_data_path": "FOO", "model": { "type": "BAR" },'
'"model.text_field_embedder.tokens.type": "BAZ",'
'"data_loader.batch_sampler.sorting_keys.0.0": "question"}'
'"data_loader.batch_sampler.sorting_keys.0": "question"}'
)
params = Params.from_file(filename, overrides)

assert "dataset_reader" in params
assert "trainer" in params
assert params["train_data_path"] == "FOO"
assert params["data_loader"]["batch_sampler"]["sorting_keys"][0][0] == "question"
assert params["data_loader"]["batch_sampler"]["sorting_keys"][0] == "question"

model_params = params.pop("model")
assert model_params.pop("type") == "BAR"
24 changes: 6 additions & 18 deletions allennlp/tests/data/samplers/bucket_batch_sampler_test.py
@@ -88,9 +88,7 @@ class TestBucketSampler(SamplerTest):
def test_create_batches_groups_correctly(self):

dataset = AllennlpDataset(self.instances, vocab=self.vocab)
sampler = BucketBatchSampler(
dataset, batch_size=2, padding_noise=0, sorting_keys=[("text", "tokens___tokens")]
)
sampler = BucketBatchSampler(dataset, batch_size=2, padding_noise=0, sorting_keys=["text"])

grouped_instances = []
for indices in sampler:
@@ -133,13 +131,13 @@ def test_guess_sorting_key_picks_the_longest_key(self):
)
assert sampler.sorting_keys is None
sampler._guess_sorting_keys(instances)
assert sampler.sorting_keys == [("passage", "tokens___tokens")]
assert sampler.sorting_keys == ["passage"]

def test_from_params(self):
dataset = AllennlpDataset(self.instances, self.vocab)
params = Params({})

sorting_keys = [("s1", "nt"), ("s2", "nt2")]
sorting_keys = ["s1", "s2"]
params["sorting_keys"] = sorting_keys
params["batch_size"] = 32
sampler = BucketBatchSampler.from_params(params=params, data_source=dataset)
@@ -166,11 +164,7 @@ def test_from_params(self):
def test_drop_last_works(self):
dataset = AllennlpDataset(self.instances, vocab=self.vocab)
sampler = BucketBatchSampler(
dataset,
batch_size=2,
padding_noise=0,
sorting_keys=[("text", "tokens___tokens")],
drop_last=True,
dataset, batch_size=2, padding_noise=0, sorting_keys=["text"], drop_last=True,
)
# We use a custom collate_fn for testing, which doesn't actually create tensors,
# just the allennlp Batches.
@@ -186,9 +180,7 @@ def test_drop_last_works(self):

def test_batch_count(self):
dataset = AllennlpDataset(self.instances, vocab=self.vocab)
sampler = BucketBatchSampler(
dataset, batch_size=2, padding_noise=0, sorting_keys=[("text", "tokens___tokens")]
)
sampler = BucketBatchSampler(dataset, batch_size=2, padding_noise=0, sorting_keys=["text"])
# We use a custom collate_fn for testing, which doesn't actually create tensors,
# just the allennlp Batches.
dataloader = DataLoader(dataset, batch_sampler=sampler, collate_fn=lambda x: Batch(x))
@@ -198,11 +190,7 @@ def test_batch_count(self):
def test_batch_count_with_drop_last(self):
dataset = AllennlpDataset(self.instances, vocab=self.vocab)
sampler = BucketBatchSampler(
dataset,
batch_size=2,
padding_noise=0,
sorting_keys=[("text", "tokens___tokens")],
drop_last=True,
dataset, batch_size=2, padding_noise=0, sorting_keys=["text"], drop_last=True,
)
# We use a custom collate_fn for testing, which doesn't actually create tensors,
# just the allennlp Batches.
@@ -78,7 +78,7 @@ local span_pair_embedding_dim = 3 * span_embedding_dim + feature_size;
"data_loader": {
"batch_sampler": {
"type": "bucket",
"sorting_keys": [["text", "tokens___token_ids"]],
"sorting_keys": ["text"],
"batch_size": 1,
"padding_noise": 0.0
}
2 changes: 1 addition & 1 deletion allennlp/tests/fixtures/simple_tagger/experiment.json
@@ -25,7 +25,7 @@
"data_loader": {
"batch_sampler": {
"type": "bucket",
"sorting_keys": [["tokens", "tokens___tokens"]],
"sorting_keys": ["tokens"],
"padding_noise": 0.0,
"batch_size" : 80
}
@@ -462,7 +462,7 @@ were to split strings into words and represent words as single ids under the nam
"validation_data_path": "https://allennlp.s3.amazonaws.com/datasets/academic-papers-example/dev.jsonl",
"iterator": {
"type": "bucket",
"sorting_keys": [["abstract", "num_tokens"], ["title", "num_tokens"]],
"sorting_keys": ["abstract", "title"],
"batch_size": 64
},
"trainer": {
2 changes: 1 addition & 1 deletion training_config/coref_bert_lstm.jsonnet
@@ -76,7 +76,7 @@ local span_pair_embedding_dim = 3 * span_embedding_dim + feature_size;
"data_loader": {
"batch_sampler": {
"type": "bucket",
"sorting_keys": [["text", "tokens___token_ids"]],
"sorting_keys": ["text"],
"padding_noise": 0.0,
"batch_size": 1
}
2 changes: 1 addition & 1 deletion training_config/coref_spanbert_large.jsonnet
@@ -90,7 +90,7 @@ local span_pair_embedding_dim = 3 * span_embedding_dim + feature_size;
"type": "bucket",
# Explicitly specifying sorting keys since the guessing heuristic could get it wrong
# as we have a span field.
"sorting_keys": [["text", "tokens___token_ids"]],
"sorting_keys": ["text"],
"batch_size": 1
}
},
@@ -50,7 +50,7 @@ local cls_is_last_token = false;
"data_loader": {
"batch_sampler": {
"type": "bucket",
"sorting_keys": [["tokens", "tokens___token_ids"]],
"sorting_keys": ["tokens"],
"batch_size" : 32
}
},
2 changes: 1 addition & 1 deletion tutorials/tagger/exercise.jsonnet
@@ -49,7 +49,7 @@ local learning_rate = 0.1;
"iterator": {
"type": "bucket",
"batch_size": batch_size,
"sorting_keys": [["sentence", "num_tokens"]]
"sorting_keys": ["sentence"]
},
"trainer": {
"num_epochs": num_epochs,
2 changes: 1 addition & 1 deletion tutorials/tagger/experiment.jsonnet
@@ -34,7 +34,7 @@ local learning_rate = 0.1;
"iterator": {
"type": "bucket",
"batch_size": batch_size,
"sorting_keys": [["sentence", "num_tokens"]]
"sorting_keys": ["sentence"]
},
"trainer": {
"num_epochs": num_epochs,
