Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable PT2E quantization for CPU #2594

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 79 additions & 20 deletions torchbenchmark/util/backends/torchdynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def parse_torchdynamo_args(dynamo_args: List[str]) -> argparse.Namespace:
)
parser.add_argument(
"--quantization",
choices=["int8dynamic", "int8weightonly", "int4weightonly"],
choices=["int8dynamic", "int8weightonly", "int4weightonly", "pt2e"],
help="Apply quantization to the model before running it",
)
parser.add_argument(
Expand Down Expand Up @@ -182,26 +182,34 @@ def apply_torchdynamo_args(
)

if args.quantization:
import torchao
from torchao.quantization import (
change_linear_weights_to_int4_woqtensors,
change_linear_weights_to_int8_dqtensors,
change_linear_weights_to_int8_woqtensors,
)
if model.device == "cpu":
if args.quantization == "pt2e":
enable_inductor_quant(model)
else:
raise ValueError(
"The quantization mode is not enabled on CPU"
)
else:
import torchao
from torchao.quantization import (
change_linear_weights_to_int4_woqtensors,
change_linear_weights_to_int8_dqtensors,
change_linear_weights_to_int8_woqtensors,
)

torch._dynamo.config.automatic_dynamic_shapes = False
torch._dynamo.config.force_parameter_static_shapes = False
torch._dynamo.config.cache_size_limit = 1000
assert "cuda" in model.device
module, example_inputs = model.get_module()
if args.quantization == "int8dynamic":
torch._inductor.config.force_fuse_int_mm_with_mul = True
change_linear_weights_to_int8_dqtensors(module)
elif args.quantization == "int8weightonly":
torch._inductor.config.use_mixed_mm = True
change_linear_weights_to_int8_woqtensors(module)
elif args.quantization == "int4weightonly":
change_linear_weights_to_int4_woqtensors(module)
torch._dynamo.config.automatic_dynamic_shapes = False
torch._dynamo.config.force_parameter_static_shapes = False
torch._dynamo.config.cache_size_limit = 1000
assert "cuda" in model.device
module, example_inputs = model.get_module()
if args.quantization == "int8dynamic":
torch._inductor.config.force_fuse_int_mm_with_mul = True
change_linear_weights_to_int8_dqtensors(module)
elif args.quantization == "int8weightonly":
torch._inductor.config.use_mixed_mm = True
change_linear_weights_to_int8_woqtensors(module)
elif args.quantization == "int4weightonly":
change_linear_weights_to_int4_woqtensors(module)

if args.freeze_prepack_weights:
torch._inductor.config.freezing = True
Expand Down Expand Up @@ -240,3 +248,54 @@ def apply_torchdynamo_args(
model.eval = optimize_ctx(model.eval)

torch._dynamo.reset()

def enable_inductor_quant(model: 'torchbenchmark.util.model.BenchmarkModel'):
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.export import Dim
module, example_inputs = model.get_module()

if isinstance(example_inputs, dict):
input_ids = torch.randn(2, 512).to(torch.long)
example_inputs = {
"input_ids": input_ids,
}
input_shapes = {k: list(v.shape) for (k, v) in example_inputs.items()}
dims = set()
for _, v in input_shapes.items():
dims.update(v)
dims=sorted(dims)
dim_str_map = {x: Dim("dim" + str(list(dims).index(x)), min=1, max=1024 * 1024) for x in dims}
dynamic_shapes = {k: {v.index(dim): dim_str_map[dim] for dim in v} for (k, v) in input_shapes.items()}
del dynamic_shapes["input_ids"][1]
# Create X86InductorQuantizer
quantizer = xiq.X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
# Generate the FX Module
if isinstance(example_inputs, dict):
input_ids = torch.ones(2, 512).to(torch.long)
example_inputs = {
"input_ids": input_ids,
}
exported_model = torch.export.export_for_training(
module,
(),
example_inputs,
dynamic_shapes=dynamic_shapes,
).module()
else:
exported_model = torch.export.export_for_training(
module,
example_inputs,
).module()
# PT2E Quantization flow
prepared_model = prepare_pt2e(exported_model, quantizer)
# Calibration
if isinstance(example_inputs, dict):
prepared_model(**example_inputs)
else:
prepared_model(*example_inputs)
with torch.no_grad():
converted_model = convert_pt2e(prepared_model)
torch.ao.quantization.move_exported_model_to_eval(converted_model)
model.set_module(converted_model)