Skip to content

Commit

Permalink
removing unused function argument
Browse files Browse the repository at this point in the history
Signed-off-by: Yannick Schnider <Yannick.Schnider1@ibm.com>
  • Loading branch information
yannicks1 committed Jan 27, 2025
1 parent 124f3a9 commit 0c9bec0
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 8 deletions.
3 changes: 1 addition & 2 deletions vllm/worker/spyre_embedding_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ def __init__(
softmax=False)

def load_model(self, prompt_lens: Iterable[int],
num_decode_tokens: Iterable[int],
batch_sizes: Iterable[int]) -> None:
num_decode_tokens: Iterable[int]) -> None:
self.model = AutoModel.from_pretrained(self.model_config.model)
self.model.eval()
torch.set_grad_enabled(False)
Expand Down
3 changes: 1 addition & 2 deletions vllm/worker/spyre_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,7 @@ def __init__(
self.model: nn.Module

def load_model(self, prompt_lens: Iterable[int],
num_decode_tokens: Iterable[int],
batch_sizes: Iterable[int]) -> None:
num_decode_tokens: Iterable[int]) -> None:
max_pad_length = max(prompt_lens)
max_decode_length = max(num_decode_tokens)
self.model = get_spyre_model(self.model_config,
Expand Down
7 changes: 3 additions & 4 deletions vllm/worker/spyre_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,12 @@ def load_model(self):
# printing env variables for debugging purposes
load_model_start_t = time.time()

wup_prompt_lens, wup_new_tokens, wup_batch_sizes = zip(
*[(s["prompt_length"], s["new_tokens"], s["batch_size"])
wup_prompt_lens, wup_new_tokens = zip(
*[(s["prompt_length"], s["new_tokens"])
for s in self.scheduler_config.spyre_warmup_shapes])

self.model_runner.load_model(prompt_lens=wup_prompt_lens,
num_decode_tokens=wup_new_tokens,
batch_sizes=wup_batch_sizes)
num_decode_tokens=wup_new_tokens)

load_model_end_t = time.time()
load_model_total_t = load_model_end_t - load_model_start_t
Expand Down

0 comments on commit 0c9bec0

Please sign in to comment.