Remove marlin(old) kernel codes & do ruff (#719)
* clean marlin

* do ruff
CSY-ModelCloud authored Nov 30, 2024
1 parent 958a066 · commit 0a4b885
Showing 14 changed files with 17 additions and 1,085 deletions.
gptqmodel/models/auto.py (2 changes: 1 addition & 1 deletion)
@@ -1,10 +1,10 @@
 from __future__ import annotations
 
 import os.path
-import torch
 from os.path import isdir, join
 from typing import Dict, List, Optional, Union
 
+import torch
 from gptqmodel.quantization import QUANT_CONFIG_FILENAME
 from huggingface_hub import list_repo_files
 from transformers import AutoConfig
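The relocation of import torch in this file is ruff's isort-style ordering: imports are grouped into blocks (standard library first, external packages after), and within each block plain import statements sort ahead of from-imports, alphabetized. The resulting header, with orienting comments added, looks like this:

from __future__ import annotations  # __future__ imports always lead

import os.path                      # standard library block
from os.path import isdir, join
from typing import Dict, List, Optional, Union

import torch                        # external packages: plain imports first,
from gptqmodel.quantization import QUANT_CONFIG_FILENAME  # then from-imports,
from huggingface_hub import list_repo_files               # alphabetized
from transformers import AutoConfig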
gptqmodel/models/base.py (8 changes: 4 additions & 4 deletions)
@@ -10,9 +10,6 @@
 from packaging import version
 from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizerBase, modeling_utils
 
-from ._const import CPU, get_best_device
-from .loader import ModelLoader
-from .writer import QUANT_LOG_DAMP, QUANT_LOG_LAYER, QUANT_LOG_LOSS, QUANT_LOG_MODULE, QUANT_LOG_TIME, ModelWriter
 from ..quantization import GPTQ, QuantizeConfig
 from ..quantization.config import FORMAT, QUANTIZE_BLACK_LIST, AutoRoundQuantizeConfig
 from ..utils.backend import BACKEND
@@ -24,6 +21,9 @@
                           get_module_by_name_suffix, get_moe_layer_modules, move_to,
                           nested_move_to, pack_model, simple_dispatch_model)
 from ..utils.progress import ProgressBar
+from ._const import CPU, get_best_device
+from .loader import ModelLoader
+from .writer import QUANT_LOG_DAMP, QUANT_LOG_LAYER, QUANT_LOG_LOSS, QUANT_LOG_MODULE, QUANT_LOG_TIME, ModelWriter
 
 
 def check_support_param_buffer_assignment(*args, **kwargs):
@@ -203,7 +203,7 @@ def quantize(
 
         if self.quantize_config.format == FORMAT.MARLIN:
             raise ValueError(
-                f"FORMAT.MARLIN is deprecated for quantization. Please switch to FORMAT.GPTQ. GPTQMOdel will auto-use Marlin kernel for accelerated inference for FORMAT.GPTQ."
+                "FORMAT.MARLIN is deprecated for quantization. Please switch to FORMAT.GPTQ. GPTQMOdel will auto-use Marlin kernel for accelerated inference for FORMAT.GPTQ."
            )
 
         if self.quantize_config.lm_head and not isinstance(self.quantize_config, AutoRoundQuantizeConfig):
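The reworded ValueError directs quantization to FORMAT.GPTQ, with the Marlin kernel auto-selected at inference time. A hedged usage sketch: the import paths are taken from the diffs in this commit, while the bits and group_size arguments are illustrative assumptions, not values from this repository:

from gptqmodel.quantization import QuantizeConfig
from gptqmodel.quantization.config import FORMAT

# FORMAT.MARLIN now raises at quantize() time; request FORMAT.GPTQ instead
# and let the loader pick the Marlin kernel for accelerated inference.
config = QuantizeConfig(bits=4, group_size=128, format=FORMAT.GPTQ)  # args assumed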
gptqmodel/nn_modules/qlinear/qlinear_marlin_inference.py (4 changes: 2 additions & 2 deletions)
@@ -3,10 +3,10 @@
 
 from typing import Any, Dict, List, Optional, Tuple
 
 import numpy as np
 import torch
-from torch.nn.parameter import Parameter
-
 from gptqmodel.nn_modules.qlinear import BaseQuantLinear
+from torch.nn.parameter import Parameter
+
 marlin_import_exception = None
 try:
gptqmodel/utils/backend.py (2 changes: 1 addition & 1 deletion)
@@ -1,5 +1,5 @@
 from enum import Enum
-import torch
 
+
 class BACKEND(Enum):
     AUTO = 0  # choose the fastest one based on quant model compatibility
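The AUTO member's comment ("choose the fastest one based on quant model compatibility") is realized by select_quant_linear in importer.py below. The standalone sketch that follows illustrates the AUTO-resolution idea only; the MARLIN and TRITON members and the resolver function are illustrative, not this repo's code:

from enum import Enum


class BACKEND(Enum):
    AUTO = 0    # choose the fastest one based on quant model compatibility
    MARLIN = 1  # illustrative member
    TRITON = 2  # illustrative member


def resolve_backend(requested, supported):
    # AUTO resolves to the first (fastest) backend the model supports.
    return supported[0] if requested is BACKEND.AUTO else requested


print(resolve_backend(BACKEND.AUTO, [BACKEND.MARLIN, BACKEND.TRITON]))  # BACKEND.MARLIN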
gptqmodel/utils/importer.py (4 changes: 2 additions & 2 deletions)
@@ -2,7 +2,6 @@
 
 import torch
 
-from .backend import BACKEND
 from ..nn_modules.qlinear.qlinear_bitblas import BitBLASQuantLinear
 from ..nn_modules.qlinear.qlinear_cuda import CudaQuantLinear
 from ..nn_modules.qlinear.qlinear_exllama import ExllamaQuantLinear
@@ -12,6 +11,7 @@
 from ..nn_modules.qlinear.qlinear_tritonv2 import TRITON_AVAILABLE, TRITON_INSTALL_HINT, TritonV2QuantLinear
 from ..quantization import FORMAT
 from ..utils.logger import setup_logger
+from .backend import BACKEND
 
 logger = setup_logger()
 
@@ -90,7 +90,7 @@ def select_quant_linear(
         if hasattr(torch, "xpu") and torch.xpu.is_available():
             return IPEXQuantLinear
 
-        # Fallback to IPEX/CPU if cpu supports AVX512
+        # Fallback to IPEX/CPU if cpu supports AVX512
         from device_smi import Device
         if "avx512_vnni" not in Device("cpu").features:
             raise ValueError("IPEX/CPU requires minimum avx512_vnni support.")
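The CPU fallback above gates the IPEX path on avx512_vnni via device_smi. The same check as a standalone snippet, assuming the device_smi package is installed:

from device_smi import Device

cpu = Device("cpu")
if "avx512_vnni" in cpu.features:
    print("IPEX/CPU path available: avx512_vnni supported")
else:
    print("IPEX/CPU requires minimum avx512_vnni support")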
gptqmodel/utils/marlin.py (7 changes: 3 additions & 4 deletions)
@@ -4,12 +4,11 @@
 import torch
 from accelerate.utils import find_tied_parameters
 
-from .model import recurse_getattr, recurse_setattr
-from .progress import ProgressBar
-from ..nn_modules.qlinear.qlinear_marlin_inference import MarlinInferenceQuantLinear
-from ..nn_modules.qlinear.qlinear_marlin_inference import _get_perms, unpack_qzeros
+from ..nn_modules.qlinear.qlinear_marlin_inference import MarlinInferenceQuantLinear, _get_perms, unpack_qzeros
 from ..quantization import FORMAT, QuantizeConfig
 from ..utils.logger import setup_logger
+from .model import recurse_getattr, recurse_setattr
+from .progress import ProgressBar
 
 logger = setup_logger()
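Beyond merging the two qlinear_marlin_inference imports onto one line, this file's job is Marlin weight conversion, which is why it pulls in unpack_qzeros. As background, GPTQ packing stores eight 4-bit zero-points per int32 word; the sketch below is a generic reconstruction of such an unpacker, not this repository's implementation:

import torch


def unpack_qzeros_4bit(qzeros: torch.Tensor) -> torch.Tensor:
    # Each int32 word stores eight 4-bit zero-points, least-significant first.
    shifts = torch.arange(0, 32, 4, device=qzeros.device)  # one shift per nibble
    nibbles = (qzeros.unsqueeze(-1) >> shifts) & 0xF       # extract every nibble
    return nibbles.reshape(qzeros.shape[0], -1)


packed = torch.tensor([[0x76543210]], dtype=torch.int32)
print(unpack_qzeros_4bit(packed))  # tensor([[0, 1, 2, 3, 4, 5, 6, 7]])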
gptqmodel/utils/model.py (8 changes: 4 additions & 4 deletions)
@@ -20,16 +20,16 @@
 from transformers import AutoConfig, PretrainedConfig
 from transformers.utils.hub import cached_file
 
-from .backend import BACKEND
-from .importer import select_quant_linear
-from .logger import setup_logger
-from .progress import ProgressBar
 from ..models._const import CPU, EXLLAMA_DEFAULT_MAX_INPUT_LENGTH, EXPERT_INDEX_PLACEHOLDER, SUPPORTED_MODELS
 from ..nn_modules.qlinear import BaseQuantLinear
 from ..nn_modules.qlinear.qlinear_exllama import ExllamaQuantLinear
 from ..nn_modules.qlinear.qlinear_exllamav2 import ExllamaV2QuantLinear
 from ..nn_modules.qlinear.qlinear_marlin_inference import MarlinInferenceQuantLinear
 from ..quantization import FORMAT, QuantizeConfig
+from .backend import BACKEND
+from .importer import select_quant_linear
+from .logger import setup_logger
+from .progress import ProgressBar
 
 logger = setup_logger()
gptqmodel_ext/marlin/marlin_cuda.cpp (80 changes: 0 additions & 80 deletions)

This file was deleted.

(Diffs for the remaining 6 changed files are not shown.)