Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Fix beam search logits processor #3454

1 change: 1 addition & 0 deletions tests/entrypoints/test_openai_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,7 @@ async def test_guided_grammar(server, client: openai.AsyncOpenAI):
prompt=("Generate a sql state that select col_1 from "
"table_1 where it is equals to 1"),
temperature=1.0,
n=3,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since n = 3 here, let's also verify all the outputs; you might also need to add `use_beam_search` to `extra_body`.

max_tokens=500,
extra_body=dict(guided_grammar=simple_sql_grammar))

Expand Down
47 changes: 39 additions & 8 deletions vllm/model_executor/guided_logits_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import math
from collections import defaultdict
Expand All @@ -25,7 +26,7 @@
from outlines.fsm.json_schema import build_regex_from_schema


class BaseLogitsProcessor:
class BaseGuidedLogitsProcessor:

def adapt_tokenizer(self, tokenizer: PreTrainedTokenizerBase):
"""Adapt vLLM's tokenizer to use to compile the FSM.
Expand Down Expand Up @@ -81,9 +82,29 @@ def __call__(self, input_ids: List[int],

seq_id = hash(tuple(input_ids))

if len(input_ids) == 0:
if not hasattr(self, "fsm_state"):
self.init_state()
else:
if not hasattr(self, "fsm_state"):
if len(input_ids) == 1:
# This special scenario arises during sampling strategies
# such as beam search when the number of sequences to be
# generated (`n`) is bigger than 1.
# During the initial step of beam search, only the input
# `prompt` is given, while the beams themselves are yet
# to be defined.
# Consequently, the logits will have a shape of
# [1, vocab_size].
# Due to this, `self.fsm_state` initialization will be
# triggered only for the very first `logits_processor`,
# leaving the remaining `n-1` uninitialized.
self.init_state()
empty_seq_id = hash(tuple([]))
self.fsm.allowed_token_ids(self.fsm_state[empty_seq_id])
else:
raise ValueError(
f"Multiple ids were generated: {input_ids}, "
"while fsm is still not initialized")
last_token = input_ids[-1]
last_seq_id = hash(tuple(input_ids[:-1]))
self.fsm_state[seq_id] = self.fsm.next_state(
Expand All @@ -99,8 +120,16 @@ def __call__(self, input_ids: List[int],

return scores

def __deepcopy__(self, memo):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that by implementing the `__deepcopy__` method here, it is now possible to `copy.deepcopy` an entire `BaseGuidedLogitsProcessor` object. These lines are needed since a naive copy of `outlines.fsm.guide.CFGGuide` fails:

Code to reproduce:

import copy
from outlines.fsm.guide import CFGGuide
from transformers import AutoTokenizer

model = "microsoft/phi_1"

tokenizer = AutoTokenizer.from_pretrained(model)

simple_sql_grammar = """
start: select_statement

select_statement: "SELECT" column "from" table "where" condition

column: "col_1" | "col_2"
table: "table_1" | "table_2"
condition: column "=" number

number: "1" | "2"
"""

fsm = CFGGuide(simple_sql_grammar, tokenizer)
copy.deepcopy(fsm)

Error message:

Traceback (most recent call last):
  File "/Users/maximzubkov/tmp.py", line 22, in <module>
    copy.deepcopy(fsm)
  File "/Users/maximzubkov/.pyenv/versions/3.10.2/lib/python3.10/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/Users/maximzubkov/.pyenv/versions/3.10.2/lib/python3.10/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
  File "/Users/maximzubkov/.pyenv/versions/3.10.2/lib/python3.10/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/Users/maximzubkov/.pyenv/versions/3.10.2/lib/python3.10/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/maximzubkov/.pyenv/versions/3.10.2/lib/python3.10/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/Users/maximzubkov/.pyenv/versions/3.10.2/lib/python3.10/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
  File "/Users/maximzubkov/.pyenv/versions/3.10.2/lib/python3.10/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/Users/maximzubkov/.pyenv/versions/3.10.2/lib/python3.10/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/maximzubkov/.pyenv/versions/3.10.2/lib/python3.10/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/Users/maximzubkov/.pyenv/versions/3.10.2/lib/python3.10/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
  File "/Users/maximzubkov/.pyenv/versions/3.10.2/lib/python3.10/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/Users/maximzubkov/.pyenv/versions/3.10.2/lib/python3.10/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/maximzubkov/.pyenv/versions/3.10.2/lib/python3.10/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/Users/maximzubkov/.pyenv/versions/3.10.2/lib/python3.10/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
  File "/Users/maximzubkov/.pyenv/versions/3.10.2/lib/python3.10/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/Users/maximzubkov/.pyenv/versions/3.10.2/lib/python3.10/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/maximzubkov/.pyenv/versions/3.10.2/lib/python3.10/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/Users/maximzubkov/.pyenv/versions/3.10.2/lib/python3.10/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/maximzubkov/.pyenv/versions/3.10.2/lib/python3.10/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/Users/maximzubkov/.pyenv/versions/3.10.2/lib/python3.10/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
  File "/Users/maximzubkov/.pyenv/versions/3.10.2/lib/python3.10/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/Users/maximzubkov/.pyenv/versions/3.10.2/lib/python3.10/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/maximzubkov/.pyenv/versions/3.10.2/lib/python3.10/copy.py", line 161, in deepcopy
    rv = reductor(4)
TypeError: cannot pickle 'module' object

logits_processor = self.__class__(
copy.deepcopy(self.fsm.cfg_string, memo),
copy.deepcopy(self.fsm.tokenizer, memo),
)
logits_processor.fsm = self.fsm.copy()
return logits_processor

class RegexLogitsProcessor(BaseLogitsProcessor):

class RegexLogitsProcessor(BaseGuidedLogitsProcessor):

def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
"""Compile the FSM that drives the regex-structured generation.
Expand All @@ -120,10 +149,12 @@ def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):

class JSONLogitsProcessor(RegexLogitsProcessor):

def __init__(self,
schema: Union[str, Dict, BaseModel],
tokenizer: PreTrainedTokenizerBase,
whitespace_pattern: Optional[str] = None):
def __init__(
self,
schema: Union[str, Dict, BaseModel],
tokenizer: PreTrainedTokenizerBase,
whitespace_pattern: Optional[str] = None,
):
"""Compile the FSM that drives the JSON-guided generation.

Parameters
Expand Down Expand Up @@ -154,7 +185,7 @@ def __init__(self,
super().__init__(regex_string, tokenizer)


class CFGLogitsProcessor(BaseLogitsProcessor):
class CFGLogitsProcessor(BaseGuidedLogitsProcessor):

def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
"""Compile the FSM that drives the context free grammar generation.
Expand Down
12 changes: 6 additions & 6 deletions vllm/model_executor/layers/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,18 +155,18 @@ def _apply_logits_processors(
logits_row_idx = 0
found_logits_processors = False
for seq_ids, sampling_params in sampling_metadata.seq_groups:
logits_processors = sampling_params.logits_processors
if logits_processors:
found_logits_processors = True
for seq_id in seq_ids:
for i, seq_id in enumerate(seq_ids):
logits_processors = sampling_params.logits_processors[i]
if logits_processors:
found_logits_processors = True
logits_row = logits[logits_row_idx]
token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
for logits_processor in logits_processors:
logits_row = logits_processor(token_ids, logits_row)
logits[logits_row_idx] = logits_row
logits_row_idx += 1
else:
logits_row_idx += len(seq_ids)
else:
logits_row_idx += len(seq_ids)
if found_logits_processors:
assert logits_row_idx == logits.shape[0]
return logits
Expand Down
15 changes: 11 additions & 4 deletions vllm/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,14 @@ def __init__(
self.prompt_logprobs = prompt_logprobs
self.skip_special_tokens = skip_special_tokens
self.spaces_between_special_tokens = spaces_between_special_tokens
self.logits_processors = logits_processors
# A separate logits processor is needed for each output sequence
# since certain logits processors (such as BaseGuidedLogitsProcessor)
# in multi-beam generation must track the sequence generated
# by each beam up to that point.
# See https://github.com/vllm-project/vllm/issues/3448 for more details
self.logits_processors = [
copy.deepcopy(logits_processors) for _ in range(n)
]
self.include_stop_str_in_output = include_stop_str_in_output
self._verify_args()
if self.use_beam_search:
Expand Down Expand Up @@ -246,10 +253,10 @@ def clone(self) -> "SamplingParams":
See https://github.com/vllm-project/vllm/issues/3087
"""

logit_processor_refs = None if self.logits_processors is None else {
logit_processor_refs = (None if self.logits_processors is None else {
id(lp): lp
for lp in self.logits_processors
}
for lp in self.logits_processors if not hasattr(lp[0], "fsm")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This affects the following line, as described in the following issue, and is partially addressed by the following PR. Unfortunately, in this case, the copy of logits_processors is inevitable, and it seems it does indeed slow down inference. However, I'm using a relatively old GPU (4x NVIDIA RTX A4000, 16GB), so I would need someone with better hardware to benchmark the speed. Refer to the issue for the code snippets to reproduce the tests.

})
return copy.deepcopy(self, memo=logit_processor_refs)

def __repr__(self) -> str:
Expand Down
Loading