
Commit 76b9639

fix: Fixed the issue where the mii.pipeline.pipe(stop) was ineffective (#409)

Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
kitstar and mrwyattii authored Feb 13, 2024
1 parent 47ce760 commit 76b9639
Showing 5 changed files with 14 additions and 14 deletions.
mii/batching/generation/stop_criterion.py (2 changes: 1 addition & 1 deletion)
@@ -30,7 +30,7 @@ class TokenStopCriterion(BaseGenerationStopCriterion):
     def __init__(self, token: Union[str, int], tokenizer) -> None:
         super().__init__(tokenizer=tokenizer)
         if isinstance(token, str):
-            token_id = self.tokenizer.encode(token)[0]
+            token_id = self.tokenizer.convert_tokens_to_ids(token)
         else:
             token_id = token
         self.stop_token_id = token_id
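The one-line change above matters because encode() on a Hugging Face tokenizer can add special tokens or split the string into several pieces, so encode(token)[0] may not be the id of the stop token at all; convert_tokens_to_ids() is a direct vocabulary lookup. An illustrative sketch, not part of this commit (bert-base-uncased is used only because its special-token behavior makes the bug obvious):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("bert-base-uncased")

    print(tok.encode("."))                 # [101, 1012, 102]: [CLS], ".", [SEP]
    print(tok.encode(".")[0])              # 101 is the [CLS] id, not "."
    print(tok.convert_tokens_to_ids("."))  # 1012, the actual id of "."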
mii/batching/postprocess.py (2 changes: 2 additions & 0 deletions)
@@ -76,4 +76,6 @@ def run_batch_stop_criterion(next_tokens: torch.Tensor,
                              Any]) -> torch.Tensor:
     stop_fns = {k: v for k, v in processor_map.items() if "Stop" in k}
     done_tokens = run_batch_processing(next_tokens, requests, stop_fns)
+    done_tokens = torch.any(done_tokens.view((len(requests), -1)), dim=1)
+
     return done_tokens
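With several stop criteria now registered per request (see ragged_batching.py below), the flags returned by run_batch_processing may carry more than one entry per request; the added torch.any reduction collapses them to one boolean per request, so a request counts as done as soon as any of its criteria fires. A minimal sketch of the reduction, not part of this commit:

    import torch

    num_requests = 2
    # One row per request, one column per stop criterion that was applied.
    flags = torch.tensor([[False, True, False],
                          [False, False, False]])

    done = torch.any(flags.view((num_requests, -1)), dim=1)
    print(done)  # tensor([ True, False])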
mii/batching/ragged_batching.py (14 changes: 8 additions & 6 deletions)
@@ -437,17 +437,19 @@ def make_request(self,
 
         stop = generate_params.stop
         if stop != []:
-            stop_name = "_".join([STOP_NAME] + stop)
-            if stop_name not in self._post_processors:
-                self._post_processors[stop_name] = TokenStopCriterion(
-                    token=stop,
-                    tokenizer=self.tokenizer)
+            for each_stop in stop:
+                stop_name = STOP_NAME + '_' + each_stop
+                if stop_name not in self._post_processors:
+                    self._post_processors[stop_name] = TokenStopCriterion(
+                        token=each_stop,
+                        tokenizer=self.tokenizer)
+                post_processing.append(stop_name)
         else:
             stop_name = STOP_NAME
             if STOP_NAME not in self._post_processors:
                 self._post_processors[stop_name] = EosGenerationStopCriterion(
                     tokenizer=self.tokenizer)
-        post_processing.append(stop_name)
+            post_processing.append(stop_name)
 
         return Request(
             tid=tid,
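Two bugs are fixed here. First, the old code built a single post-processor name from the whole stop list and passed the list itself as token=stop, so TokenStopCriterion's isinstance(token, str) check fell through to the else branch and stored the list as stop_token_id; now each stop string gets its own criterion. Second, the single post_processing.append after the if/else could only record one name, so it moves inside each branch to record every per-stop criterion. A hedged usage sketch of the behavior this restores (model choice is arbitrary, not from this commit):

    import mii

    pipe = mii.pipeline("mistralai/Mistral-7B-v0.1")
    # Each stop string now registers its own TokenStopCriterion, so generation
    # actually halts at the first "." instead of silently ignoring the setting.
    response = pipe(["DeepSpeed is"], stop=".", max_length=128)
    print(response[0].generated_text)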
mii/modeling/tokenizers.py (3 changes: 3 additions & 0 deletions)
@@ -60,6 +60,9 @@ def eos_token_id(self) -> int:
     def encode(self, input: str) -> torch.Tensor:
         return self.tokenizer.encode(input, return_tensors="pt").flatten()
 
+    def convert_tokens_to_ids(self, input: str) -> int:
+        return self.tokenizer.convert_tokens_to_ids(input)
+
     def decode(self, tokens: torch.Tensor) -> str:
         return self.tokenizer.decode(tokens)
 
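This passthrough exposes the underlying Hugging Face method on MII's tokenizer wrapper so TokenStopCriterion can call it. A simplified, runnable sketch of the same delegation pattern (the class name is illustrative, not MII's):

    from transformers import AutoTokenizer

    class TokenizerWrapper:
        def __init__(self, hf_tokenizer):
            self.tokenizer = hf_tokenizer

        def convert_tokens_to_ids(self, input: str) -> int:
            # Plain vocabulary lookup; no special tokens are added.
            return self.tokenizer.convert_tokens_to_ids(input)

    wrapper = TokenizerWrapper(AutoTokenizer.from_pretrained("gpt2"))
    print(wrapper.convert_tokens_to_ids("."))  # a single vocabulary id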
tests/test_deployment.py (7 changes: 0 additions & 7 deletions)
@@ -110,13 +110,6 @@ def test_do_sample(deployment, query):
     ), "do_sample=False should always return the same output"
 
 
-def test_stop_token(deployment, query):
-    pytest.skip("not working yet")
-    output = deployment(query, stop=".", max_length=512)
-    print(str(output.response))
-    assert str(output.response[0]).endswith("."), "output should end with 'the'"
-
-
 def test_return_full_text(deployment, query):
     outputs = deployment(query, max_length=128, return_full_text=True)
     assert outputs[0].generated_text.startswith(query), "output should start with the prompt"
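The previously skipped test_stop_token is deleted outright rather than re-enabled (its assertion message, "output should end with 'the'", was stale in any case). A hedged sketch of what a replacement test might look like, following the conventions of the surviving tests; this is not part of the commit:

    def test_stop_token(deployment, query):
        outputs = deployment(query, stop=".", max_length=512)
        assert outputs[0].generated_text.endswith("."), "output should end with '.'"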
