diff --git a/mii/batching/generation/stop_criterion.py b/mii/batching/generation/stop_criterion.py
index 53209d9b..80270cb3 100644
--- a/mii/batching/generation/stop_criterion.py
+++ b/mii/batching/generation/stop_criterion.py
@@ -30,7 +30,7 @@ class TokenStopCriterion(BaseGenerationStopCriterion):
     def __init__(self, token: Union[str, int], tokenizer) -> None:
         super().__init__(tokenizer=tokenizer)
         if isinstance(token, str):
-            token_id = self.tokenizer.encode(token)[0]
+            token_id = self.tokenizer.convert_tokens_to_ids(token)
         else:
             token_id = token
         self.stop_token_id = token_id
diff --git a/mii/batching/postprocess.py b/mii/batching/postprocess.py
index 9a23f65d..1b8ff6e7 100644
--- a/mii/batching/postprocess.py
+++ b/mii/batching/postprocess.py
@@ -76,4 +76,6 @@ def run_batch_stop_criterion(next_tokens: torch.Tensor,
                                                  Any]) -> torch.Tensor:
     stop_fns = {k: v for k, v in processor_map.items() if "Stop" in k}
     done_tokens = run_batch_processing(next_tokens, requests, stop_fns)
+    done_tokens = torch.any(done_tokens.view((len(requests), -1)), dim=1)
+
     return done_tokens
diff --git a/mii/batching/ragged_batching.py b/mii/batching/ragged_batching.py
index c703fb46..b25a12a3 100644
--- a/mii/batching/ragged_batching.py
+++ b/mii/batching/ragged_batching.py
@@ -437,17 +437,19 @@ def make_request(self,
 
         stop = generate_params.stop
         if stop != []:
-            stop_name = "_".join([STOP_NAME] + stop)
-            if stop_name not in self._post_processors:
-                self._post_processors[stop_name] = TokenStopCriterion(
-                    token=stop,
-                    tokenizer=self.tokenizer)
+            for each_stop in stop:
+                stop_name = STOP_NAME + '_' + each_stop
+                if stop_name not in self._post_processors:
+                    self._post_processors[stop_name] = TokenStopCriterion(
+                        token=each_stop,
+                        tokenizer=self.tokenizer)
+                post_processing.append(stop_name)
         else:
             stop_name = STOP_NAME
             if STOP_NAME not in self._post_processors:
                 self._post_processors[stop_name] = EosGenerationStopCriterion(
                     tokenizer=self.tokenizer)
-        post_processing.append(stop_name)
+            post_processing.append(stop_name)
 
         return Request(
             tid=tid,
diff --git a/mii/modeling/tokenizers.py b/mii/modeling/tokenizers.py
index 46190759..9cb21a84 100644
--- a/mii/modeling/tokenizers.py
+++ b/mii/modeling/tokenizers.py
@@ -60,6 +60,9 @@ def eos_token_id(self) -> int:
     def encode(self, input: str) -> torch.Tensor:
         return self.tokenizer.encode(input, return_tensors="pt").flatten()
 
+    def convert_tokens_to_ids(self, input: str) -> int:
+        return self.tokenizer.convert_tokens_to_ids(input)
+
     def decode(self, tokens: torch.Tensor) -> str:
         return self.tokenizer.decode(tokens)
 
diff --git a/tests/test_deployment.py b/tests/test_deployment.py
index 03a18c87..529e233d 100644
--- a/tests/test_deployment.py
+++ b/tests/test_deployment.py
@@ -110,13 +110,6 @@ def test_do_sample(deployment, query):
     ), "do_sample=False should always return the same output"
 
 
-def test_stop_token(deployment, query):
-    pytest.skip("not working yet")
-    output = deployment(query, stop=".", max_length=512)
-    print(str(output.response))
-    assert str(output.response[0]).endswith("."), "output should end with 'the'"
-
-
 def test_return_full_text(deployment, query):
     outputs = deployment(query, max_length=128, return_full_text=True)
     assert outputs[0].generated_text.startswith(query), "output should start with the prompt"
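
For reference, a minimal sketch (not part of the diff) of the reshape-and-reduce pattern introduced by the postprocess.py hunk. It assumes run_batch_processing yields one boolean flag per (request, stop-criterion) pair in request-major order, so a request is considered done as soon as any of its stop criteria fires; the tensor values below are invented for illustration.

    import torch

    # Hypothetical flat mask from run_batch_processing: two requests, each
    # checked against three stop criteria (values made up for this sketch).
    num_requests = 2
    done_tokens = torch.tensor([False, True, False, False, False, False])

    # Same reduction as the new line in run_batch_stop_criterion: reshape to
    # (num_requests, criteria_per_request) and mark a request done if ANY fired.
    done_per_request = torch.any(done_tokens.view((num_requests, -1)), dim=1)
    print(done_per_request)  # tensor([ True, False])

On the client side, the removed test suggests the intended call shape; whether stop accepts a bare string, a list of strings, or both is an assumption here, since ragged_batching.py now iterates over it:

    output = deployment(query, stop=["."], max_length=512)  # deployment/query as in the pytest fixtures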