Skip to content

Commit

Permalink
[TRTLLM] Add entrypoint for SM Neo AOT compilation (#1665)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethnzhng authored Mar 26, 2024
1 parent 5d78486 commit 20ebf60
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 0 deletions.
8 changes: 8 additions & 0 deletions serving/docker/dockerd-entrypoint-with-cuda-compat.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ else
echo "Skip CUDA compat libs setup as package not found"
fi

# When SM_NEO_EXECUTION_CONTEXT is set, run the TensorRT-LLM partition script
# and exit with its status; nothing after this block in the script is executed.
if test -n "$SM_NEO_EXECUTION_CONTEXT"; then
  printf '%s\n' "SageMaker Neo execution context detected"
  /usr/bin/python3 /opt/djl/partition/sm_neo_trt_llm_partition.py
  neo_status=$?
  printf '%s\n' "TensorRT-LLM compilation exited with code $neo_status"
  exit "$neo_status"
fi

# Convert select SM/TGI Environment Variables to LMI Equivalents
translateTGIToLMI "HF_MODEL_QUANTIZE" "OPTION_QUANTIZE"
# We use HF_TRUST_REMOTE_CODE in our properties parsing
Expand Down
170 changes: 170 additions & 0 deletions serving/docker/partition/sm_neo_trt_llm_partition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
#!/usr/bin/env python
#
# Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

import json
import logging
import os
import sys
import traceback

from utils import load_properties
from tensorrt_llm_toolkit import create_model_repo

# TODO: Merge the functionality of this file into trt_llm_partition.py
# so all TRT-LLM partitioning is unified

# Key prefix that main() strips from `serving.properties` entries before
# forwarding them as keyword arguments to create_model_repo.
DJL_SERVING_OPTION_PREFIX = "option."


class InputConfiguration(Exception):
    """
    Signals that the inputs supplied through the SageMaker Neo interface
    (environment variables, CompilerOptions, serving.properties) are missing
    or invalid.
    """


class CompilationFatalError(Exception):
    """
    Signals that the TensorRT-LLM build step itself failed, as opposed to a
    problem with the inputs handed to it.
    """


def write_error_to_file(error_message, error_file):
    """
    Write an error to the error file as JSON of the form ``{"error": "<repr>"}``.

    This is best-effort: a failure to write is logged and never raised, so it
    cannot mask the error that is being reported.

    :param error_message: exception (or message) to record; stored via ``repr()``
    :param error_file: path of the file to create or overwrite. If ``None``,
        nothing is written and the error is only logged.
    """
    if error_file is None:
        # Without a path there is nothing to open; say so explicitly rather
        # than logging the TypeError traceback that open(None) would produce.
        logging.error(
            "No error file path available; unable to record error: %r",
            error_message)
        return

    try:
        with open(error_file, "w", encoding="utf-8") as f:
            json.dump({"error": repr(error_message)}, f)
    except Exception:
        # Deliberately not a bare `except:` so that KeyboardInterrupt and
        # SystemExit still propagate.
        logging.error("Failed to write error file: %s",
                      traceback.format_exc())


def get_neo_env_vars():
    """
    Get environment variables required by the SageMaker Neo interface

    :return: tuple ``(compiler_options, input_model_directory,
        compiled_model_directory, compilation_error_file)`` holding the values
        of ``COMPILER_OPTIONS``, ``SM_NEO_INPUT_MODEL_DIR``,
        ``SM_NEO_COMPILED_MODEL_DIR`` and ``SM_NEO_COMPILATION_ERROR_FILE``
    :raises InputConfiguration: if any of those variables is not set
    """
    required_env_vars = (
        "COMPILER_OPTIONS",
        "SM_NEO_INPUT_MODEL_DIR",
        "SM_NEO_COMPILED_MODEL_DIR",
        "SM_NEO_COMPILATION_ERROR_FILE",
    )
    try:
        values = [os.environ[name] for name in required_env_vars]
    except KeyError as exc:
        # exc.args[0] is the name of the first missing variable
        raise InputConfiguration(
            f"SageMaker Neo environment variable '{exc.args[0]}' expected but not found"
            " \nRequired env vars are: 'COMPILER_OPTIONS', 'SM_NEO_INPUT_MODEL_DIR',"
            " 'SM_NEO_COMPILED_MODEL_DIR', 'SM_NEO_COMPILATION_ERROR_FILE'"
        ) from exc
    return tuple(values)


def get_neo_compiler_flags(compiler_options):
    """
    Get SageMaker Neo compiler_flags from the CompilerOptions field

    :param compiler_options: the CompilerOptions value as a JSON string
    :return: the value stored under the ``compiler_flags`` key, or ``None``
        when that key is absent
    :raises InputConfiguration: if ``compiler_options`` is not valid JSON or
        does not decode to a JSON object
    """
    try:
        parsed_options = json.loads(compiler_options)
    except (TypeError, ValueError, RecursionError) as exc:
        # json.JSONDecodeError is a ValueError subclass; TypeError covers
        # non-string input and RecursionError pathologically nested JSON
        raise InputConfiguration(
            f"Failed to parse SageMaker Neo CompilerOptions: {exc}") from exc

    if not isinstance(parsed_options, dict):
        raise InputConfiguration(
            "Failed to parse SageMaker Neo CompilerOptions:"
            " Parsed JSON is not a dictionary")

    # CompilerOptions JSON will always be present, but compiler_flags key is optional
    return parsed_options.get("compiler_flags")


def verify_neo_compiler_flags(compiler_flags):
    """
    Verify that provided compiler_flags is a valid configuration

    A valid configuration is a dictionary that defines `trtllm_build_flags`
    and at least one of `convert_checkpoint_flags` / `quantize_flags`. When
    both of the latter are given, a warning is logged.

    :param compiler_flags: the decoded `compiler_flags` value
    :raises InputConfiguration: if `compiler_flags` is not a dictionary or a
        required sub-field is missing
    """
    if not isinstance(compiler_flags, dict):
        # e.g. a JSON string or list; fail with a configuration error instead
        # of an AttributeError on `.get`
        raise InputConfiguration(
            "`compiler_flags` must be a JSON object, but found"
            f" {type(compiler_flags).__name__}."
            " See SageMaker Neo documentation for more info:"
            " https://docs.aws.amazon.com/sagemaker/latest/dg/neo-troubleshooting.html"
        )

    convert_checkpoint_flags = compiler_flags.get("convert_checkpoint_flags")
    quantize_flags = compiler_flags.get("quantize_flags")
    trtllm_build_flags = compiler_flags.get("trtllm_build_flags")

    if trtllm_build_flags is None:
        raise InputConfiguration(
            "`compiler_flags` were found, but required sub-field `trtllm_build_flags` was not defined."
            " See SageMaker Neo documentation for more info:"
            " https://docs.aws.amazon.com/sagemaker/latest/dg/neo-troubleshooting.html"
        )
    if convert_checkpoint_flags is None and quantize_flags is None:
        raise InputConfiguration(
            "`compiler_flags` were found, but neither sub-fields `convert_checkpoint_flags` "
            " or `quantize_flags` were defined, at least one of which must be provided."
            " See SageMaker Neo documentation for more info:"
            " https://docs.aws.amazon.com/sagemaker/latest/dg/neo-troubleshooting.html"
        )
    if convert_checkpoint_flags is not None and quantize_flags is not None:
        logging.warning(
            "Both `convert_checkpoint_flags` and `quantize_flags` were provided –"
            " `convert_checkpoint_flags` will be used.")


def main():
    """
    Convert from SageMaker Neo interface to DJL-Serving format for TRT-LLM compilation

    Build options are taken from the Neo CompilerOptions `compiler_flags` when
    present, otherwise from `serving.properties` (with the `option.` prefix
    stripped from keys). The result is passed to `create_model_repo` together
    with the compiled model directory.

    Any failure is reported through `write_error_to_file` and then re-raised.

    :raises InputConfiguration: if the Neo inputs are missing or invalid
    :raises CompilationFatalError: if `create_model_repo` fails
    """
    logging.basicConfig(stream=sys.stdout,
                        format="%(message)s",
                        level=logging.INFO)

    # Look the error file up before anything can fail, so that a missing
    # *other* required env var is still reported to the error file.
    compilation_error_file = os.environ.get("SM_NEO_COMPILATION_ERROR_FILE")
    try:
        (compiler_options, input_model_directory, compiled_model_directory,
         compilation_error_file) = get_neo_env_vars()

        # Neo requires that serving.properties is in same dir as model files
        serving_properties = load_properties(input_model_directory)
        compiler_flags = get_neo_compiler_flags(compiler_options)

        if compiler_flags is not None:
            # If set, prefer Neo CompilerOptions flags over LMI serving.properties
            logging.info(
                "Using CompilerOptions from SageMaker Neo. If a `serving.properties`"
                " file was provided, it will be ignored for compilation.")
            verify_neo_compiler_flags(compiler_flags)
            # Copy so that adding `trt_llm_model_repo` below does not mutate
            # the parsed CompilerOptions
            kwargs = dict(compiler_flags)
        elif serving_properties:
            # Else, if present, use LMI serving.properties options
            logging.info("Using compiler options from serving.properties file")
            prefix_len = len(DJL_SERVING_OPTION_PREFIX)
            kwargs = {
                (key[prefix_len:]
                 if key.startswith(DJL_SERVING_OPTION_PREFIX) else key): value
                for key, value in serving_properties.items()
            }
        else:
            raise InputConfiguration(
                "Neither compiler_flags nor serving.properties found. Please either:"
                " \nA) specify `compiler_flags` in the CompilerOptions field of SageMaker Neo or CreateCompilationJob API, or"
                " \nB) include a `serving.properties` file along with your model files."
                " \nFor info on valid `compiler_flags` fields and values for TensorRT-LLM, see SageMaker Neo documentation:"
                " https://docs.aws.amazon.com/sagemaker/latest/dg/neo-troubleshooting.html"
                " \nFor `serving.properties` configuration, see"
                " https://docs.djl.ai/docs/serving/serving/docs/lmi/user_guides/trt_llm_user_guide.html"
                " Note that SageMaker Neo requires that the `serving.properties` file is placed in the"
                " same directory as the model files, i.e. on the same level as `config.json` and checkpoints."
            )

        try:
            kwargs["trt_llm_model_repo"] = compiled_model_directory
            create_model_repo(input_model_directory, **kwargs)
        except Exception as exc:
            raise CompilationFatalError(
                f"Encountered an error during TRT-LLM compilation: {exc}"
            ) from exc

    except Exception as exc:
        write_error_to_file(exc, compilation_error_file)
        # Bare raise re-raises the active exception with its original traceback
        raise


# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

0 comments on commit 20ebf60

Please sign in to comment.