Skip to content

Commit

Permalink
[tnx] adding additional neuron config options (#1777) (#1783)
Browse files Browse the repository at this point in the history
  • Loading branch information
tosterberg authored Apr 16, 2024
1 parent 80387ee commit d2842ad
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 2 deletions.
8 changes: 8 additions & 0 deletions engines/python/setup/djl_python/neuron_utils/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,14 @@ def set_neuron_config(self) -> None:
if self.config.collectives_layout:
neuron_config[
"collectives_layout"] = self.config.collectives_layout
if self.config.attention_layout:
neuron_config["attention_layout"] = self.config.attention_layout
if self.config.cache_layout:
neuron_config["cache_layout"] = self.config.cache_layout
if self.config.all_reduce_dtype:
neuron_config["all_reduce_dtype"] = self.config.all_reduce_dtype
if self.config.cast_logits_dtype:
neuron_config["cast_logits_dtype"] = self.config.cast_logits_dtype

self.neuron_config = NeuronConfig(**neuron_config)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ class Dtype(str, Enum):
bf16 = 'bf16'


class TnXDtypeName(str, Enum):
float32 = 'float32'
float16 = 'float16'
bfloat16 = 'bfloat16'


class TnXQuantizeMethods(str, Enum):
static_int8 = 'static_int8'

Expand Down Expand Up @@ -93,8 +99,12 @@ class TransformerNeuronXProperties(Properties):
rolling_batch_strategy: Optional[TnXGenerationStrategy] = None
fuse_qkv: Optional[bool] = False
on_device_embedding: Optional[bool] = False
attention_layout: Optional[TnXMemoryLayout] = None
collectives_layout: Optional[TnXMemoryLayout] = None
cache_layout: Optional[TnXMemoryLayout] = None
partition_schema: Optional[TnXModelSchema] = None
all_reduce_dtype: Optional[TnXDtypeName] = None
cast_logits_dtype: Optional[TnXDtypeName] = None

@validator('neuron_optimize_level')
def set_neuron_optimal_env(cls, level):
Expand Down
15 changes: 13 additions & 2 deletions engines/python/setup/djl_python/tests/test_properties_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import json
import unittest
from djl_python.properties_manager.properties import Properties
from djl_python.properties_manager.tnx_properties import TransformerNeuronXProperties, TnXGenerationStrategy, TnXModelSchema, TnXMemoryLayout
from djl_python.properties_manager.tnx_properties import (
TransformerNeuronXProperties, TnXGenerationStrategy, TnXModelSchema,
TnXMemoryLayout, TnXDtypeName)
from djl_python.properties_manager.trt_properties import TensorRtLlmProperties
from djl_python.properties_manager.ds_properties import DeepSpeedProperties, DsQuantizeMethods
from djl_python.properties_manager.hf_properties import HuggingFaceProperties, HFQuantizeMethods
Expand Down Expand Up @@ -132,7 +134,11 @@ def test_tnx_all_configs(self):
"rolling_batch_strategy": "continuous_batching",
"collectives_layout": "HSB",
"on_device_embedding": "true",
"partition_schema": "legacy"
"partition_schema": "legacy",
"attention_layout": "HSB",
"cache_layout": "SBH",
"all_reduce_dtype": "float32",
"cast_logits_dtype": "float32",
}
tnx_configs = TransformerNeuronXProperties(**common_properties,
**properties)
Expand Down Expand Up @@ -163,6 +169,11 @@ def test_tnx_all_configs(self):
TnXMemoryLayout.LAYOUT_HSB)
self.assertTrue(tnx_configs.on_device_embedding)
self.assertEqual(tnx_configs.partition_schema, TnXModelSchema.legacy)
self.assertEqual(tnx_configs.attention_layout,
TnXMemoryLayout.LAYOUT_HSB)
self.assertEqual(tnx_configs.cache_layout, TnXMemoryLayout.LAYOUT_SBH)
self.assertEqual(tnx_configs.all_reduce_dtype, TnXDtypeName.float32)
self.assertEqual(tnx_configs.cast_logits_dtype, TnXDtypeName.float32)

# tests context length estimate as integer
def test_tnx_cle_int(context_length_estimate):
Expand Down

0 comments on commit d2842ad

Please sign in to comment.