Skip to content

Commit

Permalink
[SD][vLLM] record acceptance (#1586)
Browse the repository at this point in the history
  • Branch information:
Qing Lan authored Mar 2, 2024
1 parent 06b9a1a commit f511dc1
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class VllmRbProperties(Properties):
speculative_draft_model: Optional[str] = None
speculative_length: int = 5
draft_model_tp_size: int = 1
record_acceptance_rate: Optional[bool] = False

@validator('engine')
def validate_engine(cls, engine):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import json
import logging
from collections import OrderedDict

Expand Down Expand Up @@ -150,6 +151,20 @@ def inference(self, input_data: list[str], parameters: list[dict]) -> list:
f"Beam search is not supported yet, use first output by default"
)
self.request_cache[req_id]["finished"] = request_output.finished
# Record SD metrics
completion_output = request_output.outputs[0]
if self.vllm_configs.record_acceptance_rate and request_output.finished and completion_output.acceptance_history:
record = {}
record["id"] = req_id
if len(completion_output.acceptance_history) > 0:
record["mean_acceptance"] = 1.0 * sum(
completion_output.acceptance_history) / len(
completion_output.acceptance_history)
else:
record["mean_acceptance"] = 0
record["prompt_size"] = len(request_output.prompt_token_ids)
record["output_size"] = len(completion_output.token_ids)
logging.info(f"Speculative Decoding {record}")
# step 2: send result back
finished_id = []
for (key, cache), request in zip(self.request_cache.items(),
Expand Down
1 change: 1 addition & 0 deletions tests/integration/llm/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,7 @@
"speculative-llama-13b": {
"option.model_id": "TheBloke/Llama-2-13B-fp16",
"option.speculative_draft_model": "s3://djl-llm/tinyllama-1.1b-chat/",
"option.record_acceptance_rate": True,
"option.tensor_parallel_degree": "max",
"option.output_formatter": "jsonlines"
}
Expand Down

0 comments on commit f511dc1

Please sign in to comment.