
Commit

add file
yangky11 committed Jul 12, 2024
1 parent 04c80fb commit d2de547
Showing 3 changed files with 291 additions and 4 deletions.
4 changes: 2 additions & 2 deletions generation/main.py
@@ -4,8 +4,8 @@
from loguru import logger
from pytorch_lightning.cli import LightningCLI

-from generator.datamodule import GeneratorDataModule
-from generator.model import RetrievalAugmentedGenerator
+from generation.datamodule import GeneratorDataModule
+from generation.model import RetrievalAugmentedGenerator


class CLI(LightningCLI):
4 changes: 2 additions & 2 deletions generation/model.py
@@ -211,7 +211,7 @@ def validation_step(self, batch: Dict[str, Any], _) -> None:
)

def on_validation_epoch_end(self) -> None:
-if self.eval_num_theorems == 0:
+if self.eval_num_theorems == 0 or self.logger is None:
return

from prover.evaluate import evaluate # Avoid circular import.
@@ -228,7 +228,7 @@ def on_validation_epoch_end(self) -> None:
num_workers=self.eval_num_workers,
num_gpus=self.eval_num_gpus,
num_theorems=self.eval_num_theorems,
-ckpt_path=ckpt_path,
+gen_ckpt_path=ckpt_path,
)
else:
self.retriever.reindex_corpus(self.trainer.datamodule.eval_batch_size)
287 changes: 287 additions & 0 deletions prover/tactic_generator.py
@@ -0,0 +1,287 @@
import openai
from lean_dojo import Pos
from loguru import logger
from typing import List, Tuple
from abc import ABC, abstractmethod
from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoTokenizer

from retrieval.model import PremiseRetriever
from common import remove_marks, zip_strict, format_augmented_state


class TacticGenerator(ABC):
"""A tactic generator takes a state and generates multiple tactic candidates."""

@abstractmethod
async def generate(
self,
state: str,
file_path: str,
theorem_full_name: str,
theorem_pos: Pos,
num_samples: int,
) -> List[Tuple[str, float]]:
raise NotImplementedError


class GPT4TacticGenerator(TacticGenerator):
def __init__(
self,
organization: str,
api_key: str,
model: str = "gpt-4",
max_tokens: int = 1024,
num_retries: int = 3,
threshold: float = 0.9,
):
super().__init__()
openai.organization = organization
openai.api_key = api_key
self.model = model
self.default_prompt = "You are an expert in theorem proving in Lean. We are trying to solve the Lean theorem 'THEOREM_FULL_NAME' from the mathlib file 'FILE_PATH'. The current tactic state is: 'TACTIC_STATE'. Suggest exactly NUM_SAMPLES unique tactics to progress in solving 'THEOREM_FULL_NAME', along with their confidence levels as a float between 0 and 1. Rank them in order of effectiveness. Present the tactics and their confidence levels as comma-separated tuples in this format: #(tactic_{1}, confidence_{1})#, #(tactic_{2}, confidence_{2})#, ..., #(tactic_{NUM_SAMPLES}, confidence_{NUM_SAMPLES})#."
self.max_tokens = max_tokens
self.num_retries = num_retries
self.threshold = threshold
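# generate() over-requests int(num_samples / threshold) tactics and retries until at least threshold * num_samples of them parse cleanly.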

async def generate(
self,
state: str,
file_path: str,
theorem_full_name: str,
theorem_pos: Pos,
num_samples: int,
) -> List[Tuple[str, float]]:
prompt = (
self.default_prompt.replace("TACTIC_STATE", state)
.replace("FILE_PATH", file_path)
.replace("THEOREM_FULL_NAME", theorem_full_name)
.replace("NUM_SAMPLES", str(int(num_samples / self.threshold)))
)
logger.info(prompt)

for _ in range(self.num_retries):
response = None
# https://platform.openai.com/docs/guides/error-codes/python-library-error-types
try:
response = openai.ChatCompletion.create(
model=self.model,
messages=[{"role": "user", "content": prompt}],
# temperature=0,
max_tokens=self.max_tokens,
# stop="E:" #
)
except openai.error.APIError as e:
# Handle API error here, e.g. retry or log
logger.info(f"OpenAI API returned an API Error: {e}")
continue
except openai.error.APIConnectionError as e:
# Handle connection error here
logger.info(f"Failed to connect to OpenAI API: {e}")
continue
except openai.error.RateLimitError as e:
# Handle rate limit error (we recommend using exponential backoff)
logger.info(f"OpenAI API request exceeded rate limit: {e}")
continue
except Exception as e:
logger.info(e)
continue

if response is None:
continue

logger.info(f"GPT-4 response: {response}")
output = response["choices"][0]["message"]["content"]
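# The prompt asks for '#(tactic, confidence)#' tuples; locate the '#' delimiters to parse them.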
indices = []

for i, c in enumerate(output):
if c == "#":
indices.append(i)

tactics_with_scores = []

for i in range(1, len(indices), 2):
tactic_and_confidence = output[indices[i - 1] + 1 : indices[i]].strip()

try:
while tactic_and_confidence[0] == "(":
tactic_and_confidence = tactic_and_confidence[1:]

if tactic_and_confidence[-1] == ")":
tactic_and_confidence = tactic_and_confidence[:-1]

split_index = tactic_and_confidence.rindex(",")
tactic = tactic_and_confidence[:split_index].strip()
confidence = float(tactic_and_confidence[split_index + 1 :].strip())
except Exception as e:
logger.info(e)
logger.info(
f"{self.model} output {output[indices[i-1]+1:indices[i]]} was not formatted correctly and could not be parsed."
)
continue

tactics_with_scores.append((tactic, confidence))

if len(tactics_with_scores) < int(self.threshold * num_samples):
continue

tactics_with_scores = sorted(
tactics_with_scores, key=lambda x: x[1], reverse=True
)[: min(num_samples, len(tactics_with_scores))]
logger.debug(f"GPT-4 tactics: {tactics_with_scores}")
logger.debug(
f"GPT-4 tactic count requested: {num_samples} / {self.threshold} = {int(num_samples / self.threshold)}"
)
logger.debug(
f"GPT-4 tactic count received and parsed: {len(tactics_with_scores)}"
)
return tactics_with_scores

raise ValueError("GPT-4 outputs are unparsable.")


class FixedTacticGenerator(TacticGenerator):
def __init__(self, tactic, module) -> None:
self.tactic = tactic
self.module = module

async def generate(
self,
state: str,
file_path: str,
theorem_full_name: str,
theorem_pos: Pos,
num_samples: int,
) -> List[Tuple[str, float]]:
return [(f"{{ {self.tactic} }}", 1.0)]


class HuggingFaceGenerator(TacticGenerator):
def __init__(
self,
model_path: str,
device,
max_oup_seq_len: int,
length_penalty: float,
template: str = "%s",
):
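# Try loading an encoder-decoder (seq2seq) checkpoint first; fall back to a decoder-only (causal) model.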
try:
self.generator = AutoModelForSeq2SeqLM.from_pretrained(model_path)
self.decoder_only = False
except ValueError:
self.generator = AutoModelForCausalLM.from_pretrained(model_path)
self.decoder_only = True
self.generator = self.generator.to(device).eval()
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.device = device
self.max_oup_seq_len = max_oup_seq_len
self.length_penalty = length_penalty
self.template = template

async def generate(
self,
state: str,
file_path: str,
theorem_full_name: str,
theorem_pos: Pos,
num_samples: int,
) -> List[Tuple[str, float]]:
state = self.template % state
logger.debug(state)
tokenized_state = self.tokenizer(state, return_tensors="pt")
state_ids = tokenized_state.input_ids.to(self.device)
state_mask = tokenized_state.attention_mask.to(self.device)

# Generate tactic candidates using beam search.
output = self.generator.generate(
input_ids=state_ids,
attention_mask=state_mask,
max_length=self.max_oup_seq_len,
num_beams=num_samples,
length_penalty=self.length_penalty,
do_sample=False,
num_return_sequences=num_samples,
early_stopping=False,
output_scores=True,
return_dict_in_generate=True,
)

# Return the output.
raw_output_text = self.tokenizer.batch_decode(
output.sequences, skip_special_tokens=True
)
raw_scores = output.sequences_scores.tolist()

output_text = []
output_score = []
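# Deduplicate beam candidates; decoder-only models echo the prompt, so strip it from each output first.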

for j in range(num_samples):
t = remove_marks(raw_output_text[j])
if self.decoder_only and t.startswith(state):
t = t[len(state) :]
if t not in output_text:
output_text.append(t)
output_score.append(raw_scores[j])

return list(zip_strict(output_text, output_score))


class RetrievalAugmentedGenerator(TacticGenerator):

def __init__(
self,
gen_path: str,
ret_path: str,
indexed_corpus_path: str,
device,
max_oup_seq_len: int,
length_penalty: float,
max_num_retrieved: int,
) -> None:
self.hf_gen = HuggingFaceGenerator(
gen_path, device, max_oup_seq_len, length_penalty
)
self.retriever = PremiseRetriever.load_hf(ret_path, device)
self.retriever.load_corpus(indexed_corpus_path)
self.max_num_retrieved = max_num_retrieved

async def generate(
self,
state: str,
file_path: str,
theorem_full_name: str,
theorem_pos: Pos,
num_samples: int,
) -> List[Tuple[str, float]]:
retrieved_premises, _ = self.retriever.retrieve(
state,
file_path,
theorem_full_name,
theorem_pos,
self.max_num_retrieved,
)
aug_state = format_augmented_state(state, retrieved_premises)
return await self.hf_gen.generate(
aug_state, file_path, theorem_full_name, theorem_pos, num_samples
)


class VllmGenerator(TacticGenerator):
def __init__(self, vllm_actor, template: str = "[GOAL]\n%s\n[PROOFSTEP]\n") -> None:
self.vllm_actor = vllm_actor
self.template = template

async def generate(
self,
state: str,
file_path: str,
theorem_full_name: str,
theorem_pos: Pos,
num_samples: int,
) -> List[Tuple[str, float]]:
# prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n[GOAL]\n{state}\n[PROOFSTEP]\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
prompt = self.template % state
response = await self.vllm_actor.generate.remote(prompt, num_samples)
return [
(remove_marks(x.text).strip(), x.cumulative_logprob)
for x in response.outputs
]
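
A minimal usage sketch (not part of the commit), assuming the repository root is on PYTHONPATH and that lean_dojo and transformers are installed. The checkpoint path, goal state, theorem name, file path, and position below are illustrative assumptions, not values taken from this commit:

import asyncio

from lean_dojo import Pos

from prover.tactic_generator import HuggingFaceGenerator

# Example checkpoint path (assumption); any seq2seq or causal LM tactic generator should work here.
generator = HuggingFaceGenerator(
    model_path="kaiyuy/leandojo-lean4-tacgen-byt5-small",
    device="cpu",
    max_oup_seq_len=128,
    length_penalty=0.0,
)

# generate() is async, so run it with asyncio; this generator ignores the file/theorem arguments.
suggestions = asyncio.run(
    generator.generate(
        state="n : ℕ ⊢ Nat.gcd n n = n",
        file_path="Mathlib/Data/Nat/GCD/Basic.lean",
        theorem_full_name="Nat.gcd_self",
        theorem_pos=Pos(1, 1),
        num_samples=4,
    )
)

# Each candidate comes back with its beam-search sequence score (length-penalized log-probability).
for tactic, score in suggestions:
    print(f"{score:.3f}\t{tactic}")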
