Commit 6e71226
support both weight loading methods
dsikka committed Jun 27, 2024
1 parent 90094db commit 6e71226
Showing 8 changed files with 373 additions and 16 deletions.
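
Context for the hunks that follow: the commit teaches the loader to accept both the legacy pattern, where a plain torch.nn.Parameter gets a weight_loader attribute attached after construction, and the new vLLMParameter objects that carry their loading metadata (input_dim, output_dim, use_col_loading, use_row_loading) themselves. The sketch below is a minimal, self-contained stand-in for that idea — the constructor kwargs mirror the vLLMParameter(...) calls visible in the diffs below, but the class body and the dispatch logic are illustrative assumptions, not this commit's code.

# Minimal sketch, NOT vLLM's implementation. The kwargs mirror the
# vLLMParameter(...) calls in the hunks below; everything else is
# assumed for illustration.
from typing import Callable, Optional

import torch
from torch.nn import Parameter


class vLLMParameter(Parameter):
    """Stand-in: a Parameter that carries its own loading metadata."""

    def __new__(cls, data: torch.Tensor, **kwargs):
        return super().__new__(cls, data, requires_grad=False)

    def __init__(self, data: torch.Tensor, *,
                 input_dim: Optional[int] = None,
                 output_dim: Optional[int] = None,
                 use_col_loading: bool = False,
                 use_row_loading: bool = False,
                 weight_loader: Optional[Callable] = None):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.use_col_loading = use_col_loading
        self.use_row_loading = use_row_loading
        self.weight_loader = weight_loader


def load_weight(param: Parameter, loaded: torch.Tensor) -> None:
    """Support both weight loading methods."""
    if isinstance(param, vLLMParameter):
        # New method: the parameter knows how it wants to be loaded.
        param.weight_loader(param, loaded)
    elif hasattr(param, "weight_loader"):
        # Legacy method: a loader function was bolted on via set-attr.
        param.weight_loader(param, loaded)
    else:
        param.data.copy_(loaded)  # fallback: shapes must match

The real class lives in vllm/model_executor/parameter.py (diffed at the bottom of this page) and is richer than this; the sketch only fixes the shape of the idea.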
2 changes: 1 addition & 1 deletion vllm/model_executor/__init__.py
@@ -1,6 +1,6 @@
+from vllm.model_executor.parameter import *
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
-from vllm.model_executor.parameter import *
 
 __all__ = [
     "SamplingMetadata",
363 changes: 355 additions & 8 deletions vllm/model_executor/layers/linear.py

Large diffs are not rendered by default.

@@ -0,0 +1 @@
+from .compressed_tensors import CompressedTensorsLinearMethod  # noqa: F401
@@ -15,6 +15,8 @@
     CompressionFormat, QuantizationArgs, QuantizationStrategy,
     find_first_name_or_class_match)
 
+__all__ = ["CompressedTensorsLinearMethod"]
+
 
 class CompressedTensorsConfig(QuantizationConfig):
@@ -2,12 +2,13 @@
 
 import torch
 from torch.nn import Parameter
+
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
     GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N)
-from vllm.model_executor.parameter import vLLMParameter, PackedvLLMParameter
+from vllm.model_executor.parameter import PackedvLLMParameter, vLLMParameter
 
 __all__ = ["CompressedTensorsW4A16Sparse24"]
 W4A16SPARSE24_SUPPORTED_BITS = [4]
@@ -65,7 +66,8 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
             output_dim=1,
             input_dim=input_dim,
             use_col_loading=True,
-            use_row_loading=True if input_dim is not None else False,
+            use_row_loading=True #noqa: SIM210
+            if input_dim is not None else False,
             weight_loader=weight_loader)
 
         weight_shape = vLLMParameter(data=torch.empty(2, dtype=torch.int64),
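
The conditional this hunk rewraps (and the matching one in the WNA16 scheme below) is the interesting part: row-wise loading only applies when the tensor actually has an input dimension to shard. A rough sketch of the per-rank sharding such flags imply under tensor parallelism — assumed semantics for illustration, not code from this commit:

# Sketch of sharding implied by use_col_loading / use_row_loading;
# assumed semantics, not vLLM source.
from typing import Optional

import torch


def shard_for_rank(loaded: torch.Tensor, param: torch.Tensor, tp_rank: int,
                   output_dim: Optional[int] = None,
                   input_dim: Optional[int] = None,
                   use_col_loading: bool = False,
                   use_row_loading: bool = False) -> torch.Tensor:
    if use_col_loading and output_dim is not None:
        # Column-parallel: each rank keeps its slice of the output dim.
        size = param.shape[output_dim]
        loaded = loaded.narrow(output_dim, tp_rank * size, size)
    if use_row_loading and input_dim is not None:
        # Row-parallel: each rank keeps its slice of the input dim.
        size = param.shape[input_dim]
        loaded = loaded.narrow(input_dim, tp_rank * size, size)
    return loaded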
@@ -6,7 +6,8 @@
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     QuantizationStrategy)
-from vllm.model_executor.parameter import vLLMParameter, ScalerToArrayvLLMParameter
+from vllm.model_executor.parameter import (ScalerToArrayvLLMParameter,
+                                           vLLMParameter)
 
 
 class CompressedTensorsW8A8(CompressedTensorsScheme):
@@ -2,13 +2,14 @@
 
 import torch
 from torch.nn import Parameter
+
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.gptq_marlin import (
     GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQMarlinState,
     marlin_permute_scales)
-from vllm.model_executor.parameter import vLLMParameter, PackedvLLMParameter
+from vllm.model_executor.parameter import PackedvLLMParameter, vLLMParameter
 
 __all__ = ["CompressedTensorsWNA16"]
 WNA16_SUPPORTED_BITS = [4, 8]
@@ -68,7 +69,8 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
             input_dim=weight_scale_dim,
             output_dim=0,
             use_col_loading=True,
-            use_row_loading=True if weight_scale_dim is not None else False,
+            use_row_loading=True # noqa: SIM210
+            if weight_scale_dim is not None else False,
             weight_loader=weight_loader,
             data=torch.empty(
                 output_size_per_partition,
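
PackedvLLMParameter, which both Marlin-format schemes above now import, exists because sub-byte quantized weights store several logical values per physical element (e.g. eight 4-bit weights per int32), so shard sizes and offsets computed in logical elements must be rescaled before slicing storage. A hedged sketch of that arithmetic — pack_factor is an assumed name, not confirmed by this diff:

# Illustration only: rescaling logical shard bounds to packed storage.
# `pack_factor` is an assumed name; e.g. 4-bit weights in int32 -> 8.
from typing import Tuple


def adjust_packed_shard(shard_size: int, shard_offset: int,
                        pack_factor: int) -> Tuple[int, int]:
    # Logical element counts become physical storage-unit counts.
    return shard_size // pack_factor, shard_offset // pack_factor


# 4096 logical 4-bit rows at logical offset 8192 occupy
# 512 physical int32 rows at physical offset 1024.
assert adjust_packed_shard(4096, 8192, pack_factor=8) == (512, 1024)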
6 changes: 4 additions & 2 deletions vllm/model_executor/parameter.py
@@ -1,6 +1,8 @@
-from torch.nn import Parameter
-from typing import Optional, Callable, List, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
+
+import torch
+from torch.nn import Parameter
 
 from vllm.logger import init_logger
 
 __all__ = [
