Commit

add: response cache strategy
Signed-off-by: Yu Fan <fany@buaa.edu.cn>
FuryMartin committed Sep 20, 2024
1 parent 143eaf3 commit 48c2aa6
Showing 16 changed files with 497 additions and 55 deletions.
@@ -83,13 +83,14 @@ def run(self):
def _inference(self, job):
# Ianvs API
inference_dataset = self.dataset.load_data(self.dataset.test_data_info, "inference")
inference_output_dir = os.path.join(self.workspace, "output/inference/")
inference_output_dir = os.path.join(os.path.dirname(self.workspace), "output/inference/")
os.environ["RESULT_SAVED_URL"] = inference_output_dir
os.makedirs(inference_output_dir, exist_ok=True)

results = []

cloud_count, edge_count = 0,0
pbar = tqdm(inference_dataset.x, desc="Inference: ")
pbar = tqdm(inference_dataset.x, ncols=100)

for data in pbar:
# inference via sedna JointInference API
16 changes: 11 additions & 5 deletions examples/cloud-edge-collaborative-inference-for-llm/README.md
@@ -166,11 +166,17 @@ Run the following command:
After several seconds, you will see the following output:

```bash
+------+---------------+-----+-----------+----------------+-----------+------------+---------------------+--------------------------+-------------------+------------------------+---------------------+---------------------+--------------------------------------------------------------------------------+
| rank | algorithm | acc | edge-rate | paradigm | basemodel | apimodel | hard_example_mining | basemodel-model | basemodel-backend | basemodel-quantization | apimodel-model | time | url |
+------+---------------+-----+-----------+----------------+-----------+------------+---------------------+--------------------------+-------------------+------------------------+---------------------+---------------------+--------------------------------------------------------------------------------+
| 1 | query-routing | 1.0 | 0.4 | jointinference | EdgeModel | CloudModel | BERT | Qwen/Qwen2-1.5B-Instruct | huggingface | full | gpt-4o-mini | 2024-08-17 01:02:45 | ./workspace/benchmarkingjob/query-routing/493d14ea-5bf1-11ef-bf9b-755996a48c84 |
+------+---------------+-----+-----------+----------------+-----------+------------+---------------------+--------------------------+-------------------+------------------------+---------------------+---------------------+--------------------------------------------------------------------------------+
+------+---------------+----------+--------------+---------------------+------------+------------------------+---------------------+-------------------------+--------------------+------------------------+----------------+-----------+------------+---------------------+------------------------+-------------------+------------------------+-----------------------+-----------------+---------------------+------------------------------+--------------------------------+----------------------------------+------------------+------------------------+------------------+----------------------+-------------------------------+------------------------------+-------------------------------+---------------------+-------------------------------------------------------------------------------------+
| rank | algorithm | Accuracy | Rate to Edge | Time to First Token | Throughput | Internal Token Latency | Cloud Prompt Tokens | Cloud Completion Tokens | Edge Prompt Tokens | Edge Completion Tokens | paradigm | edgemodel | cloudmodel | hard_example_mining | edgemodel-model | edgemodel-backend | edgemodel-quantization | edgemodel-temperature | edgemodel-top_p | edgemodel-max_token | edgemodel-repetition_penalty | edgemodel-tensor_parallel_size | edgemodel-gpu_memory_utilization | cloudmodel-model | cloudmodel-temperature | cloudmodel-top_p | cloudmodel-max_token | cloudmodel-repetition_penalty | hard_example_mining-model | hard_example_mining-threshold | time | url |
+------+---------------+----------+--------------+---------------------+------------+------------------------+---------------------+-------------------------+--------------------+------------------------+----------------+-----------+------------+---------------------+------------------------+-------------------+------------------------+-----------------------+-----------------+---------------------+------------------------------+--------------------------------+----------------------------------+------------------+------------------------+------------------+----------------------+-------------------------------+------------------------------+-------------------------------+---------------------+-------------------------------------------------------------------------------------+
| 1 | query-routing | 0.664 | 0.24 | 0.48 | 127.07 | 0.008 | 85816 | 26718 | 63031 | 7853 | jointinference | EdgeModel | CloudModel | BERTRouter | Qwen/Qwen2-7B-Instruct | vllm | full | 0.8 | 0.8 | 512 | 1.05 | 4 | 0.8 | gpt-4o-mini | 0.8 | 0.8 | 512 | 1.05 | routellm/bert_mmlu_augmented | 0.4 | 2024-09-12 16:40:45 | ./workspace-mmlu/benchmarkingjob/query-routing/edf31729-70dc-11ef-8910-d79f5fadb467 |
| 2 | query-routing | 0.663 | 0.04 | 0.577 | 149.51 | 0.007 | 136952 | 37144 | 10429 | 1629 | jointinference | EdgeModel | CloudModel | BERTRouter | Qwen/Qwen2-7B-Instruct | vllm | full | 0.8 | 0.8 | 512 | 1.05 | 4 | 0.8 | gpt-4o-mini | 0.8 | 0.8 | 512 | 1.05 | routellm/bert_mmlu_augmented | 0.5 | 2024-09-12 16:21:29 | ./workspace-mmlu/benchmarkingjob/query-routing/edf31728-70dc-11ef-8910-d79f5fadb467 |
| 3 | query-routing | 0.656 | 0.0 | 0.608 | 303.15 | 0.007 | 147122 | 39000 | 0 | 0 | jointinference | EdgeModel | CloudModel | CloudOnly | Qwen/Qwen2-7B-Instruct | vllm | full | 0.8 | 0.8 | 512 | 1.05 | 4 | 0.8 | gpt-4o-mini | 0.8 | 0.8 | 512 | 1.05 | | | 2024-09-12 15:35:47 | ./workspace-mmlu/benchmarkingjob/query-routing/7225e02c-70d6-11ef-8910-d79f5fadb467 |
| 4 | query-routing | 0.642 | 0.44 | 0.457 | 118.63 | 0.008 | 59321 | 17952 | 90514 | 13952 | jointinference | EdgeModel | CloudModel | BERTRouter | Qwen/Qwen2-7B-Instruct | vllm | full | 0.8 | 0.8 | 512 | 1.05 | 4 | 0.8 | gpt-4o-mini | 0.8 | 0.8 | 512 | 1.05 | routellm/bert_mmlu_augmented | 0.3 | 2024-09-12 16:58:13 | ./workspace-mmlu/benchmarkingjob/query-routing/edf3172a-70dc-11ef-8910-d79f5fadb467 |
| 5 | query-routing | 0.631 | 0.66 | 0.26 | 112.43 | 0.009 | 35797 | 6962 | 115233 | 20447 | jointinference | EdgeModel | CloudModel | BERTRouter | Qwen/Qwen2-7B-Instruct | vllm | full | 0.8 | 0.8 | 512 | 1.05 | 4 | 0.8 | gpt-4o-mini | 0.8 | 0.8 | 512 | 1.05 | routellm/bert_mmlu_augmented | 0.2 | 2024-09-12 17:10:35 | ./workspace-mmlu/benchmarkingjob/query-routing/edf3172b-70dc-11ef-8910-d79f5fadb467 |
| 6 | query-routing | 0.604 | 1.0 | 0.069 | 104.44 | 0.01 | 0 | 0 | 152489 | 30175 | jointinference | EdgeModel | CloudModel | EdgeOnly | Qwen/Qwen2-7B-Instruct | vllm | full | 0.8 | 0.8 | 512 | 1.05 | 4 | 0.8 | gpt-4o-mini | 0.8 | 0.8 | 512 | 1.05 | | | 2024-09-12 15:00:23 | ./workspace-mmlu/benchmarkingjob/query-routing/906363fa-70d3-11ef-8910-d79f5fadb467 |
+------+---------------+----------+--------------+---------------------+------------+------------------------+---------------------+-------------------------+--------------------+------------------------+----------------+-----------+------------+---------------------+------------------------+-------------------+------------------------+-----------------------+-----------------+---------------------+------------------------------+--------------------------------+----------------------------------+------------------+------------------------+------------------+----------------------+-------------------------------+------------------------------+-------------------------------+---------------------+-------------------------------------------------------------------------------------+
```

Ianvs will output a `rank.csv` and `selected_rank.csv` in `ianvs/workspace`, which record the results of each test case.
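
If you want to post-process these results programmatically, here is a minimal sketch. It makes a few assumptions not guaranteed by Ianvs: `rank.csv` is a plain comma-separated file, its columns match the leaderboard header shown above, and it sits directly under your configured workspace directory; adjust the path for your setup.

```python
# Minimal sketch: inspect the generated leaderboard with pandas.
# Assumptions: rank.csv is comma-separated, its columns match the table
# above ("algorithm", "Accuracy", "Rate to Edge", ...), and the path below
# matches your workspace setting; adjust as needed.
import pandas as pd

rank = pd.read_csv("./workspace/rank.csv")
top = rank.sort_values("Accuracy", ascending=False)
print(top[["algorithm", "Accuracy", "Rate to Edge", "Time to First Token"]].head())
```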
@@ -52,11 +52,11 @@ benchmarkingjob:
# currently the options of value are as follows:
# 1> "all": select all modules in the leaderboard;
# 2> modules in the leaderboard, e.g., "basemodel"
modules: [ "all" ]
modules: [ "hard_example_mining" ]
# currently the options of value are as follows:
# 1> "all": select all hyperparameters in the leaderboard;
# 2> hyperparameters in the leaderboard, e.g., "momentum"
hyperparameters: [ "all" ]
hyperparameters: [ "edgemodel-model", "edgemodel-backend", "cloudmodel-model"]
# currently the options of value are as follows:
# 1> "all": select all metrics in the leaderboard;
# 2> metrics in the leaderboard, e.g., "f1_score"
@@ -40,9 +40,13 @@ class CloudModel:
def __init__(self, **kwargs):
# The API KEY and API URL are confidential data and should not be written in yaml.

self.client = APIBasedLLM(**kwargs)
self.model = APIBasedLLM(**kwargs)

self.client.load(model = kwargs.get("model", "gpt-4o-mini"))
self.model.load(model = kwargs.get("model", "gpt-4o-mini"))

def inference(self, data, input_shape=None, **kwargs):
return self.client.inference(data)
return self.model.inference(data)

def cleanup(self):
self.model.save_cache()
self.model.cleanup()
@@ -82,4 +82,5 @@ def predict(self, data, input_shape=None, **kwargs):
return answer_list

def cleanup(self):
self.model.save_cache()
self.model.cleanup()
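
Both model wrappers now flush their response cache before releasing resources. Below is a hypothetical driver sketch, not part of Ianvs, that only illustrates the call order this change assumes; the `CloudModel` import path and the sample input are placeholders.

```python
# Hypothetical driver sketch (not part of Ianvs) showing the assumed lifecycle:
# RESULT_SAVED_URL determines where cache.json is read from and written to,
# and cleanup() calls save_cache() before releasing the underlying model.
import os

from models.cloud_model import CloudModel  # hypothetical import path; adjust to your layout

os.environ["RESULT_SAVED_URL"] = "./output/inference/"
os.makedirs(os.environ["RESULT_SAVED_URL"], exist_ok=True)

model = CloudModel(model="gpt-4o-mini")     # constructor as used in the diff above
print(model.inference("What is 2 + 2?"))    # plain strings skip the cache;
                                            # dict inputs (question + prompt) go through it
model.cleanup()                             # save_cache() writes cache.json, then cleanup
```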
@@ -66,6 +66,7 @@ def _text_classification_postprocess(self, result):
return False if label == "LABEL_0" else True

def _predict(self, data):
print(data)
# result = self.classifier(data)
if self.task == "text-classification":
result = self.classifier(data, top_k=None)
@@ -57,7 +57,7 @@ def _infer(self, question, system_prompt):
internal_token_latency = sum(internal_token_latency) / len(internal_token_latency)
throughput = 1 / internal_token_latency

return self._format_response(
response = self._format_response(
text,
prompt_tokens,
completion_tokens,
@@ -66,6 +66,8 @@ def _infer(self, question, system_prompt):
throughput
)

return response

if __name__ == '__main__':
llm = APIBasedLLM(model="gpt-4o-mini")
data = ["你好吗?介绍一下自己"]
@@ -1,10 +1,12 @@
import time
from functools import wraps
import os
import json

class BaseLLM:
def __init__(self, **kwargs) -> None:
self.config = kwargs
self.parse_kwargs(**kwargs)

self.is_cache_loaded = False

def load(self):
raise NotImplementedError

@@ -14,16 +16,38 @@ def parse_kwargs(self, **kwargs):
self.top_p = kwargs.get("top_p", 0.8)
self.repetition_penalty = kwargs.get("repetition_penalty", 1.05)
self.max_tokens = kwargs.get("max_tokens", 512)
self.use_cache = kwargs.get("use_cache", True)

def cache_response(self, question, cache_dir):
with open(cache_dir, "r") as f:
f.write(question)
self.kwargs

def inference(self, data):
if isinstance(data, list):
return [self._infer(line) for line in data]

elif isinstance(data, str):
return self._infer(data)

elif isinstance(data, dict):
self.validate_input(data)
# from viztracer import VizTracer
# import sys
# with VizTracer(output_file="optional.json") as tracer:
question, system_prompt = self.parse_input(data)
return self._infer(question, system_prompt)

if self.use_cache:
response = self.try_cache(question, system_prompt)
if response is not None:
return response

response = self._infer(question, system_prompt)
if self.use_cache:
self._update_cache(question, system_prompt, response)

# sys.exit(0)
return response

else:
raise ValueError(f"DataType {type(data)} is not supported, it must be `list` or `str` or `dict`")

@@ -49,6 +73,7 @@ def validate_input(self, data):
raise ValueError(f"Missing Key 'prompts' in data, data should have format like {expected_format}")

def parse_input(self,data):
self.validate_input(data)
# data should have format like:
# {"question":"Lorem", "prompt": {infer_system_prompt:"Lorem"}}
question = data.get("question")
@@ -83,5 +108,54 @@ def _format_response(self, text, prompt_tokens, completion_tokens, time_to_first

return resposne

def _load_cache(self):
self.cache = None
self.cache_hash = {}
self.cache_models = []

cache_file = os.path.join(os.environ["RESULT_SAVED_URL"], "cache.json")
if os.path.exists(cache_file):
with open(cache_file, "r", encoding="utf-8") as f:
self.cache_models = json.load(f)
for cache in self.cache_models:
if cache["config"] == self.config:
self.cache = cache
self.cache_hash = {(item["question"], item["system_prompt"]):item['response'] for item in cache["result"]}
self.is_cache_loaded = True

def try_cache(self, question, system_prompt):

if not self.is_cache_loaded:
self._load_cache()

return self.cache_hash.get((question, system_prompt), None)

def _update_cache(self, question, system_prompt, response):

if not self.is_cache_loaded:
self._load_cache()

new_item = {
"question": question,
"system_prompt": system_prompt,
"response": response
}

self.cache_hash[(question, system_prompt)] = response

if self.cache is not None:
self.cache["result"].append(new_item)
else:
self.cache = {"config": self.config, "result": [new_item]}
self.cache_models.append(self.cache)

def save_cache(self):

cache_file = os.path.join(os.environ["RESULT_SAVED_URL"], "cache.json")

if self.is_cache_loaded:
with open(cache_file, "w", encoding="utf-8") as f:
json.dump(self.cache_models, f, indent=4)

def cleanup(self):
pass
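
For orientation, here is a minimal, self-contained sketch of the cache layout these methods imply. This is an inferred structure, not a documented schema, and the `response` payload is illustrative; in practice it is whatever `_format_response()` returns for the model.

```python
# Standalone sketch of the cache.json layout implied by _load_cache / _update_cache:
# one entry per model configuration, with a "result" list that maps
# (question, system_prompt) pairs to the formatted response.
import json

cache_models = [
    {
        "config": {"model": "gpt-4o-mini", "temperature": 0.8, "top_p": 0.8},
        "result": [
            {
                "question": "What is 2 + 2?",
                "system_prompt": "You are a helpful assistant.",
                "response": {"completion": "4", "prompt_tokens": 21, "completion_tokens": 1},
            }
        ],
    }
]

# In-memory index like self.cache_hash (the real code builds it only for the
# entry whose "config" matches the current model), keyed by (question, system_prompt):
cache_hash = {
    (item["question"], item["system_prompt"]): item["response"]
    for entry in cache_models
    for item in entry["result"]
}
print(cache_hash[("What is 2 + 2?", "You are a helpful assistant.")])

# Persisting mirrors save_cache(): the whole list is dumped back to cache.json.
with open("cache.json", "w", encoding="utf-8") as f:
    json.dump(cache_models, f, indent=4)
```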
@@ -43,7 +43,7 @@ def _infer(self, question, system_prompt):

model_inputs = self.tokenizer([text], return_tensors="pt").to(device)

streamer = TextIteratorStreamer(self.tokenizer)
streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True)

generation_kwargs = dict(
model_inputs,
@@ -75,16 +75,18 @@ def _infer(self, question, system_prompt):
generated_text += chunk
completion_tokens += 1

text = generated_text
text = generated_text.replace("<|im_end|>", "")
prompt_tokens = len(model_inputs.input_ids[0])
internal_token_latency = sum(internal_token_latency) / len(internal_token_latency)
throughput = 1 / internal_token_latency

return self._format_response(
response = self._format_response(
text,
prompt_tokens,
completion_tokens,
time_to_first_token,
internal_token_latency,
throughput
)
)

return response
0 comments on commit 48c2aa6