Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
fix #5132 (#5134)
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh authored Apr 20, 2021
1 parent 2526674 commit 24ec7db
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 3 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Updated CONTRIBUTING.md to remind the reader to upgrade `pip` and `setuptools` to avoid spaCy installation issues.

### Fixed

- Fixed a bug with the `ShardedDatasetReader` when used with multi-process data loading (https://github.com/allenai/allennlp/issues/5132).


## [v2.3.0](https://github.com/allenai/allennlp/releases/tag/v2.3.0) - 2021-04-14

Expand Down
12 changes: 11 additions & 1 deletion allennlp/data/dataset_readers/sharded_dataset_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import os
from typing import Iterable

from overrides import overrides

from allennlp.common.checks import ConfigurationError
from allennlp.common.file_utils import cached_path
from allennlp.data.dataset_readers.dataset_reader import DatasetReader, PathOrStr
Expand Down Expand Up @@ -46,12 +48,18 @@ def __init__(self, base_reader: DatasetReader, **kwargs) -> None:
self.reader._set_worker_info(None)
self.reader._set_distributed_info(None)

@overrides
def text_to_instance(self, *args, **kwargs) -> Instance:
    """
    Pass-through: instance construction is entirely the responsibility of
    the wrapped base reader, so we simply forward all arguments to it.
    """
    base = self.reader
    # The base reader's signature is unknown here, hence *args/**kwargs
    # and the type-checker suppression on the forwarded call.
    return base.text_to_instance(*args, **kwargs)  # type: ignore

@overrides
def apply_token_indexers(self, instance: Instance) -> None:
    """
    Delegate token-indexer application to the wrapped base reader.

    NOTE(review): this delegation matters because `_read` (below) yields
    instances via the base reader's `_read`, bypassing the base reader's
    own `read()`-time indexer application.
    """
    self.reader.apply_token_indexers(instance)

@overrides
def _read(self, file_path: PathOrStr) -> Iterable[Instance]:
try:
maybe_extracted_archive = cached_path(file_path, extract_archive=True)
Expand All @@ -76,5 +84,7 @@ def _read(self, file_path: PathOrStr) -> Iterable[Instance]:

for shard in self.shard_iterable(shards):
logger.info(f"reading instances from {shard}")
for instance in self.reader.read(shard):
# We call `self.reader._read()` here instead of `self.reader.read()` because `.read()`
# will prematurely call `self.reader.apply_token_indexers()`.
for instance in self.reader._read(shard):
yield instance
11 changes: 9 additions & 2 deletions tests/data/dataset_readers/sharded_dataset_reader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Tuple

from allennlp.common.testing import AllenNlpTestCase
from allennlp.data.data_loaders import MultiProcessDataLoader
from allennlp.data.dataset_readers import (
SequenceTaggingDatasetReader,
ShardedDatasetReader,
Expand Down Expand Up @@ -51,9 +52,12 @@ def setup_method(self) -> None:

self.reader = ShardedDatasetReader(base_reader=self.base_reader)

def read_and_check_instances(self, filepath: str):
def read_and_check_instances(self, filepath: str, num_workers: int = 0):
data_loader = MultiProcessDataLoader(
self.reader, filepath, num_workers=num_workers, batch_size=1
)
all_instances = []
for instance in self.reader.read(filepath):
for instance in data_loader.iter_instances():
all_instances.append(instance)

# 100 files * 4 sentences / file
Expand All @@ -71,5 +75,8 @@ def read_and_check_instances(self, filepath: str):
def test_sharded_read_glob(self):
    """Reading shards via a filename glob yields the full instance set."""
    self.read_and_check_instances(self.identical_files_glob)

def test_sharded_read_with_multiprocess_loader(self):
    """
    Regression test for allennlp#5132: sharded reading must also work when
    the data loader uses worker subprocesses (num_workers > 0).
    """
    self.read_and_check_instances(self.identical_files_glob, num_workers=2)

def test_sharded_read_archive(self):
    """Reading shards packed inside an archive file yields the full instance set."""
    self.read_and_check_instances(str(self.archive_filename))

0 comments on commit 24ec7db

Please sign in to comment.