
Commit

Refactored generate_support, restricted typeguard (#18)
After merging #16, the tests began failing with a typeguard error while
checking that the item returned by `generate_support` in `utils.py` is a
dictionary that actually yields (key, val) pairs. I used the opportunity
to refactor that function to only ever return a dict by making the
shuffled indexes an optional argument (they are only used internally by
`generate_episode` anyway), which saves a second call to `torch.unique`.
Since an upstream typeguard issue appeared at the same time as their
release and our error, I restricted the typeguard version to 4.0 or lower.

Refs #17
nukularrr authored Aug 3, 2023
1 parent 8f04b65 commit 4c9dfda
Showing 5 changed files with 42 additions and 35 deletions.
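
For reference, here is a minimal sketch of the new call pattern (toy tensors and label values are assumed purely for illustration). After this change, `generate_support` always returns a single dict, and callers that already hold per-class shuffled indexes, as `generate_episode` now does, can pass them in via the optional `shuffled_indexes` argument.

```python
import torch
import equine as eq

# Toy data assumed for illustration: 20 examples, 4 features, 2 balanced classes.
train_x = torch.randn(20, 4)
train_y = torch.cat([torch.zeros(10), torch.ones(10)]).long()

# Previously: support, idxs = generate_support(..., return_indexes=True)
# Now the function always returns one dict mapping label -> support examples.
support = eq.utils.generate_support(
    train_x, train_y, support_size=5, selected_labels=[0, 1]
)

# If per-class shuffled indexes were already computed (as generate_episode does
# internally), they can be reused to skip the extra torch.unique/shuffle work,
# e.g. with a hypothetical precomputed_idxs dict:
# support = eq.utils.generate_support(
#     train_x, train_y, 5, [0, 1], shuffled_indexes=precomputed_idxs
# )
```

This mirrors the call already exercised in `tests/test_utils.py`.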
3 changes: 2 additions & 1 deletion docs/CONTRIBUTING.md
@@ -6,9 +6,10 @@
2. Install virtual environment (the below example assumes conda)

```shell
conda create --name equine python>=3.10
conda create --name equine python==3.10
conda activate equine
```
We currently support python versions >= 3.9

3. Install the code with the extra `tests` dependencies

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -24,7 +24,7 @@ dependencies = [
"torchmetrics >= 0.6.0",
"numpy",
"tqdm",
"typeguard >= 3.0, <5.0",
"typeguard >= 3.0, <=4.0",
"icontract",
"scikit-learn", # TODO: remove dependency on train_test_split
"scipy", # TODO: remove dependency on gaussian_kde
2 changes: 1 addition & 1 deletion src/equine/__init__.py
@@ -19,7 +19,7 @@
generate_model_summary,
)

if not TYPE_CHECKING:
if not TYPE_CHECKING: # pragma: no cover
try:
from ._version import version as __version__
except ImportError:
68 changes: 37 additions & 31 deletions src/equine/utils.py
@@ -2,7 +2,7 @@
# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
# SPDX-License-Identifier: MIT

from typing import Any, List, Union, Tuple, Dict
from typing import Any, List, Union, Tuple
import icontract
import torch
from typeguard import typechecked
@@ -114,7 +114,7 @@ def _get_shuffle_idxs_by_class(
Returns
-------
Dict[Any, torch.Tensor]
dict[Any, torch.Tensor]
Tensor of indices corresponding to each label.
"""
shuffled_idxs_by_class = OrderedDict()
@@ -137,22 +137,24 @@
lambda support_size, selected_labels, train_x: support_size * len(selected_labels)
<= len(train_x)
)
@icontract.require(
lambda selected_labels, shuffled_indexes: (
len(shuffled_indexes.keys()) == len(selected_labels)
)
if shuffled_indexes is not None
else True
)
@icontract.ensure(
lambda result, selected_labels, return_indexes: len(result[0].keys())
== len(selected_labels)
if (return_indexes is True)
else len(result.keys()) == len(selected_labels)
lambda result, selected_labels: len(result.keys()) == len(selected_labels)
)
@typechecked
def generate_support(
train_x: torch.Tensor,
train_y: torch.Tensor,
support_size: int,
selected_labels: List,
return_indexes=False,
) -> Union[
dict[Any, torch.Tensor], Tuple[dict[Any, torch.Tensor], dict[Any, torch.Tensor]]
]:
shuffled_indexes: Union[None, dict[Any, torch.Tensor]] = None,
) -> dict[Any, torch.Tensor]:
"""
Randomly select `support_size` examples of `way` classes from the examples in
`train_x` with corresponding labels in `train_y` and return them as a dictionary.
@@ -167,20 +169,22 @@ def generate_support(
Number of support examples for each class.
selected_labels : List
Selected class labels to generate examples from.
return_indexes : bool, optional
If True, also return the indices of the support examples.
shuffled_indexes: Union[None, dict[Any, torch.Tensor]], optional
Simply use the precomputed indexes if they are available
Returns
-------
Union[Dict[Any, torch.Tensor], Tuple[Dict[Any, torch.Tensor], Dict[Any, torch.Tensor]]]
dict[Any, torch.Tensor]
Ordered dictionary of class labels with corresponding support examples.
"""
labels, counts = torch.unique(train_y, return_counts=True)
for label, count in list(zip(labels, counts)):
if (label in selected_labels) and (count < support_size):
raise ValueError(f"Not enough support examples in class {label}")

shuffled_idxs = _get_shuffle_idxs_by_class(train_y, selected_labels)
if shuffled_indexes is None:
for label, count in list(zip(labels, counts)):
if (label in selected_labels) and (count < support_size):
raise ValueError(f"Not enough support examples in class {label}")
shuffled_idxs = _get_shuffle_idxs_by_class(train_y, selected_labels)
else:
shuffled_idxs = shuffled_indexes

support = OrderedDict()
for label in selected_labels:
@@ -192,10 +196,7 @@ def generate_support(
selected_support = shuffled_x[:support_size]
support[label] = selected_support

if return_indexes:
return support, shuffled_idxs
else:
return support
return support


@icontract.require(lambda train_x: len(train_x.shape) == 2)
@@ -237,10 +238,10 @@ def generate_episode(
Returns
-------
Tuple[Dict[Any, torch.Tensor], torch.Tensor, torch.Tensor]
Tuple[dict[Any, torch.Tensor], torch.Tensor, torch.Tensor]
Tuple of support examples, query examples, and query labels.
"""
labels = torch.unique(train_y)
labels, counts = torch.unique(train_y, return_counts=True)
if way > len(labels):
raise ValueError(
f"The way (#classes in each episode), {way}, must be <= number of labels, {len(labels)}"
@@ -250,8 +251,13 @@
labels[torch.randperm(labels.shape[0])][:way].tolist()
) # need to be in same order every time

support, shuffled_idxs = generate_support(
train_x, train_y, support_size, selected_labels, return_indexes=True
for label, count in list(zip(labels, counts)):
if (label in selected_labels) and (count < support_size):
raise ValueError(f"Not enough support examples in class {label}")
shuffled_idxs = _get_shuffle_idxs_by_class(train_y, selected_labels)

support = generate_support(
train_x, train_y, support_size, selected_labels, shuffled_idxs
)

examples_per_task = episode_size // way
@@ -307,7 +313,7 @@ def generate_model_metrics(
Returns
-------
Dict[str, Any]
dict[str, Any]
Dictionary of model metrics.
"""
pred_y = torch.argmax(eq_preds.classes, dim=1)
@@ -330,7 +336,7 @@
)
@icontract.ensure(lambda result: all(d["numExamples"] >= 0 for d in result))
@typechecked
def get_num_examples_per_label(Y: torch.Tensor) -> List[Dict[str, Any]]:
def get_num_examples_per_label(Y: torch.Tensor) -> List[dict[str, Any]]:
"""
Get the number of examples per label in the given tensor.
@@ -341,7 +347,7 @@ def get_num_examples_per_label(Y: torch.Tensor) -> List[Dict[str, Any]]:
Returns
-------
List[Dict[str, Any]]
List[dict[str, Any]]
List of dictionaries containing label and number of examples.
"""
tensor_labels, tensor_counts = Y.unique(return_counts=True)
@@ -374,7 +380,7 @@ def generate_train_summary(
Returns
-------
Dict[str, Any]
dict[str, Any]
Dictionary containing training summary.
"""
train_summary = {
@@ -408,7 +414,7 @@ def generate_model_summary(
Returns
-------
Dict[str, Any]
dict[str, Any]
Dictionary containing model summary.
"""
summary = generate_model_metrics(eq_preds, test_y)
2 changes: 1 addition & 1 deletion tests/test_utils.py
@@ -40,7 +40,7 @@ def support_dataset(draw):


@given(dataset=support_dataset())
def test_generate_support(dataset) -> None:
def test_generate_support(dataset):
train_x, train_y, support_sz, tasks, _ = dataset
eq.utils.generate_support(train_x, train_y, support_sz, tasks)
