diff --git a/benchmarks/generator/README.md b/benchmarks/generator/README.md index 0142ccbc..0d766d65 100644 --- a/benchmarks/generator/README.md +++ b/benchmarks/generator/README.md @@ -9,12 +9,27 @@ wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/r export SHAREGPT_FILE_PATH=/tmp/ShareGPT_V3_unfiltered_cleaned_split.json ``` +### Generate a workload file with a constant target QPS (synthetic patterns) + +```shell +export TARGET_QPS=1 + +python workload_generator.py --prompt-file $SHAREGPT_FILE_PATH --interval-ms 1000 --duration-ms 300000 --target-qps $TARGET_QPS --trace-type constant --model "Qwen/Qwen2.5-Coder-7B-Instruct" --output-dir "output" --output-format jsonl +``` + ### Generate a workload file based on workload patterns (synthetic patterns) -If no trace file path is specified, the generator will generate workload file based on 4 synthetic pattern described [here](https://github.com/aibrix/aibrix/blob/main/benchmarks/autoscaling/bench_workload_generator.py): +The generator can produce workload files based on synthetic traffic (QPS), input-length (prompt-length), and output-length (completion-length) patterns. Currently we support 4 patterns (`quick_rising`, `slow_rising`, `slight_fluctuation`, `severe_fluctuation`), described [here](https://github.com/aibrix/aibrix/blob/main/benchmarks/autoscaling/bench_workload_generator.py): +```shell +python workload_generator.py --prompt-file $SHAREGPT_FILE_PATH --interval-ms 1000 --duration-ms 300000 --trace-type synthetic --traffic-pattern "slight_fluctuation" --prompt-len-pattern "slight_fluctuation" --completion-len-pattern "slight_fluctuation" --model "Qwen/Qwen2.5-Coder-7B-Instruct" --output-dir "./output" --output-format jsonl +``` + +Alternatively, you can specify fluctuation patterns in a .json file and pass it to the generator as shown below. Example configuration files are under the `config` directory. ```shell -python workload_generator.py --prompt-file $SHAREGPT_FILE_PATH --num-prompts 100 --interval-ms 1000 --duration-ms 600000 --trace-type synthetic --model "Qwen/Qwen2.5-Coder-7B-Instruct" --output-dir "output" +python workload_generator.py --prompt-file $SHAREGPT_FILE_PATH --interval-ms 1000 --duration-ms 1400000 --trace-type synthetic --traffic-pattern-config config/traffic-config.json --prompt-len-pattern-config config/prompt-len-config.json --completion-len-pattern-config config/completion-len-config.json --model "Qwen/Qwen2.5-Coder-7B-Instruct" --output-dir "./output" --output-format jsonl ``` + + Here `--interval-ms` specifies the granularity (in milliseconds) at which concurrent requests are dispatched, and `--duration-ms` specifies the total length of the trace in milliseconds. The generated file is stored under the `output` directory and named after the selected pattern; a plot illustrating the workload pattern is saved under the `plot` directory.
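For reference, the pattern-config fields consumed here (`A`, `B`, `sigma`, `period`, `omega`, `only_rise`; see the example files under `config/`) parameterize a sinusoid-plus-Gaussian-noise curve, mirroring the `math_function` helper added to `workload_generator.py` in this PR. The standalone sketch below is a simplified illustration of how each field shapes the generated per-interval QPS or token-length values, not the generator's exact code:

```python
import math
import random

def fluctuating_value(t, length, A, B, sigma, period, omega=None, only_rise=False, prev=None):
    """One sample of the synthetic pattern: A scales the swing, B sets the mean level,
    sigma adds Gaussian jitter, and period/omega control how fast the sinusoid cycles."""
    if omega is None:
        omega = 2 * math.pi / (length / period)   # derive omega from the pattern length
    value = round(A * math.sin(omega * t) + B + random.gauss(0, sigma))
    if only_rise and prev is not None:
        value = max(prev, value)                  # clamp so the series never decreases
    return value

# Example: first few points of the traffic pattern from config/traffic-config.json
length = 300  # roughly duration_ms // interval_ms + 1
prev = None
for t in range(5):
    prev = fluctuating_value(t, length, A=2, B=6, sigma=0.1, period=1, prev=prev)
    print(t, prev)
```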
@@ -22,15 +37,48 @@ The file would be stored under `output` folder based on the name of different pa ## Generate a workload file based on internal load summary .csv file ```shell -export SUMMARY_FILE=${PATH_TO_SUMMARY_FILE} -python workload_generator.py --prompt-file $SHAREGPT_FILE_PATH --num-prompts 100 --interval-ms 1000 --duration-ms 600000 --trace-type internal --traffic-file "$SUMMARY_FILE" --model "Qwen/Qwen2.5-Coder-7B-Instruct" --output-dir "output" +export TRAFFIC_FILE=${PATH_TO_TRAFFIC_FILE} +export PROMPT_LEN_FILE=${PATH_TO_PROMPT_LEN_FILE} +export COMPLETION_LEN_FILE=${PATH_TO_COMPLETION_LEN_FILE} + +python workload_generator.py --prompt-file $SHAREGPT_FILE_PATH --interval-ms 1000 --duration-ms 1800000 --trace-type internal --traffic-file "$TRAFFIC_FILE" --prompt-len-file "$PROMPT_LEN_FILE" --completion-len-file "$COMPLETION_LEN_FILE" --model "Qwen/Qwen2.5-Coder-7B-Instruct" --output-dir "./output" --output-format jsonl --qps-scale 1.0 --output-scale 1.0 --input-scale 1.0 --internal-trace-type "maas" ``` -This generator assumes trace file to be in the following format +The scaling factors here (e.g., `qps-scale`) scale the rate in the original trace down to the desired rate; for example, if the peak rate in the original file is 80 and the desired peak rate is 8, the scale should be set to 10.0. + +### `maas` trace type +- With the `maas` trace type, the generator assumes `$TRAFFIC_FILE` to be in the following format ``` "Time","Total","Success","4xx Error" 2024-10-1 00:00:00,100,99,1 ``` + +- `"$PROMPT_LEN_FILE"` is assumed to be in the following format +``` +"Time","P50","P70","P90","P99" +``` + +- `"$COMPLETION_LEN_FILE"` is assumed to be in the following format +``` +"Time","P50","P70","P95","P99" +``` + +### `cloudide` trace type +- With the `cloudide` trace type, the generator assumes `$TRAFFIC_FILE` to be in the following format -- the `"Rate"` column may have an arbitrary name. +``` +"Time","Rate" +``` + +- `"$PROMPT_LEN_FILE"` is assumed to be in the following format +``` +"Time","recv_bytes","sent_bytes" +``` + +- `"$COMPLETION_LEN_FILE"` is assumed to be in the following format +``` +"Time","recv_bytes","sent_bytes" +``` + ### Indicate the length of prompt/completion You can also indicate a request's prompt length with the `--prompt-len-file` flag and its output length with `--completion-len-file`; based on these parameters, the generator selects prompts of matching length from the prompt file to simulate the lengths of the real flow's load.
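For orientation, each entry the generator writes pairs a millisecond timestamp with the requests dispatched at that instant; the field names below are taken from the request structures built in this PR, while the file path is only an assumed example (it depends on `--output-dir`, the trace type or pattern, and `--output-format`). A minimal sketch of consuming a generated `.jsonl` workload:

```python
import json

# Assumed example path; the generator names files after the trace type or pattern.
workload_path = "output/constant.jsonl"

with open(workload_path) as f:
    for line in f:
        entry = json.loads(line)  # {"timestamp": <ms offset>, "requests": [...]}
        prompt_lens = [r["prompt_length"] for r in entry["requests"]]
        output_lens = [r["output_length"] for r in entry["requests"]]
        print(f't={entry["timestamp"]}ms requests={len(entry["requests"])} '
              f'prompt_lens={prompt_lens} output_lens={output_lens}')
```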
diff --git a/benchmarks/generator/config/completion-len-config.json b/benchmarks/generator/config/completion-len-config.json new file mode 100644 index 00000000..b8607f4f --- /dev/null +++ b/benchmarks/generator/config/completion-len-config.json @@ -0,0 +1,8 @@ +{ + "A": 8, + "B": 169, + "sigma": 0.1, + "period": 0.005, + "omega": null, + "only_rise": false +} \ No newline at end of file diff --git a/benchmarks/generator/config/prompt-len-config.json b/benchmarks/generator/config/prompt-len-config.json new file mode 100644 index 00000000..cd430e06 --- /dev/null +++ b/benchmarks/generator/config/prompt-len-config.json @@ -0,0 +1,8 @@ +{ + "A": 15, + "B": 309, + "sigma": 0.1, + "period": 0.005, + "omega": null, + "only_rise": false +} \ No newline at end of file diff --git a/benchmarks/generator/config/traffic-config.json b/benchmarks/generator/config/traffic-config.json new file mode 100644 index 00000000..c05e78fc --- /dev/null +++ b/benchmarks/generator/config/traffic-config.json @@ -0,0 +1,8 @@ +{ + "A": 2, + "B": 6, + "sigma": 0.1, + "period": 1, + "omega": null, + "only_rise": false +} \ No newline at end of file diff --git a/benchmarks/generator/sample_request.py b/benchmarks/generator/sample_request.py index b00c6852..1a0d9ded 100644 --- a/benchmarks/generator/sample_request.py +++ b/benchmarks/generator/sample_request.py @@ -109,7 +109,7 @@ def sample_requests_len_range( output_len = output_lens[i] err_perc = initial_err_perc - while err_perc < 1: + while err_perc <= 1: input_range = range(0, sys.maxsize) output_range = range(0, sys.maxsize) if input_len is not None: @@ -126,7 +126,6 @@ def sample_requests_len_range( (df["completion_len"] >= output_range[0]) & (df["completion_len"] <= output_range[1]) ] - if not filtered.empty: # Select the first match or random sample total_rows = len(filtered) @@ -141,7 +140,12 @@ def sample_requests_len_range( err_perc += err_step if err_perc >= 1: - raise Exception(f"No match found for request {i + 1} even after relaxing err_perc to 0") + logging.warning(f"No match found for request {i + 1} even after relaxing err_perc to {err_perc}; falling back to a random sample") + total_rows = len(df) + sample = df.iloc[random.randint(0, total_rows - 1)] + filtered_results.append({"prompt": sample["prompt"], + "prompt_length": sample["prompt_len"], + "output_length": sample["completion_len"]}) return filtered_results diff --git a/benchmarks/generator/utils.py b/benchmarks/generator/utils.py index 0bd2b723..2a1383b0 100644 --- a/benchmarks/generator/utils.py +++ b/benchmarks/generator/utils.py @@ -5,12 +5,120 @@ import numpy as np import matplotlib.pyplot as plt +import pandas as pd -from typing import List, Union, Any, Optional +from typing import List, Union, Any, Optional, Tuple, Dict from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) from datetime import datetime +def convert_to_stat_df(qps_file: str, + input_file: str, + output_file: str, + internal_trace_type: str) -> pd.DataFrame: + if internal_trace_type == "maas": + # Load CSV files into DataFrames + qps_df = pd.read_csv(qps_file) + input_len_df = pd.read_csv(input_file) + output_len_df = pd.read_csv(output_file) + + # Rename columns for merging and clarity + input_len_df.rename(columns={"P50": "input_len_p50", "P70": "input_len_p70", "P90": "input_len_p90", "P99": "input_len_p99"}, inplace=True) + output_len_df.rename(columns={"P50": "output_len_p50", "P70": "output_len_p70", "P95": "output_len_p90", "P99": "output_len_p99"}, inplace=True) + qps_df.rename(columns={"Success": 
"qps_success"}, inplace=True) + + # Merge DataFrames on the 'Time' column (now renamed to 'timestamp') + merged_df = pd.merge(input_len_df, output_len_df, on="Time") + merged_df = pd.merge(merged_df, qps_df, on="Time") + + # Drop unwanted columns (if needed) + merged_df.drop(columns=["Total", "5xx Error", "4xx Error"], inplace=True) + + # Rename the 'Time' column to 'timestamp' + merged_df.rename(columns={"Time": "timestamp"}, inplace=True) + + # Rearrange columns to match the desired order + merged_df = merged_df[[ + "timestamp", + "input_len_p50", "input_len_p70", "input_len_p90", "input_len_p99", + "output_len_p50", "output_len_p70", "output_len_p90", "output_len_p99", + "qps_success" + ]] + merged_df['timestamp'] = pd.to_datetime(merged_df['timestamp']) + elif internal_trace_type == "cloudide": + if input_file != output_file: + logging.error(f"input file {input_file} does not match output_file {output_file}") + df = pd.read_csv(input_file, parse_dates=['Time']) + df = df.replace("undefined", 0) + df['Time'] = pd.to_datetime(df['Time'], unit = 'ms') # Ensure timestamp is a datetime object + df = df.set_index('Time') # Set 'Time' as index for rolling window calculation + df_rate = pd.read_csv(qps_file, parse_dates=['Time']) + df_rate.columns.values[1] = "Rate" + df_rate = df_rate.replace("undefined", 0) + df_rate['Time'] = pd.to_datetime(df_rate['Time'], unit = 'ms') + df_rate = df_rate.set_index('Time') + + sent_columns = df.filter(regex = r'^sent_bytes.rate@') + sent_columns = sent_columns.apply(pd.to_numeric, errors='coerce').fillna(0) + df['sent'] = sent_columns.sum(axis = 1) + + recv_columns = df.filter(regex = r'^recv_bytes.rate@') + recv_columns = recv_columns.apply(pd.to_numeric, errors='coerce').fillna(0) + df['recv'] = recv_columns.sum(axis = 1) + + df_merged = pd.merge(df, df_rate, left_index=True, right_index=True, how='outer') + df_merged = df_merged.fillna(0) + df_merged = df_merged.apply(pd.to_numeric, errors='coerce').fillna(0) + + df_merged['sent_rate'] = df_merged.apply(lambda row : 0 if row['Rate'] == 0 else row['sent'] / row['Rate'], axis=1) + df_merged['recv_rate'] = df_merged.apply(lambda row : 0 if row['Rate'] == 0 else row['recv'] / row['Rate'], axis=1) + + df_merged = df_merged.reset_index() + merged_df = pd.DataFrame({ + "timestamp": df_merged['Time'], + "input_len_p50": df_merged['recv_rate'], + "input_len_p70": df_merged['recv_rate'], + "input_len_p90": df_merged['recv_rate'], + "input_len_p99": df_merged['recv_rate'], + "output_len_p50": df_merged['sent_rate'], + "output_len_p70": df_merged['sent_rate'], + "output_len_p90": df_merged['sent_rate'], + "output_len_p99": df_merged['sent_rate'], + "qps_success":df_merged['Rate'], + }) + return merged_df + +def read_distribution_stats(df: pd.DataFrame) -> Tuple[List[Dict], List[Dict], List[Dict]]: + time_diffs = df['timestamp'].diff().dt.total_seconds() + section_in_seconds = int(time_diffs.mean()) # Use average time difference + input_len_configs = [] + output_len_configs = [] + rps_configs = [] + for _, row in df.iterrows(): + input_len_configs.append({ + "p50": float(row['input_len_p50']), + "p70": float(row['input_len_p70']), + "p90": float(row['input_len_p90']), + "p99": float(row['input_len_p99']), + "period": section_in_seconds, + "total_seconds": section_in_seconds + }) + output_len_configs.append({ + "p50": float(row['output_len_p50']), + "p70": float(row['output_len_p70']), + "p90": float(row['output_len_p90']), + "p99": float(row['output_len_p99']), + "period": section_in_seconds, + "total_seconds": 
section_in_seconds + }) + rps_configs.append({ + "mean_rps": float(row['qps_success']), + "amplitude": float(row['qps_success']) * 0.2, # 20% variation + "period": section_in_seconds, + "total_seconds": section_in_seconds + }) + return input_len_configs, output_len_configs, rps_configs + def get_sample_interval_ms(file_path): # Initialize variables timestamps = [] @@ -59,50 +167,69 @@ def get_tokenizer( trust_remote_code=trust_remote_code) -def plot_workload(workload_dict, interval_ms, output_file: str = None): +def plot_workload(workload_name: str, + workload: List[Dict[str, Any]], + bin_size_sec: int = 1, + output_dir: str = None): """ - Plots the concurrency (item length) of the generated workload. + Plots workload statistics: total requests, prompt token count, and output token count binned by time. Args: - workload_dict (dict): A dictionary where the keys are workload names (labels) and the values are lists of lists representing the workload. - interval_ms (int): Interval in milliseconds. + workload_name (str): Name of the workload. + workload (list of dict): Workload entries with timestamps and request details. + bin_size_sec (int): Size of each bin in seconds for aggregation. + output_dir (str, optional): Directory to save the plot; if None, the plot is shown interactively. """ - fig, ax = plt.subplots() - for workload_name, workload in workload_dict.items(): - concurrency_values = [len(item["requests"]) for item in workload] - ax.plot(np.arange(len(concurrency_values)) * interval_ms, concurrency_values, label=workload_name) - - ax.set_ylim(0, ) - plt.xlabel('Time (ms)') - plt.ylabel('Concurrency') - plt.title('Workload Concurrency') - plt.legend() - if output_file is None: - plt.show() + print(f"plot_workload in directory {output_dir}") + # Convert workload data to a DataFrame + data = [] + for entry in workload: + timestamp_sec = entry["timestamp"] / 1000 # Convert ms to sec + num_requests = len(entry["requests"]) + total_prompt_tokens = np.mean([req["prompt_length"] for req in entry["requests"]]) if entry["requests"] else 0 + total_output_tokens = np.mean([req["output_length"] for req in entry["requests"]]) if entry["requests"] else 0 + data.append((timestamp_sec, num_requests, total_prompt_tokens, total_output_tokens)) + + df = pd.DataFrame(data, columns=["timestamp", "num_requests", "total_prompt_tokens", "total_output_tokens"]) + + # Define bins based on min/max timestamp + min_time, max_time = df["timestamp"].min(), df["timestamp"].max() + bins = np.arange(min_time, max_time + bin_size_sec, bin_size_sec) + + # Bin the data + df["time_bin"] = pd.cut(df["timestamp"], bins, labels=bins[:-1]) + + # Aggregate within each bin + binned_df = df.groupby("time_bin").sum() + + # Convert index back to numeric + binned_df.index = binned_df.index.astype(float) + + # Plotting + fig, (ax_qps, ax_input, ax_output) = plt.subplots(3, 1, figsize=(15, 12)) + + ax_qps.plot(binned_df.index, binned_df["num_requests"], label="Total Requests") + ax_input.plot(binned_df.index, binned_df["total_prompt_tokens"], label="Total Prompt Tokens") + ax_output.plot(binned_df.index, binned_df["total_output_tokens"], label="Total Output Tokens") + + # Formatting plots + for ax, ylabel, title in zip([ax_qps, ax_input, ax_output], + ["Requests per Second", "Prompt Token Count", "Output Token Count"], + ["Total Requests Sent per Second", "Total Prompt Tokens per Second", "Total Output Tokens per Second"]): + ax.set_xlabel("Time (seconds)") + ax.set_ylabel(ylabel) + ax.set_title(title) + ax.legend() + + plt.tight_layout() + + # Save or show the plot + if output_dir: + 
os.makedirs(output_dir, exist_ok=True) + plt.savefig(f"{output_dir}/{workload_name}.pdf") + logging.info(f'Saved workload plot to {output_dir}/{workload_name}.pdf') else: - os.makedirs(os.path.dirname(output_file), exist_ok=True) - plt.savefig(f"{output_file}-traffic.pdf") - logging.info(f'Saved traffic plot to {output_file}-traffic.pdf') - - - fig, ax = plt.subplots() - for workload_name, workload in workload_dict.items(): - input_lengths = [item["requests"][0]['prompt_length'] for item in workload] - output_lengths = [item["requests"][0]['output_length'] for item in workload] - ax.plot(np.arange(len(concurrency_values)) * interval_ms, input_lengths, label=f"{workload_name} prompt_length") - ax.plot(np.arange(len(concurrency_values)) * interval_ms, output_lengths, label=f"{workload_name} output_length") - - ax.set_ylim(0, ) - plt.xlabel('Time (ms)') - plt.ylabel('Lengths') - plt.title('Request Sizes') - plt.legend() - if output_file is None: plt.show() - else: - os.makedirs(os.path.dirname(output_file), exist_ok=True) - plt.savefig(f"{output_file}-requests.pdf") - logging.info(f'Saved traffic plot to {output_file}-requests.pdf') def save_workload(load_struct: List[Any], @@ -133,6 +260,10 @@ def load_workload(input_path: str) -> List[Any]: load_struct = json.load(file) return load_struct +def load_config(config_path: str) -> Dict[str, Any]: + with open(config_path, "r") as file: + config = json.load(file) + return config # Function to wrap the prompt into OpenAI's chat completion message format. def wrap_prompt_as_chat_message(prompt: str): diff --git a/benchmarks/generator/workload_generator.py b/benchmarks/generator/workload_generator.py index 95d1ed5d..6406c39f 100644 --- a/benchmarks/generator/workload_generator.py +++ b/benchmarks/generator/workload_generator.py @@ -1,88 +1,134 @@ import logging import math import random -import pandas as pd import argparse -import csv +import time +import pandas as pd +import numpy as np from pandas import Timedelta from typing import List, Tuple, Dict, Any from transformers import PreTrainedTokenizerBase from datetime import timedelta -from sample_request import (load_requests, sample_requests_len_range, sample_requests_all) -from utils import (get_tokenizer, plot_workload, make_serializable, save_workload, get_sample_interval_ms) +from sample_request import (load_requests, + sample_requests_len_range, + sample_requests_all, + ) +from distribution import (generate_poisson_dist, + generate_token_len_from_percentiles, + to_fluctuate_pattern_config, + ) + +from utils import (convert_to_stat_df, + read_distribution_stats, + get_tokenizer, + plot_workload, + make_serializable, + load_config, + save_workload, + ) # Set up logging to print only warning and above level messages logging.basicConfig(level=logging.INFO) -def generate_from_internal_csv(file_path: str, - prompt_file_path: str, - duration_ms: int, - tokenizer: PreTrainedTokenizerBase, - interval_ms: int = 1000, - output_file: str = 'output/output', - input_trace: str = None, - output_trace: str = None, - to_jsonl: bool = False, - ) -> List[List[Any]]: - traffic = [] - input_lengths = [] - output_lengths = [] - sample_interval_ms = get_sample_interval_ms(file_path) - with open(file_path, 'r') as file: - reader = csv.DictReader(file) - for row in reader: - if 'Total' in row: - total_value = row['Total'] - if total_value: - traffic.append(float(total_value)) - if input_trace is not None: - with open(input_trace, 'r') as file: - reader = csv.DictReader(file) - for row in reader: - if 
'P50' in row: - length = row['P50'] - if length: - input_lengths.append(round(float(length))) - if output_trace is not None: - with open(output_trace, 'r') as file: - reader = csv.DictReader(file) - for row in reader: - if 'P50' in row: - length = row['P50'] - if length: - output_lengths.append(round(float(length))) - - workload = [] - ts = 0 - sharegpt_df = load_requests(dataset_path=prompt_file_path, tokenizer=tokenizer) - for i, interval_requests in enumerate(traffic): - mean_rate = round(interval_requests * (interval_ms / 1000)) - input_length = input_lengths[i] if len(input_lengths)>0 else None - output_length = output_lengths[i] if len(output_lengths)>0 else None - for ts_delta in list(range(0, sample_interval_ms, interval_ms)): - concurrent_sampled_reqs = sample_requests_len_range( - df=sharegpt_df, - num_requests=mean_rate, - input_lens=[input_length] * mean_rate, - output_lens=[output_length] * mean_rate, - initial_err_perc=0.5, - err_step=0.05 - ) - if concurrent_sampled_reqs: # Only add non-empty groups - workload.append({"timestamp": ts + ts_delta, "requests": concurrent_sampled_reqs}) - else: - logging.error(f"sampled return {concurrent_sampled_reqs}") - ts += sample_interval_ms - if ts > duration_ms: - break +def generate_from_internal_csv(prompt_file_path: str, + duration_ms: int, + tokenizer: PreTrainedTokenizerBase, + output_file: str = 'output/output', + qps_stat: str = None, + input_stat: str = None, + output_stat: str = None, + to_jsonl: bool = False, + qps_scale: float = 1.0, + input_scale: float = 1.0, + output_scale: float = 1.0, + internal_trace_type: str = 'maas', + ) -> Dict[str, Any]: + merged_df = convert_to_stat_df(qps_stat, input_stat, output_stat, internal_trace_type) + input_len_configs, output_len_configs, rps_configs = read_distribution_stats(merged_df) + input_len_dist = [] + output_len_dist = [] + rps_dist = [] + for rps_config in rps_configs: + rps_segment = generate_poisson_dist(target = rps_config['mean_rps'], sample_size = rps_config['total_seconds'], generate_poisson_dist = 10) + rps_dist.extend(rps_segment) + if internal_trace_type == "maas": + for config in input_len_configs: + config['scale'] = input_scale + input_segment = generate_token_len_from_percentiles(**config) + input_len_dist.extend(input_segment) + for config in output_len_configs: + config['scale'] = output_scale + output_segment = generate_token_len_from_percentiles(**config) + output_len_dist.extend(output_segment) + elif internal_trace_type == "cloudide": + for config in input_len_configs: + config['scale'] = input_scale + input_segment = generate_token_len_from_percentiles(**config) + input_len_dist.extend(input_segment) + output_segment = generate_token_len_from_percentiles(**config) + output_len_dist.extend(output_segment) + workload = generate_synthetic_from_dist( + prompt_file_path = prompt_file_path, + tokenizer = tokenizer, + duration_ms = duration_ms, + rps_dist = rps_dist, + input_token_len_dist = input_len_dist, + output_token_len_dist = output_len_dist, + qps_scale = qps_scale, + input_scale = input_scale, + output_scale = output_scale, + ) workload = make_serializable(workload) save_workload(workload, output_file, use_jsonl=to_jsonl) return workload - + +def generate_synthetic_from_dist( + prompt_file_path: str, + tokenizer: PreTrainedTokenizerBase, + duration_ms: int, + rps_dist: List[int], + input_token_len_dist: List[int], + output_token_len_dist: List[int], + qps_scale: float, + input_scale: float, + output_scale: float, + ) -> List[Dict[str, Any]]: + + if not 
(len(rps_dist) == len(input_token_len_dist) == len(output_token_len_dist)): + raise ValueError(f"All distributions must have the same length, len(rps_dist): {len(rps_dist)}, len(input_token_len_dist): {len(input_token_len_dist)}, len(output_token_len_dist): {len(output_token_len_dist)}") + workload = [] + current_time = 0 + total_seconds = len(rps_dist) + ts = time.time() + prompt_df = load_requests(dataset_path=prompt_file_path, tokenizer=tokenizer) + logging.info(f"Load requests took {int(time.time() - ts)}s") + while current_time < total_seconds * 1000: + time_idx = int(current_time / 1000) + if time_idx >= total_seconds: + time_idx = total_seconds - 1 + current_rate = rps_dist[time_idx] / qps_scale + current_input_len = input_token_len_dist[time_idx] / input_scale + current_output_len = output_token_len_dist[time_idx] / output_scale + inter_arrival_time = 1000 if current_rate == 0 else np.random.exponential(scale=1000/current_rate) + current_time += inter_arrival_time + if current_time < total_seconds * 1000: + request = sample_requests_len_range( + df=prompt_df, + num_requests=1, + input_lens=[current_input_len], + output_lens=[current_output_len], + initial_err_perc=0.5, + err_step=0.05 + ) + workload.append({"timestamp": int(current_time), "requests": request}) + if current_time > duration_ms: + break + + return workload def generate_constant(prompt_file_path: str, qps: int, @@ -121,12 +167,9 @@ def generate_constant(prompt_file_path: str, return workload def generate_synthetic(prompt_file_path: str, - A=1, B=1, - sigma=0.1, - only_rise: bool = False, - omega: float = None, - period=0.25, - length: int = None, + qps_pattern_config: Dict[str, Any], + input_pattern_config: Dict[str, Any], + output_pattern_config: Dict[str, Any], duration_ms: int = None, interval_ms: int = None, output_file: str = 'output/output', @@ -156,7 +199,7 @@ def generate_synthetic(prompt_file_path: str, list: A list of items, where each item is a list of requests to be sent concurrently. """ - def math_function(t): + def math_function(t, pattern_config, length, prev_value): """ Calculates the concurrency value based on the given concurrency function. @@ -171,39 +214,45 @@ def math_function(t): Returns: int: The concurrency value rounded to the nearest integer. """ - trend = A * math.sin(omega * t) + B - noise = random.gauss(0, sigma) - return round(trend + noise) - - assert length is not None or (duration_ms is not None and interval_ms is not None), \ - "duration_ms and interval_ms must be specified if length is not None" - if length is None: - length = int(duration_ms // interval_ms) + 1 - assert omega is not None or period is not None, "period must be specified if length is not None" - if omega is None: - omega = 2 * math.pi / (length / period) + assert length is not None, \ + "length cannot be None" + if pattern_config['omega'] is None: + omega = 2 * math.pi / (length / pattern_config['period']) + trend = pattern_config['A'] * math.sin(omega * t) + pattern_config['B'] + noise = random.gauss(0, pattern_config['sigma']) + current_value = round(trend + noise) + if pattern_config['only_rise']: + current_value = max(prev_value, current_value) + prev_value = current_value + return current_value, prev_value + + assert duration_ms is not None and interval_ms is not None, \ + "duration_ms and interval_ms must be specified." 
+ length = int(duration_ms // interval_ms) + 1 workload = [] t = 0 previous_concurrency = -1 - end_index = 0 + previous_input_len = -1 + previous_output_len = -1 ts = 0 base_req_id = 0 sharegpt_df = load_requests(dataset_path=prompt_file_path, tokenizer=tokenizer) while t < length: - current_concurrency = math_function(t) - if only_rise: - current_concurrency = max(previous_concurrency, current_concurrency) - previous_concurrency = current_concurrency - + current_concurrency, previous_concurrency = math_function(t, qps_pattern_config, length, previous_concurrency) + current_input_len, previous_input_len = math_function(t, input_pattern_config, length, previous_input_len) + current_output_len, previous_output_len = math_function(t, output_pattern_config, length, previous_output_len) + current_concurrency_pois = generate_poisson_dist(target = current_concurrency, sample_size = 1) + current_input_len_pois = generate_poisson_dist(target = current_input_len, sample_size = 1) + current_output_len_pois = generate_poisson_dist(target = current_output_len, sample_size = 1) # start from last end index - end_index += current_concurrency + logging.debug(f"search requests for current_concurrency {current_concurrency} : {current_concurrency_pois} input_lens {current_input_len} : {current_input_len_pois} output_lens {current_output_len} : {current_output_len_pois}") concurrent_reqs = sample_requests_len_range( df=sharegpt_df, num_requests=current_concurrency, - input_lens=[None] * current_concurrency, - output_lens=[None] * current_concurrency, - initial_err_perc=0.5, + input_lens=[current_input_len] * current_concurrency, + output_lens=[current_output_len] * current_concurrency, + initial_err_perc=0.1, err_step=0.05 ) workload.append({"timestamp": ts, "requests": concurrent_reqs}) @@ -304,73 +353,118 @@ def pair_requests_with_prompts_round_robin(workload: List[List[Any]], if __name__ == '__main__': parser = argparse.ArgumentParser(description='Workload Generator') - parser.add_argument('--prompt-file', type=str, required=True, help='File containing prompts.') + parser.add_argument('--prompt-file', type=str, required=True, help='File containing sampling prompts.') parser.add_argument('--trace-type', type=str, required=True, choices=['constant','synthetic', 'internal', 'azure'], - help='Type of trace consumed. Choose among: synthetic, internal, azure') - parser.add_argument('--traffic-file', type=str, required=False, default=None, - help='Traffic file containing times of arrival, which workload generator depends upon to ' - 'convert to traffic used in workload. This is only needed for for internal and azure trace type. ') - parser.add_argument('--prompt-len-file', type=str, required=False, default=None, - help='File containing request input lengths varied by time, which workload generator depends upon to ' - 'select input prompt. This is only needed for for internal trace type. ') - parser.add_argument('--completion-len-file', type=str, required=False, default=None, - help='File containing request output lengths varied by time, which workload generator depends upon to ' - 'select input prompt. This is only needed for for internal trace type. ') + help='Type of trace consumed. 
Choose among: constant, synthetic, internal, azure.') parser.add_argument('--model', type=str, required=False, default="Qwen/Qwen2.5-Coder-7B-Instruct", help='Target model tokenizer.') - parser.add_argument('--group-interval-seconds', type=int, default=1, help='Grouping interval seconds.') parser.add_argument('--interval-ms', type=int, required=False, default=1000, help='Granularity of request injection interval in milliseconds.') parser.add_argument('--duration-ms', type=int, default=60000, help='Duration of the trace generated.') - parser.add_argument('--output-dir', type=str, required=False, default="output", help='Output directory to save ' + parser.add_argument('--group-interval-seconds', type=int, default=1, help='Grouping interval seconds.') + parser.add_argument('--internal-trace-type', type=str, choices=['maas', 'cloudide'], default="maas", help='Type of internal traces.') + parser.add_argument('--output-dir', type=str, required=False, default="output", help='Output directory to save ' 'the workload.') parser.add_argument('--output-format', type=str, choices=['json', 'jsonl'], default='json', help='Set output data format to either .json or .jsonl (default is .json).') + + ###### Synthetic and constant workload + parser.add_argument('--target-qps', type=int, required=False, default=1, help='Target QPS for the workload.') + parser.add_argument('--traffic-pattern', type=str, required=False, choices=['quick_rising', 'slow_rising', 'slight_fluctuation', 'severe_fluctuation'], default=None, + help='Traffic patterns used for synthetic workload type.') + parser.add_argument('--prompt-len-pattern', type=str, required=False, choices=['quick_rising', 'slow_rising', 'slight_fluctuation', 'severe_fluctuation'], default=None, + help='Prompt lengths patterns used for synthetic workload type.') + parser.add_argument('--completion-len-pattern', type=str, required=False, choices=['quick_rising', 'slow_rising', 'slight_fluctuation', 'severe_fluctuation'], default=None, + help='Completion lengths patterns used for synthetic workload type.') + parser.add_argument('--traffic-pattern-config', type=str, required=False, default=None, + help='Traffic configuration file used for synthetic workload type.') + parser.add_argument('--prompt-len-pattern-config', type=str, required=False, default=None, + help='Prompt lengths configuration file used for synthetic workload type.') + parser.add_argument('--completion-len-pattern-config', type=str, required=False, default=None, + help='Completion lengths configuration file used for synthetic workload type.') + + ##### Trace and stats-driven workload + parser.add_argument('--traffic-file', type=str, required=False, default=None, + help='Traffic file containing times of arrival, which workload generator depends upon to ' 'convert to traffic used in workload. This is only needed for internal and azure trace type.') + parser.add_argument('--prompt-len-file', type=str, required=False, default=None, + help='File containing request input lengths varied by time, which workload generator depends upon to ' 'select input prompt. This is only needed for internal trace type. ') + parser.add_argument('--completion-len-file', type=str, required=False, default=None, + help='File containing request output lengths varied by time, which workload generator depends upon to ' 'select input prompt. This is only needed for internal trace type. 
') + parser.add_argument('--qps-scale', type=float, required=False, default=1.0, help='QPS scaling factor.') + parser.add_argument('--input-scale', type=float, required=False, default=1.0, help='Input length scaling factor.') + parser.add_argument('--output-scale', type=float, required=False, default=1.0, help='Output length scaling factor.') + args = parser.parse_args() # Generate workloads and pair with prompts workload_dict = {} tokenizer = get_tokenizer(pretrained_model_name_or_path=args.model, trust_remote_code=True) - if args.trace_type == "constant": - generated_workload = generate_constant(prompt_file_path=args.prompt_file, - qps=1, - duration_ms=args.duration_ms, - interval_ms=args.interval_ms, - output_file=f"{args.output_dir}/{args.trace_type}", - to_jsonl=(args.output_format == "jsonl"), + if args.trace_type == "synthetic": + if args.traffic_pattern and args.prompt_len_pattern and args.completion_len_pattern: + logging.info(f"Generating synthetic workload with traffic pattern: {args.traffic_pattern}, prompt length pattern: {args.prompt_len_pattern}, completion length pattern: {args.completion_len_pattern}") + comp_pattern_type = f"synthetic_QPS_{args.traffic_pattern}_INPUT_{args.prompt_len_pattern}_OUTPUT_{args.completion_len_pattern}" + qps_pattern_config = to_fluctuate_pattern_config(config_type = args.traffic_pattern, mean = 6) + input_pattern_config = to_fluctuate_pattern_config(config_type = args.prompt_len_pattern, mean = 1024) + output_pattern_config = to_fluctuate_pattern_config(config_type = args.completion_len_pattern, mean = 1024) + logging.debug(f"qps_pattern_config {qps_pattern_config}") + logging.debug(f"input_pattern_config {input_pattern_config}") + logging.debug(f"output_pattern_config {output_pattern_config}") + generated_workload = generate_synthetic(prompt_file_path = args.prompt_file, + qps_pattern_config = qps_pattern_config, + input_pattern_config = input_pattern_config, + output_pattern_config = output_pattern_config, + duration_ms=args.duration_ms, + interval_ms=args.interval_ms, + output_file=f"{args.output_dir}/{comp_pattern_type}", + to_jsonl=(args.output_format == "jsonl"), + ) + workload_dict[comp_pattern_type] = generated_workload + elif args.traffic_pattern_config and args.prompt_len_pattern_config and args.completion_len_pattern_config: + logging.info(f"Generating synthetic workload with traffic pattern config: {args.traffic_pattern_config}, prompt length pattern config: {args.prompt_len_pattern_config}, completion length pattern config: {args.completion_len_pattern_config}") + comp_pattern_type = f"synthetic_manual_config" + qps_pattern_config = load_config(args.traffic_pattern_config) + input_pattern_config = load_config(args.prompt_len_pattern_config) + output_pattern_config = load_config(args.completion_len_pattern_config) + logging.debug(f"qps_pattern_config {qps_pattern_config}") + logging.debug(f"input_pattern_config {input_pattern_config}") + logging.debug(f"output_pattern_config {output_pattern_config}") + generated_workload = generate_synthetic(prompt_file_path = args.prompt_file, + qps_pattern_config = qps_pattern_config, + input_pattern_config = input_pattern_config, + output_pattern_config = output_pattern_config, + duration_ms=args.duration_ms, + interval_ms=args.interval_ms, + output_file=f"{args.output_dir}/{comp_pattern_type}", + to_jsonl=(args.output_format == "jsonl"), ) - elif args.trace_type == "synthetic": - # Define scenarios specific to synthetic type - scenarios = { - 'quick_rising': {'duration_ms': args.duration_ms, 
'interval_ms': args.interval_ms, 'A': 5, 'period': 5, - 'only_rise': True}, - 'slow_rising': {'duration_ms': args.duration_ms, 'interval_ms': args.interval_ms, 'A': 5, 'period': 0.25, - 'only_rise': True}, - 'slight_fluctuation': {'duration_ms': args.duration_ms, 'interval_ms': args.interval_ms, 'A': 5, 'B': 5, - 'period': 1, 'only_rise': False}, - 'severe_fluctuation': {'duration_ms': args.duration_ms, 'interval_ms': args.interval_ms, 'A': 5, 'B': 10, - 'period': 12, 'only_rise': False}, - } - for scenario_name, params in scenarios.items(): - params['prompt_file_path'] = args.prompt_file - params['output_file'] = f"{args.output_dir}/{scenario_name}" - params['to_jsonl'] = (args.output_format == "jsonl") - for scenario_name, params in scenarios.items(): - generated_workload = generate_synthetic(**params) - workload_dict[scenario_name] = generated_workload + workload_dict[comp_pattern_type] = generated_workload else: # Process for 'internal' and 'azure' - if args.trace_type == "internal": - generated_workload = generate_from_internal_csv(file_path=args.traffic_file, - prompt_file_path=args.prompt_file, + if args.trace_type == "constant": + generated_workload = generate_constant(prompt_file_path=args.prompt_file, + qps=args.target_qps, + duration_ms=args.duration_ms, + interval_ms=args.interval_ms, + output_file=f"{args.output_dir}/{args.trace_type}", + to_jsonl=(args.output_format == "jsonl"), + ) + elif args.trace_type == "internal": + generated_workload = generate_from_internal_csv(prompt_file_path=args.prompt_file, duration_ms=args.duration_ms, tokenizer=tokenizer, - interval_ms=args.interval_ms, output_file=f"{args.output_dir}/{args.trace_type}", - input_trace=args.prompt_len_file, - output_trace=args.completion_len_file, + qps_stat=args.traffic_file, + input_stat=args.prompt_len_file, + output_stat=args.completion_len_file, to_jsonl=(args.output_format == "jsonl"), + qps_scale=args.qps_scale, + input_scale=args.input_scale, + output_scale=args.output_scale, + internal_trace_type=args.internal_trace_type, ) elif args.trace_type == "azure": @@ -387,4 +481,9 @@ def pair_requests_with_prompts_round_robin(workload: List[List[Any]], if workload_dict: # Plot the workloads - plot_workload(workload_dict, interval_ms=args.interval_ms, output_file=f"{args.output_dir}/{args.trace_type}") + for workload_name, workload in workload_dict.items(): + plot_workload( + workload_name = workload_name, + workload = workload, + bin_size_sec = int(args.interval_ms/1000), + output_dir = "./plot")
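Finally, a note on the trace-driven path added above: `generate_synthetic_from_dist` turns the per-second rate trace into individual request timestamps by drawing exponential inter-arrival gaps (a Poisson arrival process whose rate changes each second), which is also where `--qps-scale` is applied. The snippet below is a standalone sketch of that sampling loop under those assumptions, not the generator's exact code:

```python
from typing import List

import numpy as np

def arrival_times_ms(rps_per_second: List[float], qps_scale: float = 1.0) -> List[int]:
    """Return request timestamps (ms) for a rate trace with one RPS value per second."""
    timestamps = []
    current_ms = 0.0
    horizon_ms = len(rps_per_second) * 1000
    while current_ms < horizon_ms:
        idx = min(int(current_ms // 1000), len(rps_per_second) - 1)
        rate = rps_per_second[idx] / qps_scale
        # When the slot carries no traffic, skip ahead a full second instead of dividing by zero.
        gap_ms = 1000.0 if rate == 0 else np.random.exponential(scale=1000.0 / rate)
        current_ms += gap_ms
        if current_ms < horizon_ms:
            timestamps.append(int(current_ms))
    return timestamps

# Example: 3 seconds at 2 requests/s followed by 3 seconds at 10 requests/s.
print(arrival_times_ms([2, 2, 2, 10, 10, 10]))
```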