
Commit

update
yangky11 committed Jul 12, 2024
1 parent 6f1d9c4 commit a554532
Showing 20 changed files with 417 additions and 1,368 deletions.
26 changes: 0 additions & 26 deletions common.py
@@ -354,36 +354,10 @@ def get_all_pos_premises(annot_tac, corpus: Corpus) -> List[Premise]:
     return list(all_pos_premises)
 
 
-def format_tactic(annot_tac: str, provenances) -> str:
-    """Use full names for all the <a>...</a> identifiers."""
-    if len(provenances) == 0:
-        return annot_tac
-
-    tac = ""
-    marks = list(re.finditer(r"<a>(?P<ident>.+?)</a>", annot_tac))
-
-    for i, (m, prov) in enumerate(zip_strict(marks, provenances)):
-        last_end = marks[i - 1].end() if i > 0 else 0
-        tac += annot_tac[last_end : m.start()] + "<a>" + prov["full_name"] + "</a>"
-
-    tac += annot_tac[marks[-1].end() :]
-    return tac
-
-
-def format_state(s: str) -> str:
-    m = re.match(r"\d+ goals", s)
-    if m is not None:
-        return s[m.end() :].strip()
-    else:
-        return s
-
-
 def format_augmented_state(
     s: str, premises: List[Premise], max_len: int, p_drop: float
 ) -> str:
     """Format a state with retrieved premises and drop some of them with probability ``p_drop``."""
-    s = format_state(s)
-
     aug_s = ""
     length = 0
     max_premises_len = max_len - len(bytes(s.encode("utf-8")))
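For reference, a minimal standalone sketch of what the two deleted helpers did, with the regex logic inlined from the diff above. The tactic string, provenance, and goal state are invented for illustration, and plain zip stands in for zip_strict:

import re

# format_tactic rewrote each <a>...</a> span to use the premise's fully qualified name.
annot_tac = "rw [<a>add_comm</a>]"              # hypothetical annotated tactic
provenances = [{"full_name": "Nat.add_comm"}]   # hypothetical provenance record
marks = list(re.finditer(r"<a>(?P<ident>.+?)</a>", annot_tac))
tac = ""
for i, (m, prov) in enumerate(zip(marks, provenances)):
    last_end = marks[i - 1].end() if i > 0 else 0
    tac += annot_tac[last_end : m.start()] + "<a>" + prov["full_name"] + "</a>"
tac += annot_tac[marks[-1].end() :]
print(tac)  # rw [<a>Nat.add_comm</a>]

# format_state stripped the leading "N goals" header from multi-goal states.
s = "2 goals\ncase inl\n⊢ p ∨ q"                # hypothetical goal state
m = re.match(r"\d+ goals", s)
print(s[m.end():].strip() if m is not None else s)  # case inl\n⊢ p ∨ q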
File renamed without changes.
3 changes: 0 additions & 3 deletions generator/datamodule.py → generation/datamodule.py
@@ -14,9 +14,7 @@
     Batch,
     Corpus,
     Example,
-    format_state,
     remove_marks,
-    format_tactic,
     format_augmented_state,
 )

@@ -54,7 +52,6 @@ def _load_data(self, data_path: str) -> List[Example]:
"commit": thm["commit"],
"file_path": thm["file_path"],
"full_name": thm["full_name"],
# "state": format_state(tac["state_before"]),
"state": tac["state_before"],
"tactic": tactic,
}
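After this change, _load_data keeps tac["state_before"] verbatim rather than piping it through the deleted format_state, so any "N goals" header survives into the training data. A hypothetical example entry (only the keys come from the diff; every value is invented):

example = {
    "commit": "1234567",                          # hypothetical values throughout
    "file_path": "Mathlib/Algebra/Group/Basic.lean",
    "full_name": "mul_one",
    "state": "2 goals\ncase inl\n⊢ p ∨ q\ncase inr\n⊢ q ∨ p",  # raw state; header retained
    "tactic": "rw [mul_comm, one_mul]",
}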
File renamed without changes.
255 changes: 255 additions & 0 deletions generation/model.py
@@ -0,0 +1,255 @@
"""Lightning module for the tactic generator."""

import os
import torch
import shutil
import pickle
from loguru import logger
import pytorch_lightning as pl
from torchmetrics import Metric
from typing import List, Dict, Any, Optional
from transformers import T5ForConditionalGeneration, AutoTokenizer

from common import (
remove_marks,
IndexedCorpus,
get_optimizers,
load_checkpoint,
)
from retrieval.model import PremiseRetriever


torch.set_float32_matmul_precision("medium")


class TopkAccuracy(Metric):
is_differentiable: Optional[bool] = False
higher_is_better: Optional[bool] = True
full_state_update: bool = True

def __init__(self, k: int) -> None:
super().__init__()
self.k = k
self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

def update(self, batch_preds: List[List[str]], batch_gt: List[str]):
assert len(batch_preds) == len(batch_gt)
for preds, gt in zip(batch_preds, batch_gt):
# This still doesn't account for short names vs. full names.
gt = remove_marks(gt)
preds = [remove_marks(p) for p in preds]
self.correct += gt in preds[: self.k]
self.total += len(batch_gt)

def compute(self) -> float:
return self.correct.float() / self.total


class RetrievalAugmentedGenerator(pl.LightningModule):
    def __init__(
        self,
        model_name: str,
        lr: float,
        warmup_steps: int,
        num_beams: int,
        eval_num_retrieved: int,
        eval_num_workers: int,
        eval_num_gpus: int,
        eval_num_theorems: int,
        max_inp_seq_len: int,
        max_oup_seq_len: int,
        length_penalty: float = 0.0,
        ret_ckpt_path: Optional[str] = None,
    ) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.lr = lr
        self.warmup_steps = warmup_steps
        self.num_beams = num_beams
        self.length_penalty = length_penalty
        self.eval_num_retrieved = eval_num_retrieved
        self.eval_num_workers = eval_num_workers
        self.eval_num_gpus = eval_num_gpus
        self.eval_num_theorems = eval_num_theorems
        self.max_inp_seq_len = max_inp_seq_len
        self.max_oup_seq_len = max_oup_seq_len

        if ret_ckpt_path is None:
            self.retriever = None
        else:
            logger.info(f"Loading the retriever from {ret_ckpt_path}")
            self.retriever = PremiseRetriever.load(
                ret_ckpt_path, self.device, freeze=True
            )

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.generator = T5ForConditionalGeneration.from_pretrained(model_name)

        self.topk_accuracies = dict()
        for k in range(1, num_beams + 1):
            acc = TopkAccuracy(k)
            self.topk_accuracies[k] = acc
            self.add_module(f"top{k}_acc_val", acc)

    @classmethod
    def load(
        cls, ckpt_path: str, device, freeze: bool
    ) -> "RetrievalAugmentedGenerator":
        return load_checkpoint(cls, ckpt_path, device, freeze)

    def forward(
        self,
        state_ids: torch.Tensor,
        state_mask: torch.Tensor,
        tactic_ids: torch.Tensor,
    ) -> torch.Tensor:
        return self.generator(
            input_ids=state_ids,
            attention_mask=state_mask,
            labels=tactic_ids,
        ).loss

    ############
    # Training #
    ############

    def training_step(self, batch, batch_idx: int):
        loss = self(
            batch["state_ids"],
            batch["state_mask"],
            batch["tactic_ids"],
        )
        self.log(
            "loss_train",
            loss,
            on_step=True,
            on_epoch=True,
            sync_dist=True,
            batch_size=len(batch),
        )
        self._log_io_texts("train", batch["state_ids"], batch["tactic_ids"])
        return loss

    def configure_optimizers(self) -> Dict[str, Any]:
        return get_optimizers(
            self.parameters(), self.trainer, self.lr, self.warmup_steps
        )

    def _log_io_texts(
        self,
        split: str,
        state_ids: torch.LongTensor,
        tactic_ids: torch.LongTensor,
    ) -> None:
        inp = self.tokenizer.decode(state_ids[0], skip_special_tokens=True)
        oup_ids = torch.where(
            tactic_ids[0] == -100, self.tokenizer.pad_token_id, tactic_ids[0]
        )
        oup = self.tokenizer.decode(oup_ids, skip_special_tokens=True)
        self.logger.log_text(
            f"{split}_samples",
            ["state", "tactic"],
            [[inp, oup]],
            step=self.global_step,
        )

    def on_fit_start(self) -> None:
        if self.logger is not None:
            self.logger.log_hyperparams(self.hparams)
            self.logger.watch(self.generator)
            assert self.trainer is not None
            logger.info(f"Logging to {self.trainer.log_dir}")

        if self.retriever is not None:
            self.retriever.load_corpus(self.trainer.datamodule.corpus)

    ##############
    # Validation #
    ##############

    def validation_step(self, batch: Dict[str, Any], _) -> None:
        state_ids = batch["state_ids"]
        state_mask = batch["state_mask"]
        tactic_ids = batch["tactic_ids"]

        loss = self(state_ids, state_mask, tactic_ids)
        self.log("loss_val", loss, on_step=False, on_epoch=True, sync_dist=True)
        self._log_io_texts("val", state_ids, tactic_ids)

        # Generate topk tactic candidates via beam search.
        output = self.generator.generate(
            input_ids=state_ids,
            attention_mask=state_mask,
            max_length=self.max_oup_seq_len,
            num_beams=self.num_beams,
            do_sample=False,
            num_return_sequences=self.num_beams,
            early_stopping=False,
        )
        output_text = self.tokenizer.batch_decode(output, skip_special_tokens=True)
        batch_size = state_ids.size(0)
        assert len(output_text) == batch_size * self.num_beams
        tactics_pred = [
            output_text[i * self.num_beams : (i + 1) * self.num_beams]
            for i in range(batch_size)
        ]

        msg = "\n".join(tactics_pred[0])
        self.logger.log_text("preds_val", ["tactics"], [[msg]], step=self.global_step)

        # Log the topk accuracies.
        for k in range(1, self.num_beams + 1):
            topk_acc = self.topk_accuracies[k]
            topk_acc(tactics_pred, batch["tactic"])
            self.log(
                f"top{k}_acc_val",
                topk_acc,
                on_step=False,
                on_epoch=True,
                sync_dist=True,
            )

    def on_validation_epoch_end(self) -> None:
        if self.eval_num_theorems == 0:
            return

        from prover.evaluate import evaluate  # Avoid circular import.

        ckpt_path = f"{self.trainer.log_dir}/last-tmp.ckpt"
        self.trainer.save_checkpoint(ckpt_path)
        logger.info(f"Saved checkpoint to {ckpt_path}. Evaluating...")
        torch.cuda.empty_cache()

        data_path = self.trainer.datamodule.data_path
        if self.retriever is None:
            acc = evaluate(
                data_path=data_path,
                num_workers=self.eval_num_workers,
                num_gpus=self.eval_num_gpus,
                num_theorems=self.eval_num_theorems,
                ckpt_path=ckpt_path,
            )
        else:
            self.retriever.reindex_corpus(self.trainer.datamodule.eval_batch_size)
            corpus_path = f"{self.trainer.log_dir}/checkpoints/indexed_corpus.pickle"
            pickle.dump(
                IndexedCorpus(
                    self.retriever.corpus, self.retriever.corpus_embeddings.cpu()
                ),
                open(corpus_path, "wb"),
            )
            acc = evaluate(
                data_path=data_path,
                num_workers=self.eval_num_workers,
                num_gpus=self.eval_num_gpus,
                num_theorems=self.eval_num_theorems,
                ckpt_path=ckpt_path,
                indexed_corpus_path=corpus_path,
            )

        self.log("Pass@1_val", acc, on_step=False, on_epoch=True, sync_dist=True)
        logger.info(f"Pass@1: {acc}")

        if os.path.exists(ckpt_path):
            shutil.rmtree(ckpt_path)
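A minimal usage sketch for the new module, mirroring the beam-search call in validation_step. The checkpoint path and goal state are invented; load, tokenizer, generator, num_beams, and the sequence-length fields come from the file above:

import torch
from generation.model import RetrievalAugmentedGenerator

# Hypothetical checkpoint path; load() delegates to common.load_checkpoint.
model = RetrievalAugmentedGenerator.load(
    "ckpts/generator.ckpt", torch.device("cpu"), freeze=True
)

state = "⊢ ∀ (n : ℕ), n + 0 = n"  # hypothetical Lean goal state
enc = model.tokenizer(
    state, return_tensors="pt", truncation=True, max_length=model.max_inp_seq_len
)
with torch.no_grad():
    out = model.generator.generate(
        input_ids=enc.input_ids,
        attention_mask=enc.attention_mask,
        max_length=model.max_oup_seq_len,
        num_beams=model.num_beams,
        num_return_sequences=model.num_beams,
        do_sample=False,
    )
# num_beams candidate tactics for the single input state, best beam first.
tactics = model.tokenizer.batch_decode(out, skip_special_tokens=True)
print(tactics)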
File renamed without changes.
