Consensus sampling decomposition and per-group standard error #211

Merged
merged 9 commits on Jan 20, 2025
8 changes: 7 additions & 1 deletion ldp/alg/__init__.py
@@ -1,4 +1,9 @@
from .algorithms import compute_pass_at_k, evaluate_consensus, to_network
from .algorithms import (
bulk_evaluate_consensus,
compute_pass_at_k,
evaluate_consensus,
to_network,
)
from .beam_search import Beam, BeamSearchRollout
from .callbacks import (
Callback,
@@ -45,6 +50,7 @@
"TrajectoryMetricsCallback",
"TreeSearchRollout",
"WandBLoggingCallback",
"bulk_evaluate_consensus",
"compute_pass_at_k",
"evaluate_consensus",
"to_network",
135 changes: 102 additions & 33 deletions ldp/alg/algorithms.py
@@ -127,70 +127,139 @@ def gvizify(x: Any) -> str:
return G


def measure_consensus_proportion(
consensus_size: int, sample_size: int
) -> tuple[float, float]:
"""Aggregate sampling accuracy mean and standard error on consensus proportion."""
sample_acc_mean = consensus_size / sample_size
# Use binomial standard error since we're comparing ideal vs not ideal
sample_acc_ste = np.sqrt(sample_acc_mean * (1.0 - sample_acc_mean) / sample_size)
return sample_acc_mean, sample_acc_ste
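# Illustrative check (values invented, mirroring the 3-of-5 cases asserted in
# tests/test_algorithms.py below): a 3-of-5 consensus gives
#     mean, ste = measure_consensus_proportion(consensus_size=3, sample_size=5)
# with mean == 3 / 5 == 0.6 and ste == sqrt(0.6 * 0.4 / 5) ≈ 0.219.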


TData = TypeVar("TData")
TGroupKey = TypeVar("TGroupKey", bound=Hashable)
TAnswer = TypeVar("TAnswer")
NO_IDEAL_ANSWER_FN: Literal["NO_IDEAL_ANSWER_FN"] = "NO_IDEAL_ANSWER_FN" # Sentinel


async def evaluate_consensus(
async def bulk_evaluate_consensus(
data: Iterable[TData],
grouping_fn: Callable[[TData], TGroupKey],
extract_answer_fn: Callable[[TData], TAnswer | Awaitable[TAnswer]],
num_samples: int | None = None,
seed: np.random.Generator | random.Random | int | None = None,
ideal_answer_fn: (
Callable[[TData], TAnswer] | Literal["NO_IDEAL_ANSWER_FN"]
) = NO_IDEAL_ANSWER_FN,
num_samples: int = 1,
seed: int | None = None,
consensus_callback: Callable[[TAnswer, int, int], Any] | None = None,
) -> tuple[dict[TGroupKey, list[tuple[TAnswer, int]]], float]:
"""
Create consensus groups and evaluate the consensus accuracy for each one.
Create consensus groups and evaluate the consensus accuracy for each.

Args:
data: Data to evaluate consensus upon, length is called N.
grouping_fn: Function to extract the group key from a datum.
extract_answer_fn: Function to extract the actual answer from a datum. It can
be async so this can be done using an LLM call.
data: Flattened data to evaluate consensus upon. Think of this as all results
from at least one evaluation upon a TaskDataset.
grouping_fn: Function to extract the group key from a datum. For a QA dataset,
the group key could be the question or question ID.
extract_answer_fn: Passed through to evaluate_consensus.
num_samples: Passed through to evaluate_consensus.
seed: Passed through to evaluate_consensus.
ideal_answer_fn: Optional function to extract the ideal answer from a datum to
compute accuracy with, or pass NO_IDEAL_ANSWER_FN to skip this calculation.
num_samples: Number of samples to choose from the N total.
seed: Optional seed for sampling.
consensus_callback: Passed through to evaluate_consensus.

Returns:
Two-tuple of consensus list generated by collections.Counter.most_common and
the proportion of groups for which the consensus matches the ideal.
Two-tuple of (1) a dictionary mapping group keys to a consensus list
(see evaluate_consensus.__doc__'s Returns for more details), and
(2) the proportion of groups for which the consensus matches the ideal.
"""
groups = collections.defaultdict(list)
for x in data:
groups[grouping_fn(x)].append(x)

ideal_count = 0
grouped_consensus: dict[TGroupKey, list[tuple[TAnswer, int]]] = {}
rand = random.Random(seed) if seed is not None else random
for group_key, group in groups.items():
if len(group) < num_samples: # Too few items, sample with replacement
sampled = [rand.choice(group) for _ in range(num_samples)]
else: # Sample without replacement
sampled = rand.sample(group, num_samples)

async def extract_answer(datum: TData) -> TAnswer:
answer = extract_answer_fn(datum)
return await answer if inspect.isawaitable(answer) else answer

# Get answers for the sampled data
answers = await asyncio.gather(*(extract_answer(x) for x in sampled))

# Compute consensus: mode of the sampled answers
grouped_consensus[group_key] = collections.Counter(answers).most_common()
# NOTE: If there are multiple modes, just use the first one
consensus: TAnswer = grouped_consensus[group_key][0][0]
if ideal_answer_fn != NO_IDEAL_ANSWER_FN:

async def add_consensus_check_ideal(
group_key: TGroupKey, group: list[TData]
) -> int:
grouped_consensus[group_key], consensus = await evaluate_consensus(
group, extract_answer_fn, num_samples, seed, consensus_callback
)
if ideal_answer_fn != NO_IDEAL_ANSWER_FN: # If we have an ideal
# Assume all items in the group have the same ideal answer
ideal_count += consensus == ideal_answer_fn(group[0])
return consensus == ideal_answer_fn(group[0])
return 0

ideal_count = sum(
await asyncio.gather(
*itertools.starmap(add_consensus_check_ideal, groups.items())
)
)
return grouped_consensus, ideal_count / len(groups) if groups else 0.0


async def evaluate_consensus(
data: Sequence[TData],
extract_answer_fn: Callable[[TData], TAnswer | Awaitable[TAnswer]],
num_samples: int | None = None,
seed: np.random.Generator | random.Random | int | None = None,
consensus_callback: Callable[[TAnswer, int, int], Any] | None = None,
) -> tuple[list[tuple[TAnswer, int]], TAnswer]:
"""
Create consensus bins given data.

Args:
data: Data to evaluate consensus upon, length is called N.
extract_answer_fn: Function to extract the actual answer from a datum. It can
be async so this can be done using an LLM call.
num_samples: Number of samples to choose from the N total, or None (default) to
infer this value to match N.
seed: Optional seed for sampling.
consensus_callback: Optional callback function called just after computing
consensus, it's passed the consensus answer, consensus size, and
sample size. This is useful for computing summary statistics.

Returns:
Two-tuple containing (1) an ordered list of `(extracted_answer, count)` tuples,
ordered by highest `count` first, and (2) the consensus answer
(i.e. the `extracted_answer` with the highest `count`).
"""
rand = (
seed
if isinstance(seed, np.random.Generator | random.Random)
else np.random.default_rng(seed)
)
if num_samples is None:
num_samples = len(data)
if len(data) < num_samples: # Too few data points to sample without replacement
raise ValueError(
f"A number of samples {num_samples} exceeding the {len(data)} data points"
" present is disallowed since sampling with replacement can produce"
" misleading consensus. Imagine if there was 1 data point, but 100 samples,"
" this would report perfect consensus with low standard error (but it would"
" be statistically artificial)."
)
if isinstance(rand, random.Random): # Built-in random sample without replacement
sampled: Iterable[TData] = rand.sample(data, num_samples)
else: # NumPy random sample without replacement
sampled = rand.choice(data, num_samples, replace=False) # type: ignore[arg-type]

async def extract_answer(datum: TData) -> TAnswer:
answer = extract_answer_fn(datum)
return await answer if inspect.isawaitable(answer) else answer

# Get answers for the sampled data
answers = await asyncio.gather(*(extract_answer(x) for x in sampled))
# Compute consensus: mode of the sampled answers
most_common = collections.Counter(answers).most_common()
# NOTE: If there are multiple modes, just use the first one
consensus_answer, consensus_count = most_common[0]
if consensus_callback:
consensus_callback(consensus_answer, consensus_count, num_samples)
return most_common, consensus_answer
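# Illustrative direct call (data values invented): with three answers and the
# default num_samples=len(data),
#     most_common, consensus = await evaluate_consensus(
#         ["42", "42", "11"], extract_answer_fn=lambda a: a
#     )
# yields most_common == [("42", 2), ("11", 1)] and consensus == "42".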


def compute_pass_at_k(n: int, c: int, k: int) -> float:
"""Compute an unbiased estimation for 'pass @ k'.

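A minimal usage sketch of the new API introduced in this file (the data, question IDs, and numbers are invented for illustration; the function names and parameters come from the diff above):

import asyncio
import operator

import numpy as np

from ldp.alg import bulk_evaluate_consensus
from ldp.alg.algorithms import measure_consensus_proportion


async def main() -> None:
    # Each datum is a (question_id, answer) tuple, e.g. collected across rollouts
    data = [("q1", "42"), ("q1", "42"), ("q1", "11"), ("q2", "apple"), ("q2", "pie")]
    per_group_stats: list[tuple[float, float]] = []

    def on_consensus(answer: str, consensus_size: int, sample_size: int) -> None:
        # Record each group's consensus proportion mean and binomial standard error
        per_group_stats.append(measure_consensus_proportion(consensus_size, sample_size))

    grouped, accuracy = await bulk_evaluate_consensus(
        data,
        grouping_fn=operator.itemgetter(0),
        extract_answer_fn=operator.itemgetter(1),
        num_samples=2,  # Cannot exceed the smallest group, since replacement is disallowed
        seed=np.random.default_rng(0),
        consensus_callback=on_consensus,
    )
    # grouped maps each question ID to Counter.most_common-style (answer, count) pairs;
    # accuracy stays 0.0 because no ideal_answer_fn was supplied


asyncio.run(main())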
75 changes: 59 additions & 16 deletions tests/test_algorithms.py
@@ -1,11 +1,13 @@
import operator

import numpy as np
import pytest
from aviary.core import DummyEnv
from aviary.utils import MultipleChoiceQuestion

from ldp.agent import SimpleAgent
from ldp.alg import compute_pass_at_k, evaluate_consensus
from ldp.alg import bulk_evaluate_consensus, compute_pass_at_k
from ldp.alg.algorithms import measure_consensus_proportion
from ldp.utils import discounted_returns


@@ -39,7 +41,7 @@ async def test_rollout_and_discounting(dummy_env: DummyEnv) -> None:


@pytest.mark.asyncio
async def test_evaluate_consensus() -> None:
async def test_consensus_evaluation() -> None:
# We have two questions, so let's group based on question
question_1 = MultipleChoiceQuestion(
question="What is the meaning of life?",
@@ -57,7 +59,7 @@ async def test_evaluate_consensus() -> None:
ideal_answer="8",
)
data_with_several_groups: list[tuple[MultipleChoiceQuestion, str]] = [
# Correct consensus
# Has consensus and it was correct
(question_1, "-84"),
(question_1, "11"),
(question_1, "11"),
@@ -68,7 +70,7 @@ async def test_evaluate_consensus() -> None:
(question_1, "42"),
(question_1, "42"),
(question_1, "42"),
# Correct consensus
# Has consensus and it was correct
(question_2, "brownie"),
(question_2, "brownie"),
(question_2, "apple"),
@@ -77,36 +79,77 @@ async def test_evaluate_consensus() -> None:
(question_2, "apple"),
(question_2, "apple"),
(question_2, "apple"),
# Incorrect consensus
# Has no clear consensus (a tie), and it's incorrect regardless
(question_3, "1"),
(question_3, "2"),
(question_3, "1"),
(question_3, "2"),
(question_3, "4"),
]
# NOTE: this consensus is sensitive to seed
expected_consensus = {
question_1.question: [("42", 3), ("11", 1), ("-84", 1)],
question_2.question: [("apple", 4), ("brownie", 1)],
question_3.question: [("1", 3), ("2", 2)],
question_1.question: (
[("42", 3), ("cheesecake", 1), ("-84", 1)],
3 / 5,
0.2190890,
),
question_2.question: ([("apple", 4), ("brownie", 1)], 4 / 5, 0.1788854),
question_3.question: ([("2", 2), ("1", 2), ("4", 1)], 2 / 5, 0.2190890),
}
stored_accuracy_mean_ste: list[tuple[float, float]] = []

def append_accuracy_metrics(
consensus_answer: str, # noqa: ARG001
consensus_size: int,
sample_size: int,
) -> None:
stored_accuracy_mean_ste.append(
measure_consensus_proportion(consensus_size, sample_size)
)

# Check accuracy is 0% without an ideal answer
assert await evaluate_consensus(
groups, accuracy = await bulk_evaluate_consensus(
data_with_several_groups,
grouping_fn=lambda x: x[0].question,
extract_answer_fn=operator.itemgetter(1),
num_samples=5,
seed=42,
) == (expected_consensus, 0.0)
seed=np.random.default_rng(42),
extract_answer_fn=operator.itemgetter(1),
consensus_callback=append_accuracy_metrics,
)
assert len(groups) == len(expected_consensus)
for (q, (consensus, acc_mean, acc_ste)), actual_acc_ste in zip(
expected_consensus.items(), stored_accuracy_mean_ste, strict=True
):
assert groups[q] == consensus
assert actual_acc_ste == (pytest.approx(acc_mean), pytest.approx(acc_ste))
assert accuracy == 0.0, "Can't compute accuracy without an ideal answer"
stored_accuracy_mean_ste.clear() # Prepare for next batch of assertions

# Check accuracy is present when we can get an ideal answer
assert await evaluate_consensus(
groups, accuracy = await bulk_evaluate_consensus(
data_with_several_groups,
grouping_fn=lambda x: x[0].question,
extract_answer_fn=operator.itemgetter(1),
ideal_answer_fn=lambda x: x[0].ideal_answer,
num_samples=5,
seed=42,
) == (expected_consensus, 2 / 3)
seed=np.random.default_rng(42),
extract_answer_fn=operator.itemgetter(1),
consensus_callback=append_accuracy_metrics,
)
assert len(groups) == len(expected_consensus)
for (q, (consensus, acc_mean, acc_ste)), actual_acc_ste in zip(
expected_consensus.items(), stored_accuracy_mean_ste, strict=True
):
assert groups[q] == consensus
assert actual_acc_ste == (pytest.approx(acc_mean), pytest.approx(acc_ste))
assert accuracy == 2 / 3

with pytest.raises(ValueError, match="sampling with replacement"):
await bulk_evaluate_consensus(
data_with_several_groups,
grouping_fn=lambda x: x[0].question,
num_samples=10, # Sampling with replacement is disallowed
extract_answer_fn=operator.itemgetter(1),
)


@pytest.mark.parametrize(