[tnx] support codellama 70b instruct tokenizer (#1653)
tosterberg authored Mar 21, 2024
1 parent 21a6d57 commit ce7875c
Showing 2 changed files with 18 additions and 6 deletions.
@@ -190,6 +190,15 @@ def increment_cache_id(self):
     def trim_cache_id(self):
         self._cache_id = self._cache_id.max()
 
+    def is_slot_eos_token(self, token) -> bool:
+        if hasattr(self._generation_config, "eos_token_id"):
+            if isinstance(self._generation_config.eos_token_id, int):
+                return token == self._generation_config.eos_token_id
+            else:
+                return token in self._generation_config.eos_token_id
+        else:
+            return False
+
     @property
     def stopped(self) -> bool:
         return self._selector.stopping_criteria(self._tokens, None)
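Models such as CodeLlama 70B Instruct can declare several end-of-sequence ids, so generation_config.eos_token_id may arrive as a list rather than a single int; the new Slot.is_slot_eos_token accepts both shapes. Below is a minimal standalone sketch of the same int-or-list check: the GenerationConfig API is real, but the function name and token ids are illustrative, and the None guard is a small defensive addition on top of the hasattr test above.

from transformers import GenerationConfig

def is_eos_token(generation_config: GenerationConfig, token: int) -> bool:
    # Treat a missing/unset eos_token_id as "not EOS", compare directly
    # against a scalar id, and use membership when it is a list.
    eos = getattr(generation_config, "eos_token_id", None)
    if eos is None:
        return False
    if isinstance(eos, int):
        return token == eos
    return token in eos

# Hypothetical multi-EOS config in the style of CodeLlama 70B Instruct.
config = GenerationConfig(eos_token_id=[2, 32015])
assert is_eos_token(config, 32015)
assert not is_eos_token(config, 7)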
@@ -381,7 +390,9 @@ def _generate_token(
             slot.append(next_token, next_token_text)
             generated_text = None
             finish_reason = None
-            if next_token == self.tokenizer.eos_token_id:
+            if slot.is_slot_eos_token(next_token):
+                finish_reason = FinishReason.FINISH_REASON_EOS_TOKEN
+            elif next_token == self.tokenizer.eos_token_id:
                 finish_reason = FinishReason.FINISH_REASON_EOS_TOKEN
             elif slot.stopped:
                 finish_reason = FinishReason.FINISH_REASON_LENGTH
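The per-slot check now runs before the tokenizer-level comparison, so any id from the request's generation config ends the sequence even when it differs from tokenizer.eos_token_id. A rough sketch of the resulting decision order, using stand-ins for the real Slot, tokenizer, and FinishReason enum:

class FakeSlot:
    """Stand-in for Slot: fixed EOS id list, no other stop condition."""

    def __init__(self, eos_ids):
        self._eos_ids = eos_ids
        self.stopped = False

    def is_slot_eos_token(self, token):
        return token in self._eos_ids

def finish_reason_for(slot, tokenizer_eos_id, next_token):
    if slot.is_slot_eos_token(next_token):  # per-request EOS ids win first
        return "FINISH_REASON_EOS_TOKEN"
    if next_token == tokenizer_eos_id:  # tokenizer-level fallback
        return "FINISH_REASON_EOS_TOKEN"
    if slot.stopped:  # stopping criteria, e.g. max length
        return "FINISH_REASON_LENGTH"
    return None

# 32015 is a hypothetical extra EOS id the tokenizer itself does not define.
assert finish_reason_for(FakeSlot([2, 32015]), 2, 32015) == "FINISH_REASON_EOS_TOKEN"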
@@ -403,7 +414,8 @@ def _generate_token(
                     token_id=next_token,
                     token_logprob=next_log_prob,
                     token_text=next_token_text,
-                    token_is_special=(next_token in [self.special_tokens]),
+                    token_is_special=(next_token in [self.special_tokens])
+                    or (finish_reason == FinishReason.FINISH_REASON_EOS_TOKEN),
                     generated_text=generated_text,
                 ))
         return generations
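token_is_special is now also true whenever the token triggered FINISH_REASON_EOS_TOKEN, so consumers can filter extra EOS ids out of displayed text even when those ids are absent from the tokenizer's special-token set. (Note that next_token in [self.special_tokens] tests membership in a one-element list wrapping the whole collection, so the new or-clause is what actually flags these tokens.) A small sketch of the combined condition; all values are illustrative:

special_tokens = [1, 2]  # illustrative special-token ids
next_token = 32015  # hypothetical extra EOS id
finish_reason = "FINISH_REASON_EOS_TOKEN"  # set by the checks shown earlier

token_is_special = (next_token in [special_tokens]) or (
    finish_reason == "FINISH_REASON_EOS_TOKEN")
assert token_is_special  # flagged special purely via the finish reason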
Second changed file (the token selector adapted from Optimum Neuron):

@@ -13,7 +13,7 @@
 # The below code is heavily inspired from Optimum Neuron under the following link:
 # https://github.com/huggingface/optimum-neuron/blob/974f34336bb36b1b64890c191c558a1575372be7/optimum/neuron/generation/token_selector.py
 import logging
-from typing import Optional
+from typing import Optional, Union, List
 import torch
 from transformers.generation import (
     GenerationConfig,
@@ -63,7 +63,7 @@ def __init__(
         mode: GenerationMode,
         logits_processor: LogitsProcessorList,
         stopping_criteria: StoppingCriteriaList,
-        eos_token_id: int,
+        eos_token_id: Union[List[int], int],
         pad_token_id: int,
         logits_warper: Optional[LogitsProcessorList] = None,
     ):
@@ -139,9 +139,9 @@ def create(cls, input_ids: torch.Tensor,
         # The generation requires special tokens
         eos_token_id = generation_config.eos_token_id
         # This is not supposed to happen for any of the models we support
-        assert eos_token_id is not None and not isinstance(eos_token_id, list)
         if generation_config.pad_token_id is None:
-            generation_config.pad_token_id = eos_token_id
+            generation_config.pad_token_id = eos_token_id if isinstance(
+                eos_token_id, int) else eos_token_id[0]
 
         generation_mode = model._get_generation_mode(generation_config, None)
         if generation_mode not in [
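When pad_token_id is unset, the selector now falls back to the first EOS id instead of assuming eos_token_id is a scalar. A runnable sketch of the same fallback, with illustrative GenerationConfig values:

from typing import List, Union

from transformers import GenerationConfig

def resolve_pad_token_id(generation_config: GenerationConfig) -> int:
    # Fall back to eos_token_id, taking the first element when it is a list.
    eos: Union[int, List[int]] = generation_config.eos_token_id
    if generation_config.pad_token_id is None:
        generation_config.pad_token_id = eos if isinstance(eos, int) else eos[0]
    return generation_config.pad_token_id

assert resolve_pad_token_id(GenerationConfig(eos_token_id=[2, 32015])) == 2
assert resolve_pad_token_id(GenerationConfig(eos_token_id=2)) == 2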
