From 4263cfbbbbfd1ff890c71d7f9f6fb1708b474f88 Mon Sep 17 00:00:00 2001
From: Kaiyu Yang
Date: Sat, 13 Jul 2024 07:25:08 +0000
Subject: [PATCH] minor fix

---
 prover/tactic_generator.py | 4 +++-
 retrieval/index.py         | 2 +-
 retrieval/model.py         | 4 +---
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/prover/tactic_generator.py b/prover/tactic_generator.py
index 660040e..5e01960 100644
--- a/prover/tactic_generator.py
+++ b/prover/tactic_generator.py
@@ -270,7 +270,9 @@ def __init__(
 
     def initialize(self) -> None:
         self.hf_gen.initialize()
-        self.retriever = PremiseRetriever.load_hf(self.ret_path, self.device)
+        self.retriever = PremiseRetriever.load_hf(
+            self.ret_path, self.max_inp_seq_len, self.device
+        )
         self.retriever.load_corpus(self.indexed_corpus_path)
 
     async def generate(
diff --git a/retrieval/index.py b/retrieval/index.py
index c7b51b7..6beb7d8 100644
--- a/retrieval/index.py
+++ b/retrieval/index.py
@@ -30,7 +30,7 @@ def main() -> None:
         device = torch.device("cpu")
     else:
         device = torch.device("cuda")
-    model = PremiseRetriever.load_hf(args.ckpt_path, device, max_seq_len=2048)
+    model = PremiseRetriever.load_hf(args.ckpt_path, 2048, device)
     model.load_corpus(args.corpus_path)
     model.reindex_corpus(batch_size=args.batch_size)
 
diff --git a/retrieval/model.py b/retrieval/model.py
index daaf85f..7583563 100644
--- a/retrieval/model.py
+++ b/retrieval/model.py
@@ -51,10 +51,8 @@ def load(cls, ckpt_path: str, device, freeze: bool) -> "PremiseRetriever":
 
     @classmethod
     def load_hf(
-        cls, ckpt_path: str, device: int, dtype=None, max_seq_len: Optional[int] = None
+        cls, ckpt_path: str, max_seq_len: int, device: int, dtype=None
     ) -> "PremiseRetriever":
-        if max_seq_len is None:
-            max_seq_len = 999999999999
         model = PremiseRetriever(ckpt_path, 0.0, 0, max_seq_len, 100).to(device).eval()
         if dtype is not None:
             return model.to(dtype)