diff --git a/README.md b/README.md index d063d04..a720017 100644 --- a/README.md +++ b/README.md @@ -310,22 +310,31 @@ python generation/main.py fit --config generation/confs/cli_lean4_novel_premises After the tactic generator is trained, we combine it with best-first search to prove theorems by interacting with Lean. -For models without retrieval, run: +The evaluation script takes Hugging Face model checkpoints (either local or remote) as input. For remote models, you can simply use their names, e.g., [kaiyuy/leandojo-lean4-tacgen-byt5-small](https://huggingface.co/kaiyuy/leandojo-lean4-tacgen-byt5-small). For locally trained models, you first need to convert them from PyTorch Ligthning checkpoints to Hugging Face checkpoints: ```bash -python prover/evaluate.py --data-path data/leandojo_benchmark_4/random/ --ckpt_path $PATH_TO_MODEL_CHECKPOINT --split test --num-workers 5 --num-gpus 1 -python prover/evaluate.py --data-path data/leandojo_benchmark_4/novel_premises/ --ckpt_path $PATH_TO_MODEL_CHECKPOINT --split test --num-workers 5 --num-gpus 1 +python scripts/convert_checkpoint.py generator --src $PATH_TO_GENERATOR_CHECKPOINT --dst ./leandojo-lean4-tacgen-byt5-small +python scripts/convert_checkpoint.py retriever --src $PATH_TO_RETRIEVER_CHECKPOINT --dst ./leandojo-lean4-retriever-byt5-small ``` +, where `PATH_TO_GENERATOR_CHECKPOINT` and `PATH_TO_RETRIEVER_CHECKPOINT` are PyTorch Ligthning checkpoints produced by the training script. -For models with retrieval, first use the retriever to index the corpus (pre-computing the embeddings of all premises): + +To evaluate the model without retrieval, run (using the `random` data split as example): +```bash +python prover/evaluate.py --data-path data/leandojo_benchmark_4/random/ --gen_ckpt_path ./leandojo-lean4-tacgen-byt5-small --split test --num-workers 5 --num-gpus 1 +``` +You may tweak `--num-workers` and `--num-gpus` to fit your hardware. + + +For the model with retrieval, first use the retriever to index the corpus (pre-computing the embeddings of all premises): ```bash -python retrieval/index.py --ckpt_path $PATH_TO_RETRIEVER_CHECKPOINT --corpus-path data/leandojo_benchmark_4/corpus.jsonl --output-path $PATH_TO_INDEXED_CORPUS -# Do it separately for two data splits. +python retrieval/index.py --ckpt_path ./leandojo-lean4-retriever-byt5-small --corpus-path data/leandojo_benchmark_4/corpus.jsonl --output-path $PATH_TO_INDEXED_CORPUS ``` +It saves the indexed corpurs as a pickle file to `PATH_TO_INDEXED_CORPUS`. Then, run: ```bash -python prover/evaluate.py --data-path data/leandojo_benchmark_4/random/ --ckpt_path $PATH_TO_REPROVER_CHECKPOINT --indexed-corpus-path $PATH_TO_INDEXED_CORPUS --split test --num-cpus 8 --with-gpus -# Do it separately for two data splits. +python scripts/convert_checkpoint.py generator --src $PATH_TO_REPROVER_CHECKPOINT --dst ./leandojo-lean4-retriever-tacgen-byt5-small +python prover/evaluate.py --data-path data/leandojo_benchmark_4/random/ --gen_ckpt_path ./leandojo-lean4-retriever-tacgen-byt5-small --ret_ckpt_path ./leandojo-lean4-retriever-byt5-small --indexed-corpus-path $PATH_TO_INDEXED_CORPUS --split test --num-workers 5 --num-gpus 1 ``` diff --git a/prover/proof_search.py b/prover/proof_search.py index a6813ec..359b9ab 100644 --- a/prover/proof_search.py +++ b/prover/proof_search.py @@ -65,6 +65,7 @@ def __init__( debug: bool, ) -> None: self.tac_gen = tac_gen + self.tac_gen.initialize() self.timeout = timeout self.num_sampled_tactics = num_sampled_tactics self.debug = debug @@ -309,7 +310,6 @@ def __init__( num_sampled_tactics: int, debug: bool, ) -> None: - tac_gen.initialize() self.prover = BestFirstSearchProver( tac_gen, timeout, diff --git a/retrieval/index.py b/retrieval/index.py index 6beb7d8..27d532d 100644 --- a/retrieval/index.py +++ b/retrieval/index.py @@ -35,7 +35,7 @@ def main() -> None: model.reindex_corpus(batch_size=args.batch_size) pickle.dump( - IndexedCorpus(model.corpus, model.corpus_embeddings.cpu()), + IndexedCorpus(model.corpus, model.corpus_embeddings.to(torch.float32).cpu()), open(args.output_path, "wb"), ) logger.info(f"Indexed corpus saved to {args.output_path}") diff --git a/scripts/stats.py b/scripts/stats.py index e1f4644..6450015 100644 --- a/scripts/stats.py +++ b/scripts/stats.py @@ -1,17 +1,36 @@ +import re import sys +import numpy as np from glob import glob +from loguru import logger +import matplotlib.pyplot as plt +total_time = [] +TOTAL_TIME_REGEX = re.compile(r"total_time=(?P