Skip to content

Commit

Permalink
allow TGI compat to work with output token ids (#1900)
Browse files Browse the repository at this point in the history
  • Loading branch information
Qing Lan authored May 13, 2024
1 parent ab4cdd0 commit 22b2849
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class Token(object):
"""

def __init__(self,
id: List[int],
id: Union[List[int], int],
text: str,
log_prob: float = None,
special_token: bool = None):
Expand All @@ -49,6 +49,8 @@ def as_dict(self):
output = {}
if self.id:
output["id"] = self.id
if TGI_COMPAT:
output["id"] = self.id[0]
if self.text:
output["text"] = self.text
if self.log_prob:
Expand Down
27 changes: 24 additions & 3 deletions engines/python/setup/djl_python/tests/test_rolling_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_json_fmt_with_appending(self):
print(req.get_next_token(), end='')
assert req.get_next_token() == ' world"}'

def test_json_fmt_hf_compat(self):
def test_fmt_hf_compat(self):
djl_python.rolling_batch.rolling_batch.TGI_COMPAT = True

req = Request(0,
Expand All @@ -64,6 +64,7 @@ def test_json_fmt_hf_compat(self):
"max_new_tokens": 256,
"return_full_text": True,
},
details=True,
output_formatter=_json_output_formatter)

final_str = []
Expand All @@ -78,8 +79,28 @@ def test_json_fmt_hf_compat(self):
final_json = json.loads(''.join(final_str))
print(final_json, end='')
assert final_json == [{
"generated_text":
"This is a wonderful dayHello world",
'generated_text': 'This is a wonderful dayHello world',
'details': {
'finish_reason':
'length',
'generated_tokens':
3,
'inputs':
'This is a wonderful day',
'tokens': [{
'id': 244,
'text': 'He',
'log_prob': -0.334532
}, {
'id': 576,
'text': 'llo',
'log_prob': -0.123123
}, {
'id': 4558,
'text': ' world',
'log_prob': -0.567854
}]
}
}]
djl_python.rolling_batch.rolling_batch.TGI_COMPAT = False

Expand Down

0 comments on commit 22b2849

Please sign in to comment.