[python] validate each request in the batch (deepjavalibrary#1008)
frankfliu authored Aug 10, 2023
1 parent d6f07f1 commit 12985e9
Showing 3 changed files with 85 additions and 41 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/client-test.yml
@@ -42,7 +42,7 @@ jobs:
run: |
cd tests/binary && python prepare.py pt && cd ..
djl-serving -m test::PyTorch=file://$PWD/binary/model.pt &> output.log &
sleep 30
sleep 40
cd java-client
./gradlew build
cd ../binary && python prepare.py pt --clean && cd ..
@@ -52,7 +52,7 @@ jobs:
run: |
cd tests/binary && python prepare.py pt && cd ..
djl-serving -m test::PyTorch=file://$PWD/binary/model.pt &> output.log &
sleep 30
sleep 40
python test_binary.py 1,3,224,224 1,1000
cd binary && python prepare.py pt --clean && cd ..
jobs
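
The workflow change simply lengthens the fixed startup delay from 30 s to 40 s so djl-serving has time to load the PyTorch binary model before the client tests run. A speculative alternative, not part of this commit, would be to poll until the server answers; the sketch below assumes djl-serving's health endpoint is reachable at http://localhost:8080/ping, which should be verified for the actual deployment:

import time
import urllib.error
import urllib.request

def wait_for_server(url="http://localhost:8080/ping", timeout=60):
    # Poll the health endpoint until it responds or the deadline passes.
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            with urllib.request.urlopen(url, timeout=5) as resp:
                if resp.status == 200:
                    return
        except (urllib.error.URLError, OSError):
            pass  # server not up yet; retry
        time.sleep(2)
    raise TimeoutError(f"server at {url} not ready within {timeout}s")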
115 changes: 79 additions & 36 deletions engines/python/setup/djl_python/huggingface.py
@@ -10,7 +10,7 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

import json
import logging
import os

@@ -199,54 +199,88 @@ def initialize(self, properties: dict):

self.initialized = True

def inference(self, inputs):
def parse_input(self, inputs):
input_data = []
input_size = []
parameters = []
errors = {}
batch = inputs.get_batches()
first = True
for i, item in enumerate(batch):
content_type = item.get_property("Content-Type")
input_map = decode(item, content_type)
_inputs = input_map.pop("inputs", input_map)
if isinstance(_inputs, list):
input_data.extend(_inputs)
input_size.append(len(_inputs))
else:
input_data.append(_inputs)
input_size.append(1)
if first or self.rolling_batch_type:
parameters.append(input_map.pop("parameters", {}))
first = False
else:
if parameters != input_map.pop("parameters", {}):
raise ValueError(
"In order to enable dynamic batching, all input batches must have the same parameters"
)
if "cached_prompt" in input_map:
parameters[i]["cached_prompt"] = input_map.pop("cached_prompt")

seed_key = 'seed' if inputs.is_batch() else f'batch_{i}.seed'
if item.contains_key(seed_key):
seed = parameters[i].get("seed")
if not seed:
# set server provided seed if seed is not part of request
parameters[i]["seed"] = item.get_as_string(key=seed_key)
try:
content_type = item.get_property("Content-Type")
input_map = decode(item, content_type)
_inputs = input_map.pop("inputs", input_map)
if first or self.rolling_batch_type:
parameters.append(input_map.pop("parameters", {}))
first = False
else:
param = input_map.pop("parameters", {})
if parameters[0] != param:
logging.warning(
f"expected param: {parameters}, actual: {param}")
raise ValueError(
"In order to enable dynamic batching, all input batches must have the same parameters"
)
if isinstance(_inputs, list):
input_data.extend(_inputs)
input_size.append(len(_inputs))
else:
input_data.append(_inputs)
input_size.append(1)

if "cached_prompt" in input_map:
parameters[i]["cached_prompt"] = input_map.pop(
"cached_prompt")

seed_key = 'seed' if inputs.is_batch() else f'batch_{i}.seed'
if item.contains_key(seed_key):
seed = parameters[i].get("seed")
if not seed:
# set server provided seed if seed is not part of request
parameters[i]["seed"] = item.get_as_string(
key=seed_key)
except Exception as e: # pylint: disable=broad-except
logging.exception(f"Parse input failed: {i}")
errors[i] = str(e)

return input_data, input_size, parameters, errors, batch

def inference(self, inputs):
outputs = Output()

input_data, input_size, parameters, errors, batch = self.parse_input(
inputs)
if len(input_data) == 0:
for i in range(len(batch)):
err = errors.get(i)
err = json.dumps({"code": 500, "error": err})
if self.rolling_batch_type:
err = json.dumps({"data": err, "last": True})
outputs.add(err, key="data", batch_index=i)
return outputs

if self.rolling_batch_type:
if inputs.get_property("reset_rollingbatch"):
self.rolling_batch.reset()
result = self.rolling_batch.inference(input_data, parameters)
for i in range(inputs.get_batch_size()):
outputs.add(result[i], key="data", batch_index=i)
idx = 0
for i in range(len(batch)):
err = errors.get(i)
if err:
err = json.dumps({"code": 500, "error": err})
err = json.dumps({"data": err, "last": True})
outputs.add(err, key="data", batch_index=i)
else:
outputs.add(result[idx], key="data", batch_index=i)
idx += 1

content_type = self.rolling_batch.get_content_type()
if content_type:
outputs.add_property("content-type", content_type)
return outputs
elif self.enable_streaming:
# TODO support dynamic batch
outputs.add_property("content-type", "application/jsonlines")
if self.enable_streaming == "huggingface":
outputs.add_stream_content(
@@ -268,15 +302,24 @@ def inference(self, inputs):
content_type = item.get_property("Content-Type")
accept = item.get_property("Accept")
if not accept:
content_type = content_type if content_type else "application/json"
accept = content_type if content_type.startswith(
"tensor/") else "application/json"
elif "*/*" in accept:
accept = "application/json"
encode(outputs,
prediction[offset:offset + input_size[i]],
accept,
key=inputs.get_content().key_at(i))
offset += input_size[i]

err = errors.get(i)
if err:
encode(outputs,
err,
accept,
key=inputs.get_content().key_at(i))
else:
encode(outputs,
prediction[offset:offset + input_size[i]],
accept,
key=inputs.get_content().key_at(i))
offset += input_size[i]

return outputs
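
The heart of the change: request decoding moves out of inference into parse_input, which traps per-request failures in an errors dict keyed by batch index instead of letting one malformed request fail the whole batch, and inference then interleaves model results with 500-style error payloads, advancing a separate idx counter so results stay aligned with the requests that parsed successfully. A minimal standalone sketch of the same pattern, with plain dicts and a toy model standing in for the djl_python Input/Output classes:

import json
import logging

def parse_batch(raw_items):
    # Decode each request on its own; record failures instead of raising.
    inputs, errors = [], {}
    for i, raw in enumerate(raw_items):
        try:
            body = json.loads(raw)
            inputs.append(body["inputs"])
        except Exception as e:  # pylint: disable=broad-except
            logging.exception("Parse input failed: %d", i)
            errors[i] = str(e)
    return inputs, errors

def run_batch(raw_items):
    inputs, errors = parse_batch(raw_items)
    results = [x.upper() for x in inputs]  # toy model: uppercase each input
    merged, idx = [], 0
    for i in range(len(raw_items)):
        if i in errors:
            merged.append({"code": 500, "error": errors[i]})
        else:
            # idx lags i by the number of failed slots, mirroring the
            # bookkeeping in the rolling-batch branch above.
            merged.append({"data": results[idx]})
            idx += 1
    return merged

print(run_batch(['{"inputs": "hello"}', 'not json', '{"inputs": "world"}']))
# [{'data': 'HELLO'}, {'code': 500, 'error': '...'}, {'data': 'WORLD'}]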

@@ -451,7 +494,7 @@ def _read_model_config(self, model_config_path: str):
trust_remote_code=self.trust_remote_code)
except Exception as e:
logging.error(
f"{self.model_id_or_path} does not contain a config.json or adapter_config.json for lora models. "
f"{model_config_path} does not contain a config.json or adapter_config.json for lora models. "
f"This is required for loading huggingface models")
raise e

7 changes: 4 additions & 3 deletions engines/python/setup/djl_python/inputs.py
@@ -83,12 +83,15 @@ def __str__(self):
return cur_str

def is_batch(self) -> bool:
return self.get_batch_size() > 1
return "batch_size" in self.properties

def get_batch_size(self) -> int:
return int(self.properties.get("batch_size", "1"))

def get_batches(self) -> list:
if not self.is_batch():
return [self]

batch_size = self.get_batch_size()
batch = []
for i in range(batch_size):
@@ -99,8 +102,6 @@ def get_batches(self) -> list:
if key.startswith(prefix):
key = key[length:]
item.properties[key] = value
elif not key.startswith("batch_"):
item.properties[key] = value

batch.append(item)

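In inputs.py, is_batch now tests for the presence of the batch_size property rather than for batch_size > 1, so a batch of one still flows through get_batches and gets the per-request validation above, and get_batches no longer copies un-prefixed properties onto every item. A rough sketch of the batch_<i>. prefix splitting, with a plain dict standing in for the real Input class:

def split_batch_properties(properties: dict) -> list:
    # Split "batch_<i>."-prefixed properties into one dict per request.
    batch_size = int(properties.get("batch_size", "1"))
    items = []
    for i in range(batch_size):
        prefix = f"batch_{i}."
        items.append({
            key[len(prefix):]: value
            for key, value in properties.items()
            if key.startswith(prefix)
        })
    return items

props = {"batch_size": "2", "batch_0.seed": "42", "batch_1.seed": "7",
         "Content-Type": "application/json"}
print(split_batch_properties(props))
# [{'seed': '42'}, {'seed': '7'}]  (shared keys are no longer copied per item)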
