[WIP] Quark Quantizer Support #1207

Open

wants to merge 18 commits into main
Conversation

shobrienDMA
Contributor

This allows Quark Quantized models to be processed by ONNX Runtime GenAI.

Quark models must be exported in hf_format.

An example quantize_quark.py command:

python quantize_quark.py --model_dir /[Model_Path] \
    --output_dir /[Output_Model_Path] \
    --quant_scheme w_uint4_per_group_asym \
    --num_calib_data 128 \
    --quant_algo awq \
    --dataset pileval_for_awq_benchmark \
    --seq_len 512 \
    --model_export hf_format \
    --data_type float32

It also allows different group sizes for different layers, depending on what is present in the config.json that Quark produces. A Quark config can look like:

...
  "quantization_config": {
    "algo_config": {
      "model_decoder_layers": "model.layers",
      "name": "awq",
      "num_attention_heads": -1,
      "num_key_value_heads": -1,
      "scaling_layers": [
        {
          "inp": "self_attn.q_proj",
          "layers": [
            "self_attn.q_proj",
            "self_attn.k_proj",
            "self_attn.v_proj"
          ],
          "module2inspect": "self_attn",
          "prev_op": "input_layernorm"
        },
        {
          "inp": "self_attn.o_proj",
          "layers": [
            "self_attn.o_proj"
          ],
          "prev_op": "self_attn.v_proj"
        },
        {
          "inp": "mlp.gate_proj",
          "layers": [
            "mlp.gate_proj",
            "mlp.up_proj"
          ],
          "module2inspect": "mlp",
          "prev_op": "post_attention_layernorm"
        },
        {
          "inp": "mlp.down_proj",
          "layers": [
            "mlp.down_proj"
          ],
          "prev_op": "mlp.up_proj"
        }
      ]
    },
    "exclude": [],
    "export": {
      "kv_cache_group": [],
      "pack_method": "reorder",
      "weight_format": "real_quantized",
      "weight_merge_groups": null
    },
    "global_quant_config": {
      "bias": null,
      "input_tensors": null,
      "output_tensors": null,
      "target_device": null,
      "weight": {
        "ch_axis": 1,
        "dtype": "uint4",
        "group_size": 128,
        "is_dynamic": false,
        "observer_cls": "PerGroupMinMaxObserver",
        "qscheme": "per_group",
        "round_method": "half_even",
        "scale_type": "float",
        "symmetric": false
      }
    },
    "layer_quant_config": {
      "lm_head": {
        "bias": null,
        "input_tensors": null,
        "output_tensors": null,
        "target_device": null,
        "weight": {
          "ch_axis": 1,
          "dtype": "uint4",
          "group_size": 32,
          "is_dynamic": false,
          "observer_cls": "PerGroupMinMaxObserver",
          "qscheme": "per_group",
          "round_method": "half_even",
          "scale_type": "float",
          "symmetric": false
        }
      }
    },
    "layer_type_quant_config": {},
    "quant_method": "quark",
    "quant_mode": "eager_mode"
  },
...

As you can see, the lm_head entry in layer_quant_config uses a different group size (32, rather than the global 128).
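To illustrate how a converter might consume this config, here is a minimal sketch (not the PR's code; the helper name resolve_weight_config is made up for illustration) that returns the effective per-layer weight settings, falling back to global_quant_config when no layer_quant_config entry matches:

import json

def resolve_weight_config(quant_config, layer_name):
    # A per-layer override wins; otherwise fall back to the global settings.
    overrides = quant_config.get("layer_quant_config", {})
    if layer_name in overrides:
        return overrides[layer_name]["weight"]
    return quant_config["global_quant_config"]["weight"]

with open("config.json") as f:
    quant_config = json.load(f)["quantization_config"]

print(resolve_weight_config(quant_config, "lm_head")["group_size"])                     # 32
print(resolve_weight_config(quant_config, "model.layers.0.self_attn.q_proj")["group_size"])  # 128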

@BowenBao
Contributor

cc @kunal-vaishnavi this is the PR for Quark integration, please take a look and let us know what you think, thanks!

self.local_bits = self.global_bits
self.local_group_size = self.global_group_size

# Per-layer quantization support
Contributor

Can we make this section before the if-elif-else checks cleaner by having a method that sets all properties for Quark vs non-Quark scenarios (similar to self.get_config)? I'm thinking something like this could work.

for name, tensor in weights.items():
    module = self.set_module(...)

    if tensor.dtype == torch.bfloat16:
        ...

def set_module(self, ...):
    # Set any shared attributes and variables

    if quant_type == "quark":
        # Set Quark-related attributes and variables
    else:
        # Set non-Quark-related attributes and variables

Contributor

Updated with get_local_bits and get_local_group_size methods that are overridden by QuarkModel.
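For context, a rough sketch of what those overrides might look like (the class shapes and attribute names below are assumptions based on this thread, not the actual PR code):

import re

class Model:
    def get_local_bits(self, layer_name):
        # Non-Quark models use a single global setting for every layer.
        return self.global_bits

    def get_local_group_size(self, layer_name):
        return self.global_group_size

class QuarkModel(Model):
    def get_local_bits(self, layer_name):
        weight_cfg = self.layer_quant_config.get(layer_name, {}).get("weight")
        if weight_cfg is not None:
            return int(re.search(r"\d+", weight_cfg["dtype"]).group())  # "uint4" -> 4
        return self.global_bits

    def get_local_group_size(self, layer_name):
        weight_cfg = self.layer_quant_config.get(layer_name, {}).get("weight")
        if weight_cfg is not None:
            return weight_cfg["group_size"]
        return self.global_group_size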

if isinstance(self.lm_head, TensorModule):
    weight = self.lm_head.weight
    bias = self.lm_head.bias
    self.lm_head = QuantizedTensorModule(self.local_bits, self.local_group_size)
Contributor

Can we find a cleaner way to initialize self.lm_head to be a TensorModule versus a QuantizedTensorModule beforehand rather than using the self._initialize_quantized_lm_head approach or this approach?

Contributor

This part of the code has been removed and merged with the existing blocks that update self.lm_head.
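Roughly, the merged approach could look like the following inside the existing weight loop (illustrative only; the regex and names are assumptions, not the actual diff):

if re.match(r"^lm_head\.(qweight|weight)$", name):
    if not isinstance(self.lm_head, QuantizedTensorModule):
        # Swap the plain TensorModule for a quantized one lazily,
        # using the per-layer settings resolved for lm_head.
        self.lm_head = QuantizedTensorModule(
            self.get_local_bits("lm_head"), self.get_local_group_size("lm_head")
        )
    self.lm_head.qweight = tensor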

@@ -151,13 +188,13 @@ def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, in
# model.layers.layer_id.self_attn.rotary_emb.inv_freq
# Skip rotary embedding weights since they can be re-calculated when looping through the model
continue
elif bool(re.match(r"^model.layers\.\d+\.self_attn.q_proj\.qweight$", name)):
elif bool(re.match(r"^model.layers\.\d+\.self_attn.q_proj\.q?weight$", name)):
Contributor

This regex appears to work for both .qweight and .weight as the suffix. Does Quark produce quantized weight tensors with the .weight suffix instead of the .qweight suffix?

Contributor

Yes
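For reference, a quick check (not part of the PR) confirming the relaxed pattern accepts both suffixes:

import re

pattern = r"^model.layers\.\d+\.self_attn.q_proj\.q?weight$"
print(bool(re.match(pattern, "model.layers.0.self_attn.q_proj.qweight")))  # True
print(bool(re.match(pattern, "model.layers.0.self_attn.q_proj.weight")))   # True
print(bool(re.match(pattern, "model.layers.0.self_attn.q_proj.scales")))   # False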

# model.layers.layer_id.mlp.up_proj.scales
module.mlp.up_proj.scales = tensor
elif bool(re.match(r"^model.layers\.\d+\.mlp.up_proj\.qzeros$", name)):
elif bool(re.match(r"^model.layers\.\d+\.mlp.up_proj\.(qzeros|weight_zero_point)$", name)):
# model.layers.layer_id.mlp.up_proj.qzeros
Contributor

For each new regex pattern added in all of the condition checks, can you add a comment listing the exact pattern that is matched (like model.layers.layer_id.mlp.up_proj.qzeros)? It's easier to read a comment to know what is being matched than to parse the regex and work it out.

Contributor

done

# Set `g_idx` to None since it's not used in `MatMulNBits`
q_tensors.g_idx = None

# Unpack and repack all `Quantized TensorModule` classes in MLP
Contributor

Suggested change
# Unpack and repack all `Quantized TensorModule` classes in MLP
# Unpack and repack all `QuantizedTensorModule` classes in MLP

Contributor

done

@@ -77,29 +76,38 @@ def __init__(self, layer_id, bits, group_size):
self.input_layernorm = TensorModule()
self.self_attn = QuantizedAttention(bits, group_size)
self.post_attention_layernorm = TensorModule()
self.pre_feedforward_layernorm = TensorModule()
Contributor

This is for Gemma2 model support.

layer_id = curr_layer_id
module = self.layers.setdefault(layer_id, QuantizedDecoderLayer(layer_id, bits, group_size))
layer_id = int(name.split(".")[2])
module = self.layers.setdefault(layer_id, QuantizedDecoderLayer(layer_id, local_bits, local_group_size))
Contributor

Moved the module initialization here, since module is only used once layer_id is available; it is unused in the earlier if-elif branches that handle other layers.
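For reference, a quick illustration (plain Python, not PR code) of how the layer index is recovered from a tensor name of this form:

name = "model.layers.12.mlp.up_proj.weight_zero_point"
layer_id = int(name.split(".")[2])
print(layer_id)  # 12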
