Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fixed the issue where the mii.pipeline.pipe(stop) was ineffective #409

Merged
merged 6 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mii/batching/generation/stop_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class TokenStopCriterion(BaseGenerationStopCriterion):
def __init__(self, token: Union[str, int], tokenizer) -> None:
super().__init__(tokenizer=tokenizer)
if isinstance(token, str):
token_id = self.tokenizer.encode(token)[0]
token_id = self.tokenizer.convert_tokens_to_ids(token)
else:
token_id = token
self.stop_token_id = token_id
Expand Down
2 changes: 2 additions & 0 deletions mii/batching/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,6 @@ def run_batch_stop_criterion(next_tokens: torch.Tensor,
Any]) -> torch.Tensor:
stop_fns = {k: v for k, v in processor_map.items() if "Stop" in k}
done_tokens = run_batch_processing(next_tokens, requests, stop_fns)
done_tokens = torch.any(done_tokens.view((len(requests), -1)), dim=1)

return done_tokens
14 changes: 8 additions & 6 deletions mii/batching/ragged_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,17 +437,19 @@ def make_request(self,

stop = generate_params.stop
if stop != []:
stop_name = "_".join([STOP_NAME] + stop)
if stop_name not in self._post_processors:
self._post_processors[stop_name] = TokenStopCriterion(
token=stop,
tokenizer=self.tokenizer)
for each_stop in stop:
stop_name = STOP_NAME + '_' + each_stop
if stop_name not in self._post_processors:
self._post_processors[stop_name] = TokenStopCriterion(
token=each_stop,
tokenizer=self.tokenizer)
post_processing.append(stop_name)
else:
stop_name = STOP_NAME
if STOP_NAME not in self._post_processors:
self._post_processors[stop_name] = EosGenerationStopCriterion(
tokenizer=self.tokenizer)
post_processing.append(stop_name)
post_processing.append(stop_name)

return Request(
tid=tid,
Expand Down
3 changes: 3 additions & 0 deletions mii/modeling/tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ def eos_token_id(self) -> int:
def encode(self, input: str) -> torch.Tensor:
return self.tokenizer.encode(input, return_tensors="pt").flatten()

def convert_tokens_to_ids(self, input: str) -> int:
return self.tokenizer.convert_tokens_to_ids(input)

def decode(self, tokens: torch.Tensor) -> str:
return self.tokenizer.decode(tokens)

Expand Down
7 changes: 0 additions & 7 deletions tests/test_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,6 @@ def test_do_sample(deployment, query):
), "do_sample=False should always return the same output"


def test_stop_token(deployment, query):
pytest.skip("not working yet")
output = deployment(query, stop=".", max_length=512)
print(str(output.response))
assert str(output.response[0]).endswith("."), "output should end with 'the'"


def test_return_full_text(deployment, query):
outputs = deployment(query, max_length=128, return_full_text=True)
assert outputs[0].generated_text.startswith(query), "output should start with the prompt"
Expand Down