feat: support moe hf chkpt (#133)
Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
kmehant authored Mar 7, 2025
1 parent de9a4f1 commit f7210f7
Showing 2 changed files with 58 additions and 52 deletions.
File 1 of 2:

@@ -13,7 +13,10 @@
 # limitations under the License.
 
 # Local
-from .checkpoint_utils import patch_huggingface_save_and_load_for_dtensors
+from .checkpoint_utils import (
+    patch_huggingface_save_and_load_for_dtensors,
+    recover_safetensors_from_dcp,
+)
 from .scattermoe_prepare import prepare_scattermoe
 
 # this is a special patch function to disable foreach for
File 2 of 2:

@@ -457,75 +457,38 @@ def save_sharded_safetensors(
 # --------------------------- SCRIPT -------------------------
 
 
-# have it serve as a conversion script
-if __name__ == "__main__":
-    # Standard
-    import argparse
-
-    parser = argparse.ArgumentParser(
-        description=(
-            "Utility for converting ScatterMoE checkpoint back to the "
-            "orginal state dict format. "
-            "The ScatterMoE checkpoint was saved after the pretrained model "
-            "had been converted by a module swap, hence the state dict will "
-            "no longer resemble the original. This utility creaes"
-        )
-    )
-
-    parser.add_argument(
-        "checkpoint_dir",
-        help="Path to the checkpoint.",
-    )
-
-    parser.add_argument(
-        "output_dir", help="Path to the location to write the converted checkpoint."
-    )
-
-    parser.add_argument(
-        "pretrained_model_name_or_path",
-        help=(
-            "In order to reconstruct the state dict, we requre hints from "
-            "the original pretrained model checkpoint (from which this "
-            "checkpoint is obtained)."
-        ),
-        default=None,
-    )
-
-    args = parser.parse_args()
-
-    # search for an FSDP checkpoint. If it is an FSDP checkpoint, it must
-    # start with FSDP_MODEL_NAME
-    if args.checkpoint_dir.startswith(FSDP_MODEL_NAME):
-        checkpoint_dir = args.checkpoint_dir
+def recover_safetensors_from_dcp(
+    checkpoint_dir, pretrained_model_name_or_path, output_dir
+):
+    if checkpoint_dir.startswith(FSDP_MODEL_NAME):
         loader = get_state_dict_from_dcp_checkpoint
     else:
-        checkpoint_dir = [
+        fsdp_checkpoint_dirs = [
             x
-            for x in os.listdir(args.checkpoint_dir)
-            if os.path.isdir(os.path.join(args.checkpoint_dir, x))
+            for x in os.listdir(checkpoint_dir)
+            if os.path.isdir(os.path.join(checkpoint_dir, x))
             and x.startswith(FSDP_MODEL_NAME)
         ]
-        if len(checkpoint_dir) == 1:
-            checkpoint_dir = os.path.join(args.checkpoint_dir, checkpoint_dir[0])
+        if len(fsdp_checkpoint_dirs) == 1:
+            checkpoint_dir = os.path.join(checkpoint_dir, fsdp_checkpoint_dirs[0])
             loader = get_state_dict_from_dcp_checkpoint
-        elif len(checkpoint_dir) > 1:
+        elif len(fsdp_checkpoint_dirs) > 1:
            raise ValueError(
-                f"Found > 1 dirs in dcp checkpoint dir {args.checkpoint_dir} "
+                f"Found > 1 dirs in dcp checkpoint dir {checkpoint_dir} "
                f"that starts with {FSDP_MODEL_NAME}. Please spectify the exact dir."
            )
         else:
             # then take it as a safetensors checkpoint
             # - do not support .bin checkpoints
-            checkpoint_dir = args.checkpoint_dir
             loader = get_state_dict_from_safe_checkpoint
 
     # - pretrained model name
-    _name_or_path = args.pretrained_model_name_or_path
+    _name_or_path = pretrained_model_name_or_path
 
     # assume output directory exists, we do not create it
     # - copy the config file if exists
     config_file = os.path.join(checkpoint_dir, CONFIG_NAME)
-    target_config_file = os.path.join(args.output_dir, CONFIG_NAME)
+    target_config_file = os.path.join(output_dir, CONFIG_NAME)
     if os.path.exists(config_file):
         shutil.copyfile(config_file, target_config_file)
 
@@ -544,6 +507,46 @@ def save_sharded_safetensors(
     # save it as a safetensors file
     save_sharded_safetensors(
         {k: v.contiguous() for k, v in state_dict.items()},
-        args.output_dir,
+        output_dir,
         metadata={"format": "pt"},
     )
+
+
+# have it serve as a conversion script
+if __name__ == "__main__":
+    # Standard
+    import argparse
+
+    parser = argparse.ArgumentParser(
+        description=(
+            "Utility for converting ScatterMoE checkpoint back to the "
+            "orginal state dict format. "
+            "The ScatterMoE checkpoint was saved after the pretrained model "
+            "had been converted by a module swap, hence the state dict will "
+            "no longer resemble the original. This utility creaes"
+        )
+    )
+
+    parser.add_argument(
+        "checkpoint_dir",
+        help="Path to the checkpoint.",
+    )
+
+    parser.add_argument(
+        "output_dir", help="Path to the location to write the converted checkpoint."
+    )
+
+    parser.add_argument(
+        "pretrained_model_name_or_path",
+        help=(
+            "In order to reconstruct the state dict, we requre hints from "
+            "the original pretrained model checkpoint (from which this "
+            "checkpoint is obtained)."
+        ),
+        default=None,
+    )
+
+    args = parser.parse_args()
+    recover_safetensors_from_dcp(
+        args.checkpoint_dir, args.pretrained_model_name_or_path, args.output_dir
+    )
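
Note: the commit turns the inline conversion script into an importable helper, recover_safetensors_from_dcp, and re-exports it from the package (first hunk). Below is a minimal usage sketch under stated assumptions: the import path, checkpoint directory, model id, and output directory are placeholders, since the diff only shows the relative module name .checkpoint_utils.

# Hypothetical usage sketch -- import path and argument values are
# placeholders, not taken from this commit.
from fms_plugin.utils.checkpoint_utils import recover_safetensors_from_dcp  # placeholder package path

# Convert a DCP/FSDP (or safetensors) ScatterMoE checkpoint back into sharded
# safetensors. Per the diff, the output directory must already exist (it is
# not created), and the model config file is copied over when present.
recover_safetensors_from_dcp(
    checkpoint_dir="./output/checkpoint-100",            # placeholder: dir containing the DCP/safetensors checkpoint
    pretrained_model_name_or_path="org/moe-base-model",  # placeholder: original pretrained model used as hints
    output_dir="./converted",                            # placeholder: existing dir for the converted checkpoint
)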
