Skip to content

Commit

Permalink
[tnx] improve model partitioning time (#1652)
Browse files Browse the repository at this point in the history
  • Loading branch information
tosterberg authored Mar 21, 2024
1 parent 672070b commit 21a6d57
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 4 deletions.
10 changes: 7 additions & 3 deletions engines/python/setup/djl_python/neuron_utils/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def load_hf_model(self) -> "PreTrainedModel":
self.config.model_id_or_path,
trust_remote_code=self.config.trust_remote_code,
revision=self.config.revision,
low_cpu_mem_usage=True)
low_cpu_mem_usage=self.config.low_cpu_mem_usage)

def load_inf2_model_from_disk(self) -> "PreTrainedModel":
if not self.config.load_split_model:
Expand All @@ -217,6 +217,10 @@ def load_inf2_model_from_memory(self) -> "PreTrainedModel":
model.load_state_dict_low_memory(self.model.state_dict())
return model

def save_split_model(self):
    """Write the current model to disk in neuron split-checkpoint form.

    Serializes ``self.model`` with ``save_pretrained_split`` so it can later
    be reloaded as a split model, logging the destination first.
    """
    destination = self.split_model_path
    logging.info(f"Saving INF2 model to {destination} ...")
    save_pretrained_split(self.model, destination)

def set_load_path(self) -> None:
"""
Sets the path to which to load artifacts - based on specified format
Expand Down Expand Up @@ -302,8 +306,8 @@ def partition(self, save_path: str, **kwargs):
self.model_config.save_pretrained(save_path)
self.model = self.load_hf_model()
self.load_path = self.get_load_path()
self.model = self.load_inf2_model_from_disk()
shutil.copytree(self.load_path, self.split_model_path)
self.save_split_model()
self.model = self.load_inf2_model_from_memory()

# Neuron compiler serialization workaround
path = os.getcwd()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ class TnXGQAMethods(str, Enum):
all_gather_heads = 'all-gather-heads'


class TnXModelLoaders(str, Enum):
    """Supported model-loader backends for transformers-neuronx.

    String-valued so property parsing can compare members directly against
    the raw ``model_loader`` configuration string (e.g. ``== "tnx"``).
    """
    # Native transformers-neuronx loader (TNXModelLoader).
    tnx = "tnx"
    # HuggingFace Optimum Neuron loader (OptimumModelLoader).
    optimum = "optimum"


TNX_SUPPORTED_ROLLING_BATCH_TYPES = ['auto']


Expand All @@ -65,6 +70,7 @@ class TransformerNeuronXProperties(Properties):
task: Optional[str] = None
save_mp_checkpoint_path: Optional[str] = None
group_query_attention: Optional[str] = None
model_loader: Optional[TnXModelLoaders] = None

@validator('neuron_optimize_level')
def set_neuron_optimal_env(cls, level):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def test_set_model_loader(self):
"rolling_batch": "auto",
"task": "text-generation",
"max_rolling_batch_size": 4,
"load_split_model": True
"model_loader": "tnx"
}])
def test_initialize(self, params):
# Setup
Expand Down
7 changes: 7 additions & 0 deletions engines/python/setup/djl_python/transformers_neuronx.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ def __init__(self) -> None:
self._model_loader_class = OptimumModelLoader

def set_model_loader_class(self):
if self.config.model_loader == "optimum":
return

if self.config.model_loader == "tnx":
self._model_loader_class = TNXModelLoader
return

use_tnx = False
if self.model_config.architectures is not None and any(
"CausalLM" in arch
Expand Down

0 comments on commit 21a6d57

Please sign in to comment.