Makes Lambada work with GPT2 #115

Merged
13 commits, merged on Feb 21, 2023
1 change: 1 addition & 0 deletions .gitignore
@@ -7,6 +7,7 @@

# Python
__pycache__/
/ai2_catwalk.egg-info/
/catwalk.egg-info/
/build/
/dist/
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed

- Fixed the way we compute SQuAD metrics.
- Fixed wikitext on GPT2
- Fixed lambada on GPT2


## [v0.2.2](https://github.com/allenai/catwalk/releases/tag/v0.2.2) - 2023-01-27
103 changes: 66 additions & 37 deletions catwalk/models/gpt.py
@@ -105,41 +105,61 @@ class ModelInstance:
targets: torch.Tensor

def make_model_instances(
texts: Iterator[str],
eleuther_requests: Iterator,
overlap: int = 1
) -> Iterator[ModelInstance]:
for text in texts:
token_ids = [tokenizer.eos_token_id] + tokenizer.encode(text)
# The next line puts the entire text into GPU memory. In principle this is a problem, because it
# might OOM when the text is long. In practice, that doesn't happen.
token_ids = torch.tensor(token_ids, dtype=torch.long, device=device)
window_start = 0
while True:
window_end = window_start + tokenizer.model_max_length
if window_end > len(token_ids) - 1:
break
for request_or_requests in eleuther_requests:
if isinstance(request_or_requests, tuple):
assert all(r.request_type == "loglikelihood" for r in request_or_requests)
requests = {r.args for r in request_or_requests}
if len(requests) != 1:
raise ValueError("This way of calling GPT can only handle loglikelihood requests.")
context, continuation = next(iter(requests))
token_ids = [tokenizer.eos_token_id] + tokenizer.encode(context, continuation)
continuation_tokens = len(tokenizer.encode(continuation))
assert continuation_tokens < tokenizer.model_max_length
token_ids = token_ids[-tokenizer.model_max_length:]
token_ids = torch.tensor(token_ids, dtype=torch.long, device=device)
yield ModelInstance(
context + continuation,
len(token_ids) - continuation_tokens - 1,
token_ids[:-1],
token_ids[1:])
else:
text = request_or_requests.args[0]
token_ids = [tokenizer.eos_token_id] + tokenizer.encode(text)
# The next line puts the entire text into GPU memory. In principle this is a problem, because it
# might OOM when the text is long. In practice, that doesn't happen.
token_ids = torch.tensor(token_ids, dtype=torch.long, device=device)
window_start = 0
while True:
window_end = window_start + tokenizer.model_max_length
if window_end > len(token_ids) - 1:
break
yield ModelInstance(
text,
1 if window_start == 0 else overlap,
token_ids[window_start:window_end],
token_ids[window_start+1:window_end+1])
window_start += tokenizer.model_max_length
window_start -= overlap
window_end = len(token_ids) - 1
if window_start == 0:
last_batch_context_tokens = 1
else:
new_window_start = window_end - tokenizer.model_max_length
last_batch_context_tokens = window_start - new_window_start + overlap
window_start = new_window_start
del new_window_start
yield ModelInstance(
text,
last_batch_context_tokens,
token_ids[window_start:window_end],
token_ids[window_start+1:window_end+1])
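
For intuition, here is a minimal, pure-Python mirror of the sliding-window arithmetic in the `else:` branch above, run on toy token ids with hypothetical `model_max_length=4` and `overlap=1` (a sketch for illustration, not the PR's code):

```python
# Toy mirror of the windowing logic above. The first `num_context` positions of
# each window only provide context; positions that fall into more than one
# window are therefore scored only once, by the earliest window covering them.
def toy_windows(token_ids, model_max_length=4, overlap=1):
    window_start = 0
    while True:
        window_end = window_start + model_max_length
        if window_end > len(token_ids) - 1:
            break
        yield (1 if window_start == 0 else overlap,        # num_context_tokens
               token_ids[window_start:window_end],          # inputs
               token_ids[window_start + 1:window_end + 1])  # targets
        window_start += model_max_length
        window_start -= overlap
    window_end = len(token_ids) - 1
    if window_start == 0:
        last_batch_context_tokens = 1
    else:
        new_window_start = window_end - model_max_length
        last_batch_context_tokens = window_start - new_window_start + overlap
        window_start = new_window_start
    yield (last_batch_context_tokens,
           token_ids[window_start:window_end],
           token_ids[window_start + 1:window_end + 1])

for num_context, inputs, targets in toy_windows(list(range(10))):
    print(num_context, inputs, targets, "scored:", targets[num_context:])
# 1 [0, 1, 2, 3] [1, 2, 3, 4] scored: [2, 3, 4]
# 1 [3, 4, 5, 6] [4, 5, 6, 7] scored: [5, 6, 7]
# 2 [5, 6, 7, 8] [6, 7, 8, 9] scored: [8, 9]
```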

def make_model_predictions(model_instances: Iterator[ModelInstance]) -> Iterator[Tuple[str, torch.Tensor]]:

def make_model_predictions(
model_instances: Iterator[ModelInstance]
) -> Iterator[Tuple[str, torch.Tensor, bool]]:
for batch in more_itertools.chunked(model_instances, batch_size):
batch_results = []
with torch.inference_mode():
@@ -154,23 +174,31 @@ def make_model_predictions(model_instances: Iterator[ModelInstance]) -> Iterator
output[mi.num_context_tokens:],
1,
mi.targets[mi.num_context_tokens:].unsqueeze(-1)).squeeze(-1)
batch_results.append((mi.text, logprobs))
exact_match = torch.equal(
mi.targets[mi.num_context_tokens:],
output[mi.num_context_tokens:].argmax(dim=-1))
batch_results.append((mi.text, logprobs, exact_match))
yield from batch_results
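
As a quick illustration of the `gather`/`argmax` pattern above (hypothetical logits, not the PR's code): the gather pulls out the log-probability assigned to each gold target token past the context, and the argmax comparison is the per-window exact-match signal that later feeds LAMBADA's accuracy.

```python
import torch

# Hypothetical log-probabilities for 3 positions over a 4-token vocabulary.
output = torch.log_softmax(torch.tensor([
    [2.0, 0.1, 0.1, 0.1],   # top-1 prediction: token 0
    [0.1, 0.1, 3.0, 0.1],   # top-1 prediction: token 2
    [0.1, 0.1, 0.1, 2.5],   # top-1 prediction: token 3
]), dim=-1)
targets = torch.tensor([0, 2, 1])   # gold next tokens
num_context_tokens = 1              # the first position is context and isn't scored

# Log-probability of each scored gold token.
logprobs = torch.gather(
    output[num_context_tokens:],
    1,
    targets[num_context_tokens:].unsqueeze(-1)).squeeze(-1)

# True only if the top-1 prediction matches the gold token at every scored position.
exact_match = torch.equal(
    targets[num_context_tokens:],
    output[num_context_tokens:].argmax(dim=-1))

print(logprobs)      # two values: log p(gold) at the scored positions
print(exact_match)   # False here: the last position predicts token 3, gold is token 1
```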

def group_model_predictions(model_predictions: Iterator[Tuple[str, torch.Tensor]]) -> Iterator[Tuple[str, float]]:
def group_model_predictions(
model_predictions: Iterator[Tuple[str, torch.Tensor, bool]]
) -> Iterator[Tuple[str, float, bool]]:
last_text = None
summed_logprobs = 0.0
for text, logprobs in model_predictions:
summed_exact_match = True
for text, logprobs, exact_match in model_predictions:
if last_text is not None and text != last_text:
yield last_text, float(summed_logprobs)
yield last_text, float(summed_logprobs), summed_exact_match
summed_logprobs = 0.0
summed_exact_match = True
summed_logprobs += logprobs.sum()
last_text = text
summed_exact_match &= exact_match
if last_text is not None:
yield last_text, float(summed_logprobs)
yield last_text, float(summed_logprobs), summed_exact_match

model_instances = make_model_instances(
task.convert_instance(instance, InstanceFormat.ELEUTHER_REQUESTS).args[0] for instance in Tqdm.tqdm(
task.convert_instance(instance, InstanceFormat.ELEUTHER_REQUESTS) for instance in Tqdm.tqdm(
instances,
desc="Calculating log probabilities")
)
@@ -179,12 +207,13 @@

from spacy.lang.en import English
spacy_tokenizer = English().tokenizer
for text, logprob in grouped_predictions:
for text, logprob, exact_match in grouped_predictions:
yield {
"text": text,
"word_perplexity": (logprob, len(spacy_tokenizer(text))),
# bytes aren't characters, but this is what Eleuther calls it
"byte_perplexity": (logprob, len(text)),
"bits_per_byte": (logprob, len(text))
"bits_per_byte": (logprob, len(text)),
"acc": (1 if exact_match else 0,)
}
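
A hedged sketch of how the `(logprob, count)` pairs yielded above typically become corpus-level numbers (Eleuther-style weighted perplexity; catwalk's own metric classes may differ in detail, and the function names below are illustrative):

```python
import math

def weighted_perplexity(pairs):
    # pairs: per-instance (summed logprob, unit count); units are words or bytes.
    total_logprob = sum(lp for lp, _ in pairs)
    total_units = sum(n for _, n in pairs)
    return math.exp(-total_logprob / total_units)

def bits_per_byte(pairs):
    total_logprob = sum(lp for lp, _ in pairs)
    total_bytes = sum(n for _, n in pairs)
    return -total_logprob / (total_bytes * math.log(2))

# Hypothetical per-instance values.
docs = [(-210.0, 100), (-95.0, 40)]
print(weighted_perplexity(docs))  # word or byte perplexity, depending on the count used
print(bits_per_byte(docs))        # equals log2 of the byte perplexity
```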

15 changes: 8 additions & 7 deletions catwalk/tasks/__init__.py
@@ -1,6 +1,7 @@
from typing import Dict, Optional

import datasets
from torchmetrics import MeanMetric

from catwalk.task import InstanceFormat, ENTAILMENT_METRICS, QA_METRICS, Task, \
classification_metrics, BINARY_CLASSIFICATION_METRICS, mc_metrics, PERPLEXITY_METRICS
@@ -183,13 +184,13 @@
).add_metrics(mc_metrics(2)),
#"coqa": EleutherTask("coqa"), # currently broken in the datasets library
"drop": EleutherTask("drop").add_metrics(QA_METRICS),
"lambada": EleutherTask("lambada_standard"),
"lambada_cloze": EleutherTask("lambada_standard_cloze"),
"lambada_mt_en": EleutherTask("lambada_openai_mt_en"),
"lambada_mt_fr": EleutherTask("lambada_openai_mt_fr"),
"lambada_mt_de": EleutherTask("lambada_openai_mt_de"),
"lambada_mt_it": EleutherTask("lambada_openai_mt_it"),
"lambada_mt_es": EleutherTask("lambada_openai_mt_es"),
"lambada": EleutherTask("lambada_standard").add_metrics(PERPLEXITY_METRICS).add_metric("acc", MeanMetric),
"lambada_cloze": EleutherTask("lambada_standard_cloze").add_metrics(PERPLEXITY_METRICS),
"lambada_mt_en": EleutherTask("lambada_openai_mt_en").add_metrics(PERPLEXITY_METRICS),
"lambada_mt_fr": EleutherTask("lambada_openai_mt_fr").add_metrics(PERPLEXITY_METRICS),
"lambada_mt_de": EleutherTask("lambada_openai_mt_de").add_metrics(PERPLEXITY_METRICS),
"lambada_mt_it": EleutherTask("lambada_openai_mt_it").add_metrics(PERPLEXITY_METRICS),
"lambada_mt_es": EleutherTask("lambada_openai_mt_es").add_metrics(PERPLEXITY_METRICS),
"prost": EleutherTask("prost", ranked_classification=True).add_metrics(mc_metrics(4)),
"mc_taco": EleutherTask("mc_taco", ranked_classification=True).add_metrics(BINARY_CLASSIFICATION_METRICS),
"pubmedqa": EleutherTaskWithRenamedSplits("pubmedqa").add_metrics(BINARY_CLASSIFICATION_METRICS),
16 changes: 16 additions & 0 deletions catwalk/tasks/eleuther.py
@@ -128,6 +128,13 @@ def instance_as_rank_classification(
fewshot_instances: Optional[List[Dict[str, Any]]] = None,
**kwargs
) -> RankClassificationInstance:
"""
Converts the given instance to an instance for performing ranked classification

:param instance: the instance to convert
:param fewshot_instances: a list of few-shot instances to include
:return: the instance in :class:`~catwalk.task.RankClassificationInstance` format
"""
if fewshot_instances is None:
fewshot_instances = []
prefix = ""
@@ -170,6 +177,15 @@ def instance_as_rank_classification(
fewshot_instances: Optional[List[Dict[str, Any]]] = None,
**kwargs
) -> RankClassificationInstance:
"""
Converts the given instance to an instance for performing ranked classification

:param instance: the instance to convert
:param fewshot_instances: a list of few-shot instances to include. These instances are given in
Huggingface dict format.
:param kwargs: extra arguments that are ignored
:return: the instance in :class:`~catwalk.task.RankClassificationInstance` format
"""
if fewshot_instances is None:
fewshot_instances = []
prefix = ""
14 changes: 14 additions & 0 deletions docs/source/conf.py
@@ -111,3 +111,17 @@
},
],
}


#
# The following is a workaround for a bug where Sphinx 5.3.0 tries to find a reference that isn't used anywhere.
#
def on_missing_reference(app, env, node, contnode):
if node['reftarget'] == 'metric kwargs':
return contnode
else:
return None


def setup(app):
app.connect('missing-reference', on_missing_reference)
22 changes: 21 additions & 1 deletion tests/test_spotchecks.py
@@ -30,4 +30,24 @@ def test_gpt2_performance(task: str):
predictions = PredictStep(model=model, task=task, limit=100)
metrics = CalculateMetricsStep(model=model, task=task, predictions=predictions)
results = metrics.result()
assert results["relative_improvement"] > 0
assert results['relative_improvement'] > 0


def test_lambada_gpt():
model = "gpt2"
task = "lambada"
predictions = PredictStep(model=model, task=task, limit=10)
metrics = CalculateMetricsStep(model=model, task=task, predictions=predictions)
results = metrics.result()
assert results['acc'] >= 0.4


def test_perplexity_gpt():
model = "gpt2"
task = "wikitext"
predictions = PredictStep(model=model, task=task, limit=10)
metrics = CalculateMetricsStep(model=model, task=task, predictions=predictions)
results = metrics.result()
assert results['word_perplexity'] < 40
assert results['byte_perplexity'] < 2.5
assert results['bits_per_byte'] < 1.5
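
As a sanity check on the thresholds above (assuming the standard relation between the two metrics, namely bits_per_byte = log2(byte_perplexity)), the byte-perplexity and bits-per-byte bounds are roughly consistent with each other:

```python
import math

# byte_perplexity < 2.5 corresponds to bits_per_byte < log2(2.5) ≈ 1.32,
# comfortably inside the < 1.5 bound asserted above.
print(math.log2(2.5))  # 1.3219...
```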