Skip to content

Commit

Permalink
format code
Browse files Browse the repository at this point in the history
  • Loading branch information
Kaiyu Yang committed Apr 5, 2024
1 parent 3236b83 commit bff3a51
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 33 deletions.
15 changes: 11 additions & 4 deletions generator/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,9 @@ def batch_generate(
return tactics_with_scores


def trial_completion_with_args(args_tuple: Tuple[openai.Client, int, float, Dict[str, Any]]) -> List[Tuple[str, float]]:
def trial_completion_with_args(
args_tuple: Tuple[openai.Client, int, float, Dict[str, Any]]
) -> List[Tuple[str, float]]:
client, num_retries, backoff_time, completion_args = args_tuple
trial = 0
while trial < num_retries:
Expand All @@ -382,6 +384,7 @@ def trial_completion_with_args(args_tuple: Tuple[openai.Client, int, float, Dict
logger.info(f"Retrying in {backoff_time} seconds...")
time.sleep(backoff_time)


class VLLMGenerator(TacticGenerator):
def __init__(
self,
Expand Down Expand Up @@ -415,8 +418,11 @@ def generate_from_args(self, args: List[Dict[str, Any]]) -> List[Tuple[str, floa
with mpp.ThreadPool(64) as p:
all_results = []
for result in p.imap(
trial_completion_with_args,
[(self.client, self.num_retries, self.backoff_time, arg) for arg in args],
trial_completion_with_args,
[
(self.client, self.num_retries, self.backoff_time, arg)
for arg in args
],
):
all_results.extend(result)
return all_results
Expand All @@ -434,7 +440,7 @@ def generate(
prompt = self.prompt_format.replace("TACTIC_STATE", state.strip())
completion_args = self.get_completion_args(prompt)
return self.generate_from_args([completion_args] * num_samples)

def get_completion_args(self, prompt: str) -> dict[str, Any]:
return {
"model": self.model,
Expand Down Expand Up @@ -464,6 +470,7 @@ def batch_generate(

return self.generate_from_args(all_args)


class GPT4TacticGenerator(TacticGenerator):
def __init__(
self,
Expand Down
46 changes: 23 additions & 23 deletions prover/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import List, Tuple, Optional, Any
from lean_dojo import LeanGitRepo, Theorem, Pos, is_available_in_cache

from common import set_logger, zip_strict
from common import set_logger
from prover.proof_search import Status, DistributedProver


Expand Down Expand Up @@ -139,7 +139,9 @@ def evaluate(
num_sampled_tactics=num_sampled_tactics,
debug=verbose,
)
results = prover.search_unordered(repo, unfinished_theorems, unfinished_positions, progress_dir=progress_dir)
results = prover.search_unordered(
repo, unfinished_theorems, unfinished_positions, progress_dir=progress_dir
)

# Calculate the result statistics.
num_proved = num_failed = num_discarded = 0
Expand Down Expand Up @@ -228,9 +230,7 @@ def main() -> None:
parser.add_argument(
"--vllm-args-json-path", type=str, help="URL of the VLLM server."
)
parser.add_argument(
"--progress-dir", type=str, help="Progress directory"
)
parser.add_argument("--progress-dir", type=str, help="Progress directory")
args = parser.parse_args()

assert args.ckpt_path or args.tactic or args.vllm_args_json_path
Expand All @@ -244,24 +244,24 @@ def main() -> None:
logger.info(args)

pass_1 = evaluate(
data_path = args.data_path,
exp_id = args.exp_id,
split = args.split,
file_path = args.file_path,
full_name = args.full_name,
name_filter = args.name_filter,
num_theorems = args.num_theorems,
ckpt_path = args.ckpt_path,
indexed_corpus_path = args.indexed_corpus_path,
tactic = args.tactic,
module = args.module,
num_sampled_tactics = args.num_sampled_tactics,
vllm_args = vllm_args,
timeout = args.timeout,
num_workers = args.num_workers,
num_gpus = args.num_gpus,
verbose = args.verbose,
progress_dir = args.progress_dir,
data_path=args.data_path,
exp_id=args.exp_id,
split=args.split,
file_path=args.file_path,
full_name=args.full_name,
name_filter=args.name_filter,
num_theorems=args.num_theorems,
ckpt_path=args.ckpt_path,
indexed_corpus_path=args.indexed_corpus_path,
tactic=args.tactic,
module=args.module,
num_sampled_tactics=args.num_sampled_tactics,
vllm_args=vllm_args,
timeout=args.timeout,
num_workers=args.num_workers,
num_gpus=args.num_gpus,
verbose=args.verbose,
progress_dir=args.progress_dir,
)

logger.info(f"Pass@1: {pass_1}")
Expand Down
30 changes: 24 additions & 6 deletions prover/proof_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@

from common import zip_strict
from prover.search_tree import *
from generator.model import RetrievalAugmentedGenerator, FixedTacticGenerator, VLLMGenerator
from generator.model import (
RetrievalAugmentedGenerator,
FixedTacticGenerator,
VLLMGenerator,
)


@dataclass(frozen=True)
Expand Down Expand Up @@ -61,7 +65,6 @@ def serialize(self) -> str:
return json.dumps(result_dict, ensure_ascii=False, indent=4)



class BestFirstSearchProver:
"""A prover that uses best-first search to find proofs using a tactic generator."""

Expand All @@ -83,10 +86,14 @@ def __init__(
self.total_time = None

def search(
self, repo: LeanGitRepo, thm: Theorem, pos: Pos, progress_dir: Optional[str] = None
self,
repo: LeanGitRepo,
thm: Theorem,
pos: Pos,
progress_dir: Optional[str] = None,
) -> Optional[SearchResult]:
logger.info(f"Proving {thm}")

theorem_uid = thm.uid
if progress_dir is not None:
assert os.path.isdir(progress_dir)
Expand Down Expand Up @@ -336,7 +343,14 @@ def __init__(
if vllm_args:
assert all(
key in vllm_args
for key in ["server_url", "model", "max_tokens", "temperature", "stop", "prompt_format"]
for key in [
"server_url",
"model",
"max_tokens",
"temperature",
"stop",
"prompt_format",
]
), vllm_args
tac_gen = VLLMGenerator(
server_url=vllm_args["server_url"],
Expand Down Expand Up @@ -483,7 +497,11 @@ def __init__(
self.prover_pool = ActorPool(provers)

def search_unordered(
self, repo: LeanGitRepo, theorems: List[Theorem], positions: List[Pos], progress_dir: Optional[str] = None
self,
repo: LeanGitRepo,
theorems: List[Theorem],
positions: List[Pos],
progress_dir: Optional[str] = None,
) -> List[SearchResult]:
"""Parallel proof search for `theorems`. The order of the results is not guaranteed to match the order of the input."""
if not self.distributed:
Expand Down

0 comments on commit bff3a51

Please sign in to comment.