From 440915b37e2666a8fbdaf4d54db0f252b9102832 Mon Sep 17 00:00:00 2001
From: zhisbug
Date: Tue, 4 Apr 2023 14:06:30 -0700
Subject: [PATCH 1/8] changes

---
 cacheflow/http_frontend/test_cli_client.py |  6 +++---
 cacheflow/master/scheduler.py              |  9 +++++++++
 cacheflow/master/server.py                 |  5 ++++-
 cacheflow/models/memory_analyzer.py        |  4 +++-
 cacheflow/models/model_utils.py            | 15 +++++++++++----
 cacheflow/sampling_params.py               |  2 ++
 cacheflow/worker/worker.py                 |  3 ++-
 simple_server.py                           | 10 +++++++---
 8 files changed, 41 insertions(+), 13 deletions(-)

diff --git a/cacheflow/http_frontend/test_cli_client.py b/cacheflow/http_frontend/test_cli_client.py
index 217f8088645ab..90c43cb773537 100644
--- a/cacheflow/http_frontend/test_cli_client.py
+++ b/cacheflow/http_frontend/test_cli_client.py
@@ -7,11 +7,11 @@ def http_request():
     headers = {"User-Agent": "Test Client"}
     pload = {
         "prompt": prompt,
-        "n": 4,
-        "use_beam_search": True,
+        "n": 1,
+        "use_beam_search": False,
         "temperature": 0.0,
     }
-    response = requests.post("http://localhost:10002/generate", headers=headers, json=pload, stream=True)
+    response = requests.post("http://localhost:10002/worker_generate_stream", headers=headers, json=pload, stream=True)
 
     for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
         if chunk:
diff --git a/cacheflow/master/scheduler.py b/cacheflow/master/scheduler.py
index c0ab33066c977..67c4d900f7af0 100644
--- a/cacheflow/master/scheduler.py
+++ b/cacheflow/master/scheduler.py
@@ -34,12 +34,14 @@ def __init__(
         num_gpu_blocks: int,
         num_cpu_blocks: int,
         max_num_batched_tokens: int,
+        tokenizer,
     ) -> None:
         self.controllers = controllers
         self.block_size = block_size
         self.num_gpu_blocks = num_gpu_blocks
         self.num_cpu_blocks = num_cpu_blocks
         self.max_num_batched_tokens = max_num_batched_tokens
+        self.tokenizer = tokenizer
 
         # Instantiate the scheduling policy.
         self.policy = PolicyFactory.get_policy(policy_name='fcfs')
@@ -233,6 +235,7 @@ def post_step(
             group_id = seq_group.group_id
             self.num_steps[group_id] += 1
             stop_token_ids = self.sampling_params[group_id].stop_token_ids
+            stop_func = self.sampling_params[group_id].stop_func
 
             # Process beam search results before processing the next tokens.
             for seq in seq_group.seqs:
@@ -263,6 +266,12 @@ def post_step(
                     self._free_seq(seq)
                     continue
 
+                if stop_func is not None:
+                    if self.tokenizer.decode(seq.get_token_ids(), skip_special_tokens=True).endswith(stop_func):
+                        print(f"hitting the separation symbols: {seq.get_token_ids()[-2:]}.. Stopped!")
+                        self._free_seq(seq)
+                        continue
+
                 # Check if the sequence has reached the maximum number of steps.
                 max_num_steps = self.sampling_params[group_id].max_num_steps
                 if self.num_steps[group_id] == max_num_steps:
diff --git a/cacheflow/master/server.py b/cacheflow/master/server.py
index 1f224316c01b4..10ce7bbc042ad 100644
--- a/cacheflow/master/server.py
+++ b/cacheflow/master/server.py
@@ -3,6 +3,7 @@
 import random
 
 import ray
+from transformers import AutoTokenizer
 
 from cacheflow.master.scheduler import Scheduler
 from cacheflow.models import get_memory_analyzer
@@ -33,9 +34,10 @@ def __init__(
         self.num_nodes = num_nodes
         self.num_devices_per_node = num_devices_per_node
         self.world_size = pipeline_parallel_size * tensor_parallel_size
-
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
         self.memory_analyzer = get_memory_analyzer(
             model_name=model,
+            model_path=model_path,
             block_size=block_size,
             dtype=dtype,
             gpu_memory=gpu_memory,
@@ -76,6 +78,7 @@ def __init__(
             num_gpu_blocks=self.num_gpu_blocks,
             num_cpu_blocks=self.num_cpu_blocks,
             max_num_batched_tokens=max_batch_size,
+            tokenizer=self.tokenizer
         )
         # Connect the controllers.
         for i in range(len(self.controllers) - 1):
diff --git a/cacheflow/models/memory_analyzer.py b/cacheflow/models/memory_analyzer.py
index 3a539cca97633..fc651948fddb0 100644
--- a/cacheflow/models/memory_analyzer.py
+++ b/cacheflow/models/memory_analyzer.py
@@ -145,6 +145,7 @@ class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
     def __init__(
         self,
         model_name: str,
+        model_path: str,
         block_size: int,
         dtype: torch.dtype,
         gpu_memory: int,
@@ -152,13 +153,14 @@ def __init__(
         tensor_parallel_size: int,
     ) -> None:
         self.model_name = model_name
+        self.model_path = model_path
         self.block_size = block_size
         self.dtype = dtype
         self.gpu_memory = gpu_memory
         self.cpu_memory = cpu_memory
         self.tensor_parallel_size = tensor_parallel_size
 
-        config = AutoConfig.from_pretrained(model_name)
+        config = AutoConfig.from_pretrained(model_path)
         self.num_layers = config.num_hidden_layers
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
diff --git a/cacheflow/models/model_utils.py b/cacheflow/models/model_utils.py
index aaf81bc2b5130..f07cc2f7163f6 100644
--- a/cacheflow/models/model_utils.py
+++ b/cacheflow/models/model_utils.py
@@ -3,7 +3,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
-from transformers import AutoConfig
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 
 from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
 from cacheflow.models.memory_analyzer import LlamaMemoryAnalyzer
@@ -16,26 +16,32 @@
 _MODELS = {
     'llama': LlamaForCausalLM,
     'opt': OPTForCausalLM,
+    'vicuna': LlamaForCausalLM
 }
 
 _MEMORY_ANALYZERS = {
     'llama': LlamaMemoryAnalyzer,
     'opt': OPTMemoryAnalyzer,
+    'vicuna': LlamaMemoryAnalyzer
 }
 
 
 def get_model(
     model_name: str,
+    model_path: str,
     dtype: Union[torch.dtype, str],
     path: str,
 ) -> nn.Module:
     torch_dtype = get_torch_dtype(dtype)
     torch.set_default_dtype(torch_dtype)
-    config = AutoConfig.from_pretrained(model_name)
+
+    # config = AutoConfig.from_pretrained(model_name)
+    config = AutoConfig.from_pretrained(model_path)
     for model_class_name, model_class in _MODELS.items():
         if model_class_name in model_name:
             # Download model weights if it's not cached.
-            weights_dir = model_class.get_weights(model_name, path=path)
+            # weights_dir = model_class.get_weights(model_name, path=path)
+            weights_dir = model_class.get_weights(model_path, path=path)
             # Create a model instance.
             model = model_class(config)
             # Load the weights from the cached or downloaded files.
@@ -46,6 +52,7 @@ def get_model(
 
 def get_memory_analyzer(
     model_name: str,
+    model_path: str,
     block_size: int,
     dtype: Union[torch.dtype, str],
     gpu_memory: int,
@@ -56,6 +63,6 @@ def get_memory_analyzer(
     for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
         if model_class in model_name:
             return memory_analyzer(
-                model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
+                model_name, model_path, block_size, torch_dtype, gpu_memory, cpu_memory,
                 tensor_parallel_size)
     raise ValueError(f'Unsupported model name: {model_name}')
diff --git a/cacheflow/sampling_params.py b/cacheflow/sampling_params.py
index 4daeaa486e569..d71f504540362 100644
--- a/cacheflow/sampling_params.py
+++ b/cacheflow/sampling_params.py
@@ -13,6 +13,7 @@ def __init__(
         max_num_steps: int,
         num_logprobs: int,
         context_window_size: Optional[int],
+        stop_func=None
     ) -> None:
         if n < 1:
             raise ValueError(f'n must be at least 1, got {n}.')
@@ -59,6 +60,7 @@ def __init__(
         self.max_num_steps = max_num_steps
         self.num_logprobs = num_logprobs
         self.context_window_size = context_window_size
+        self.stop_func = stop_func
 
     def __repr__(self) -> str:
         return (f'SamplingParams(n={self.n}, '
diff --git a/cacheflow/worker/worker.py b/cacheflow/worker/worker.py
index db0d46aabe9e1..54b15feb5f6b3 100644
--- a/cacheflow/worker/worker.py
+++ b/cacheflow/worker/worker.py
@@ -40,8 +40,9 @@ def __init__(
         set_random_seed(seed)
 
         # Initialize the model.
-        self.model, self.dtype = get_model(model_name, dtype=dtype, path=model_path)
+        self.model, self.dtype = get_model(model_name, model_path, dtype=dtype, path=model_path)
         self.model = self.model.cuda()
+        print("loading model done...")
         tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
         self.num_layers = self.model.config.num_hidden_layers
diff --git a/simple_server.py b/simple_server.py
index 4d6aa93b97b88..41532ea56e2d3 100644
--- a/simple_server.py
+++ b/simple_server.py
@@ -45,9 +45,13 @@ def main(args: argparse.Namespace):
 
     # Test the following inputs.
     test_inputs = [
-        ('Ion Stoica is a', {'n': 4, 'use_beam_search': True, 'temperature': 0.0}),
-        ('UC Berkeley is', {'n': 3, 'temperature': 0.8, 'top_p': 0.99}),
-        ('The future of cloud computing is', {}),  # Use default parameters.
+        # ('Ion Stoica is a', {'n': 4, 'use_beam_search': True, 'temperature': 0.0}),
+        # ('UC Berkeley is', {'n': 3, 'temperature': 0.8, 'top_p': 0.99}),
+        # ('The future of cloud computing is', {}),  # Use default parameters.
+
+        ('Ion Stoica is a', {'temperature': 0.7, 'max_num_steps': 256}),
+        ('UC Berkeley is', {'temperature': 0.7, 'max_num_steps': 256}),
+        ('The future of cloud computing is', {'temperature': 0.7, 'max_num_steps': 256}),  # Use default parameters.
     ]
     while True:
         if test_inputs:

From c858c580c33999ef2b5936d7019cbb533b8813cd Mon Sep 17 00:00:00 2001
From: zhisbug
Date: Fri, 7 Apr 2023 02:44:46 -0700
Subject: [PATCH 2/8] update stop_str

---
 cacheflow/master/scheduler.py | 7 +++----
 cacheflow/master/server.py    | 2 +-
 cacheflow/sampling_params.py  | 4 ++--
 3 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/cacheflow/master/scheduler.py b/cacheflow/master/scheduler.py
index 67c4d900f7af0..587ef7d533275 100644
--- a/cacheflow/master/scheduler.py
+++ b/cacheflow/master/scheduler.py
@@ -235,7 +235,7 @@ def post_step(
             group_id = seq_group.group_id
             self.num_steps[group_id] += 1
             stop_token_ids = self.sampling_params[group_id].stop_token_ids
-            stop_func = self.sampling_params[group_id].stop_func
+            stor_str = self.sampling_params[group_id].stop_str
 
             # Process beam search results before processing the next tokens.
             for seq in seq_group.seqs:
@@ -266,9 +266,8 @@ def post_step(
                     self._free_seq(seq)
                     continue
 
-                if stop_func is not None:
-                    if self.tokenizer.decode(seq.get_token_ids(), skip_special_tokens=True).endswith(stop_func):
-                        print(f"hitting the separation symbols: {seq.get_token_ids()[-2:]}.. Stopped!")
+                if stor_str is not None:
+                    if self.tokenizer.decode(seq.get_token_ids(), skip_special_tokens=True).endswith(stor_str):
                         self._free_seq(seq)
                         continue
 
diff --git a/cacheflow/master/server.py b/cacheflow/master/server.py
index 1b4a15f2591db..dcf469a6588ac 100644
--- a/cacheflow/master/server.py
+++ b/cacheflow/master/server.py
@@ -1,5 +1,5 @@
 import argparse
-from typing import List, Tuple
+from typing import List, Tuple, Union
 import random
 
 import ray
diff --git a/cacheflow/sampling_params.py b/cacheflow/sampling_params.py
index d71f504540362..78bff4362a2ab 100644
--- a/cacheflow/sampling_params.py
+++ b/cacheflow/sampling_params.py
@@ -13,7 +13,7 @@ def __init__(
         max_num_steps: int,
         num_logprobs: int,
         context_window_size: Optional[int],
-        stop_func=None
+        stop_str=None
     ) -> None:
         if n < 1:
             raise ValueError(f'n must be at least 1, got {n}.')
@@ -60,7 +60,7 @@ def __init__(
         self.max_num_steps = max_num_steps
         self.num_logprobs = num_logprobs
         self.context_window_size = context_window_size
-        self.stop_func = stop_func
+        self.stop_str = stop_str
 
     def __repr__(self) -> str:
         return (f'SamplingParams(n={self.n}, '

From 3f235207b09f9d8d63c17eafe65b7112fb8f2f37 Mon Sep 17 00:00:00 2001
From: zhisbug
Date: Fri, 7 Apr 2023 13:11:35 -0700
Subject: [PATCH 3/8] recover

---
 cacheflow/http_frontend/test_cli_client.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/cacheflow/http_frontend/test_cli_client.py b/cacheflow/http_frontend/test_cli_client.py
index 90c43cb773537..217f8088645ab 100644
--- a/cacheflow/http_frontend/test_cli_client.py
+++ b/cacheflow/http_frontend/test_cli_client.py
@@ -7,11 +7,11 @@ def http_request():
     headers = {"User-Agent": "Test Client"}
     pload = {
         "prompt": prompt,
-        "n": 1,
-        "use_beam_search": False,
+        "n": 4,
+        "use_beam_search": True,
         "temperature": 0.0,
     }
-    response = requests.post("http://localhost:10002/worker_generate_stream", headers=headers, json=pload, stream=True)
+    response = requests.post("http://localhost:10002/generate", headers=headers, json=pload, stream=True)
 
     for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
         if chunk:

From 6442e3ebc15452ec8d24b231d5e9e3d2704b33da Mon Sep 17 00:00:00 2001
From: zhisbug
Date: Fri, 7 Apr 2023 13:12:57 -0700
Subject: [PATCH 4/8] fix a stop_str name

---
 cacheflow/master/scheduler.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/cacheflow/master/scheduler.py b/cacheflow/master/scheduler.py
index 587ef7d533275..88fd7966b8117 100644
--- a/cacheflow/master/scheduler.py
+++ b/cacheflow/master/scheduler.py
@@ -235,7 +235,7 @@ def post_step(
             group_id = seq_group.group_id
             self.num_steps[group_id] += 1
             stop_token_ids = self.sampling_params[group_id].stop_token_ids
-            stor_str = self.sampling_params[group_id].stop_str
+            stop_str = self.sampling_params[group_id].stop_str
 
             # Process beam search results before processing the next tokens.
             for seq in seq_group.seqs:
@@ -266,7 +266,7 @@ def post_step(
                     self._free_seq(seq)
                     continue
 
-                if stor_str is not None:
+                if stop_str is not None:
                     if self.tokenizer.decode(seq.get_token_ids(), skip_special_tokens=True).endswith(stor_str):
                         self._free_seq(seq)
                         continue

From 7b121fab871de010d4a410cb0b0262466f1629fa Mon Sep 17 00:00:00 2001
From: zhisbug
Date: Fri, 7 Apr 2023 13:20:45 -0700
Subject: [PATCH 5/8] update

---
 cacheflow/master/server.py |  2 +-
 simple_server.py           | 10 +++-------
 2 files changed, 4 insertions(+), 8 deletions(-)

diff --git a/cacheflow/master/server.py b/cacheflow/master/server.py
index dcf469a6588ac..1b4a15f2591db 100644
--- a/cacheflow/master/server.py
+++ b/cacheflow/master/server.py
@@ -1,5 +1,5 @@
 import argparse
-from typing import List, Tuple, Union
+from typing import List, Tuple
 import random
 
 import ray
diff --git a/simple_server.py b/simple_server.py
index 7b613fdab0579..d333ece34c072 100644
--- a/simple_server.py
+++ b/simple_server.py
@@ -45,13 +45,9 @@ def main(args: argparse.Namespace):
 
     # Test the following inputs.
     test_inputs = [
-        # ('Ion Stoica is a', {'n': 4, 'use_beam_search': True, 'temperature': 0.0}),
-        # ('UC Berkeley is', {'n': 3, 'temperature': 0.8, 'top_p': 0.99}),
-        # ('The future of cloud computing is', {}),  # Use default parameters.
-
-        ('Ion Stoica is a', {'temperature': 0.7, 'max_num_steps': 256}),
-        ('UC Berkeley is', {'temperature': 0.7, 'max_num_steps': 256}),
-        ('The future of cloud computing is', {'temperature': 0.7, 'max_num_steps': 256}),  # Use default parameters.
+        ('Ion Stoica is a', {'n': 4, 'use_beam_search': True, 'temperature': 0.0}),
+        ('UC Berkeley is', {'n': 3, 'temperature': 0.8, 'top_p': 0.99}),
+        ('The future of cloud computing is', {}),  # Use default parameters.
     ]
     while True:
         if test_inputs:

From b85250e75812f63ddbba889a46d8bc630e6d6a01 Mon Sep 17 00:00:00 2001
From: zhisbug
Date: Fri, 7 Apr 2023 13:32:02 -0700
Subject: [PATCH 6/8] update

---
 cacheflow/master/scheduler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cacheflow/master/scheduler.py b/cacheflow/master/scheduler.py
index 88fd7966b8117..abf2a83245e7f 100644
--- a/cacheflow/master/scheduler.py
+++ b/cacheflow/master/scheduler.py
@@ -267,7 +267,7 @@ def post_step(
                     continue
 
                 if stop_str is not None:
-                    if self.tokenizer.decode(seq.get_token_ids(), skip_special_tokens=True).endswith(stor_str):
+                    if self.tokenizer.decode(seq.get_token_ids(), skip_special_tokens=True).endswith(stop_str):
                         self._free_seq(seq)
                         continue

From 0bdd8143f35f2d9cf71f413a7e7a4073f2bfaf34 Mon Sep 17 00:00:00 2001
From: zhisbug
Date: Sun, 9 Apr 2023 19:34:16 -0700
Subject: [PATCH 7/8] not using fast tokenizer

---
 cacheflow/master/server.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cacheflow/master/server.py b/cacheflow/master/server.py
index 1b4a15f2591db..bc6d45d9418d1 100644
--- a/cacheflow/master/server.py
+++ b/cacheflow/master/server.py
@@ -34,7 +34,7 @@ def __init__(
         self.num_nodes = num_nodes
         self.num_devices_per_node = num_devices_per_node
         self.world_size = pipeline_parallel_size * tensor_parallel_size
-        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
         self.memory_analyzer = get_memory_analyzer(
             model_name=model,
             model_path=model_path,

From 8e56ab628b250a0906237c29356a7cbe4ad88669 Mon Sep 17 00:00:00 2001
From: zhisbug
Date: Sun, 9 Apr 2023 23:05:35 -0700
Subject: [PATCH 8/8] add support for koala and alpaca

---
 cacheflow/models/model_utils.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/cacheflow/models/model_utils.py b/cacheflow/models/model_utils.py
index f07cc2f7163f6..72af3f64e00a2 100644
--- a/cacheflow/models/model_utils.py
+++ b/cacheflow/models/model_utils.py
@@ -16,13 +16,17 @@
 _MODELS = {
     'llama': LlamaForCausalLM,
     'opt': OPTForCausalLM,
-    'vicuna': LlamaForCausalLM
+    'vicuna': LlamaForCausalLM,
+    'koala': LlamaForCausalLM,
+    'alpaca': LlamaForCausalLM,
 }
 
 _MEMORY_ANALYZERS = {
     'llama': LlamaMemoryAnalyzer,
     'opt': OPTMemoryAnalyzer,
-    'vicuna': LlamaMemoryAnalyzer
+    'vicuna': LlamaMemoryAnalyzer,
+    'koala': LlamaMemoryAnalyzer,
+    'alpaca': LlamaMemoryAnalyzer,
 }
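
Note (editorial, not part of the patch series): the sketch below illustrates the stop criteria that Scheduler.post_step() applies once patches 1-6 are in place. A sequence is freed either when the last sampled token is one of the configured stop token ids, or when the decoded text ends with the new stop_str value (for example a chat separator). The standalone helper name sequence_should_stop and its type hints are hypothetical; the series itself performs this check inline in the scheduler.

# Editorial sketch only -- mirrors the stop_str check added to Scheduler.post_step();
# sequence_should_stop is a hypothetical helper, not code from the patches.
from typing import List, Optional

from transformers import PreTrainedTokenizerBase


def sequence_should_stop(
    tokenizer: PreTrainedTokenizerBase,
    token_ids: List[int],
    stop_token_ids: List[int],
    stop_str: Optional[str],
) -> bool:
    # Stop when the last sampled token is one of the configured stop token ids.
    if token_ids and token_ids[-1] in stop_token_ids:
        return True
    # Stop when the decoded text ends with the stop string, which is what the
    # SamplingParams.stop_str field added in patches 1, 2, 4 and 6 enables.
    if stop_str is not None:
        text = tokenizer.decode(token_ids, skip_special_tokens=True)
        if text.endswith(stop_str):
            return True
    return False

# The server would build the tokenizer the way patch 7 does, i.e.
# AutoTokenizer.from_pretrained(model_path, use_fast=False), and pass stop_str
# through SamplingParams so the scheduler can apply this check per step.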