Skip to content

Commit

Permalink
AutoFP8 to llmcompressor migration for FP8 quantization (#2701)
Browse files Browse the repository at this point in the history
  • Loading branch information
a-ys authored Feb 2, 2025
1 parent 6e0cca2 commit d4f5ee7
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 60 deletions.
3 changes: 2 additions & 1 deletion serving/docker/lmi-container-requirements-common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ onnx
sentence_transformers
onnxruntime-gpu==1.20.0
autoawq==0.2.5
llmcompressor==0.3.1
tokenizers==0.20.3
pydantic==2.9.2
optimum==1.23.2
torch==2.5.1
torchvision==0.20.1
# sequence scheduler wheel for hf accelerate rolling batch
https://publish.djl.ai/seq_scheduler/seq_scheduler-0.1.0-py3-none-any.whl
https://publish.djl.ai/seq_scheduler/seq_scheduler-0.1.0-py3-none-any.whl
102 changes: 43 additions & 59 deletions serving/docker/partition/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from properties_manager import PropertiesManager
from huggingface_hub import snapshot_download
from datasets import load_dataset

from utils import (get_partition_cmd, extract_python_jar,
get_python_executable, get_download_dir,
Expand Down Expand Up @@ -217,8 +216,8 @@ def run_quantization(self):
self.properties_manager.generate_properties_file()
self.upload_checkpoints_to_s3()
elif quant_method == 'fp8':
logging.info("Running AutoFP8 quantization")
self.autofp8_quantize()
logging.info("Running FP8 quantization")
self.fp8_quantize()
self.properties_manager.generate_properties_file()
self.upload_checkpoints_to_s3()
else:
Expand Down Expand Up @@ -266,67 +265,52 @@ def autoawq_quantize(self):
raise ImportError(
"AutoAWQ is not installed. Failing during quantization.")

def autofp8_quantize(self):
def fp8_quantize(self):
"""
Quantizes model using AutoFP8.
Quantizes model using llm-compressor.
Recipe: Simple PTQ + FP8 weight & activation quantization.
"""
# initialize configs
hf_configs, tokenizer = load_hf_config_and_tokenizer(self.properties)
if not tokenizer.pad_token:
tokenizer.pad_token = tokenizer.eos_token
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import oneshot
from transformers import AutoModelForCausalLM

quant_config = {
"activation_scheme":
self.properties.get("option.fp8_activation_scheme", "static"),
# initialize configs and model
hf_configs, tokenizer = load_hf_config_and_tokenizer(self.properties)
output_path = self.properties['option.save_mp_checkpoint_path']
model = AutoModelForCausalLM.from_pretrained(
hf_configs.model_id_or_path, **hf_configs.kwargs)

# parse options and define quantization recipe
quant_config = {"targets": "Linear"}
quant_config["scheme"] = self.properties.get("option.fp8_scheme",
"FP8")
quant_config["ignore"] = [
s.strip() for s in self.properties.get("option.fp8_ignore",
"lm_head").split(',')
]
recipe = QuantizationModifier(**quant_config)

# calibration dataset options
oneshot_kwargs = {
"model": model,
"recipe": recipe,
}
if self.properties.get("option.fp8_kv_cache_quant_targets"):
quant_config["kv_cache_quant_targets"] = tuple([
s.strip() for s in self.properties.get(
"option.fp8_kv_cache_quant_targets").split(',')
])
if self.properties.get("option.fp8_ignore_patterns"):
quant_config["ignore_patterns"] = [
s.strip() for s in self.properties.get(
"option.fp8_ignore_patterns").split(',')
]

# create samples for calibrating scaling factors
if quant_config["activation_scheme"] == "dynamic":
# If using dynamic activation scales, a calibration dataset is not required
examples = []
if "dynamic" in recipe.scheme:
pass
else:
calib_size = int(self.properties.get("option.calib_size", 512))
# Tokenize dataset for calibrating static activation scales
ds = load_dataset("abisee/cnn_dailymail",
"3.0.0",
split="validation").shuffle(seed=42).select(
range(calib_size))
examples = [batch["article"] for batch in ds]
examples = tokenizer(examples,
padding=True,
truncation=True,
return_tensors="pt").to("cuda")

# quantization
try:
from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig
quantize_config = BaseQuantizeConfig(**quant_config)
logging.info(
f"Using the following configurations for fp8 quantization: {vars(quantize_config)}"
)
model = AutoFP8ForCausalLM.from_pretrained(
hf_configs.model_id_or_path, quantize_config,
**hf_configs.kwargs)
model.quantize(examples)
output_path = self.properties['option.save_mp_checkpoint_path']
logging.info(
f"Quantization complete. Saving model to: {output_path}")
model.save_quantized(output_path)
except ImportError:
logging.error(
"AutoFP8 is not installed. Failing during quantization.")
raise ImportError(
"AutoFP8 is not installed. Failing during quantization.")
oneshot_kwargs["dataset"] = "cnn_dailymail"
oneshot_kwargs["num_calibration_samples"] = int(
self.properties.get("option.calib_size", 512))
oneshot_kwargs["max_seq_length"] = int(
self.properties.get("option.max_model_len", 2048))

logging.info(
f"Using the following configuartions for fp8 quantization: {oneshot_kwargs}"
)
oneshot(**oneshot_kwargs)
logging.info(f"Quantization complete. Saving model to: {output_path}")
model.save_pretrained(output_path)
tokenizer.save_pretrained(output_path)


def main():
Expand Down
8 changes: 8 additions & 0 deletions serving/docker/partition/sm_neo_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,16 @@ def write_properties(self):
"""
Updates outputted serving.properties.
## tensor parallel degree & device_map
We set option.tensor_parallel_degree & option.device_map for quantization.
This function passes through these values to the outputted serving.properties if received from the customer.
Otherwise, nothing is outputted for these values.
## quantization
For FP8 quantization with llm-compressor, vllm requires quantization_method to be set to 'compressed-tensors'
"""
passthrough_properties = {}
# checking if customer set property through envvar or serving.properties.
passthrough_properties[
"option.tensor_parallel_degree"] = os.environ.get(
"OPTION_TENSOR_PARALLEL_DEGREE") if os.environ.get(
Expand All @@ -127,6 +132,9 @@ def write_properties(self):
f"User did not pass {k}. Outputted serving.properties "
"will not include this field.")

if output_properties.get("option.quantize") == "fp8":
output_properties["option.quantize"] = "compressed-tensors"

self.properties_manager.properties = output_properties
self.properties_manager.generate_properties_file()

Expand Down

0 comments on commit d4f5ee7

Please sign in to comment.