[DO NOT MERGE] Hao integration #31

Closed · wants to merge 9 commits
1 change: 1 addition & 0 deletions .gitignore
@@ -4,3 +4,4 @@
 *.eggs/
 *.so
 build/
+.idea
8 changes: 8 additions & 0 deletions cacheflow/master/scheduler.py
@@ -34,12 +34,14 @@ def __init__(
         num_gpu_blocks: int,
         num_cpu_blocks: int,
         max_num_batched_tokens: int,
+        tokenizer,
     ) -> None:
         self.controllers = controllers
         self.block_size = block_size
         self.num_gpu_blocks = num_gpu_blocks
         self.num_cpu_blocks = num_cpu_blocks
         self.max_num_batched_tokens = max_num_batched_tokens
+        self.tokenizer = tokenizer
 
         # Instantiate the scheduling policy.
         self.policy = PolicyFactory.get_policy(policy_name='fcfs')
@@ -233,6 +235,7 @@ def post_step(
             group_id = seq_group.group_id
             self.num_steps[group_id] += 1
             stop_token_ids = self.sampling_params[group_id].stop_token_ids
+            stop_str = self.sampling_params[group_id].stop_str
 
             # Process beam search results before processing the next tokens.
             for seq in seq_group.seqs:
@@ -263,6 +266,11 @@ def post_step(
                     self._free_seq(seq)
                     continue
 
+                if stop_str is not None:
+                    if self.tokenizer.decode(seq.get_token_ids(), skip_special_tokens=True).endswith(stop_str):
+                        self._free_seq(seq)
+                        continue
+
                 # Check if the sequence has reached the maximum number of steps.
                 max_num_steps = self.sampling_params[group_id].max_num_steps
                 if self.num_steps[group_id] == max_num_steps:
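The new check decodes the entire token history every step and tests whether the text ends with the stop string. A minimal standalone sketch of the same idea (the helper name and the `gpt2` tokenizer are illustrative, not part of this PR):

```python
# Standalone sketch of the stop-string check above; the helper name and the
# tokenizer choice are illustrative only.
from transformers import AutoTokenizer

def hits_stop_str(tokenizer, token_ids, stop_str):
    if stop_str is None:
        return False
    text = tokenizer.decode(token_ids, skip_special_tokens=True)
    return text.endswith(stop_str)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
ids = tokenizer.encode("Assistant: the answer is 42.###")
print(hits_stop_str(tokenizer, ids, "###"))  # True -> the scheduler frees the sequence
```

Because the full sequence is decoded on every step, the per-step cost grows with sequence length; caching the decoded text or decoding only a short tail would avoid the repeated work.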
5 changes: 4 additions & 1 deletion cacheflow/master/server.py
@@ -3,6 +3,7 @@
 import random
 
 import ray
+from transformers import AutoTokenizer
 
 from cacheflow.master.scheduler import Scheduler
 from cacheflow.models import get_memory_analyzer
@@ -33,9 +34,10 @@ def __init__(
         self.num_nodes = num_nodes
         self.num_devices_per_node = num_devices_per_node
         self.world_size = pipeline_parallel_size * tensor_parallel_size
 
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
         self.memory_analyzer = get_memory_analyzer(
-            model_name=model,
+            model_path=model_path,
             block_size=block_size,
             dtype=dtype,
             gpu_memory=gpu_memory,
@@ -77,6 +79,7 @@ def __init__(
             num_gpu_blocks=self.num_gpu_blocks,
             num_cpu_blocks=self.num_cpu_blocks,
             max_num_batched_tokens=max_num_batched_tokens,
+            tokenizer=self.tokenizer
         )
         # Connect the controllers.
         for i in range(len(self.controllers) - 1):
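The server now owns the tokenizer (loaded as a slow tokenizer from the local checkpoint directory) and hands it to the scheduler, and the memory analyzer is keyed by the same `model_path`. A small sketch of what gets resolved from that directory (the path is a placeholder, not one used by this PR):

```python
# Sketch only: what the server now resolves from a local checkpoint directory.
from transformers import AutoConfig, AutoTokenizer

model_path = "/path/to/vicuna-13b"  # contains config.json, tokenizer files, and weights
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)  # slow tokenizer, as in the diff
config = AutoConfig.from_pretrained(model_path)
print(config.model_type, type(tokenizer).__name__)
```

`use_fast=False` is presumably chosen because fast-tokenizer conversion for LLaMA-family checkpoints was slow or unreliable at the time; the slow SentencePiece tokenizer is the safer default for these models.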
4 changes: 3 additions & 1 deletion cacheflow/models/memory_analyzer.py
@@ -145,20 +145,22 @@ class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
     def __init__(
         self,
         model_name: str,
+        model_path: str,
         block_size: int,
         dtype: torch.dtype,
         gpu_memory: int,
         cpu_memory: int,
         tensor_parallel_size: int,
     ) -> None:
         self.model_name = model_name
+        self.model_path = model_path
         self.block_size = block_size
         self.dtype = dtype
         self.gpu_memory = gpu_memory
         self.cpu_memory = cpu_memory
         self.tensor_parallel_size = tensor_parallel_size
 
-        config = AutoConfig.from_pretrained(model_name)
+        config = AutoConfig.from_pretrained(model_path)
         self.num_layers = config.num_hidden_layers
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
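For reference, the config fields read here (`num_hidden_layers`, `hidden_size`, `num_attention_heads`) are what the KV-cache sizing is based on. A rough, illustrative calculation using the usual 2 × block_size × num_layers × hidden_size × dtype_size formula (this is not the actual `LlamaMemoryAnalyzer` arithmetic, and the hub id is only an example — any local path with a LLaMA config works):

```python
# Illustrative KV-cache sizing from the same config fields; not the actual
# LlamaMemoryAnalyzer code.
import torch
from transformers import AutoConfig

config = AutoConfig.from_pretrained("huggyllama/llama-7b")
dtype_size = torch.finfo(torch.float16).bits // 8  # 2 bytes for fp16
block_size = 16
# 2 = key + value; one block caches `block_size` tokens across all layers.
block_bytes = 2 * block_size * config.num_hidden_layers * config.hidden_size * dtype_size
print(f"{block_bytes / 1024**2:.1f} MiB per block")  # ~8 MiB for LLaMA-7B at fp16
```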
19 changes: 15 additions & 4 deletions cacheflow/models/model_utils.py
@@ -3,7 +3,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
-from transformers import AutoConfig
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 
 from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
 from cacheflow.models.memory_analyzer import LlamaMemoryAnalyzer
@@ -16,26 +16,36 @@
 _MODELS = {
     'llama': LlamaForCausalLM,
     'opt': OPTForCausalLM,
+    'vicuna': LlamaForCausalLM,
+    'koala': LlamaForCausalLM,
+    'alpaca': LlamaForCausalLM,
 }
 
 _MEMORY_ANALYZERS = {
     'llama': LlamaMemoryAnalyzer,
     'opt': OPTMemoryAnalyzer,
+    'vicuna': LlamaMemoryAnalyzer,
+    'koala': LlamaMemoryAnalyzer,
+    'alpaca': LlamaMemoryAnalyzer,
 }
 
 
 def get_model(
     model_name: str,
+    model_path: str,
     dtype: Union[torch.dtype, str],
     path: str,
 ) -> nn.Module:
     torch_dtype = get_torch_dtype(dtype)
     torch.set_default_dtype(torch_dtype)
-    config = AutoConfig.from_pretrained(model_name)
+
+    # config = AutoConfig.from_pretrained(model_name)
+    config = AutoConfig.from_pretrained(model_path)
     for model_class_name, model_class in _MODELS.items():
         if model_class_name in model_name:
             # Download model weights if it's not cached.
-            weights_dir = model_class.get_weights(model_name, path=path)
+            # weights_dir = model_class.get_weights(model_name, path=path)
+            weights_dir = model_class.get_weights(model_path, path=path)
             # Create a model instance.
             model = model_class(config)
             # Load the weights from the cached or downloaded files.
@@ -46,6 +56,7 @@ def get_model(
 
 def get_memory_analyzer(
     model_name: str,
+    model_path: str,
     block_size: int,
     dtype: Union[torch.dtype, str],
     gpu_memory: int,
@@ -56,6 +67,6 @@
     for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
         if model_class in model_name:
             return memory_analyzer(
-                model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
+                model_name, model_path, block_size, torch_dtype, gpu_memory, cpu_memory,
                 tensor_parallel_size)
     raise ValueError(f'Unsupported model name: {model_name}')
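The new entries work because `get_model` and `get_memory_analyzer` dispatch by substring match against `model_name`, so e.g. `'vicuna-13b'` hits the `'vicuna'` key and reuses the LLaMA classes. A tiny sketch of that dispatch (stub classes stand in for the real implementations):

```python
# Sketch of the substring dispatch in get_model(); stub classes stand in for
# the real model implementations.
class LlamaForCausalLM: ...
class OPTForCausalLM: ...

_MODELS = {
    'llama': LlamaForCausalLM,
    'opt': OPTForCausalLM,
    'vicuna': LlamaForCausalLM,
    'koala': LlamaForCausalLM,
    'alpaca': LlamaForCausalLM,
}

def resolve_model_class(model_name: str):
    for key, cls in _MODELS.items():
        if key in model_name:  # substring match, as in get_model()
            return cls
    raise ValueError(f'Unsupported model name: {model_name}')

print(resolve_model_class('vicuna-13b').__name__)  # LlamaForCausalLM
```

Since Vicuna, Koala, and Alpaca are LLaMA fine-tunes, they reuse `LlamaForCausalLM` and `LlamaMemoryAnalyzer`; only the checkpoint path and tokenizer differ.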
2 changes: 2 additions & 0 deletions cacheflow/sampling_params.py
@@ -13,6 +13,7 @@ def __init__(
         max_num_steps: int,
         num_logprobs: int,
         context_window_size: Optional[int],
+        stop_str=None
     ) -> None:
         if n < 1:
             raise ValueError(f'n must be at least 1, got {n}.')
@@ -59,6 +60,7 @@ def __init__(
         self.max_num_steps = max_num_steps
         self.num_logprobs = num_logprobs
         self.context_window_size = context_window_size
+        self.stop_str = stop_str
 
     def __repr__(self) -> str:
         return (f'SamplingParams(n={self.n}, '
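`stop_str` complements the existing `stop_token_ids`: the textual turn separators used by chat fine-tunes typically span several token ids, so checking the last generated token id alone cannot catch them, while a suffix check on the decoded text can. A toy illustration (the `###` separator is an example, not a default set anywhere in this PR):

```python
# Toy illustration of why a stop *string* is needed in addition to stop token
# ids; the separator value is an example, not one set by this PR.
generated = "Sure, the capital of France is Paris.###"
stop_str = "###"
if stop_str is not None and generated.endswith(stop_str):
    print("stop: the scheduler would free this sequence")
```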
3 changes: 2 additions & 1 deletion cacheflow/worker/worker.py
@@ -43,8 +43,9 @@ def __init__(
         set_random_seed(seed)
 
         # Initialize the model.
-        self.model, self.dtype = get_model(model_name, dtype=dtype, path=model_path)
+        self.model, self.dtype = get_model(model_name, model_path, dtype=dtype, path=model_path)
         self.model = self.model.cuda()
+        print("loading model done...")
         tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
         initialize_all_reduce_launcher(
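For completeness, a hypothetical call site mirroring the updated `get_model` signature, as used by the worker above; the paths are placeholders. `model_path` points at the local checkpoint, while `path` remains the cache directory for converted weights:

```python
# Hypothetical usage of the updated signature; paths are placeholders.
import torch
from cacheflow.models.model_utils import get_model

model, dtype = get_model(
    'vicuna-13b',             # matched against the _MODELS keys by substring
    '/path/to/vicuna-13b',    # local directory with config.json and weights
    dtype=torch.float16,
    path='/path/to/weight-cache',
)
```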