Skip to content

Commit

Permalink
[python] Add AOT (ahead-of-time compilation) config for NxD Inference with vLLM
Browse files Browse the repository at this point in the history
  • Loading branch information
sindhuvahinis committed Jan 28, 2025
1 parent a1d2ea3 commit 9d5b661
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 0 deletions.
6 changes: 6 additions & 0 deletions engines/python/setup/djl_python/transformers_neuronx.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
OPTIMUM_CAUSALLM_MODEL_TYPES = {"gpt2", "opt", "bloom", "llama", "mistral"}
OPTIMUM_CAUSALLM_CONTINUOUS_BATCHING_MODELS = {"llama", "mistral"}
VLLM_CONTINUOUS_BATCHING_MODELS = {"llama"}
NXDI_COMPILED_MODEL_FILE_NAME = "model.pt"


class TransformersNeuronXService(object):
Expand Down Expand Up @@ -141,6 +142,11 @@ def set_model_loader_class(self) -> None:
if self.config.model_loader == "nxdi":
os.environ[
'VLLM_NEURON_FRAMEWORK'] = "neuronx-distributed-inference"
if self.config.save_mp_checkpoint_path:
os.environ["NEURON_COMPILED_ARTIFACTS"] = self.config.save_mp_checkpoint_path
nxdi_compiled_model_path = os.path.join(self.config.model_id_or_path, NXDI_COMPILED_MODEL_FILE_NAME)
if os.path.isfile(nxdi_compiled_model_path):
os.environ["NEURON_COMPILED_ARTIFACTS"] = self.config.model_id_or_path
return

if self.config.model_loader == "vllm":
Expand Down
4 changes: 4 additions & 0 deletions tests/integration/llm/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@ def get_model_name():
"llama-3-1-8b-instruct-vllm-nxdi": {
"batch_size": [1, 2],
"seq_length": [256],
},
"llama-3-2-1b-instruct-vllm-nxdi-aot": {
"batch_size": [1],
"seq_length": [128],
}
}

Expand Down
16 changes: 16 additions & 0 deletions tests/integration/llm/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,22 @@
"deterministic": False
}
}
},
"llama-3-2-1b-instruct-vllm-nxdi-aot": {
"option.model_id": "",
"option.tensor_parallel_degree": 2,
"option.rolling_batch": "vllm",
"option.model_loading_timeout": 1200,
"option.model_loader": "nxdi",
"option.override_neuron_config": {
"on_device_sampling_config": {
"global_topk": 64,
"dynamic": True,
"deterministic": False
}
},
"option.n_positions": 128,
"option.max_rolling_batch_size": 1,
}
}

Expand Down
15 changes: 15 additions & 0 deletions tests/integration/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,21 @@ def test_llama_vllm_nxdi(self):
"transformers_neuronx_rolling_batch llama-3-1-8b-instruct-vllm-nxdi"
)

def test_llama_vllm_nxdi_aot(self):
    """Exercise the ahead-of-time (AOT) NxDI+vLLM flow: first compile the
    model artifacts with `partition`, then serve from the saved checkpoint
    and run the rolling-batch client against it."""
    model = "llama-3-2-1b-instruct-vllm-nxdi-aot"
    with Runner('pytorch-inf2', model) as runner:
        prepare.build_transformers_neuronx_handler_model(model)
        # Step 1: AOT compile — writes compiled artifacts under
        # /opt/ml/input/data/training/aot (see --save-mp-checkpoint-path).
        runner.launch(
            container="pytorch-inf2-1",
            cmd=
            "partition --model-dir /opt/ml/input/data/training --save-mp-checkpoint-path /opt/ml/input/data/training/aot --skip-copy"
        )
        # Step 2: serve directly from the pre-compiled artifacts.
        runner.launch(container="pytorch-inf2-1",
                      cmd="serve -m test=file:/opt/ml/model/test/aot")
        client.run("transformers_neuronx_rolling_batch " + model)


@pytest.mark.correctness
@pytest.mark.trtllm
Expand Down

0 comments on commit 9d5b661

Please sign in to comment.