Skip to content

Commit

Permalink
add request length for traces
Browse files Browse the repository at this point in the history
clean up

bug fix: error bound relaxation
  • Loading branch information
Le Xu committed Jan 16, 2025
1 parent 0e9dd75 commit 4f541cb
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 41 deletions.
14 changes: 7 additions & 7 deletions benchmarks/generator/sample_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,14 @@ def sample_sharegpt_requests_len_range(
input_len = input_lens[i]
output_len = output_lens[i]
err_perc = initial_err_perc
while err_perc >= 0:

while err_perc < 1:
input_range = range(0, sys.maxsize)
output_range = range(0, sys.maxsize)
if input_len is not None:
input_range = (int(input_len * err_perc), int(input_len * (1 + err_perc)))
input_range = (int(input_len * (1 - err_perc)), int(input_len * (1 + err_perc)))
if output_len is not None:
output_range = (int(output_len * err_perc), int(output_len * (1 + err_perc)))

output_range = (int(output_len * (1 - err_perc)), int(output_len * (1 + err_perc)))
filtered = df[
(df["prompt_len"] >= input_range[0]) &
(df["prompt_len"] <= input_range[1]) &
Expand All @@ -105,10 +105,10 @@ def sample_sharegpt_requests_len_range(
break # Stop relaxing for this request once a match is found

# Reduce err_perc for next iteration
logging.warn(f"Relax err_perc {err_perc} by {err_step}")
err_perc -= err_step
logging.debug(f"Relax err_perc {err_perc} by {err_step} new err_perc {err_perc + err_step} input_range {input_range} output_range {output_range}")
err_perc += err_step

if err_perc < 0:
if err_perc >= 1:
raise Exception(f"No match found for request {i + 1} even after relaxing err_perc to 0")

return filtered_results
52 changes: 49 additions & 3 deletions benchmarks/generator/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,39 @@
import logging
import json
import os
import csv

import numpy as np
import matplotlib.pyplot as plt

from typing import List, Union, Any, Optional
from transformers import (AutoTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast)
from datetime import datetime

def get_sample_interval_ms(file_path):
# Initialize variables
timestamps = []

# Read the file and extract the first two timestamps
with open(file_path, 'r') as file:
reader = csv.DictReader(file)
for row in reader:
if 'Time' in row and row['Time']:
# Parse the timestamp
timestamps.append(datetime.strptime(row['Time'], "%Y-%m-%d %H:%M:%S"))
# Stop after reading the first two timestamps
if len(timestamps) == 2:
break

# Calculate the interval in milliseconds
interval = None
if len(timestamps) == 2:
interval = int((timestamps[1] - timestamps[0]).total_seconds() * 1000)
logging.info(f"Sampling interval: {interval} milliseconds")
else:
logging.error("Insufficient data to calculate the sampling interval.")
return interval


def make_serializable(data):
Expand Down Expand Up @@ -43,7 +69,7 @@ def plot_workload(workload_dict, interval_ms, output_file: str = None):
"""
fig, ax = plt.subplots()
for workload_name, workload in workload_dict.items():
concurrency_values = [len(item) for (_, item) in workload]
concurrency_values = [len(item["requests"]) for item in workload]
ax.plot(np.arange(len(concurrency_values)) * interval_ms, concurrency_values, label=workload_name)

ax.set_ylim(0, )
Expand All @@ -55,8 +81,28 @@ def plot_workload(workload_dict, interval_ms, output_file: str = None):
plt.show()
else:
os.makedirs(os.path.dirname(output_file), exist_ok=True)
plt.savefig(output_file)
logging.info(f'Saved workload plot to {output_file}')
plt.savefig(f"{output_file}-traffic.pdf")
logging.info(f'Saved traffic plot to {output_file}-traffic.pdf')


fig, ax = plt.subplots()
for workload_name, workload in workload_dict.items():
input_lengths = [item["requests"][0]['prompt_length'] for item in workload]
output_lengths = [item["requests"][0]['output_length'] for item in workload]
ax.plot(np.arange(len(concurrency_values)) * interval_ms, input_lengths, label=f"{workload_name} prompt_length")
ax.plot(np.arange(len(concurrency_values)) * interval_ms, output_lengths, label=f"{workload_name} output_length")

ax.set_ylim(0, )
plt.xlabel('Time (ms)')
plt.ylabel('Lengths')
plt.title('Request Sizes')
plt.legend()
if output_file is None:
plt.show()
else:
os.makedirs(os.path.dirname(output_file), exist_ok=True)
plt.savefig(f"{output_file}-requests.pdf")
logging.info(f'Saved traffic plot to {output_file}-requests.pdf')


def save_workload(load_struct: List[Any],
Expand Down
40 changes: 9 additions & 31 deletions benchmarks/generator/workload_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from transformers import PreTrainedTokenizerBase
from datetime import timedelta
from sample_request import (load_sharegpt_requests, sample_sharegpt_requests_len_range)
from utils import (get_tokenizer, plot_workload, make_serializable, save_workload)
from utils import (get_tokenizer, plot_workload, make_serializable, save_workload, get_sample_interval_ms)

# Set up logging to print only warning and above level messages
logging.basicConfig(level=logging.INFO)
Expand All @@ -19,7 +19,6 @@
def generate_from_internal_csv(file_path: str,
prompt_file_path: str,
duration_ms: int,
summary_interval_ms: int,
tokenizer: PreTrainedTokenizerBase,
interval_ms: int = 1000,
output_file: str = 'output/output',
Expand All @@ -30,6 +29,7 @@ def generate_from_internal_csv(file_path: str,
traffic = []
input_lengths = []
output_lengths = []
sample_interval_ms = get_sample_interval_ms(file_path)
with open(file_path, 'r') as file:
reader = csv.DictReader(file)
for row in reader:
Expand All @@ -55,51 +55,30 @@ def generate_from_internal_csv(file_path: str,
output_lengths.append(round(float(length)))

workload = []
# base = 0
ts = 0

print(f"input_lengths size {len(input_lengths)} output_lengths size {len(output_lengths)}")
sharegpt_df = load_sharegpt_requests(dataset_path=prompt_file_path, tokenizer=tokenizer)
for i, interval_requests in enumerate(traffic):
mean_rate = round(interval_requests / (summary_interval_ms / interval_ms))
mean_rate = round(interval_requests * (interval_ms / 1000))
input_length = input_lengths[i] if len(input_lengths)>0 else None
output_length = output_lengths[i] if len(output_lengths)>0 else None
for ts_delta in list(range(0, summary_interval_ms, interval_ms)):
#concurrent_reqs = [(req_id, input_length, output_length) for req_id in range(base, base + mean_rate)]
for ts_delta in list(range(0, sample_interval_ms, interval_ms)):
concurrent_sampled_reqs = sample_sharegpt_requests_len_range(
df=sharegpt_df,
num_requests=mean_rate,
input_lens=[input_length] * mean_rate, #[input_length for _ in range(base, base + mean_rate)],
output_lens=[output_length] * mean_rate, #[output_length for _ in range(base, base + mean_rate)],
input_lens=[input_length] * mean_rate,
output_lens=[output_length] * mean_rate,
initial_err_perc=0.5,
err_step=0.05
)
if concurrent_sampled_reqs: # Only add non-empty groups
workload.append({"timestamp": ts + ts_delta, "requests": concurrent_sampled_reqs})
else:
print(f"sampled return {concurrent_sampled_reqs}")
#workload.append((ts + ts_delta, concurrent_reqs))
#base += mean_rate
ts += summary_interval_ms
logging.error(f"sampled return {concurrent_sampled_reqs}")
ts += sample_interval_ms
if ts > duration_ms:
break

# grouped_requests = []

# for ts, reqs in workload:
# sampled_requests = sample_sharegpt_requests_len_range(
# df=sharegpt_df,
# num_requests=len(reqs),
# input_lens=[req[1] for req in reqs],
# output_lens=[req[2] for req in reqs],
# initial_err_perc=0.5,
# err_step=0.05
# )
# grouped_requests.append({"timestamp": ts, "requests": sampled_requests})

print(f"head {workload[0]}")
typename = type(workload[0]["requests"])
print(f"value type {typename}")
workload = make_serializable(workload)
save_workload(workload, output_file, use_jsonl=to_jsonl)
return workload
Expand Down Expand Up @@ -345,7 +324,6 @@ def pair_requests_with_prompts_round_robin(workload: List[List[Any]],
generated_workload = generate_from_internal_csv(file_path=args.traffic_file,
prompt_file_path=args.prompt_file,
duration_ms=args.duration_ms,
summary_interval_ms=15000,
tokenizer=tokenizer,
interval_ms=args.interval_ms,
output_file=f"{args.output_dir}/{args.trace_type}",
Expand All @@ -368,4 +346,4 @@ def pair_requests_with_prompts_round_robin(workload: List[List[Any]],

if workload_dict:
# Plot the workloads
plot_workload(workload_dict, interval_ms=args.interval_ms, output_file=f"plot/{args.trace_type}.pdf")
plot_workload(workload_dict, interval_ms=args.interval_ms, output_file=f"{args.output_dir}/{args.trace_type}")

0 comments on commit 4f541cb

Please sign in to comment.