Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a check for way < num_class #16

Merged
merged 1 commit into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ EQUINE was created to simplify two kinds of uncertainty quantification for super
2) An in-distribution score, indicating whether any of the model's known labels should be trusted.

Dive into our [documentation examples](https://mit-ll-responsible-ai.github.io/equine/)
to get started. Additionally, we provide a [companion web application](https://mit-ll-responsible-ai.github.io/equine-webapp/).
to get started. Additionally, we provide a [companion web application](https://github.com/mit-ll-responsible-ai/equine-webapp).

## Installation
Users are recommended to install a virtual environment such as Anaconda, as is also recommended
Expand Down
52 changes: 20 additions & 32 deletions docs/example_notebooks/vnat_example.ipynb

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions src/equine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,11 @@ def generate_episode(
Tuple of support examples, query examples, and query labels.
"""
labels = torch.unique(train_y)
if way > len(labels):
raise ValueError(
f"The way (#classes in each episode), {way}, must be <= number of labels, {len(labels)}"
)

selected_labels = sorted(
labels[torch.randperm(labels.shape[0])][:way].tolist()
) # need to be in same order every time
Expand Down
14 changes: 13 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from hypothesis import given
from hypothesis import strategies as st
import hypothesis.extra.numpy as hnp
import pytest


@st.composite
Expand Down Expand Up @@ -40,10 +41,21 @@ def support_dataset(draw):

@given(dataset=support_dataset())
def test_generate_support(dataset) -> None:
train_x, train_y, support_sz, tasks, way = dataset
train_x, train_y, support_sz, tasks, _ = dataset
eq.utils.generate_support(train_x, train_y, support_sz, tasks)


@given(dataset=support_dataset())
def test_generate_episode(dataset) -> None:
train_x, train_y, support_sz, tasks, way = dataset
episode_size = max(len(tasks), train_x.shape[0] // 4)
eq.utils.generate_episode(train_x, train_y, support_sz, way, episode_size)
with pytest.raises(ValueError):
eq.utils.generate_episode(
train_x, train_y, support_sz, len(tasks) + 1, episode_size
)


@st.composite
def draw_two_tensors(draw):
num_classes = draw(st.integers(min_value=2, max_value=128))
Expand Down