Skip to content

Commit

Permalink
[Bugfix] Support cpu offloading with fp8 quantization (vllm-project#6960
Browse files Browse the repository at this point in the history
)
  • Loading branch information
mgoin authored Jul 31, 2024
1 parent bd70013 commit 460c188
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 33 deletions.
43 changes: 37 additions & 6 deletions tests/basic_correctness/test_cpu_offload.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,44 @@
from vllm.utils import is_hip
import pytest

from tests.quantization.utils import is_quant_method_supported

from ..utils import compare_two_settings


def test_cpu_offload():
compare_two_settings("meta-llama/Llama-2-7b-hf", [],
["--cpu-offload-gb", "4"])
if not is_hip():
# compressed-tensors quantization is currently not supported in ROCm.
compare_two_settings(
"nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t", [],
["--cpu-offload-gb", "1"])


@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="fp8 is not supported on this GPU type.")
def test_cpu_offload_fp8():
# Test quantization of an unquantized checkpoint
compare_two_settings("meta-llama/Meta-Llama-3-8B-Instruct",
["--quantization", "fp8"],
["--quantization", "fp8", "--cpu-offload-gb", "2"])
# Test loading a quantized checkpoint
compare_two_settings("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", [],
["--cpu-offload-gb", "2"])


@pytest.mark.skipif(not is_quant_method_supported("awq"),
reason="awq is not supported on this GPU type.")
def test_cpu_offload_awq():
compare_two_settings("casperhansen/llama-3-8b-instruct-awq", [],
["--cpu-offload-gb", "2"])


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="gptq_marlin is not supported on this GPU type.")
def test_cpu_offload_compressed_tensors():
# Test wNa16
compare_two_settings("nm-testing/tinyllama-oneshot-w4a16-channel-v2", [],
["--cpu-offload-gb", "1"])
# Test w4a16_marlin24
compare_two_settings("nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t",
[], ["--cpu-offload-gb", "1"])
# Test w8a8
compare_two_settings(
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", [],
["--cpu-offload-gb", "1"])
56 changes: 53 additions & 3 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import math
import os
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, Generator, List, Optional, Tuple, Type

import huggingface_hub
Expand Down Expand Up @@ -37,7 +38,49 @@
supports_vision)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import is_tpu
from vllm.utils import is_pin_memory_available, is_tpu


@contextmanager
def device_loading_context(module: torch.nn.Module,
target_device: torch.device):
if target_device.type == "cpu":
# If target is CPU, no need to move anything
yield module
return

original_device_states: Dict[str, torch.device] = {}

# Store original device states and move parameters to GPU if they're on CPU
for name, p in module.named_parameters():
if p.device.type == "cpu":
original_device_states[name] = p.device
p.data = p.data.to(target_device)
# Parameters already on target device are not touched

try:
yield module

finally:
# Restore parameters to their original devices, ignoring new parameters
pin_memory = is_pin_memory_available()
for name, p in module.named_parameters():
if name in original_device_states:
original_device: torch.device = original_device_states[name]
if original_device.type == "cpu":
# `torch.empty_like` does not support `pin_memory` argument
cpu_data = torch.empty_strided(size=p.data.size(),
stride=p.data.stride(),
dtype=p.data.dtype,
layout=p.data.layout,
device="cpu",
pin_memory=pin_memory)
cpu_data.copy_(p.data)
p.data = cpu_data
else:
p.data = p.data.to(original_device)
# New parameters or parameters already on target device are untouched


logger = init_logger(__name__)

Expand Down Expand Up @@ -275,8 +318,9 @@ def load_model(self, *, model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
with target_device:
model = _initialize_model(model_config, self.load_config,
lora_config, multimodal_config,
cache_config, scheduler_config)
Expand All @@ -291,7 +335,13 @@ def load_model(self, *, model_config: ModelConfig,
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
# When quant methods need to process weights after loading
# (for repacking, quantizing, etc), they expect parameters
# to be on the global target device. This scope is for the
# case where cpu offloading is used, where we will move the
# parameters onto device for processing and back off after.
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
return model.eval()


Expand Down
50 changes: 26 additions & 24 deletions vllm/model_executor/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,42 +87,44 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:

# offload parameters to CPU
# use pin_memory if possible, which helps cudagraph capture speed
offloaded_parameters = False
for p in module.parameters():
if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
# we use per-parameter offloading
# one module might have some parameters offloaded and some not
break

# `torch.empty_like` does not support `pin_memory` argument
cpu_data = torch.empty(size=p.data.size(),
dtype=p.data.dtype,
layout=p.data.layout,
device='cpu',
pin_memory=pin_memory)
cpu_data = torch.empty_strided(size=p.data.size(),
stride=p.data.stride(),
dtype=p.data.dtype,
layout=p.data.layout,
device='cpu',
pin_memory=pin_memory)
cpu_data.copy_(p.data)
p.data = cpu_data
_CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
offloaded_parameters = True

if offloaded_parameters:
original_forward = module.forward

def forward(*args, **kwargs):
module.forward = original_forward
device_state = {
# here we blindly call `to(device)`
# if the parameter is already on the device, it will be a no-op
k: v.to(device, non_blocking=True)
for k, v in module.state_dict().items()
}
output = functional_call(module,
device_state,
args=args,
kwargs=kwargs)
module.forward = forward
return output

state_dict: Dict[str, torch.Tensor] = module.state_dict()

original_forward = module.forward

def forward(*args, **kwargs):
module.forward = original_forward
device_state = {
# here we blindly call `to(device)`
# if the parameter is already on the device, it will be a no-op
k: v.to(device, non_blocking=True)
for k, v in state_dict.items()
}
output = functional_call(module,
device_state,
args=args,
kwargs=kwargs)
module.forward = forward
return output

module.forward = forward

return module

Expand Down

0 comments on commit 460c188

Please sign in to comment.