Skip to content

Commit

Permalink
[RollingBatch] add customized rollingbatch (#1468)
Browse files Browse the repository at this point in the history
  • Loading branch information
Qing Lan authored Jan 11, 2024
1 parent 63954bb commit dca98df
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@ def __init__(self, **kwargs):
self.output_formatter = _jsonlines_output_formatter
elif "none" == formatter:
pass
elif callable(formatter):
self.output_formatter = formatter
else:
# TODO: allow loading a custom formatter from a module
logging.warning(f"Unsupported formatter: {formatter}")
Expand Down
51 changes: 50 additions & 1 deletion engines/python/setup/djl_python/tests/test_rolling_batch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import unittest
from djl_python.rolling_batch.rolling_batch import Request, Token, _json_output_formatter, _jsonlines_output_formatter
from djl_python.rolling_batch.rolling_batch import Request, Token, _json_output_formatter, _jsonlines_output_formatter, \
RollingBatch


class TestRollingBatch(unittest.TestCase):
Expand Down Expand Up @@ -113,6 +114,54 @@ def test_details(self):
}
}

def test_custom_fmt(self):
    """Stream three tokens through a user-supplied output formatter.

    A callable is handed to a ``RollingBatch`` subclass as
    ``output_formatter``; the test then feeds tokens to a ``Request``
    using ``rb.output_formatter`` and checks that each emitted chunk is
    the JSON produced by that callable, with ``finish_reason`` present
    only on the last token.
    """

    def custom_fmt(token: Token, first_token: bool, last_token: bool,
                   details: dict, generated_tokens: str):
        # Emits one JSON line per token. ``first_token`` and
        # ``generated_tokens`` are unused here but kept so the signature
        # matches what the caller passes to a formatter.
        result = {"token_id": token.id, "token_text": token.text}
        if last_token:
            result["finish_reason"] = details["finish_reason"]
        return json.dumps(result) + "\n"

    class CustomRB(RollingBatch):
        # Minimal subclass: the overrides are no-op stubs because this
        # test only reads ``output_formatter`` from the instance.

        def preprocess_requests(self, requests):
            pass

        def postprocess_results(self):
            pass

        def inference(self, input_data, parameters):
            pass

    rb = CustomRB(output_formatter=custom_fmt)

    req = Request(0, "This is a wonderful day", {
        "max_new_tokens": 256,
        "details": True
    })

    # Intermediate tokens carry only the id and text.
    req.set_next_token(Token(244, "He", -0.334532), rb.output_formatter)
    self.assertEqual(json.loads(req.get_next_token()), {
        'token_id': 244,
        'token_text': 'He'
    })

    req.set_next_token(Token(576, "llo", -0.123123), rb.output_formatter)
    self.assertEqual(json.loads(req.get_next_token()), {
        'token_id': 576,
        'token_text': 'llo'
    })

    # The last token is flagged and given a finish reason, which the
    # formatter is expected to surface.
    req.set_next_token(Token(4558, " world", -0.567854),
                       rb.output_formatter, True, 'length')
    self.assertEqual(
        json.loads(req.get_next_token()), {
            'token_id': 4558,
            'token_text': ' world',
            'finish_reason': 'length'
        })


# Script entry point: run this module's tests with unittest when the file is
# executed directly rather than imported.
if __name__ == '__main__':
unittest.main()
4 changes: 4 additions & 0 deletions engines/python/setup/djl_python_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ def run_server(self):
if self.tensor_parallel_degree:
prop["tensor_parallel_degree"] = self.tensor_parallel_degree
prop["device_id"] = self.device_id
if "output_formatter" in prop and hasattr(
self.service, prop["output_formatter"]):
prop["output_formatter"] = getattr(self.service,
prop["output_formatter"])
function_name = inputs.get_function_name()
try:
outputs = self.service.invoke_handler(function_name, inputs)
Expand Down

0 comments on commit dca98df

Please sign in to comment.