From 460c1884e3cb781730f85cb5591a85d5864bdac8 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Wed, 31 Jul 2024 15:47:46 -0400 Subject: [PATCH] [Bugfix] Support cpu offloading with fp8 quantization (#6960) --- tests/basic_correctness/test_cpu_offload.py | 43 +++++++++++++--- vllm/model_executor/model_loader/loader.py | 56 +++++++++++++++++++-- vllm/model_executor/models/utils.py | 50 +++++++++--------- 3 files changed, 116 insertions(+), 33 deletions(-) diff --git a/tests/basic_correctness/test_cpu_offload.py b/tests/basic_correctness/test_cpu_offload.py index 9ebcc48a9b93e..180b926637ecb 100644 --- a/tests/basic_correctness/test_cpu_offload.py +++ b/tests/basic_correctness/test_cpu_offload.py @@ -1,4 +1,6 @@ -from vllm.utils import is_hip +import pytest + +from tests.quantization.utils import is_quant_method_supported from ..utils import compare_two_settings @@ -6,8 +8,37 @@ def test_cpu_offload(): compare_two_settings("meta-llama/Llama-2-7b-hf", [], ["--cpu-offload-gb", "4"]) - if not is_hip(): - # compressed-tensors quantization is currently not supported in ROCm. - compare_two_settings( - "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t", [], - ["--cpu-offload-gb", "1"]) + + +@pytest.mark.skipif(not is_quant_method_supported("fp8"), + reason="fp8 is not supported on this GPU type.") +def test_cpu_offload_fp8(): + # Test quantization of an unquantized checkpoint + compare_two_settings("meta-llama/Meta-Llama-3-8B-Instruct", + ["--quantization", "fp8"], + ["--quantization", "fp8", "--cpu-offload-gb", "2"]) + # Test loading a quantized checkpoint + compare_two_settings("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", [], + ["--cpu-offload-gb", "2"]) + + +@pytest.mark.skipif(not is_quant_method_supported("awq"), + reason="awq is not supported on this GPU type.") +def test_cpu_offload_awq(): + compare_two_settings("casperhansen/llama-3-8b-instruct-awq", [], + ["--cpu-offload-gb", "2"]) + + +@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), + reason="gptq_marlin is not supported on this GPU type.") +def test_cpu_offload_compressed_tensors(): + # Test wNa16 + compare_two_settings("nm-testing/tinyllama-oneshot-w4a16-channel-v2", [], + ["--cpu-offload-gb", "1"]) + # Test w4a16_marlin24 + compare_two_settings("nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t", + [], ["--cpu-offload-gb", "1"]) + # Test w8a8 + compare_two_settings( + "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", [], + ["--cpu-offload-gb", "1"]) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index bbe49655020da..f72515e014829 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -7,6 +7,7 @@ import math import os from abc import ABC, abstractmethod +from contextlib import contextmanager from typing import Any, Dict, Generator, List, Optional, Tuple, Type import huggingface_hub @@ -37,7 +38,49 @@ supports_vision) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.utils import is_tpu +from vllm.utils import is_pin_memory_available, is_tpu + + +@contextmanager +def device_loading_context(module: torch.nn.Module, + target_device: torch.device): + if target_device.type == "cpu": + # If target is CPU, no need to move anything + yield module + return + + original_device_states: Dict[str, torch.device] = {} + + # Store original device states and move parameters to GPU if they're on CPU + for name, p in module.named_parameters(): + if p.device.type 
== "cpu": + original_device_states[name] = p.device + p.data = p.data.to(target_device) + # Parameters already on target device are not touched + + try: + yield module + + finally: + # Restore parameters to their original devices, ignoring new parameters + pin_memory = is_pin_memory_available() + for name, p in module.named_parameters(): + if name in original_device_states: + original_device: torch.device = original_device_states[name] + if original_device.type == "cpu": + # `torch.empty_like` does not support `pin_memory` argument + cpu_data = torch.empty_strided(size=p.data.size(), + stride=p.data.stride(), + dtype=p.data.dtype, + layout=p.data.layout, + device="cpu", + pin_memory=pin_memory) + cpu_data.copy_(p.data) + p.data = cpu_data + else: + p.data = p.data.to(original_device) + # New parameters or parameters already on target device are untouched + logger = init_logger(__name__) @@ -275,8 +318,9 @@ def load_model(self, *, model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, cache_config: CacheConfig) -> nn.Module: + target_device = torch.device(device_config.device) with set_default_torch_dtype(model_config.dtype): - with torch.device(device_config.device): + with target_device: model = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, cache_config, scheduler_config) @@ -291,7 +335,13 @@ def load_model(self, *, model_config: ModelConfig, for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) if quant_method is not None: - quant_method.process_weights_after_loading(module) + # When quant methods need to process weights after loading + # (for repacking, quantizing, etc), they expect parameters + # to be on the global target device. This scope is for the + # case where cpu offloading is used, where we will move the + # parameters onto device for processing and back off after. 
+ with device_loading_context(module, target_device): + quant_method.process_weights_after_loading(module) return model.eval() diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 197d3839a766a..91b4a27814bf4 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -87,6 +87,7 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module: # offload parameters to CPU # use pin_memory if possible, which helps cudagraph capture speed + offloaded_parameters = False for p in module.parameters(): if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES: # we use per-parameter offloading @@ -94,35 +95,36 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module: break # `torch.empty_like` does not support `pin_memory` argument - cpu_data = torch.empty(size=p.data.size(), - dtype=p.data.dtype, - layout=p.data.layout, - device='cpu', - pin_memory=pin_memory) + cpu_data = torch.empty_strided(size=p.data.size(), + stride=p.data.stride(), + dtype=p.data.dtype, + layout=p.data.layout, + device='cpu', + pin_memory=pin_memory) cpu_data.copy_(p.data) p.data = cpu_data _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size() + offloaded_parameters = True + + if offloaded_parameters: + original_forward = module.forward + + def forward(*args, **kwargs): + module.forward = original_forward + device_state = { + # here we blindly call `to(device)` + # if the parameter is already on the device, it will be a no-op + k: v.to(device, non_blocking=True) + for k, v in module.state_dict().items() + } + output = functional_call(module, + device_state, + args=args, + kwargs=kwargs) + module.forward = forward + return output - state_dict: Dict[str, torch.Tensor] = module.state_dict() - - original_forward = module.forward - - def forward(*args, **kwargs): - module.forward = original_forward - device_state = { - # here we blindly call `to(device)` - # if the parameter is already on the device, it will be a no-op - k: v.to(device, non_blocking=True) - for k, v in state_dict.items() - } - output = functional_call(module, - device_state, - args=args, - kwargs=kwargs) module.forward = forward - return output - - module.forward = forward return module
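
Usage note (illustrative, not part of the patch): the new tests drive this feature through `compare_two_settings` with the `--cpu-offload-gb` and `--quantization` CLI flags. Below is a minimal offline sketch of the same combination, assuming `vllm.LLM` forwards `cpu_offload_gb` and `quantization` to the engine the same way those flags do; the model names are the ones used in `test_cpu_offload_fp8`.

    from vllm import LLM, SamplingParams

    # Quantized FP8 checkpoint from test_cpu_offload_fp8, with up to 2 GiB of
    # weights held in CPU memory (assumes cpu_offload_gb mirrors --cpu-offload-gb).
    llm = LLM(model="neuralmagic/Meta-Llama-3-8B-Instruct-FP8", cpu_offload_gb=2)

    # Alternatively, quantize an unquantized checkpoint on the fly:
    # llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct",
    #           quantization="fp8", cpu_offload_gb=2)

    out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
    print(out[0].outputs[0].text)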
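
Mechanism note (illustrative sketch, not vLLM code): `maybe_offload_to_cpu` keeps offloaded parameters in (optionally pinned) CPU buffers and rebinds `module.forward` so each call goes through `torch.func.functional_call` with a temporary on-device copy of the state dict, while the new `device_loading_context` gives `process_weights_after_loading` the same temporary on-device view during loading and then restores the CPU placement. The core `functional_call` trick, reduced to plain PyTorch with a hypothetical `layer`:

    import torch
    from torch.func import functional_call

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    layer = torch.nn.Linear(8, 8)  # parameters stay resident on CPU

    x = torch.randn(2, 8, device=device)
    # Blindly move every entry to the target device; entries already there are no-ops.
    device_state = {k: v.to(device, non_blocking=True)
                    for k, v in layer.state_dict().items()}
    # Forward runs against the on-device copies; the layer's own weights remain on CPU.
    y = functional_call(layer, device_state, args=(x,))
    print(y.device, layer.weight.device)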