Skip to content

Commit

Permalink
Move sm_fml import in AOT script
Browse files Browse the repository at this point in the history
  • Loading branch information
ethnzhng committed Jan 31, 2025
1 parent b888a0d commit 7ebe7dd
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion serving/docker/partition/sm_neo_shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from utils import (update_kwargs_with_env_vars, load_properties)

import torch
import sagemaker_fast_model_loader_rust as sm_fml
from mpi4py import MPI

from lmi_dist.init_engine import engine_from_args
Expand All @@ -50,6 +49,7 @@ def __init__(self):

def save_configs(self, pp_degree: int, tp_degree: int, input_dir: str,
output_dir: str, configs: list) -> None:
import sagemaker_fast_model_loader_rust as sm_fml
py_version = "{}.{}.{}".format(*sys.version_info[:3])
conf = sm_fml.ModelConfig(
pipeline_parallel_size=pp_degree,
Expand Down Expand Up @@ -160,6 +160,8 @@ def shard_lmi_dist_model(self, input_dir: str, output_dir: str,
engine_configs = engine_args.create_engine_configs()
engine_worker = load_model_for_sharding(engine_configs)

# Lazy import to avoid MPI not-inited errors
import sagemaker_fast_model_loader_rust as sm_fml
model_dir = os.path.join(output_dir, sm_fml.MODEL_DIR_NAME)
os.makedirs(model_dir, exist_ok=True)

Expand Down

0 comments on commit 7ebe7dd

Please sign in to comment.