Enable Llama 3.1 405B in FP8 (#124) (#1745)
jaygala223 authored Feb 14, 2025
1 parent 2704931 commit 4b4c8a8
Showing 2 changed files with 13 additions and 2 deletions.
7 changes: 6 additions & 1 deletion (quantization config JSON; file path not captured in this view)
@@ -3,5 +3,10 @@
     "mode": "QUANTIZE",
     "observer": "maxabs",
     "scale_method": "unit_scale",
-    "dump_stats_path": "./hqt_output/measure"
+    "whitelist": {"types": [], "names": []},
+    "blacklist": {"types": [], "names": []},
+    "quantize_weight": false,
+    "dump_stats_path": "./results/hk",
+    "ignore_modules_wo_measures": "True",
+    "dump_stats_xlsx_path": "./run_outputs/fp8stats.xlsx"
 }
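For context, this JSON configures the Intel Neural Compressor FP8 flow on Gaudi: the commit adds empty whitelist/blacklist filters, disables weight quantization, ignores modules without measurements, and redirects the stats dumps, including a new xlsx dump. A hedged usage sketch in Python follows; optimum-habana's text-generation example reads the config path from the QUANT_CONFIG environment variable, but the config file name and the exact CLI flags below are assumptions, since this view does not capture them:

# Hedged sketch: launching the FP8 text-generation example with this config.
# QUANT_CONFIG is the variable the example reads to find the quantization
# JSON; the config path and the CLI flags below are assumptions.
import os
import subprocess

env = dict(os.environ, QUANT_CONFIG="quantization_config/unit_scale_quant.json")  # assumed path
subprocess.run(
    [
        "python", "run_generation.py",
        "--model_name_or_path", "meta-llama/Llama-3.1-405B-Instruct",
        "--bf16",
        "--use_hpu_graphs",
    ],
    env=env,
    check=True,
)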
8 changes: 7 additions & 1 deletion examples/text-generation/utils.py
@@ -439,7 +439,12 @@ def setup_distributed_model(args, model_dtype, model_kwargs, logger):
     logger.info("DeepSpeed is enabled.")
     deepspeed.init_distributed(dist_backend="hccl")
     config = AutoConfig.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype, **model_kwargs)
-    load_to_meta = model_on_meta(config)
+
+    keep_module_on_host = False
+    if "Llama-3.1-405B" in args.model_name_or_path:
+        keep_module_on_host = True
+
+    load_to_meta = False if keep_module_on_host else model_on_meta(config)
 
     if args.assistant_model is None:
         assistant_model = None
@@ -494,6 +499,7 @@ def setup_distributed_model(args, model_dtype, model_kwargs, logger):
 
     # Initialize the model
     ds_inference_kwargs = {"dtype": model_dtype}
+    ds_inference_kwargs["keep_module_on_host"] = keep_module_on_host
     ds_inference_kwargs["tensor_parallel"] = {"tp_size": args.world_size}
     ds_inference_kwargs["enable_cuda_graph"] = args.use_hpu_graphs
     ds_inference_kwargs["injection_policy"] = get_ds_injection_policy(config)
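The net effect of the utils.py change: for Llama 3.1 405B, weights stay in host memory while DeepSpeed shards the model (keep_module_on_host), and the meta-tensor loading path is skipped, since load_to_meta is forced to False. A minimal sketch of the pattern, assuming deepspeed.init_inference and its keep_module_on_host option (available in recent DeepSpeed releases); the helper is a simplified stand-in for the repo's setup_distributed_model, not the actual code:

# Minimal sketch, assuming deepspeed.init_inference accepts keep_module_on_host
# (recent DeepSpeed releases); a simplified stand-in, not the repo's code.
import deepspeed
import torch


def init_for_inference(model, model_name_or_path: str, world_size: int):
    # For a 405B checkpoint, keep full weights in host RAM while DeepSpeed
    # shards them across devices, instead of loading onto meta tensors.
    keep_module_on_host = "Llama-3.1-405B" in model_name_or_path

    ds_inference_kwargs = {
        "dtype": torch.bfloat16,
        "keep_module_on_host": keep_module_on_host,
        "tensor_parallel": {"tp_size": world_size},
    }
    return deepspeed.init_inference(model, **ds_inference_kwargs)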
