diff --git a/engines/python/setup/djl_python/rolling_batch/rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/rolling_batch.py index f408b52b8..c0314ad7a 100644 --- a/engines/python/setup/djl_python/rolling_batch/rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/rolling_batch.py @@ -41,6 +41,18 @@ def __init__(self, self.log_prob = log_prob self.special_token = special_token + def as_dict(self): + output = {} + if self.id: + output["id"] = self.id + if self.text: + output["text"] = self.text + if self.log_prob: + output["log_prob"] = self.log_prob + if self.special_token: + output["special_token"] = self.special_token + return output + def _json_output_formatter(token: Token, first_token: bool, last_token: bool, details: dict): @@ -68,13 +80,8 @@ def _jsonlines_output_formatter(token: Token, first_token: bool, :return: formatted output """ - token_dict = token.__dict__ - # backwards compatible to V5 - final_dict = { - "token": token_dict, - "details": None, - "outputs": [token.text] - } + token_dict = token.as_dict() + final_dict = {"token": token_dict} if last_token and details: final_dict["details"] = { "finish_reason": details.get("finish_reason", None) @@ -133,7 +140,7 @@ def set_next_token(self, if isinstance(next_token, str): next_token = Token(-1, next_token) if self.token_cache is not None: - self.token_cache.append(next_token.__dict__) + self.token_cache.append(next_token.as_dict()) details = {} if last_token and self.token_cache is not None: details["finish_reason"] = finish_reason