Commit

fix for tp
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Varun Sundar Rabindranath committed Dec 17, 2024
1 parent 83ee29a · commit 51ef92a
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion vllm/lora/layers.py
@@ -1027,7 +1027,19 @@ def _get_logits(
         logits = lm_head.linear_method.apply(lm_head, hidden_states)
         if embedding_bias is not None:
             logits += embedding_bias
-        logits = tensor_model_parallel_gather(logits)
+
+        # TODO (varun) : Replace with base layer get_logits()
+        if self.use_gather:
+            # None may be returned for rank > 0
+            logits = tensor_model_parallel_gather(logits)
+        else:
+            # Gather is not supported for some devices such as TPUs.
+            # Use all-gather instead.
+            # NOTE(woosuk): Here, the outputs of every device should not be None
+            # because XLA requires strict SPMD among all devices. Every device
+            # should execute the same operations after gathering the logits.
+            logits = tensor_model_parallel_all_gather(logits)
+
         if logits is None:
             return None

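For context, here is a minimal standalone sketch (not part of this commit, and not vLLM code) of why the two collectives above behave differently under tensor parallelism: a gather delivers the full logits only to the destination rank, so the other ranks end up with nothing (vLLM returns None there), while an all-gather leaves every rank holding the full logits, so every rank can execute the same operations afterwards, which is what XLA's strict SPMD execution on TPUs requires. The script uses plain torch.distributed with the gloo backend and a made-up shard shape purely for illustration.

# Hypothetical illustration (not vLLM code): gather-to-rank-0 vs. all-gather
# of per-rank logit shards under tensor parallelism.
# Run with, e.g.:  torchrun --nproc_per_node=2 gather_vs_all_gather.py
import torch
import torch.distributed as dist


def main() -> None:
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Each TP rank holds a shard of the vocab dimension of the logits.
    shard = torch.full((1, 4), float(rank))

    # gather: only the destination rank (0) receives the full logits;
    # every other rank gets nothing.
    if rank == 0:
        parts = [torch.empty_like(shard) for _ in range(world_size)]
        dist.gather(shard, gather_list=parts, dst=0)
        logits = torch.cat(parts, dim=-1)
    else:
        dist.gather(shard, dst=0)
        logits = None
    print(f"[rank {rank}] gather     -> {logits}")

    # all-gather: every rank ends up with the full logits, so every rank
    # can run the same post-processing (what XLA's SPMD model expects).
    parts = [torch.empty_like(shard) for _ in range(world_size)]
    dist.all_gather(parts, shard)
    logits = torch.cat(parts, dim=-1)
    print(f"[rank {rank}] all_gather -> {logits}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()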
