int8 dynamic prefill weight only decode #1436

Merged: 63 commits, Dec 30, 2024

Commits
f390fd9
Add sparsity flag to benchmark
jcaip Oct 18, 2024
67937a9
update
jcaip Oct 18, 2024
6b62266
update
jcaip Oct 18, 2024
aa4c9df
fp8 testing
jcaip Oct 18, 2024
6b1ede1
fp8 testing
jcaip Oct 18, 2024
3c07c40
wip
jcaip Oct 22, 2024
a6c7de9
update benchmark script
jcaip Oct 22, 2024
3660766
update
jcaip Oct 22, 2024
ddf2e10
wip
jcaip Oct 22, 2024
ad4d3b0
udpate
jcaip Oct 22, 2024
653587e
update
jcaip Oct 22, 2024
c757357
wip
jcaip Oct 22, 2024
f1b0841
wip
jcaip Oct 22, 2024
afeaff5
test
jcaip Oct 22, 2024
c294765
wip
jcaip Oct 22, 2024
803e9b3
update
jcaip Oct 22, 2024
eb18850
fix
jcaip Oct 22, 2024
2642212
wip
jcaip Oct 22, 2024
4eccdb9
move out of aqt
jcaip Oct 22, 2024
13e6fd6
wip
jcaip Oct 22, 2024
608d70c
moved float8+24 to it's own file
jcaip Oct 22, 2024
b1f1796
Merge branch 'main' into jcaip/sparse-benchmarking-updates
jcaip Oct 22, 2024
30a4fac
update
jcaip Oct 23, 2024
6091592
wip
jcaip Oct 23, 2024
17f9121
remove float8 for now
jcaip Oct 23, 2024
75d0a0b
wip
jcaip Oct 23, 2024
b2fba99
fix
jcaip Oct 28, 2024
ba5665d
fix
jcaip Oct 28, 2024
4fdfa7b
time prefill by default
jcaip Dec 2, 2024
111babc
update
jcaip Dec 3, 2024
35f1fc7
merge
jcaip Dec 3, 2024
23f981d
fix merge conflicts
jcaip Dec 3, 2024
74c52ff
update
jcaip Dec 3, 2024
eed072d
update benchmarks
jcaip Dec 3, 2024
67cbcbb
fix ruff check
jcaip Dec 3, 2024
0e579ae
fix ruff v2
jcaip Dec 3, 2024
443db19
undo change
jcaip Dec 3, 2024
054717e
add padding
jcaip Dec 3, 2024
2e5b72a
update import
jcaip Dec 3, 2024
2b81dd6
final commit
jcaip Dec 3, 2024
de2d447
fix script
jcaip Dec 3, 2024
c0fa0da
wip
jcaip Dec 6, 2024
584c013
update
jcaip Dec 6, 2024
38d60c7
update
jcaip Dec 25, 2024
97cca7a
update
jcaip Dec 25, 2024
525053b
merge main
jcaip Dec 25, 2024
4da1b31
fix merge confligt
jcaip Dec 25, 2024
2517406
demo
jcaip Dec 25, 2024
5b8a28c
update
jcaip Dec 30, 2024
e25b30c
update generate
jcaip Dec 30, 2024
a58e0fd
moved summarization to standalone script
jcaip Dec 30, 2024
ea5cb0c
update
jcaip Dec 30, 2024
17a191a
update weight only decode flag
jcaip Dec 30, 2024
8899435
remove prompt.txt
jcaip Dec 30, 2024
a3056ff
cleanup
jcaip Dec 30, 2024
67a1a35
remove moby.txt
jcaip Dec 30, 2024
1554a8c
update
jcaip Dec 30, 2024
5161364
update
jcaip Dec 30, 2024
562191f
update
jcaip Dec 30, 2024
bf18806
update benchmars
jcaip Dec 30, 2024
89f03d8
rename arg
jcaip Dec 30, 2024
ce58e1e
update demo script
jcaip Dec 30, 2024
b144a53
formatting
jcaip Dec 30, 2024
2 changes: 2 additions & 0 deletions scripts/prepare.sh
@@ -1,11 +1,13 @@
python scripts/download.py --repo_id meta-llama/Llama-2-7b-chat-hf
python scripts/download.py --repo_id meta-llama/Meta-Llama-3-8B
python scripts/download.py --repo_id meta-llama/Meta-Llama-3.1-8B
python scripts/download.py --repo_id meta-llama/Meta-Llama-3.1-8B-Instruct
python scripts/download.py --repo_id meta-llama/Llama-3.2-3B
python scripts/download.py --repo_id nm-testing/SparseLlama-3-8B-pruned_50.2of4
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-2-7b-chat-hf
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3-8B
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3.1-8B
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3.1-8B-Instruct
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-3.2-3B
# neuralmagic doesn't come with tokenizer, so we need to copy it over
mkdir -p checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/original && cp checkpoints/meta-llama/Meta-Llama-3-8B/original/tokenizer.model checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/original/tokenizer.model
1 change: 1 addition & 0 deletions torchao/_models/llama/.gitignore
@@ -0,0 +1 @@
moby.txt
8 changes: 8 additions & 0 deletions torchao/_models/llama/demo_summarize.sh
@@ -0,0 +1,8 @@
# grab moby dick prompt
wget -nc -O moby.txt https://gist.githubusercontent.com/jcaip/f319146bb543e92e23b2c76815b0f29f/raw/31a9cd12b0b59f323eb197c9534953bdac352986/gistfile1.txt

export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B-Instruct

python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8dq_prefill_wo_decode --prefill_size 8192 --max_new_tokens 256 --num_samples 1 --demo_summarize_prompt moby.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8wo --prefill_size 8192 --max_new_tokens 256 --num_samples 1 --demo_summarize_prompt moby.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8dq --prefill_size 8192 --max_new_tokens 256 --num_samples 1 --demo_summarize_prompt moby.txt
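The script above runs the same ~8k-token summarization prompt under three settings, so the new prefill/decode split can be compared against plain int8 weight-only and plain int8 dynamic quantization. As a rough orientation, the same three configurations expressed through torchao's Python API might look like the sketch below (model loading and the generation loop from generate.py are omitted; the dictionary keys are the CLI flag values used above):

# Sketch only: the three quantization settings exercised by demo_summarize.sh,
# mapped to torchao configs. int8_weight_only and
# int8_dynamic_activation_int8_weight are existing torchao APIs;
# weight_only_decode is the flag added in this PR.
from torchao.quantization import (
    int8_dynamic_activation_int8_weight,
    int8_weight_only,
    quantize_,
)

QUANT_CONFIGS = {
    # int8 dynamic activations + int8 weights for prefill, weight-only decode
    "int8dq_prefill_wo_decode": lambda: int8_dynamic_activation_int8_weight(
        weight_only_decode=True
    ),
    # int8 weight-only everywhere
    "int8wo": lambda: int8_weight_only(),
    # int8 dynamic activations + int8 weights everywhere
    "int8dq": lambda: int8_dynamic_activation_int8_weight(),
}

def apply_quantization(model, name):
    # quantize the model's linear layers in place with one of the settings
    quantize_(model, QUANT_CONFIGS[name]())
    return model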
11 changes: 10 additions & 1 deletion torchao/_models/llama/eval.py
@@ -34,6 +34,7 @@ def run_evaluation(
device = "cuda",
precision = torch.bfloat16,
quantization: Optional[str] = None,
sparsity:Optional[str] = None,
compile=False,
max_length=None,
calibration_tasks: Optional[List[str]] = None,
@@ -44,7 +45,7 @@
"""Runs the evaluation of a model using LM Eval."""
print(
f"\nEvaluating model {checkpoint_path} on tasks: {tasks}, limit: {limit}, device: {device}, precision: {precision}, "
+f"quantization: {quantization}, compile: {compile}, max_length: {max_length}, calibration_tasks: {calibration_tasks}, "
+f"quantization: {quantization}, sparsity: {sparsity}, compile: {compile}, max_length: {max_length}, calibration_tasks: {calibration_tasks}, "
+f"calibration_seq_length: {calibration_seq_length}, pad_calibration_inputs: {pad_calibration_inputs}\n"
)
torchao.quantization.utils.recommended_inductor_config_setter()
@@ -236,6 +237,13 @@ def run_evaluation(
"float8wo, float8dq, float8saq"
),
)
parser.add_argument(
"--sparsity",
type=str,
help=(
"Which sparsity techniques to apply: semi-structured"
),
)
parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq')
@@ -251,6 +259,7 @@ def run_evaluation(
args.device,
args.precision,
args.quantization,
args.sparsity,
args.compile,
args.max_length,
args.calibration_tasks,
42 changes: 34 additions & 8 deletions torchao/_models/llama/generate.py
@@ -25,7 +25,7 @@
)

torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False

torch.backends.cuda.enable_cudnn_sdp(True)

class HostEvent:
def __init__(self):
@@ -256,6 +256,7 @@ def _load_model(checkpoint_path, device, precision):
def main(
prefill_size: Optional[int] = None,
prompt: str = "Hello, my name is",
demo_summarize_prompt: Optional[str] = None,
interactive: bool = False,
num_samples: int = 5,
max_new_tokens: int = 100,
@@ -285,7 +286,11 @@

if prefill_size is not None and prefill_size > 0:
# create prompt of prefill size
prompt = "prompt " * (int(prefill_size) - 3)
if demo_summarize_prompt is None:
prompt = "prompt " * (int(prefill_size) - 2)
else:
with open(demo_summarize_prompt, "r") as f:
prompt = f.read()

torchao.quantization.utils.recommended_inductor_config_setter()

@@ -306,6 +311,12 @@
tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)

encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)

if demo_summarize_prompt is not None:
end_tag = encode_tokens(tokenizer, "\n <END_TEXT>", bos=False, device=device)
encoded = encoded[:prefill_size-end_tag.size(0)]
encoded = torch.cat((encoded, end_tag), dim=0)

prompt_length = encoded.size(0)

torch.manual_seed(1234)
@@ -390,6 +401,8 @@ def ffn_or_attn_only(mod, fqn):
quantize_(
model, int8_dynamic_activation_int8_weight(), filter_fn=not_ffn_only
)
elif "int8dq_prefill_wo_decode" in quantization:
quantize_(model, int8_dynamic_activation_int8_weight(weight_only_decode=True))
else:
quantize_(model, int8_dynamic_activation_int8_weight())
if "int4wo" in quantization:
@@ -809,14 +822,23 @@ def callback(x):
nonlocal done_generating
if done_generating:
return
buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])
buffer.append(tokenizer.decode([period_id] + x.squeeze(0).tolist())[1:])
if x.item() == tokenizer.eos_id():
done_generating = True
if len(buffer) == 4 or done_generating:
print("".join(buffer), end="", flush=True)
buffer.clear()
# print(, end='', flush=True)
# print(, end="", flush=True)

elif demo_summarize_prompt is not None and i >= 0:
buffer = []
period_id = tokenizer.encode(".")[0]

def callback(x):
buffer.append(tokenizer.decode([period_id] + x.squeeze(0).tolist())[1:])
if len(buffer) == 4:
print("".join(buffer), end="", flush=True)
buffer.clear()
else:
callback = lambda x: x
t0 = time.perf_counter()
@@ -851,15 +873,15 @@
decode_start_event=decode_start_event,
decode_end_event=decode_end_event,
)
if i == -1:
if i < 0:
print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
continue
if hasattr(prof, "export_chrome_trace"):
prof.export_chrome_trace(f"{profile}.json")
device_sync(device=device) # MKG
t = time.perf_counter() - t0

if not interactive and prefill_size is None:
if not interactive and demo_summarize_prompt is None:
tok_list = y[0].tolist()
# truncate text after end of string token
tokens = (
@@ -869,7 +891,7 @@ def callback(x):
)
print(tokenizer.decode(tokens))
else:
print()
print("\n")
tokens_generated = y.size(-1) - prompt_length
tokens_sec = tokens_generated / t
aggregate_metrics["tokens_per_sec"].append(tokens_sec)
@@ -913,7 +935,7 @@ def callback(x):
bandwidth = model_size * tokpersec
mem = torch.cuda.max_memory_reserved() / 1e9
print(f"Average overall tokens/sec: {tokpersec:.2f}")
print(f"Average decode tokens/sec: {decode_tokens_sec:.04f} s")
print(f"Average decode tokens/sec: {decode_tokpersec:.04f} s")
print(f"Average TTFT: {ttft:.04f} s")
if device == "cuda":
mem = torch.cuda.max_memory_reserved() / 1e9
@@ -975,6 +997,9 @@ def callback(x):
parser.add_argument(
"--prompt", type=str, default="Hello, my name is", help="Input prompt."
)
parser.add_argument(
"--demo_summarize_prompt", type=str, help="Read prompt from text file"
)
parser.add_argument(
"--interactive",
action="store_true",
@@ -1073,6 +1098,7 @@ def callback(x):
main(
args.prefill_size,
args.prompt,
args.demo_summarize_prompt,
args.interactive,
args.num_samples,
args.max_new_tokens,
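One generate.py detail worth noting: when --demo_summarize_prompt is combined with --prefill_size, the encoded prompt is clipped so that the appended "\n <END_TEXT>" marker still fits, giving a prefill of at most prefill_size tokens (exactly prefill_size for a long prompt such as moby.txt). A standalone sketch of that clipping step, with toy token ids standing in for the real tokenizer output:

# Sketch of the prompt-fitting logic added in generate.py: clip the encoded
# prompt so that prompt plus end tag together fit in `prefill_size` tokens.
import torch

def fit_prompt(encoded, end_tag, prefill_size):
    clipped = encoded[: prefill_size - end_tag.size(0)]  # drop overflow tokens
    return torch.cat((clipped, end_tag), dim=0)          # append the end marker

# toy usage: fake ids standing in for the encoded moby.txt prompt and end tag
prompt_ids = torch.arange(10_000)
end_tag_ids = torch.tensor([10, 11, 12, 13])
fitted = fit_prompt(prompt_ids, end_tag_ids, prefill_size=8192)
assert fitted.size(0) == 8192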
38 changes: 33 additions & 5 deletions torchao/quantization/quant_api.py
@@ -803,8 +803,33 @@ def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:
)


def _int8_symm_per_token_reduced_range_quant_noop_decode(
x: torch.Tensor,
) -> torch.Tensor:
mapping_type = MappingType.SYMMETRIC
target_dtype = torch.int8
eps = 1e-5
quant_min = -127
quant_max = 127
if x.shape[1] == 1:
return x
else:
return to_affine_quantized_intx(
x,
mapping_type,
_get_per_token_block_size(x),
target_dtype,
eps=eps,
quant_min=quant_min,
quant_max=quant_max,
scale_dtype=torch.float32 if x.dtype == torch.float16 else None,
)


def int8_dynamic_activation_int8_weight(
layout=PlainLayout(), act_mapping_type=MappingType.SYMMETRIC
layout=PlainLayout(),
act_mapping_type=MappingType.SYMMETRIC,
weight_only_decode=False,
):
"""
Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
@@ -831,11 +856,14 @@ def get_weight_block_size(x):
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64

# input settings
if act_mapping_type == MappingType.SYMMETRIC:
input_quant_func = _int8_symm_per_token_reduced_range_quant
if weight_only_decode:
input_quant_func = _int8_symm_per_token_reduced_range_quant_noop_decode
else:
input_quant_func = _int8_asymm_per_token_quant
# input settings
if act_mapping_type == MappingType.SYMMETRIC:
input_quant_func = _int8_symm_per_token_reduced_range_quant
else:
input_quant_func = _int8_asymm_per_token_quant

block_size = get_weight_block_size(weight)
weight = to_affine_quantized_intx(
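The quant_api.py change is what gives the PR its name: with weight_only_decode=True, the activation quantization function becomes a no-op whenever the input's second dimension is 1, which is the single-token decode step in generate.py, so decode effectively runs int8 weight-only, while prefill inputs (sequence length > 1) still get per-token dynamic int8 activation quantization. A minimal usage sketch, assuming a CUDA build of torchao that includes this PR; the toy model and shapes are placeholders, and quantize_ / int8_dynamic_activation_int8_weight are the same calls used in generate.py above:

# Sketch: enable the prefill/decode split on a toy model (illustrative only).
import torch
from torchao.quantization import int8_dynamic_activation_int8_weight, quantize_

model = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).to("cuda", torch.bfloat16)

# int8 dynamic activation + int8 weight for prefill, int8 weight-only for decode
quantize_(model, int8_dynamic_activation_int8_weight(weight_only_decode=True))

prefill_x = torch.randn(1, 8192, 4096, device="cuda", dtype=torch.bfloat16)
decode_x = torch.randn(1, 1, 4096, device="cuda", dtype=torch.bfloat16)

with torch.no_grad():
    model(prefill_x)  # seq len > 1: activations quantized per token
    model(decode_x)   # seq len == 1: activation quant skipped (weight-only path)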