-
-
Notifications
You must be signed in to change notification settings - Fork 5.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Bugfix] Fix beam search logits processor #3454
Changes from 7 commits
9535c1e
7e1edef
afe7aa5
a1bca5e
00db489
5ea2899
1128f86
c0e4028
d5e62e7
fea77ba
4f31aa4
39b26b9
4c2a75a
d62e522
4e6444c
76f7520
c63b61c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import copy | ||
import json | ||
import math | ||
from collections import defaultdict | ||
|
@@ -25,7 +26,7 @@ | |
from outlines.fsm.json_schema import build_regex_from_schema | ||
|
||
|
||
class BaseLogitsProcessor: | ||
class BaseGuidedLogitsProcessor: | ||
maximzubkov marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def adapt_tokenizer(self, tokenizer: PreTrainedTokenizerBase): | ||
"""Adapt vLLM's tokenizer to use to compile the FSM. | ||
|
@@ -81,9 +82,29 @@ def __call__(self, input_ids: List[int], | |
|
||
seq_id = hash(tuple(input_ids)) | ||
|
||
if len(input_ids) == 0: | ||
if not hasattr(self, "fsm_state"): | ||
self.init_state() | ||
else: | ||
if not hasattr(self, "fsm_state"): | ||
if len(input_ids) == 1: | ||
maximzubkov marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# This special scenario arises during sampling strategies | ||
# such as beam search when the number of sequences to be | ||
# generated (`n`) is bigger than 1. | ||
# During the initial step of beam search, only the input | ||
# `prompt` is given, while the beams themselves are yet | ||
# to be defined. | ||
# Consequently, the logits will have a shape of | ||
# [1, vocab_size]. | ||
# Due to this, `self.fsm_state` initialization will be | ||
# triggered only for the very first `logits_processor`, | ||
# leaving the remaining `n-1` uninitialized. | ||
self.init_state() | ||
empty_seq_id = hash(tuple([])) | ||
self.fsm.allowed_token_ids(self.fsm_state[empty_seq_id]) | ||
else: | ||
raise ValueError( | ||
f"Multiple ids were generated: {input_ids}, " | ||
"while fsm is still not initialized") | ||
last_token = input_ids[-1] | ||
last_seq_id = hash(tuple(input_ids[:-1])) | ||
self.fsm_state[seq_id] = self.fsm.next_state( | ||
|
@@ -99,8 +120,16 @@ def __call__(self, input_ids: List[int], | |
|
||
return scores | ||
|
||
def __deepcopy__(self, memo): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note that by implementing the Code to reproduce:
Error message:
|
||
logits_processor = self.__class__( | ||
copy.deepcopy(self.fsm.cfg_string, memo), | ||
copy.deepcopy(self.fsm.tokenizer, memo), | ||
) | ||
logits_processor.fsm = self.fsm.copy() | ||
return logits_processor | ||
|
||
class RegexLogitsProcessor(BaseLogitsProcessor): | ||
|
||
class RegexLogitsProcessor(BaseGuidedLogitsProcessor): | ||
|
||
def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase): | ||
"""Compile the FSM that drives the regex-structured generation. | ||
|
@@ -120,10 +149,12 @@ def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase): | |
|
||
class JSONLogitsProcessor(RegexLogitsProcessor): | ||
|
||
def __init__(self, | ||
schema: Union[str, Dict, BaseModel], | ||
tokenizer: PreTrainedTokenizerBase, | ||
whitespace_pattern: Optional[str] = None): | ||
def __init__( | ||
self, | ||
schema: Union[str, Dict, BaseModel], | ||
tokenizer: PreTrainedTokenizerBase, | ||
whitespace_pattern: Optional[str] = None, | ||
): | ||
"""Compile the FSM that drives the JSON-guided generation. | ||
|
||
Parameters | ||
|
@@ -154,7 +185,7 @@ def __init__(self, | |
super().__init__(regex_string, tokenizer) | ||
|
||
|
||
class CFGLogitsProcessor(BaseLogitsProcessor): | ||
class CFGLogitsProcessor(BaseGuidedLogitsProcessor): | ||
maximzubkov marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase): | ||
maximzubkov marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Compile the FSM that drives the context free grammar generation. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -148,7 +148,14 @@ def __init__( | |
self.prompt_logprobs = prompt_logprobs | ||
self.skip_special_tokens = skip_special_tokens | ||
self.spaces_between_special_tokens = spaces_between_special_tokens | ||
self.logits_processors = logits_processors | ||
# A separate logit processor is needed for each output sequence | ||
# since certain logits processors (such as BaseGuidedLogitsProcessor) | ||
# in multi-beam generation must track the sequences generated | ||
# by each beam up to that point. | ||
# See https://github.com/vllm-project/vllm/issues/3448 for more | ||
self.logits_processors = [ | ||
copy.deepcopy(logits_processors) for _ in range(n) | ||
maximzubkov marked this conversation as resolved.
Show resolved
Hide resolved
|
||
] | ||
self.include_stop_str_in_output = include_stop_str_in_output | ||
self._verify_args() | ||
if self.use_beam_search: | ||
|
@@ -246,10 +253,10 @@ def clone(self) -> "SamplingParams": | |
See https://github.com/vllm-project/vllm/issues/3087 | ||
""" | ||
|
||
logit_processor_refs = None if self.logits_processors is None else { | ||
logit_processor_refs = (None if self.logits_processors is None else { | ||
id(lp): lp | ||
for lp in self.logits_processors | ||
} | ||
for lp in self.logits_processors if not hasattr(lp[0], "fsm") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This affects the following line and as described in the following issue and partially addressed by the following PR. Unfortunately, in this case, the copy of logits_processors is inevitable and seems like it indeed does slow down the inference. However, I'm using a relatively old GPU (4x NVIDIA RTX A4000, 16Gb) so I would need someone with better hardware to benchmark the speed. Refer to the issue for the code snippets to reproduce the tests |
||
}) | ||
return copy.deepcopy(self, memo=logit_processor_refs) | ||
|
||
def __repr__(self) -> str: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
since n = 3 here let's also verify all the outputs, you might also need to add
use_beam_search
toextra_body