diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py
index a4a0d92fc..3f17e1fc4 100644
--- a/src/python/py/models/builder.py
+++ b/src/python/py/models/builder.py
@@ -304,8 +304,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
         }
         if self.quant_type is not None:
             # Create quantized attributes from quantization config
-            self.quant_attrs["bits"] = config.quantization_config["bits"]
-            self.quant_attrs["group_size"] = config.quantization_config["group_size"]
+            self.quant_attrs["config"] = config.quantization_config
             self.quant_attrs["use_g_idx"] = config.quantization_config["desc_act"] if "desc_act" in config.quantization_config else False
 
     def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir):
@@ -2101,7 +2100,15 @@ def make_model(self, input_path):
             from onnxruntime_genai.models.quantized_model import QuantModel
             q_size = self.num_attn_heads * self.head_size
             kv_size = self.num_kv_heads * self.head_size
-            model = QuantModel.from_pretrained(self.quant_type, input_path, self.quant_attrs["bits"], self.quant_attrs["group_size"], self.quant_attrs["use_g_idx"], q_size, kv_size, self.intermediate_size, self.num_layers)
+            model = QuantModel.from_pretrained(
+                self.quant_type,
+                input_path=input_path,
+                quant_attrs=self.quant_attrs,
+                q_size=q_size,
+                kv_size=kv_size,
+                intermediate_size=self.intermediate_size,
+                num_layers=self.num_layers,
+            )
         else:
             # Load PyTorch model
             extra_kwargs = {"num_hidden_layers": self.num_layers} if "num_hidden_layers" in self.extra_options else {}
diff --git a/src/python/py/models/quantized_model.py b/src/python/py/models/quantized_model.py
index 52e1876da..d24160990 100644
--- a/src/python/py/models/quantized_model.py
+++ b/src/python/py/models/quantized_model.py
@@ -52,7 +52,6 @@ def __init__(self):
         self.weight = None
         self.bias = None
 
-
 class QuantizedAttention:
     def __init__(self, bits, group_size):
         self.q_proj = QuantizedTensorModule(bits, group_size)
@@ -77,29 +76,38 @@ def __init__(self, layer_id, bits, group_size):
         self.input_layernorm = TensorModule()
         self.self_attn = QuantizedAttention(bits, group_size)
         self.post_attention_layernorm = TensorModule()
+        self.pre_feedforward_layernorm = TensorModule()
+        self.post_feedforward_layernorm = TensorModule()
         self.mlp = QuantizedMLP(bits, group_size)
+        self.bits = bits
+        self.group_size = group_size
 
     def is_empty(self):
         return self.input_layernorm.weight is None
 
 
 class QuantizedModel:
-    def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size, num_layers):
+    def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, intermediate_size, num_layers):
         self.quant_type = quant_type
         self.embedding = TensorModule()
         self.final_norm = TensorModule()
         self.lm_head = TensorModule()
         self.layers = {}
         self.num_layers = num_layers
+        self._quant_attrs = quant_attrs
+        self._load_quant_config(quant_attrs)
 
-        layer_id = 0
         for weight_file in os.listdir(input_path):
             if weight_file.endswith(".safetensors"):
-                module = self.layers.setdefault(layer_id, QuantizedDecoderLayer(layer_id, bits, group_size))
                 weights = load_file(os.path.join(input_path, weight_file))
 
                 # Map weights to modules
                 for name, tensor in weights.items():
+
+                    # Per-layer quantization support
+                    local_bits = self.get_layer_bits(name)
+                    local_group_size = self.get_layer_group_size(name)
+
                     if tensor.dtype == torch.bfloat16:
                         # Cast bfloat16 to float32 since NumPy does not support bfloat16
                         tensor = tensor.to(torch.float32)
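Note: for reference, a minimal sketch of the two quantization_config shapes that the new quant_attrs["config"] entry can carry. The flat GPTQ/AWQ-style form is what the base QuantizedModel._load_quant_config and get_layer_* helpers read; the nested Quark-style form is consumed by the QuarkModel subclass added later in this diff. The concrete values (4 bits, group size 128, the lm_head override) are illustrative assumptions, not taken from this change.

# Illustrative only: mirrors the structures read by _load_quant_config / get_layer_bits / get_layer_group_size.
gptq_or_awq_config = {"bits": 4, "group_size": 128, "desc_act": False}

quark_config = {
    "global_quant_config": {"weight": {"dtype": "uint4", "group_size": 128}},
    "layer_quant_config": {
        # Per-layer overrides are keyed by the first segment of the tensor name, i.e. layer_name.split(".")[0].
        "lm_head": {"weight": {"dtype": "int4", "group_size": 32}},
    },
}

# builder.py now forwards the whole quantization config plus the g_idx flag:
quant_attrs = {"config": gptq_or_awq_config, "use_g_idx": False}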
@@ -118,26 +126,25 @@ def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, in
                         # Skip rotary embedding weights since they can be re-calculated when looping through the model
                         continue
                     elif name == "lm_head.qweight" or name == "transformer.output_layer.qweight":
-                        self._initialize_quantized_lm_head(bits, group_size)
+                        self._initialize_quantized_lm_head(local_bits, local_group_size)
                         self.lm_head.qweight = tensor
-                    elif name == "lm_head.qzeros" or name == "transformer.output_layer.qzeros":
-                        self._initialize_quantized_lm_head(bits, group_size)
+                    elif name in {"lm_head.qzeros", "lm_head.weight_zero_point", "transformer.output_layer.qzeros"}:
+                        self._initialize_quantized_lm_head(local_bits, local_group_size)
                         self.lm_head.qzeros = tensor
-                    elif name == "lm_head.scales" or name == "transformer.output_layer.scales":
-                        self._initialize_quantized_lm_head(bits, group_size)
+                    elif name in {"lm_head.scales", "lm_head.weight_scale", "transformer.output_layer.scales"}:
+                        self._initialize_quantized_lm_head(local_bits, local_group_size)
                         self.lm_head.scales = tensor
                     elif name == "lm_head.g_idx" or name == "transformer.output_layer.g_idx":
-                        self._initialize_quantized_lm_head(bits, group_size)
+                        self._initialize_quantized_lm_head(local_bits, local_group_size)
                         self.lm_head.g_idx = tensor
                     else:
                         if name.startswith("transformer.encoder"):
                             # Chatglm3, e.g., transformer.encoder.layers.0.input_layernorm.weight
                             name = name.replace("transformer.encoder", "model")
-                        curr_layer_id = int(name.split(".")[2])
-                        if curr_layer_id != layer_id:
-                            # Switch layer module used
-                            layer_id = curr_layer_id
-                            module = self.layers.setdefault(layer_id, QuantizedDecoderLayer(layer_id, bits, group_size))
+                        layer_id = int(name.split(".")[2])
+                        module = self.layers.setdefault(layer_id, QuantizedDecoderLayer(layer_id, local_bits, local_group_size))
+                        if local_bits != module.bits or local_group_size != module.group_size:
+                            raise NotImplementedError("Setting different bits or group sizes for various linear modules within a decoder layer is not yet supported in the builder.")
 
                         # Map weights and biases of norm, attention, and feed-forward network
                         # Graph order is input_layernorm --> q_proj/k_proj/v_proj --> o_proj --> post_attention_layernorm --> gate_proj/up_proj --> down_proj
@@ -151,14 +158,17 @@ def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, in
                             # model.layers.layer_id.self_attn.rotary_emb.inv_freq
                             # Skip rotary embedding weights since they can be re-calculated when looping through the model
                             continue
-                        elif bool(re.match(r"^model.layers\.\d+\.self_attn.q_proj\.qweight$", name)):
+                        elif bool(re.match(r"^model.layers\.\d+\.self_attn.q_proj\.q?weight$", name)):
+                            # model.layers.layer_id.self_attn.q_proj.weight
                             # model.layers.layer_id.self_attn.q_proj.qweight
                             module.self_attn.q_proj.qweight = tensor
-                        elif bool(re.match(r"^model.layers\.\d+\.self_attn.q_proj\.scales$", name)):
+                        elif bool(re.match(r"^model.layers\.\d+\.self_attn.q_proj\.(scales|weight_scale)$", name)):
                             # model.layers.layer_id.self_attn.q_proj.scales
+                            # model.layers.layer_id.self_attn.q_proj.weight_scale
                             module.self_attn.q_proj.scales = tensor
-                        elif bool(re.match(r"^model.layers\.\d+\.self_attn.q_proj\.qzeros$", name)):
+                        elif bool(re.match(r"^model.layers\.\d+\.self_attn.q_proj\.(qzeros|weight_zero_point)$", name)):
                             # model.layers.layer_id.self_attn.q_proj.qzeros
+                            # model.layers.layer_id.self_attn.q_proj.weight_zero_point
                             module.self_attn.q_proj.qzeros = tensor
                         elif bool(re.match(r"^model.layers\.\d+\.self_attn.q_proj\.g_idx$", name)):
                             # model.layers.layer_id.self_attn.q_proj.g_idx
@@ -166,14 +176,17 @@ def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, in
                         elif bool(re.match(r"^model.layers\.\d+\.self_attn.q_proj\.bias$", name)):
                             # model.layers.layer_id.self_attn.q_proj.bias
                             module.self_attn.q_proj.bias = tensor
-                        elif bool(re.match(r"^model.layers\.\d+\.self_attn.k_proj\.qweight$", name)):
+                        elif bool(re.match(r"^model.layers\.\d+\.self_attn.k_proj\.q?weight$", name)):
                             # model.layers.layer_id.self_attn.k_proj.qweight
+                            # model.layers.layer_id.self_attn.k_proj.weight
                             module.self_attn.k_proj.qweight = tensor
-                        elif bool(re.match(r"^model.layers\.\d+\.self_attn.k_proj\.scales$", name)):
+                        elif bool(re.match(r"^model.layers\.\d+\.self_attn.k_proj\.(scales|weight_scale)$", name)):
                             # model.layers.layer_id.self_attn.k_proj.scales
+                            # model.layers.layer_id.self_attn.k_proj.weight_scale
                             module.self_attn.k_proj.scales = tensor
-                        elif bool(re.match(r"^model.layers\.\d+\.self_attn.k_proj\.qzeros$", name)):
+                        elif bool(re.match(r"^model.layers\.\d+\.self_attn.k_proj\.(qzeros|weight_zero_point)$", name)):
                             # model.layers.layer_id.self_attn.k_proj.qzeros
+                            # model.layers.layer_id.self_attn.k_proj.weight_zero_point
                             module.self_attn.k_proj.qzeros = tensor
                         elif bool(re.match(r"^model.layers\.\d+\.self_attn.k_proj\.g_idx$", name)):
                             # model.layers.layer_id.self_attn.k_proj.g_idx
@@ -181,14 +194,17 @@ def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, in
                         elif bool(re.match(r"^model.layers\.\d+\.self_attn.k_proj\.bias$", name)):
                             # model.layers.layer_id.self_attn.k_proj.bias
                             module.self_attn.k_proj.bias = tensor
-                        elif bool(re.match(r"^model.layers\.\d+\.self_attn.v_proj\.qweight$", name)):
+                        elif bool(re.match(r"^model.layers\.\d+\.self_attn.v_proj\.q?weight$", name)):
                             # model.layers.layer_id.self_attn.v_proj.qweight
+                            # model.layers.layer_id.self_attn.v_proj.weight
                             module.self_attn.v_proj.qweight = tensor
-                        elif bool(re.match(r"^model.layers\.\d+\.self_attn.v_proj\.scales$", name)):
+                        elif bool(re.match(r"^model.layers\.\d+\.self_attn.v_proj\.(scales|weight_scale)$", name)):
                             # model.layers.layer_id.self_attn.v_proj.scales
+                            # model.layers.layer_id.self_attn.v_proj.weight_scale
                             module.self_attn.v_proj.scales = tensor
-                        elif bool(re.match(r"^model.layers\.\d+\.self_attn.v_proj\.qzeros$", name)):
+                        elif bool(re.match(r"^model.layers\.\d+\.self_attn.v_proj\.(qzeros|weight_zero_point)$", name)):
                             # model.layers.layer_id.self_attn.v_proj.qzeros
+                            # model.layers.layer_id.self_attn.v_proj.weight_zero_point
                             module.self_attn.v_proj.qzeros = tensor
                         elif bool(re.match(r"^model.layers\.\d+\.self_attn.v_proj\.g_idx$", name)):
                             # model.layers.layer_id.self_attn.v_proj.g_idx
@@ -196,17 +212,21 @@ def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, in
                         elif bool(re.match(r"^model.layers\.\d+\.self_attn.v_proj\.bias$", name)):
                             # model.layers.layer_id.self_attn.v_proj.bias
                             module.self_attn.v_proj.bias = tensor
-                        elif bool(re.match(r"^model.layers\.\d+\.(self_attn.o_proj|self_attention.dense)\.qweight$", name)):
+                        elif bool(re.match(r"^model.layers\.\d+\.(self_attn.o_proj|self_attention.dense)\.q?weight$", name)):
                             # model.layers.layer_id.self_attn.o_proj.qweight
                             # model.layers.layer_id.self_attention.dense.qweight
                             module.self_attn.o_proj.qweight = tensor
-                        elif bool(re.match(r"^model.layers\.\d+\.(self_attn.o_proj|self_attention.dense)\.scales$", name)):
+                        elif bool(re.match(r"^model.layers\.\d+\.(self_attn.o_proj|self_attention.dense)\.(scales|weight_scale)$", name)):
                             # model.layers.layer_id.self_attn.o_proj.scales
                             # model.layers.layer_id.self_attention.dense.scales
+                            # model.layers.layer_id.self_attn.o_proj.weight_scale
+                            # model.layers.layer_id.self_attention.dense.weight_scale
                             module.self_attn.o_proj.scales = tensor
-                        elif bool(re.match(r"^model.layers\.\d+\.(self_attn.o_proj|self_attention.dense)\.qzeros$", name)):
+                        elif bool(re.match(r"^model.layers\.\d+\.(self_attn.o_proj|self_attention.dense)\.(qzeros|weight_zero_point)$", name)):
                             # model.layers.layer_id.self_attn.o_proj.qzeros
                             # model.layers.layer_id.self_attention.dense.qzeros
+                            # model.layers.layer_id.self_attn.o_proj.weight_zero_point
+                            # model.layers.layer_id.self_attention.dense.weight_zero_point
                             module.self_attn.o_proj.qzeros = tensor
                         elif bool(re.match(r"^model.layers\.\d+\.(self_attn.o_proj|self_attention.dense)\.g_idx$", name)):
                             # model.layers.layer_id.self_attn.o_proj.g_idx
@@ -222,14 +242,29 @@ def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, in
                         elif bool(re.match(r"^model.layers\.\d+\.post_attention_layernorm\.bias$", name)):
                             # model.layers.layer_id.post_attention_layernorm.bias
                             module.post_attention_layernorm.bias = tensor
-                        elif bool(re.match(r"^model.layers\.\d+\.mlp.gate_proj\.qweight$", name)):
+                        elif bool(re.match(r"^model.layers\.\d+\.pre_feedforward_layernorm\.weight$", name)):
+                            # model.layers.layer_id.pre_feedforward_layernorm.weight
+                            module.pre_feedforward_layernorm.weight = tensor
+                        elif bool(re.match(r"^model.layers\.\d+\.pre_feedforward_layernorm\.bias$", name)):
+                            # model.layers.layer_id.pre_feedforward_layernorm.bias
+                            module.pre_feedforward_layernorm.bias = tensor
+                        elif bool(re.match(r"^model.layers\.\d+\.post_feedforward_layernorm\.weight$", name)):
+                            # model.layers.layer_id.post_feedforward_layernorm.weight
+                            module.post_feedforward_layernorm.weight = tensor
+                        elif bool(re.match(r"^model.layers\.\d+\.post_feedforward_layernorm\.bias$", name)):
+                            # model.layers.layer_id.post_feedforward_layernorm.bias
+                            module.post_feedforward_layernorm.bias = tensor
+                        elif bool(re.match(r"^model.layers\.\d+\.mlp.gate_proj\.q?weight$", name)):
                             # model.layers.layer_id.mlp.gate_proj.qweight
+                            # model.layers.layer_id.mlp.gate_proj.weight
                             module.mlp.gate_proj.qweight = tensor
-                        elif bool(re.match(r"^model.layers\.\d+\.mlp.gate_proj\.scales$", name)):
+                        elif bool(re.match(r"^model.layers\.\d+\.mlp.gate_proj\.(scales|weight_scale)$", name)):
                             # model.layers.layer_id.mlp.gate_proj.scales
+                            # model.layers.layer_id.mlp.gate_proj.weight_scale
                             module.mlp.gate_proj.scales = tensor
-                        elif bool(re.match(r"^model.layers\.\d+\.mlp.gate_proj\.qzeros$", name)):
+                        elif bool(re.match(r"^model.layers\.\d+\.mlp.gate_proj\.(qzeros|weight_zero_point)$", name)):
                             # model.layers.layer_id.mlp.gate_proj.qzeros
+                            # model.layers.layer_id.mlp.gate_proj.weight_zero_point
                             module.mlp.gate_proj.qzeros = tensor
                         elif bool(re.match(r"^model.layers\.\d+\.mlp.gate_proj\.g_idx$", name)):
                             # model.layers.layer_id.mlp.gate_proj.g_idx
@@ -237,14 +272,17 @@ def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, in
                         elif bool(re.match(r"^model.layers\.\d+\.mlp.gate_proj\.bias$", name)):
                             # model.layers.layer_id.mlp.gate_proj.bias
                             module.mlp.gate_proj.bias = tensor
-                        elif bool(re.match(r"^model.layers\.\d+\.mlp.up_proj\.qweight$", name)):
+                        elif bool(re.match(r"^model.layers\.\d+\.mlp.up_proj\.q?weight$", name)):
                             # model.layers.layer_id.mlp.up_proj.qweight
+                            # model.layers.layer_id.mlp.up_proj.weight
                             module.mlp.up_proj.qweight = tensor
-                        elif bool(re.match(r"^model.layers\.\d+\.mlp.up_proj\.scales$", name)):
+                        elif bool(re.match(r"^model.layers\.\d+\.mlp.up_proj\.(scales|weight_scale)$", name)):
                             # model.layers.layer_id.mlp.up_proj.scales
+                            # model.layers.layer_id.mlp.up_proj.weight_scale
                             module.mlp.up_proj.scales = tensor
-                        elif bool(re.match(r"^model.layers\.\d+\.mlp.up_proj\.qzeros$", name)):
+                        elif bool(re.match(r"^model.layers\.\d+\.mlp.up_proj\.(qzeros|weight_zero_point)$", name)):
                             # model.layers.layer_id.mlp.up_proj.qzeros
+                            # model.layers.layer_id.mlp.up_proj.weight_zero_point
                             module.mlp.up_proj.qzeros = tensor
                         elif bool(re.match(r"^model.layers\.\d+\.mlp.up_proj\.g_idx$", name)):
                             # model.layers.layer_id.mlp.up_proj.g_idx
@@ -252,17 +290,23 @@ def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, in
                         elif bool(re.match(r"^model.layers\.\d+\.mlp.up_proj\.bias$", name)):
                             # model.layers.layer_id.mlp.up_proj.bias
                             module.mlp.up_proj.bias = tensor
-                        elif bool(re.match(r"^model.layers\.\d+\.mlp.(down_proj|dense_4h_to_h)\.qweight$", name)):
+                        elif bool(re.match(r"^model.layers\.\d+\.mlp.(down_proj|dense_4h_to_h)\.q?weight$", name)):
                             # model.layers.layer_id.mlp.down_proj.qweight
                             # model.layers.layer_id.mlp.dense_4h_to_h.qweight
+                            # model.layers.layer_id.mlp.down_proj.weight
+                            # model.layers.layer_id.mlp.dense_4h_to_h.weight
                             module.mlp.down_proj.qweight = tensor
-                        elif bool(re.match(r"^model.layers\.\d+\.mlp.(down_proj|dense_4h_to_h)\.scales$", name)):
+                        elif bool(re.match(r"^model.layers\.\d+\.mlp.(down_proj|dense_4h_to_h)\.(scales|weight_scale)$", name)):
                             # model.layers.layer_id.mlp.down_proj.scales
                             # model.layers.layer_id.mlp.dense_4h_to_h.scales
+                            # model.layers.layer_id.mlp.down_proj.weight_scale
+                            # model.layers.layer_id.mlp.dense_4h_to_h.weight_scale
                             module.mlp.down_proj.scales = tensor
-                        elif bool(re.match(r"^model.layers\.\d+\.mlp.(down_proj|dense_4h_to_h)\.qzeros$", name)):
+                        elif bool(re.match(r"^model.layers\.\d+\.mlp.(down_proj|dense_4h_to_h)\.(qzeros|weight_zero_point)$", name)):
                             # model.layers.layer_id.mlp.down_proj.qzeros
                             # model.layers.layer_id.mlp.dense_4h_to_h.qzeros
+                            # model.layers.layer_id.mlp.down_proj.weight_zero_point
+                            # model.layers.layer_id.mlp.dense_4h_to_h.weight_zero_point
                             module.mlp.down_proj.qzeros = tensor
                         elif bool(re.match(r"^model.layers\.\d+\.mlp.(down_proj|dense_4h_to_h)\.g_idx$", name)):
                             # model.layers.layer_id.mlp.down_proj.g_idx
@@ -273,25 +317,31 @@ def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, in
                             # model.layers.layer_id.mlp.dense_4h_to_h.bias
                             module.mlp.down_proj.bias = tensor
                         # Match against fused layers
-                        elif bool(re.match(r"^model.layers\.\d+\.(self_attn.qkv_proj|self_attention.query_key_value)\.qweight$", name)):
+                        elif bool(re.match(r"^model.layers\.\d+\.(self_attn.qkv_proj|self_attention.query_key_value)\.q?weight$", name)):
                             # model.layers.layer_id.self_attn.qkv_proj.qweight
                             # model.layers.layer_id.self_attention.query_key_value.qweight
-                            q_dim = q_size // (32 // bits) if quant_type == "awq" else q_size
-                            kv_dim = kv_size // (32 // bits) if quant_type == "awq" else kv_size
+                            # model.layers.layer_id.self_attn.qkv_proj.weight
+                            # model.layers.layer_id.self_attention.query_key_value.weight
+                            q_dim = q_size // (32 // local_bits) if quant_type in {"awq", "quark"} else q_size
+                            kv_dim = kv_size // (32 // local_bits) if quant_type in {"awq", "quark"} else kv_size
                             module.self_attn.q_proj.qweight = tensor[:, : q_dim]
                             module.self_attn.k_proj.qweight = tensor[:, q_dim : q_dim + kv_dim]
                             module.self_attn.v_proj.qweight = tensor[:, q_dim + kv_dim :]
-                        elif bool(re.match(r"^model.layers\.\d+\.(self_attn.qkv_proj|self_attention.query_key_value)\.scales$", name)):
+                        elif bool(re.match(r"^model.layers\.\d+\.(self_attn.qkv_proj|self_attention.query_key_value)\.(scales|weight_scale)$", name)):
                             # model.layers.layer_id.self_attn.qkv_proj.scales
                             # model.layers.layer_id.self_attention.query_key_value.scales
+                            # model.layers.layer_id.self_attn.qkv_proj.weight_scale
+                            # model.layers.layer_id.self_attention.query_key_value.weight_scale
                             module.self_attn.q_proj.scales = tensor[:, : q_size]
                             module.self_attn.k_proj.scales = tensor[:, q_size : q_size + kv_size]
                             module.self_attn.v_proj.scales = tensor[:, q_size + kv_size :]
-                        elif bool(re.match(r"^model.layers\.\d+\.(self_attn.qkv_proj|self_attention.query_key_value)\.qzeros$", name)):
+                        elif bool(re.match(r"^model.layers\.\d+\.(self_attn.qkv_proj|self_attention.query_key_value)\.(qzeros|weight_zero_point)$", name)):
                             # model.layers.layer_id.self_attn.qkv_proj.qzeros
                             # model.layers.layer_id.self_attention.query_key_value.qzeros
-                            q_dim = q_size // (32 // bits) if quant_type in {"awq", "gptq"} else q_size
-                            kv_dim = kv_size // (32 // bits) if quant_type in {"awq", "gptq"} else kv_size
+                            # model.layers.layer_id.self_attn.qkv_proj.weight_zero_point
+                            # model.layers.layer_id.self_attention.query_key_value.weight_zero_point
+                            q_dim = q_size // (32 // local_bits) if quant_type in {"awq", "gptq", "quark"} else q_size
+                            kv_dim = kv_size // (32 // local_bits) if quant_type in {"awq", "gptq", "quark"} else kv_size
                             module.self_attn.q_proj.qzeros = tensor[:, : q_dim]
                             module.self_attn.k_proj.qzeros = tensor[:, q_dim : q_dim + kv_dim]
                             module.self_attn.v_proj.qzeros = tensor[:, q_dim + kv_dim :]
@@ -307,21 +357,27 @@ def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, in
                             module.self_attn.q_proj.bias = tensor[: q_size]
                             module.self_attn.k_proj.bias = tensor[q_size : q_size + kv_size]
                             module.self_attn.v_proj.bias = tensor[q_size + kv_size : ]
-                        elif bool(re.match(r"^model.layers\.\d+\.mlp.(gate_up_proj|dense_h_to_4h)\.qweight$", name)):
+                        elif bool(re.match(r"^model.layers\.\d+\.mlp.(gate_up_proj|dense_h_to_4h|gate_proj)\.q?weight$", name)):
                             # model.layers.layer_id.mlp.gate_up_proj.qweight
                             # model.layers.layer_id.mlp.dense_h_to_4h.qweight
-                            intermediate_dim = intermediate_size // (32 // bits) if quant_type == "awq" else intermediate_size
+                            # model.layers.layer_id.mlp.gate_up_proj.weight
+                            # model.layers.layer_id.mlp.dense_h_to_4h.weight
+                            intermediate_dim = intermediate_size // (32 // local_bits) if quant_type in {"awq", "quark"} else intermediate_size
                             module.mlp.gate_proj.qweight = tensor[:, : intermediate_dim]
                             module.mlp.up_proj.qweight = tensor[:, intermediate_dim :]
-                        elif bool(re.match(r"^model.layers\.\d+\.mlp.(gate_up_proj|dense_h_to_4h)\.scales$", name)):
+                        elif bool(re.match(r"^model.layers\.\d+\.mlp.(gate_up_proj|dense_h_to_4h|gate_proj)\.(scales|weight_scale)$", name)):
                             # model.layers.layer_id.mlp.gate_up_proj.scales
                             # model.layers.layer_id.mlp.dense_h_to_4h.scales
+                            # model.layers.layer_id.mlp.gate_up_proj.weight_scale
+                            # model.layers.layer_id.mlp.dense_h_to_4h.weight_scale
                             module.mlp.gate_proj.scales = tensor[:, : intermediate_size]
                             module.mlp.up_proj.scales = tensor[:, intermediate_size :]
-                        elif bool(re.match(r"^model.layers\.\d+\.mlp.(gate_up_proj|dense_h_to_4h)\.qzeros$", name)):
+                        elif bool(re.match(r"^model.layers\.\d+\.mlp.(gate_up_proj|dense_h_to_4h|gate_proj)\.(qzeros|weight_zero_point)$", name)):
                             # model.layers.layer_id.mlp.gate_up_proj.qzeros
                             # model.layers.layer_id.mlp.dense_h_to_4h.qzeros
-                            intermediate_dim = intermediate_size // (32 // bits) if quant_type in {"awq", "gptq"} else intermediate_size
+                            # model.layers.layer_id.mlp.gate_up_proj.weight_zero_point
+                            # model.layers.layer_id.mlp.dense_h_to_4h.weight_zero_point
+                            intermediate_dim = intermediate_size // (32 // local_bits) if quant_type in {"awq", "gptq", "quark"} else intermediate_size
                             module.mlp.gate_proj.qzeros = tensor[:, : intermediate_dim]
                             module.mlp.up_proj.qzeros = tensor[:, intermediate_dim :]
                         elif bool(re.match(r"^model.layers\.\d+\.mlp.(gate_up_proj|dense_h_to_4h)\.g_idx$", name)):
@@ -351,22 +407,34 @@ def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, in
         # Set properties of each layer based on quantization type
         self.set_properties()
 
+    def _load_quant_config(self, quant_attrs):
+        self.global_group_size = quant_attrs["config"]["group_size"]
+        self.global_bits = quant_attrs["config"]["bits"]
+
+    def get_layer_bits(self, layer_name):
+        # 'bits' is globally defined for all layers
+        return self.global_bits
+
+    def get_layer_group_size(self, layer_name):
+        # 'group_size' is globally defined for all layers
+        return self.global_group_size
+
     def _initialize_quantized_lm_head(self, bits, group_size):
         """
         Initialize `QuantizedTensorModule` for LM head if not already set
         """
-        if isinstance(self.lm_head, TensorModule):
-            assert self.lm_head.weight is None
-            assert self.lm_head.bias is None
         if not isinstance(self.lm_head, QuantizedTensorModule):
-            self.lm_head = QuantizedTensorModule(bits, group_size)
+            q_lm_head = QuantizedTensorModule(bits, group_size)
+            q_lm_head.qweight = self.lm_head.weight
+            q_lm_head.bias = self.lm_head.bias
+            self.lm_head = q_lm_head
 
     def set_properties(self):
         """
         Set in_features, out_features, and g_idx based on quantization type
         """
         if isinstance(self.lm_head, QuantizedTensorModule):
-            if self.quant_type == "awq":
+            if self.quant_type == "awq" or self.quant_type == "quark":
                 self.lm_head.out_features = self.lm_head.scales.shape[1]
                 self.lm_head.in_features = self.lm_head.qweight.shape[0]
                 # Set g_idx if not already set
@@ -377,7 +445,7 @@ def set_properties(self):
             else:
                 raise NotImplementedError(f"The {self.quant_type} quantization method is not recognized.")
         for module in self.layers:
-            if self.quant_type == "awq":
+            if self.quant_type == "awq" or self.quant_type == "quark":
                 # Set in_features and out_features
                 module.self_attn.q_proj.out_features = module.self_attn.q_proj.scales.shape[1]
                 module.self_attn.q_proj.in_features = module.self_attn.q_proj.qweight.shape[0]
@@ -587,8 +655,8 @@ def pack_ort_format(self, module, intweight):
 
 
 class AWQModel(QuantizedModel):
-    def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size, num_layers):
-        super().__init__(quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size, num_layers)
+    def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, intermediate_size, num_layers):
+        super().__init__(quant_type, input_path, quant_attrs, q_size, kv_size, intermediate_size, num_layers)
 
         # Unpack and repack all `QuantizedTensorModule` classes in model
         for i, layer in enumerate(self.layers):
@@ -598,7 +666,7 @@ def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, in
 
             # Unpack and repack all `QuantizedTensorModule` classes in attention
             self_attn = getattr(layer, "self_attn", None) or getattr(layer, "self_attention", None)
-            for name, q_tensors in self_attn.__dict__.items():
+            for _, q_tensors in self_attn.__dict__.items():
                 if isinstance(q_tensors, QuantizedTensorModule) and q_tensors.qweight is not None:
                     self.unpack(q_tensors)
                     self.repack(q_tensors)
@@ -606,8 +674,8 @@ def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, in
                     # Set `g_idx` to None since it's not used in `MatMulNBits`
                     q_tensors.g_idx = None
 
-            # Unpack and repack all `Quantized TensorModule` classes in MLP
-            for name, q_tensors in layer.mlp.__dict__.items():
+            # Unpack and repack all `QuantizedTensorModule` classes in MLP
+            for _, q_tensors in layer.mlp.__dict__.items():
                 if isinstance(q_tensors, QuantizedTensorModule) and q_tensors.qweight is not None:
                     self.unpack(q_tensors)
                     self.repack(q_tensors)
@@ -662,8 +730,8 @@ def reverse_reorder_tensor(self, tensor, bits):
 
 
 class GPTQModel(QuantizedModel):
-    def __init__(self, quant_type, input_path, bits, group_size, use_g_idx, q_size, kv_size, intermediate_size, num_layers):
-        super().__init__(quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size, num_layers)
+    def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, intermediate_size, num_layers):
+        super().__init__(quant_type, input_path, quant_attrs, q_size, kv_size, intermediate_size, num_layers)
 
         # Unpack and repack all `QuantizedTensorModule` classes in model
         for i, layer in enumerate(self.layers):
@@ -672,24 +740,24 @@ def __init__(self, quant_type, input_path, bits, group_size, use_g_idx, q_size,
             print(f"Unpacking and repacking layer {i}")
 
             # Unpack and repack all `QuantizedTensorModule` classes in attention
-            for name, q_tensors in layer.self_attn.__dict__.items():
+            for _, q_tensors in layer.self_attn.__dict__.items():
                 if isinstance(q_tensors, QuantizedTensorModule) and q_tensors.qweight is not None:
                     self.handle_qzeros(q_tensors)
                     self.unpack(q_tensors)
                     self.repack(q_tensors)
 
-                    if not use_g_idx:
+                    if not quant_attrs["use_g_idx"]:
                         # Set `g_idx` to None since it's not used in `MatMulNBits`
                         q_tensors.g_idx = None
 
             # Unpack and repack all `QuantizedTensorModule` classes in MLP
-            for name, q_tensors in layer.mlp.__dict__.items():
+            for _, q_tensors in layer.mlp.__dict__.items():
                 if isinstance(q_tensors, QuantizedTensorModule) and q_tensors.qweight is not None:
                     self.handle_qzeros(q_tensors)
                     self.unpack(q_tensors)
                     self.repack(q_tensors)
 
-                    if not use_g_idx:
+                    if not quant_attrs["use_g_idx"]:
                         # Set `g_idx` to None since it's not used in `MatMulNBits`
                         q_tensors.g_idx = None
 
@@ -698,7 +766,7 @@ def __init__(self, quant_type, input_path, bits, group_size, use_g_idx, q_size,
             self.unpack(self.lm_head)
             self.repack(self.lm_head)
 
-            if not use_g_idx:
+            if not quant_attrs["use_g_idx"]:
                 # Set `g_idx` to None since it's not used in `MatMulNBits`
                 self.lm_head.g_idx = None
 
@@ -726,19 +794,129 @@ def __init__(self, module):
         self.pack_qzeros(temp_module)
         module.qzeros = temp_module.qzeros
 
+
+class QuarkModel(QuantizedModel):
+    def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, intermediate_size, num_layers):
+        super().__init__(quant_type, input_path, quant_attrs, q_size, kv_size, intermediate_size, num_layers)
+
+        # Unpack and repack all `QuantizedTensorModule` classes in model
+        for i, layer in enumerate(self.layers):
+            if i >= self.num_layers:
+                break
+
+            # Unpack and repack all `QuantizedTensorModule` classes in attention
+            self_attn = getattr(layer, "self_attn", None) or getattr(layer, "self_attention", None)
+            for _, q_tensors in self_attn.__dict__.items():
+                if isinstance(q_tensors, QuantizedTensorModule) and q_tensors.qweight is not None:
+                    self.unpack(q_tensors)
+                    self.repack(q_tensors)
+
+                    # Set `g_idx` to None since it's not used in `MatMulNBits`
+                    q_tensors.g_idx = None
+
+            # Unpack and repack all `QuantizedTensorModule` classes in MLP
+            for _, q_tensors in layer.mlp.__dict__.items():
+                if isinstance(q_tensors, QuantizedTensorModule) and q_tensors.qweight is not None:
+                    self.unpack(q_tensors)
+                    self.repack(q_tensors)
+
+                    # Set `g_idx` to None since it's not used in `MatMulNBits`
+                    q_tensors.g_idx = None
+
+        if isinstance(self.lm_head, QuantizedTensorModule) and self.lm_head.qweight is not None:
+            self.unpack(self.lm_head)
+            self.repack(self.lm_head)
+
+            # Set `g_idx` to None since it's not used in `MatMulNBits`
+            self.lm_head.g_idx = None
+
+    def _load_quant_config(self, quant_attrs):
+        self.global_quant_config = quant_attrs["config"]["global_quant_config"]["weight"]
+        self.global_group_size = self.global_quant_config["group_size"]
+        global_dtype = self.global_quant_config["dtype"]
+
+        dtype_bits_maps = {
+            "uint4": 4,
+            "int4": 4,
+        }
+
+        if global_dtype not in dtype_bits_maps:
+            raise ValueError(f"Unexpected dtype: {global_dtype}.")
+        self.global_bits = dtype_bits_maps[global_dtype]
+
+    def get_layer_bits(self, layer_name):
+        name = layer_name.split(".")[0]
+        if name in self._quant_attrs["config"]["layer_quant_config"]:
+            layer_quant_config = self._quant_attrs["config"]["layer_quant_config"][name]["weight"]
+            local_dtype = layer_quant_config["dtype"]
+
+            dtype_bits_maps = {
+                "uint4": 4,
+                "int4": 4,
+            }
+            if local_dtype not in dtype_bits_maps:
+                raise ValueError(f"Unexpected dtype: {local_dtype}.")
+            return dtype_bits_maps[local_dtype]
+        return self.global_bits
+
+    def get_layer_group_size(self, layer_name):
+        name = layer_name.split(".")[0]
+        if name in self._quant_attrs["config"]["layer_quant_config"]:
+            layer_quant_config = self._quant_attrs["config"]["layer_quant_config"][name]["weight"]
+            return layer_quant_config["group_size"]
+        return self.global_group_size
+
+    def unpack_qweight(self, module):
+        """
+        Unpack `qweight` to standard format
+        """
+        expected_shape = (module.qweight.shape[0], module.out_features)
+        transpose = module.qweight.shape != expected_shape
+        module.qweight = self.unpack_on_row(module.qweight.T, module.bits, transpose)
+        module.qweight = self.reverse_reorder_tensor(module.qweight.T, module.bits)
+
+    def unpack_qzeros(self, module):
+        """
+        Unpack `qzeros` to standard format
+        """
+        super().unpack_qzeros(module)
+        module.qzeros = self.reverse_reorder_tensor(module.qzeros, module.bits)
+
+    def reverse_reorder_tensor(self, tensor, bits):
+        """
+        Re-arrange tensor data in a new order
+        """
+        compress_ratio = 32 // bits
+        assert tensor.shape[-1] % compress_ratio == 0
+
+        if bits == 4:
+            order_map = [0, 2, 4, 6, 1, 3, 5, 7]
+        else:
+            raise NotImplementedError(f"Unpacking for {bits}-bit quantization is not currently supported.")
+
+        order_tensor = torch.tensor(order_map, dtype=torch.int32).reshape(1, -1)
+        order_tensor = order_tensor.repeat(tensor.shape[1] // compress_ratio, 1)
+        order_tensor = order_tensor + torch.arange(0, tensor.shape[1], compress_ratio, dtype=torch.int32).reshape(-1, 1)
+        order_tensor = order_tensor.reshape(-1)
+
+        reverse_order_tensor = torch.arange(order_tensor.shape[0])[order_tensor]
+        reverse_order_tensor = reverse_order_tensor[order_tensor]
+        int_tensor = tensor[:, reverse_order_tensor]
+        return int_tensor
+
 
 class QuantModel:
     @staticmethod
-    def from_pretrained(quant_type, input_path, bits, group_size, use_g_idx, q_size, kv_size, intermediate_size, num_layers):
+    def from_pretrained(quant_type, **kwargs):
         """
         Unpack quantized weights in PyTorch models, store them in a standard format, and
         repack them into ONNX Runtime's format. Also performs any pre-processing and
         post-processing when unpacking the quantized weights.
         """
         if quant_type == "awq":
-            model = AWQModel(quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size, num_layers)
+            model = AWQModel(quant_type, **kwargs)
         elif quant_type == "gptq":
-            model = GPTQModel(quant_type, input_path, bits, group_size, use_g_idx, q_size, kv_size, intermediate_size, num_layers)
+            model = GPTQModel(quant_type, **kwargs)
+        elif quant_type == "quark":
+            model = QuarkModel(quant_type, **kwargs)
         else:
             raise NotImplementedError(f"The {quant_type} quantized model is not currently supported.")
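For reference, a minimal usage sketch of the new keyword-based dispatch added above. The checkpoint path, head counts, and sizes are illustrative assumptions, not values from this change; the real call is made by builder.py with values taken from the model's config.

from onnxruntime_genai.models.quantized_model import QuantModel

# Illustrative Quark-style quantization config, shaped as QuarkModel._load_quant_config expects.
quant_attrs = {
    "config": {
        "global_quant_config": {"weight": {"dtype": "uint4", "group_size": 128}},
        "layer_quant_config": {},
    },
    "use_g_idx": False,
}

model = QuantModel.from_pretrained(
    "quark",                                 # dispatches to QuarkModel via the new elif branch
    input_path="/path/to/quark/checkpoint",  # hypothetical directory containing *.safetensors files
    quant_attrs=quant_attrs,
    q_size=32 * 128,                         # num_attn_heads * head_size
    kv_size=8 * 128,                         # num_kv_heads * head_size
    intermediate_size=14336,
    num_layers=32,
)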