From 1cddbf4382064e90f1bcf6a36d7a03b6e5df7002 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Thu, 27 Apr 2023 10:30:15 +0000 Subject: [PATCH 1/3] Add an option to launch cacheflow without ray --- .gitignore | 3 ++ benchmark/benchmark_latency.py | 4 +-- benchmark/benchmark_text_completion.py | 8 +++--- cacheflow/http_frontend/fastapi_frontend.py | 16 +++++++++-- cacheflow/master/server.py | 32 +++++++++++++++++++-- cacheflow/worker/controller.py | 30 +++++++++++++------ simple_server.py | 6 ++-- 7 files changed, 77 insertions(+), 22 deletions(-) diff --git a/.gitignore b/.gitignore index e6c41a1b61774..471b1f3090e87 100644 --- a/.gitignore +++ b/.gitignore @@ -3,8 +3,11 @@ *.egg-info/ *.eggs/ *.so +*.log +*.csv build/ *.pkl *.png **/log.txt +.vscode/ \ No newline at end of file diff --git a/benchmark/benchmark_latency.py b/benchmark/benchmark_latency.py index 33ec6c495b15b..12a4efdde032c 100644 --- a/benchmark/benchmark_latency.py +++ b/benchmark/benchmark_latency.py @@ -8,7 +8,7 @@ from cacheflow.master.simple_frontend import SimpleFrontend from cacheflow.master.server import (Server, add_server_arguments, - initialize_ray_cluster) + initialize_cluster) from cacheflow.sampling_params import SamplingParams from cacheflow.utils import get_gpu_memory, get_cpu_memory @@ -20,7 +20,7 @@ def main(args: argparse.Namespace): (num_nodes, num_devices_per_node, distributed_init_method, all_stage_devices) = ( - initialize_ray_cluster( + initialize_cluster( address='local', pipeline_parallel_size=args.pipeline_parallel_size, tensor_parallel_size=args.tensor_parallel_size)) diff --git a/benchmark/benchmark_text_completion.py b/benchmark/benchmark_text_completion.py index e6741d577c6e1..6ab6fc1611b43 100644 --- a/benchmark/benchmark_text_completion.py +++ b/benchmark/benchmark_text_completion.py @@ -11,7 +11,7 @@ from benchmark.trace import generate_text_completion_requests from cacheflow.master.simple_frontend import SimpleFrontend from cacheflow.master.server import (Server, add_server_arguments, - initialize_ray_cluster) + initialize_cluster) from cacheflow.sampling_params import SamplingParams from cacheflow.utils import get_gpu_memory, get_cpu_memory @@ -25,7 +25,7 @@ def main(args: argparse.Namespace): (num_nodes, num_devices_per_node, distributed_init_method, all_stage_devices) = ( - initialize_ray_cluster( + initialize_cluster( address='local', pipeline_parallel_size=args.pipeline_parallel_size, tensor_parallel_size=args.tensor_parallel_size)) @@ -134,7 +134,7 @@ def main(args: argparse.Namespace): finished.append({ 'group_id': seq_group.group_id, 'seq_id': seq.seq_id, - 'arrival_time': arrival_time, + 'arrival_time': arrival_time, 'finish_time': finish_time, 'prompt_len': seq.prompt_len, 'output_len': output_len, @@ -226,7 +226,7 @@ def get_sampling_dir_name( if __name__ == '__main__': parser = argparse.ArgumentParser(description='CacheFlow simple server.') - parser = add_server_arguments(parser) + parser = add_server_arguments(parser) parser.add_argument('--output-dir', type=str, help='path to output directory', default=None) parser.add_argument('--dataset', type=str, help='path to dataset', required=True) diff --git a/cacheflow/http_frontend/fastapi_frontend.py b/cacheflow/http_frontend/fastapi_frontend.py index 209536310e875..c6155b656eff2 100644 --- a/cacheflow/http_frontend/fastapi_frontend.py +++ b/cacheflow/http_frontend/fastapi_frontend.py @@ -13,7 +13,7 @@ from cacheflow.sampling_params import SamplingParams from cacheflow.sequence import Sequence, SequenceGroup from 
cacheflow.master.server import (Server, add_server_arguments, - initialize_ray_cluster) + initialize_cluster) from cacheflow.worker.controller import DeviceID from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory @@ -33,17 +33,22 @@ def __init__( seed: int, swap_space: int, max_num_batched_tokens: int, + max_num_sequences: int, num_nodes: int, num_devices_per_node: int, distributed_init_method: str, all_stage_devices: List[List[DeviceID]], + server_use_ray: bool, ): self.block_size = block_size self.tokenizer = AutoTokenizer.from_pretrained(model) self.seq_group_counter = Counter() self.seq_counter = Counter() - remote_server_class = ray.remote(num_cpus=0)(Server) + if server_use_ray: + remote_server_class = ray.remote(num_cpus=0)(Server) + else: + remote_server_class = ray.remote(num_gpus=1)(Server) self.server = remote_server_class.remote( model=model, model_path=model_path, @@ -55,12 +60,14 @@ def __init__( seed=seed, swap_space=swap_space, max_num_batched_tokens=max_num_batched_tokens, + max_num_sequences=max_num_sequences, num_nodes=num_nodes, num_devices_per_node=num_devices_per_node, distributed_init_method=distributed_init_method, all_stage_devices=all_stage_devices, gpu_memory=get_gpu_memory(), cpu_memory=get_cpu_memory(), + use_ray=server_use_ray, ) self.running_seq_groups: Dict[int, SequenceGroup] = {} @@ -156,7 +163,8 @@ async def generate_stream(request: Request): (num_nodes, num_devices_per_node, distributed_init_method, all_stage_devices) = ( - initialize_ray_cluster( + initialize_cluster( + use_ray=True, pipeline_parallel_size=args.pipeline_parallel_size, tensor_parallel_size=args.tensor_parallel_size)) @@ -170,10 +178,12 @@ async def generate_stream(request: Request): seed=args.seed, swap_space=args.swap_space, max_num_batched_tokens=args.max_num_batched_tokens, + max_num_sequences=args.max_num_sequences, num_nodes=num_nodes, num_devices_per_node=num_devices_per_node, distributed_init_method=distributed_init_method, all_stage_devices=all_stage_devices, + server_use_ray=args.use_ray, ) uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/cacheflow/master/server.py b/cacheflow/master/server.py index 5b8110a3dab4c..fff508dcef101 100644 --- a/cacheflow/master/server.py +++ b/cacheflow/master/server.py @@ -2,7 +2,11 @@ from typing import List, Tuple import random -import ray +import torch +try: + import ray +except ImportError: + ray = None from cacheflow.master.scheduler import Scheduler from cacheflow.models import get_memory_analyzer @@ -31,6 +35,7 @@ def __init__( all_stage_devices: List[List[DeviceID]], gpu_memory: int, cpu_memory: int, + use_ray: bool, collect_stats: bool = False, do_memory_analysis: bool = False, ): @@ -38,6 +43,10 @@ def __init__( self.num_devices_per_node = num_devices_per_node self.world_size = pipeline_parallel_size * tensor_parallel_size + if not use_ray: + assert self.world_size == 1, ( + "Only support single GPU without Ray.") + self.memory_analyzer = get_memory_analyzer( model_name=model, block_size=block_size, @@ -72,6 +81,7 @@ def __init__( model_path=model_path, use_dummy_weights=use_dummy_weights, max_num_batched_tokens=max_num_batched_tokens, + use_ray=use_ray, ) self.controllers.append(controller) @@ -105,11 +115,28 @@ def has_unfinished_requests(self): self.scheduler.swapped) -def initialize_ray_cluster( +def initialize_cluster( + use_ray: bool = False, address: str = 'auto', pipeline_parallel_size: int = 1, tensor_parallel_size: int = 1, ) -> Tuple[int, int, str, List[List[DeviceID]]]: + # 
Initialize cluster locally. + if not use_ray: + assert pipeline_parallel_size * tensor_parallel_size == 1, ( + "Only support single GPU without Ray.") + num_nodes = 1 + num_devices_per_node = torch.cuda.device_count() + port = random.randint(10000, 20000) + distributed_init_method = f"tcp://localhost:{port}" + all_stage_devices = [[(0, None, 0)]] + return (num_nodes, num_devices_per_node, distributed_init_method, + all_stage_devices) + + assert ray is not None, ( + "Ray is not installed. Please install Ray to use distributed " + "serving.") + # Connect to a ray cluster. ray.init(address=address) @@ -177,6 +204,7 @@ def add_server_arguments(parser: argparse.ArgumentParser): parser.add_argument('--model-path', type=str, default='~/.cacheflow/model_weights', help='model path to download and load the weights') # Parallel arguments + parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed training, required when using more than 1 GPU') parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages') parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas') # KV cache arguments diff --git a/cacheflow/worker/controller.py b/cacheflow/worker/controller.py index dce3fddf89fad..7c3d66cc99a93 100644 --- a/cacheflow/worker/controller.py +++ b/cacheflow/worker/controller.py @@ -1,6 +1,9 @@ from typing import Dict, List, Union, Tuple -import ray +try: + import ray +except ImportError: + ray = None from cacheflow.master.scheduler import Scheduler from cacheflow.sequence import SequenceGroupInputs @@ -29,6 +32,7 @@ def __init__( model_path: str, use_dummy_weights: bool, max_num_batched_tokens: int, + use_ray: bool, ) -> None: self.stage_id = stage_id self.stage_devices = stage_devices @@ -36,6 +40,7 @@ def __init__( self.block_size = block_size self.num_gpu_blocks = num_gpu_blocks self.num_cpu_blocks = num_cpu_blocks + self.use_ray = use_ray # Which pipeline stage is this node assigned to? self.is_first_stage = stage_id == 0 @@ -43,10 +48,13 @@ def __init__( self.workers: List[Worker] = [] for rank, node_resource, device_id in stage_devices: - worker_cls = ray.remote(num_cpus=0, - num_gpus=1, - resources={node_resource: 1e-5})(Worker) - worker = worker_cls.remote( + if self.use_ray: + worker_cls = ray.remote(num_cpus=0, + num_gpus=1, + resources={node_resource: 1e-5})(Worker).remote + else: + worker_cls = Worker + worker = worker_cls( model_name=model_name, block_size=block_size, num_gpu_blocks=num_gpu_blocks, @@ -78,17 +86,21 @@ def execute_stage( blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], ) -> None: - futures = [] + all_outputs = [] for worker in self.workers: - future = worker.execute_stage.remote( + executor = (worker.execute_stage.remote + if self.use_ray else worker.execute_stage) + output = executor( input_seq_groups, blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy, ) - futures.append(future) + all_outputs.append(output) + + if self.use_ray: + all_outputs = ray.get(all_outputs) - all_outputs = ray.get(futures) # Make sure all workers have the same results. 
output = all_outputs[0] for other_output in all_outputs[1:]: diff --git a/simple_server.py b/simple_server.py index 17df72aff0554..c2f8edd6b2b9a 100644 --- a/simple_server.py +++ b/simple_server.py @@ -3,7 +3,7 @@ from cacheflow.master.simple_frontend import SimpleFrontend from cacheflow.master.server import (Server, add_server_arguments, - initialize_ray_cluster) + initialize_cluster) from cacheflow.sampling_params import SamplingParams from cacheflow.utils import get_gpu_memory, get_cpu_memory @@ -14,7 +14,8 @@ def main(args: argparse.Namespace): (num_nodes, num_devices_per_node, distributed_init_method, all_stage_devices) = ( - initialize_ray_cluster( + initialize_cluster( + use_ray=args.use_ray, pipeline_parallel_size=args.pipeline_parallel_size, tensor_parallel_size=args.tensor_parallel_size)) @@ -37,6 +38,7 @@ def main(args: argparse.Namespace): all_stage_devices=all_stage_devices, gpu_memory=get_gpu_memory(), cpu_memory=get_cpu_memory(), + use_ray=args.use_ray, ) # Create a frontend. From 1649750a8f93e87db589ca7e5af7b3fb2eacaacf Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sat, 29 Apr 2023 15:36:12 +0000 Subject: [PATCH 2/3] Fix review issues --- .gitignore | 2 +- benchmark/benchmark_latency.py | 2 ++ benchmark/benchmark_text_completion.py | 2 ++ cacheflow/http_frontend/fastapi_frontend.py | 2 ++ cacheflow/master/server.py | 13 ++++++++++--- simple_server.py | 2 ++ 6 files changed, 19 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 471b1f3090e87..25d148565ad83 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,4 @@ build/ *.pkl *.png **/log.txt -.vscode/ \ No newline at end of file +.vscode/ diff --git a/benchmark/benchmark_latency.py b/benchmark/benchmark_latency.py index 12a4efdde032c..5dc5bd51c359d 100644 --- a/benchmark/benchmark_latency.py +++ b/benchmark/benchmark_latency.py @@ -8,6 +8,7 @@ from cacheflow.master.simple_frontend import SimpleFrontend from cacheflow.master.server import (Server, add_server_arguments, + process_server_arguments, initialize_cluster) from cacheflow.sampling_params import SamplingParams from cacheflow.utils import get_gpu_memory, get_cpu_memory @@ -99,6 +100,7 @@ def profile_step(profile=False): parser.add_argument('--n', type=int, default=1) parser.add_argument('--use-beam-search', action='store_true') args = parser.parse_args() + args = process_server_arguments(args) args.max_num_batched_tokens = max( args.max_num_batched_tokens, args.batch_size * args.input_len) print(args) diff --git a/benchmark/benchmark_text_completion.py b/benchmark/benchmark_text_completion.py index 6ab6fc1611b43..cb361fbd45266 100644 --- a/benchmark/benchmark_text_completion.py +++ b/benchmark/benchmark_text_completion.py @@ -11,6 +11,7 @@ from benchmark.trace import generate_text_completion_requests from cacheflow.master.simple_frontend import SimpleFrontend from cacheflow.master.server import (Server, add_server_arguments, + process_server_arguments, initialize_cluster) from cacheflow.sampling_params import SamplingParams from cacheflow.utils import get_gpu_memory, get_cpu_memory @@ -246,6 +247,7 @@ def get_sampling_dir_name( parser.add_argument('--n6-beam', type=float, help='ratio of requests with n=6 & beam search', default=0.0) parser.add_argument('--n8-beam', type=float, help='ratio of requests with n=8 & beam search', default=0.0) args = parser.parse_args() + args = process_server_arguments(args) if args.n1 + args.n2 + args.n3 + args.n4 + args.n6 + args.n2_beam + args.n4_beam + args.n6_beam + args.n8_beam != 1.0: raise ValueError('The 
ratios of requests must sum to 1.') diff --git a/cacheflow/http_frontend/fastapi_frontend.py b/cacheflow/http_frontend/fastapi_frontend.py index c6155b656eff2..a2ae6ce4c8a2d 100644 --- a/cacheflow/http_frontend/fastapi_frontend.py +++ b/cacheflow/http_frontend/fastapi_frontend.py @@ -13,6 +13,7 @@ from cacheflow.sampling_params import SamplingParams from cacheflow.sequence import Sequence, SequenceGroup from cacheflow.master.server import (Server, add_server_arguments, + process_server_arguments, initialize_cluster) from cacheflow.worker.controller import DeviceID from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory @@ -156,6 +157,7 @@ async def generate_stream(request: Request): parser.add_argument("--port", type=int, default=10002) parser = add_server_arguments(parser) args = parser.parse_args() + args = process_server_arguments(args) # TODO(zhuohan): Support pipeline parallelism. assert args.pipeline_parallel_size == 1, ( diff --git a/cacheflow/master/server.py b/cacheflow/master/server.py index fff508dcef101..311251800d22b 100644 --- a/cacheflow/master/server.py +++ b/cacheflow/master/server.py @@ -1,5 +1,5 @@ import argparse -from typing import List, Tuple +from typing import List, Tuple, Optional import random import torch @@ -117,7 +117,7 @@ def has_unfinished_requests(self): def initialize_cluster( use_ray: bool = False, - address: str = 'auto', + address: Optional[str] = None, pipeline_parallel_size: int = 1, tensor_parallel_size: int = 1, ) -> Tuple[int, int, str, List[List[DeviceID]]]: @@ -128,6 +128,8 @@ def initialize_cluster( num_nodes = 1 num_devices_per_node = torch.cuda.device_count() port = random.randint(10000, 20000) + # We need to setup the distributed init method to make sure + # the distributed megatron code (e.g., get world size) works correctly. 
distributed_init_method = f"tcp://localhost:{port}" all_stage_devices = [[(0, None, 0)]] return (num_nodes, num_devices_per_node, distributed_init_method, @@ -204,7 +206,7 @@ def add_server_arguments(parser: argparse.ArgumentParser): parser.add_argument('--model-path', type=str, default='~/.cacheflow/model_weights', help='model path to download and load the weights') # Parallel arguments - parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed training, required when using more than 1 GPU') + parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU') parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages') parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas') # KV cache arguments @@ -218,3 +220,8 @@ def add_server_arguments(parser: argparse.ArgumentParser): parser.add_argument('--max-num-sequences', type=int, default=256, help='maximum number of sequences per iteration') parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights') return parser + +def process_server_arguments(args: argparse.Namespace): + if args.pipeline_parallel_size * args.tensor_parallel_size > 1: + args.use_ray = True + return args diff --git a/simple_server.py b/simple_server.py index c2f8edd6b2b9a..0c0c7cb4c2335 100644 --- a/simple_server.py +++ b/simple_server.py @@ -3,6 +3,7 @@ from cacheflow.master.simple_frontend import SimpleFrontend from cacheflow.master.server import (Server, add_server_arguments, + process_server_arguments, initialize_cluster) from cacheflow.sampling_params import SamplingParams from cacheflow.utils import get_gpu_memory, get_cpu_memory @@ -72,4 +73,5 @@ def main(args: argparse.Namespace): parser = argparse.ArgumentParser(description='CacheFlow simple server.') parser = add_server_arguments(parser) args = parser.parse_args() + args = process_server_arguments(args) main(args) From fa0cb9348b82e8eabebc58678f142291a83b9959 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sat, 29 Apr 2023 16:06:25 +0000 Subject: [PATCH 3/3] fix --- benchmark/benchmark_latency.py | 6 ++++-- benchmark/benchmark_text_completion.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/benchmark/benchmark_latency.py b/benchmark/benchmark_latency.py index 5dc5bd51c359d..f9c6966120da0 100644 --- a/benchmark/benchmark_latency.py +++ b/benchmark/benchmark_latency.py @@ -22,7 +22,7 @@ def main(args: argparse.Namespace): (num_nodes, num_devices_per_node, distributed_init_method, all_stage_devices) = ( initialize_cluster( - address='local', + use_ray=args.use_ray, pipeline_parallel_size=args.pipeline_parallel_size, tensor_parallel_size=args.tensor_parallel_size)) @@ -45,6 +45,7 @@ def main(args: argparse.Namespace): all_stage_devices=all_stage_devices, gpu_memory=get_gpu_memory(), cpu_memory=get_cpu_memory(), + use_ray=args.use_ray, ) # Create a frontend. 
@@ -92,7 +93,8 @@ def profile_step(profile=False): if __name__ == '__main__': - parser = argparse.ArgumentParser(description='CacheFlow simple server.') + parser = argparse.ArgumentParser( + description='Benchmark the latency of decoding a single sentence.') parser = add_server_arguments(parser) parser.add_argument('--input-len', type=int, default=32) parser.add_argument('--output-len', type=int, default=128) diff --git a/benchmark/benchmark_text_completion.py b/benchmark/benchmark_text_completion.py index cb361fbd45266..55d602a064532 100644 --- a/benchmark/benchmark_text_completion.py +++ b/benchmark/benchmark_text_completion.py @@ -27,7 +27,7 @@ def main(args: argparse.Namespace): (num_nodes, num_devices_per_node, distributed_init_method, all_stage_devices) = ( initialize_cluster( - address='local', + use_ray=args.use_ray, pipeline_parallel_size=args.pipeline_parallel_size, tensor_parallel_size=args.tensor_parallel_size)) @@ -50,6 +50,7 @@ def main(args: argparse.Namespace): all_stage_devices=all_stage_devices, gpu_memory=get_gpu_memory(), cpu_memory=get_cpu_memory(), + use_ray=args.use_ray, collect_stats=True, do_memory_analysis=args.do_memory_analysis, ) @@ -226,7 +227,8 @@ def get_sampling_dir_name( if __name__ == '__main__': - parser = argparse.ArgumentParser(description='CacheFlow simple server.') + parser = argparse.ArgumentParser( + description='Benchmark the performance on a series of requests.') parser = add_server_arguments(parser) parser.add_argument('--output-dir', type=str, help='path to output directory', default=None)
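
A quick usage note for reviewers (not part of the patches themselves): the sketch below shows how the entry points introduced in this series compose for a single-GPU launch without Ray. It assumes every argument registered by add_server_arguments has a usable default, so parse_args([]) is only an illustration; a real launch would pass flags such as --use-ray, -pp, and -tp explicitly.

    import argparse

    from cacheflow.master.server import (add_server_arguments,
                                         process_server_arguments,
                                         initialize_cluster)

    parser = argparse.ArgumentParser(description='Single-GPU launch sketch.')
    parser = add_server_arguments(parser)
    # Defaults here: pipeline_parallel_size = 1, tensor_parallel_size = 1,
    # and --use-ray is a store_true flag, so use_ray starts out False.
    args = parser.parse_args([])
    # process_server_arguments (PATCH 2/3) only flips use_ray on when
    # pipeline_parallel_size * tensor_parallel_size > 1, so it stays False
    # here and the server runs without Ray.
    args = process_server_arguments(args)

    # With use_ray=False, initialize_cluster (PATCH 1/3) skips ray.init()
    # entirely and returns a local tcp://localhost:<random port> init method
    # plus a single (rank 0, node None, device 0) stage device.
    (num_nodes, num_devices_per_node, distributed_init_method,
     all_stage_devices) = initialize_cluster(
         use_ray=args.use_ray,
         pipeline_parallel_size=args.pipeline_parallel_size,
         tensor_parallel_size=args.tensor_parallel_size)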