Skip to content

Commit

Permalink
[Neo] Refactor Neo TRT-LLM partition script (#2166)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethnzhng authored Jul 17, 2024
1 parent 6ee4496 commit 9b72600
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 98 deletions.
12 changes: 3 additions & 9 deletions serving/docker/partition/sm_neo_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@

from sm_neo_utils import (CompilationFatalError, write_error_to_file,
get_neo_env_vars)
from utils import extract_python_jar, load_properties
from utils import (extract_python_jar, load_properties,
update_dataset_cache_location)
from properties_manager import PropertiesManager
from partition import PartitionService

Expand All @@ -40,13 +41,6 @@ def __init__(self):
self.COMPILATION_ERROR_FILE: Final[str] = env[3]
self.HF_CACHE_LOCATION: Final[str] = env[5]

def update_dataset_cache_location(self):
    """Point the HuggingFace Datasets cache at ``self.HF_CACHE_LOCATION``.

    Logs the new location at INFO level and then sets the
    ``HF_DATASETS_CACHE`` environment variable for the current process.
    """
    notice = (
        f"Updating HuggingFace Datasets cache directory to: {self.HF_CACHE_LOCATION}"
    )
    logging.info(notice)
    os.environ['HF_DATASETS_CACHE'] = self.HF_CACHE_LOCATION
    # NOTE(review): HF_DATASETS_OFFLINE is not set here; the original only
    # carried a commented-out assignment for it -- confirm whether offline
    # mode should be enabled.

def initialize_partition_args_namespace(self):
"""
Initialize args, a SimpleNamespace object that is used to instantiate a
Expand Down Expand Up @@ -122,7 +116,7 @@ def write_properties(self):
self.properties_manager.generate_properties_file()

def neo_quantize(self):
self.update_dataset_cache_location()
update_dataset_cache_location(self.HF_CACHE_LOCATION)
self.initialize_partition_args_namespace()
self.construct_properties_manager()
self.run_quantization()
Expand Down
136 changes: 54 additions & 82 deletions serving/docker/partition/sm_neo_trt_llm_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,104 +12,76 @@
# the specific language governing permissions and limitations under the License.

import logging
import os
import sys
from typing import Final

from utils import load_properties
from sm_neo_utils import InputConfiguration, CompilationFatalError, write_error_to_file, get_neo_env_vars, get_neo_compiler_flags
from utils import (update_kwargs_with_env_vars, load_properties,
update_dataset_cache_location,
remove_option_from_properties)
from sm_neo_utils import (CompilationFatalError, write_error_to_file,
get_neo_env_vars)
from tensorrt_llm_toolkit import create_model_repo

# TODO: Merge the functionality of this file into trt_llm_partition.py
# so all TRT-LLM partitioning is unified

DJL_SERVING_OPTION_PREFIX = "option."
class NeoTRTLLMPartitionService():

def __init__(self):
self.properties: dict = {}

def verify_neo_compiler_flags(compiler_flags):
"""
Verify that provided compiler_flags is a valid configuration
"""
convert_checkpoint_flags = compiler_flags.get("convert_checkpoint_flags")
quantize_flags = compiler_flags.get("quantize_flags")
trtllm_build_flags = compiler_flags.get("trtllm_build_flags")
env = get_neo_env_vars()
self.INPUT_MODEL_DIRECTORY: Final[str] = env[1]
self.OUTPUT_MODEL_DIRECTORY: Final[str] = env[2]
self.COMPILATION_ERROR_FILE: Final[str] = env[3]
self.COMPILER_CACHE_LOCATION: Final[str] = env[4]
self.HF_CACHE_LOCATION: Final[str] = env[5]

if trtllm_build_flags is None:
raise InputConfiguration(
"`compiler_flags` were found, but required sub-field `trtllm_build_flags` was not defined."
" See SageMaker Neo documentation for more info:"
" https://docs.aws.amazon.com/sagemaker/latest/dg/neo-troubleshooting.html"
)
if convert_checkpoint_flags is None and quantize_flags is None:
raise InputConfiguration(
"`compiler_flags` were found, but neither sub-fields `convert_checkpoint_flags` "
" or `quantize_flags` were defined, at least one of which must be provided."
" See SageMaker Neo documentation for more info:"
" https://docs.aws.amazon.com/sagemaker/latest/dg/neo-troubleshooting.html"
)
if convert_checkpoint_flags is not None and quantize_flags is not None:
logging.warning(
"Both `convert_checkpoint_flags` and `quantize_flags` were provided –"
" `convert_checkpoint_flags` will be used.")
def run_partition(self):
    """Build the TRT-LLM model repo from the input model directory.

    Keyword arguments for ``create_model_repo`` are derived from
    ``self.properties`` via ``remove_option_from_properties``, then extended
    with the output repo location and the compiler cache location held on
    this instance.

    Raises:
        CompilationFatalError: if ``create_model_repo`` raises anything; the
            original exception is attached as ``__cause__``.
    """
    kwargs = remove_option_from_properties(self.properties)
    kwargs["trt_llm_model_repo"] = self.OUTPUT_MODEL_DIRECTORY
    kwargs["neo_cache_dir"] = self.COMPILER_CACHE_LOCATION
    # NOTE(review): judging by their names these flags ask the TRT-LLM
    # toolkit to keep its download/checkpoint directories -- confirm against
    # the toolkit. They are set process-wide and are not restored afterwards.
    os.environ['TRTLLM_TOOLKIT_SKIP_DOWNLOAD_DIR_CLEANUP'] = 'true'
    os.environ['TRTLLM_TOOLKIT_SKIP_CHECKPOINT_DIR_CLEANUP'] = 'true'
    try:
        create_model_repo(self.INPUT_MODEL_DIRECTORY, **kwargs)
    except Exception as exc:
        # Chain explicitly so the traceback shows the toolkit failure as the
        # direct cause rather than as an error "during handling".
        raise CompilationFatalError(
            f"Encountered an error during TRT-LLM compilation: {exc}"
        ) from exc

def get_properties(self):
    """Populate ``self.properties`` from environment variables and serving.properties.

    The dict returned by ``update_kwargs_with_env_vars({})`` is stored first;
    entries returned by ``load_properties`` for the input model directory are
    then merged into it, overwriting any identical keys.
    """
    collected = update_kwargs_with_env_vars({})
    self.properties = collected
    from_file = load_properties(self.INPUT_MODEL_DIRECTORY)
    collected.update(from_file)

def generate_properties_file(self):
    """Write ``serving.properties`` into the output model directory.

    The file starts with ``engine=MPI`` followed by one ``key=value`` line
    per entry of ``self.properties``, skipping ``option.model_id`` and
    ``option.model_dir``.
    """
    skipped_keys = ("option.model_id", "option.model_dir")
    destination = os.path.join(self.OUTPUT_MODEL_DIRECTORY,
                               "serving.properties")
    with open(destination, "w") as out:
        out.write("engine=MPI\n")
        # NOTE(review): an "engine" key in self.properties is written again
        # below, after the forced MPI line -- confirm that is acceptable.
        out.writelines(f"{key}={value}\n"
                       for key, value in self.properties.items()
                       if key not in skipped_keys)

def neo_partition(self):
    """Run the partition flow end to end.

    Sets the HuggingFace Datasets cache location first, then gathers
    properties, runs the partition, and writes the output
    ``serving.properties`` file, in that order.
    """
    update_dataset_cache_location(self.HF_CACHE_LOCATION)
    stages = (
        self.get_properties,
        self.run_partition,
        self.generate_properties_file,
    )
    for stage in stages:
        stage()


def main():
    """
    Convert from SageMaker Neo interface to DJL-Serving format for TRT-LLM compilation

    Configures logging to stdout, builds a ``NeoTRTLLMPartitionService`` and
    runs its ``neo_partition`` flow. Any exception is reported through
    ``write_error_to_file`` and then re-raised to the caller.
    """
    logging.basicConfig(stream=sys.stdout,
                        format="%(message)s",
                        level=logging.INFO,
                        force=True)

    # Bind the name before the try block: if the constructor itself raises,
    # the handler below must not fail with an unbound-name error that would
    # hide the real failure.
    neo_trtllm_partition_service = None
    try:
        neo_trtllm_partition_service = NeoTRTLLMPartitionService()
        neo_trtllm_partition_service.neo_partition()
    except Exception as exc:
        # The error-file path is None when the service was never constructed.
        # The pre-refactor code also passed None to write_error_to_file in
        # that situation -- NOTE(review): confirm the helper tolerates None.
        compilation_error_file = getattr(neo_trtllm_partition_service,
                                         "COMPILATION_ERROR_FILE", None)
        write_error_to_file(exc, compilation_error_file)
        raise


Expand Down
9 changes: 2 additions & 7 deletions serving/docker/partition/trt_llm_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,11 @@
import sys
from tensorrt_llm_toolkit import create_model_repo

from utils import update_kwargs_with_env_vars, load_properties
from utils import update_kwargs_with_env_vars, load_properties, remove_option_from_properties


def create_trt_llm_repo(properties, args):
kwargs = {}
for key, value in properties.items():
if key.startswith("option."):
kwargs[key[7:]] = value
else:
kwargs[key] = value
kwargs = remove_option_from_properties(properties)
kwargs['trt_llm_model_repo'] = args.trt_llm_model_repo
kwargs["tensor_parallel_degree"] = args.tensor_parallel_degree
model_id_or_path = args.model_path or kwargs['model_id']
Expand Down
8 changes: 8 additions & 0 deletions serving/docker/partition/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,11 @@ def init_hf_tokenizer(model_id_or_path: str, hf_configs):
revision=hf_configs.revision,
)
return tokenizer


def update_dataset_cache_location(hf_cache_location):
    """Point the HuggingFace Datasets cache at ``hf_cache_location``.

    Logs the new location at INFO level and then sets the
    ``HF_DATASETS_CACHE`` environment variable for the current process.
    """
    notice = (
        f"Updating HuggingFace Datasets cache directory to: {hf_cache_location}"
    )
    logging.info(notice)
    os.environ['HF_DATASETS_CACHE'] = hf_cache_location
    # NOTE(review): HF_DATASETS_OFFLINE is not set here; the original only
    # carried a commented-out assignment for it -- confirm whether offline
    # mode should be enabled.

0 comments on commit 9b72600

Please sign in to comment.