Skip to content

Commit

Permalink
Use TypedDict for IndexedInputs
Browse files Browse the repository at this point in the history
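Expanded use of PEP 585 built-in generic types (e.g. `list[...]` and `dict[...]` in place of `typing.List` and `typing.Dict`) in type annotations.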

Added TypeErrors and docstrings in Metrics.

PiperOrigin-RevId: 476073314
  • Loading branch information
RyanMullins authored and LIT team committed Sep 22, 2022
1 parent 4dee37e commit 4760b4d
Show file tree
Hide file tree
Showing 9 changed files with 294 additions and 129 deletions.
29 changes: 17 additions & 12 deletions lit_nlp/api/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import os
import random
from types import MappingProxyType # pylint: disable=g-importing-member
from typing import cast, List, Dict, Optional, Callable, Mapping, Sequence
from typing import cast, Optional, Callable, Mapping, Sequence

from absl import logging

Expand Down Expand Up @@ -49,13 +49,13 @@ class Dataset(object):
"""Base class for LIT datasets."""

_spec: Spec = {}
_examples: List[JsonDict] = []
_examples: list[JsonDict] = []
_description: Optional[str] = None
_base: Optional['Dataset'] = None

def __init__(self,
spec: Optional[Spec] = None,
examples: Optional[List[JsonDict]] = None,
examples: Optional[list[JsonDict]] = None,
description: Optional[str] = None,
base: Optional['Dataset'] = None):
"""Base class constructor.
Expand Down Expand Up @@ -111,7 +111,7 @@ def load(self, path: str):
return self._base.load(path)
pass

def save(self, examples: List[IndexedInput], path: str):
def save(self, examples: list[IndexedInput], path: str):
"""Save newly-created datapoints to disk in a dataset-specific format.
Subclasses should override this method if they wish to save new, persisted
Expand All @@ -134,7 +134,7 @@ def spec(self) -> Spec:
return self._spec

@property
def examples(self) -> List[JsonDict]:
def examples(self) -> list[JsonDict]:
"""Return examples, in format described by spec."""
return self._examples

Expand Down Expand Up @@ -171,7 +171,7 @@ def shuffle(self, seed=42):
# random.shuffle will shuffle in-place; use sample to make a new list.
return self.sample(n=len(self), seed=seed)

def remap(self, field_map: Dict[str, str]):
def remap(self, field_map: dict[str, str]):
"""Return a copy of this dataset with some fields renamed."""
new_spec = utils.remap_dict(self.spec(), field_map)
new_examples = [utils.remap_dict(ex, field_map) for ex in self.examples]
Expand All @@ -194,19 +194,24 @@ def bytes_from_lit_example(lit_example: JsonDict) -> bytes:
class IndexedDataset(Dataset):
"""Dataset with additional indexing information."""

_index: Dict[ExampleId, IndexedInput] = {}
_index: dict[ExampleId, IndexedInput] = {}

def index_inputs(self, examples: List[types.Input]) -> List[IndexedInput]:
def index_inputs(self, examples: list[types.Input]) -> list[IndexedInput]:
"""Create indexed versions of inputs."""
# pylint: disable=g-complex-comprehension not complex, just a line-too-long
return [
IndexedInput({'data': example, 'id': self.id_fn(example), 'meta': {}})
IndexedInput(
data=example,
id=self.id_fn(example),
meta=types.InputMetadata(added=None, parentId=None, source=None))
for example in examples
] # pyformat: disable
]
# pylint: enable=g-complex-comprehension

def __init__(self,
*args,
id_fn: Optional[IdFnType] = None,
indexed_examples: Optional[List[IndexedInput]] = None,
indexed_examples: Optional[list[IndexedInput]] = None,
**kw):
super().__init__(*args, **kw)
assert id_fn is not None, 'id_fn must be specified.'
Expand Down Expand Up @@ -244,7 +249,7 @@ def index(self) -> Mapping[ExampleId, IndexedInput]:
"""Return a read-only view of the index."""
return MappingProxyType(self._index)

def save(self, examples: List[IndexedInput], path: str):
def save(self, examples: list[IndexedInput], path: str):
"""Save newly-created datapoints to disk.
Args:
Expand Down
Loading

0 comments on commit 4760b4d

Please sign in to comment.