From bff3a5107774144a241defb4c7009edc7bb846de Mon Sep 17 00:00:00 2001 From: Kaiyu Yang Date: Fri, 5 Apr 2024 14:28:10 +0000 Subject: [PATCH] format code --- generator/model.py | 15 ++++++++++---- prover/evaluate.py | 46 +++++++++++++++++++++--------------------- prover/proof_search.py | 30 +++++++++++++++++++++------ 3 files changed, 58 insertions(+), 33 deletions(-) diff --git a/generator/model.py b/generator/model.py index 28fb448..1c3c627 100644 --- a/generator/model.py +++ b/generator/model.py @@ -364,7 +364,9 @@ def batch_generate( return tactics_with_scores -def trial_completion_with_args(args_tuple: Tuple[openai.Client, int, float, Dict[str, Any]]) -> List[Tuple[str, float]]: +def trial_completion_with_args( + args_tuple: Tuple[openai.Client, int, float, Dict[str, Any]] +) -> List[Tuple[str, float]]: client, num_retries, backoff_time, completion_args = args_tuple trial = 0 while trial < num_retries: @@ -382,6 +384,7 @@ def trial_completion_with_args(args_tuple: Tuple[openai.Client, int, float, Dict logger.info(f"Retrying in {backoff_time} seconds...") time.sleep(backoff_time) + class VLLMGenerator(TacticGenerator): def __init__( self, @@ -415,8 +418,11 @@ def generate_from_args(self, args: List[Dict[str, Any]]) -> List[Tuple[str, floa with mpp.ThreadPool(64) as p: all_results = [] for result in p.imap( - trial_completion_with_args, - [(self.client, self.num_retries, self.backoff_time, arg) for arg in args], + trial_completion_with_args, + [ + (self.client, self.num_retries, self.backoff_time, arg) + for arg in args + ], ): all_results.extend(result) return all_results @@ -434,7 +440,7 @@ def generate( prompt = self.prompt_format.replace("TACTIC_STATE", state.strip()) completion_args = self.get_completion_args(prompt) return self.generate_from_args([completion_args] * num_samples) - + def get_completion_args(self, prompt: str) -> dict[str, Any]: return { "model": self.model, @@ -464,6 +470,7 @@ def batch_generate( return self.generate_from_args(all_args) + class GPT4TacticGenerator(TacticGenerator): def __init__( self, diff --git a/prover/evaluate.py b/prover/evaluate.py index f7754a0..31db9e3 100644 --- a/prover/evaluate.py +++ b/prover/evaluate.py @@ -12,7 +12,7 @@ from typing import List, Tuple, Optional, Any from lean_dojo import LeanGitRepo, Theorem, Pos, is_available_in_cache -from common import set_logger, zip_strict +from common import set_logger from prover.proof_search import Status, DistributedProver @@ -139,7 +139,9 @@ def evaluate( num_sampled_tactics=num_sampled_tactics, debug=verbose, ) - results = prover.search_unordered(repo, unfinished_theorems, unfinished_positions, progress_dir=progress_dir) + results = prover.search_unordered( + repo, unfinished_theorems, unfinished_positions, progress_dir=progress_dir + ) # Calculate the result statistics. num_proved = num_failed = num_discarded = 0 @@ -228,9 +230,7 @@ def main() -> None: parser.add_argument( "--vllm-args-json-path", type=str, help="URL of the VLLM server." ) - parser.add_argument( - "--progress-dir", type=str, help="Progress directory" - ) + parser.add_argument("--progress-dir", type=str, help="Progress directory") args = parser.parse_args() assert args.ckpt_path or args.tactic or args.vllm_args_json_path @@ -244,24 +244,24 @@ def main() -> None: logger.info(args) pass_1 = evaluate( - data_path = args.data_path, - exp_id = args.exp_id, - split = args.split, - file_path = args.file_path, - full_name = args.full_name, - name_filter = args.name_filter, - num_theorems = args.num_theorems, - ckpt_path = args.ckpt_path, - indexed_corpus_path = args.indexed_corpus_path, - tactic = args.tactic, - module = args.module, - num_sampled_tactics = args.num_sampled_tactics, - vllm_args = vllm_args, - timeout = args.timeout, - num_workers = args.num_workers, - num_gpus = args.num_gpus, - verbose = args.verbose, - progress_dir = args.progress_dir, + data_path=args.data_path, + exp_id=args.exp_id, + split=args.split, + file_path=args.file_path, + full_name=args.full_name, + name_filter=args.name_filter, + num_theorems=args.num_theorems, + ckpt_path=args.ckpt_path, + indexed_corpus_path=args.indexed_corpus_path, + tactic=args.tactic, + module=args.module, + num_sampled_tactics=args.num_sampled_tactics, + vllm_args=vllm_args, + timeout=args.timeout, + num_workers=args.num_workers, + num_gpus=args.num_gpus, + verbose=args.verbose, + progress_dir=args.progress_dir, ) logger.info(f"Pass@1: {pass_1}") diff --git a/prover/proof_search.py b/prover/proof_search.py index 14f99ad..d7a61f4 100644 --- a/prover/proof_search.py +++ b/prover/proof_search.py @@ -29,7 +29,11 @@ from common import zip_strict from prover.search_tree import * -from generator.model import RetrievalAugmentedGenerator, FixedTacticGenerator, VLLMGenerator +from generator.model import ( + RetrievalAugmentedGenerator, + FixedTacticGenerator, + VLLMGenerator, +) @dataclass(frozen=True) @@ -61,7 +65,6 @@ def serialize(self) -> str: return json.dumps(result_dict, ensure_ascii=False, indent=4) - class BestFirstSearchProver: """A prover that uses best-first search to find proofs using a tactic generator.""" @@ -83,10 +86,14 @@ def __init__( self.total_time = None def search( - self, repo: LeanGitRepo, thm: Theorem, pos: Pos, progress_dir: Optional[str] = None + self, + repo: LeanGitRepo, + thm: Theorem, + pos: Pos, + progress_dir: Optional[str] = None, ) -> Optional[SearchResult]: logger.info(f"Proving {thm}") - + theorem_uid = thm.uid if progress_dir is not None: assert os.path.isdir(progress_dir) @@ -336,7 +343,14 @@ def __init__( if vllm_args: assert all( key in vllm_args - for key in ["server_url", "model", "max_tokens", "temperature", "stop", "prompt_format"] + for key in [ + "server_url", + "model", + "max_tokens", + "temperature", + "stop", + "prompt_format", + ] ), vllm_args tac_gen = VLLMGenerator( server_url=vllm_args["server_url"], @@ -483,7 +497,11 @@ def __init__( self.prover_pool = ActorPool(provers) def search_unordered( - self, repo: LeanGitRepo, theorems: List[Theorem], positions: List[Pos], progress_dir: Optional[str] = None + self, + repo: LeanGitRepo, + theorems: List[Theorem], + positions: List[Pos], + progress_dir: Optional[str] = None, ) -> List[SearchResult]: """Parallel proof search for `theorems`. The order of the results is not guaranteed to match the order of the input.""" if not self.distributed: