Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Neo] Refactor Neo TRT-LLM partition script #2166

Merged
merged 7 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 3 additions & 9 deletions serving/docker/partition/sm_neo_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@

from sm_neo_utils import (CompilationFatalError, write_error_to_file,
get_neo_env_vars)
from utils import extract_python_jar, load_properties
from utils import (extract_python_jar, load_properties,
update_dataset_cache_location)
from properties_manager import PropertiesManager
from partition import PartitionService

Expand All @@ -41,13 +42,6 @@ def __init__(self):
self.HF_CACHE_LOCATION: Final[str] = env[5]
self.TARGET_INSTANCE_TYPE: Final[str] = env[6]

def update_dataset_cache_location(self):
    """Redirect the HuggingFace Datasets cache to ``self.HF_CACHE_LOCATION``.

    Logs the new location and then exports it through the
    ``HF_DATASETS_CACHE`` environment variable of the current process.
    """
    cache_dir = self.HF_CACHE_LOCATION
    logging.info(
        f"Updating HuggingFace Datasets cache directory to: {cache_dir}")
    os.environ.update(HF_DATASETS_CACHE=cache_dir)
    #os.environ['HF_DATASETS_OFFLINE'] = "1"

def initialize_partition_args_namespace(self):
"""
Initialize args, a SimpleNamespace object that is used to instantiate a
Expand Down Expand Up @@ -123,7 +117,7 @@ def write_properties(self):
self.properties_manager.generate_properties_file()

def neo_quantize(self):
self.update_dataset_cache_location()
update_dataset_cache_location(self.HF_CACHE_LOCATION)
self.initialize_partition_args_namespace()
self.construct_properties_manager()
self.run_quantization()
Expand Down
136 changes: 54 additions & 82 deletions serving/docker/partition/sm_neo_trt_llm_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,104 +12,76 @@
# the specific language governing permissions and limitations under the License.

import logging
import os
import sys
from typing import Final

from utils import load_properties
from sm_neo_utils import InputConfiguration, CompilationFatalError, write_error_to_file, get_neo_env_vars, get_neo_compiler_flags
from utils import (update_kwargs_with_env_vars, load_properties,
update_dataset_cache_location,
remove_option_from_properties)
from sm_neo_utils import (CompilationFatalError, write_error_to_file,
get_neo_env_vars)
from tensorrt_llm_toolkit import create_model_repo

# TODO: Merge the functionality of this file into trt_llm_partition.py
# so all TRT-LLM partitioning is unified

DJL_SERVING_OPTION_PREFIX = "option."
class NeoTRTLLMPartitionService():
    """Run TRT-LLM compilation for a SageMaker Neo job.

    The service reads its directories from ``get_neo_env_vars()``, gathers
    properties from environment variables and the model's
    ``serving.properties``, invokes ``create_model_repo`` and finally writes a
    ``serving.properties`` file next to the compiled artifacts.
    """

    def __init__(self):
        # Merged env-var / serving.properties options; filled by get_properties().
        self.properties: dict = {}

        env = get_neo_env_vars()
        # Index 0 of the returned tuple is not consumed by this service.
        self.INPUT_MODEL_DIRECTORY: Final[str] = env[1]
        self.OUTPUT_MODEL_DIRECTORY: Final[str] = env[2]
        self.COMPILATION_ERROR_FILE: Final[str] = env[3]
        self.COMPILER_CACHE_LOCATION: Final[str] = env[4]
        self.HF_CACHE_LOCATION: Final[str] = env[5]

    def run_partition(self):
        """Build the TRT-LLM model repo from the collected properties.

        Raises:
            CompilationFatalError: if ``create_model_repo`` raises anything;
                the original exception is attached as ``__cause__``.
        """
        kwargs = remove_option_from_properties(self.properties)
        kwargs["trt_llm_model_repo"] = self.OUTPUT_MODEL_DIRECTORY
        kwargs["neo_cache_dir"] = self.COMPILER_CACHE_LOCATION
        # NOTE(review): presumably these tell the toolkit to keep its download
        # and checkpoint directories after the build - confirm against
        # tensorrt_llm_toolkit.
        os.environ['TRTLLM_TOOLKIT_SKIP_DOWNLOAD_DIR_CLEANUP'] = 'true'
        os.environ['TRTLLM_TOOLKIT_SKIP_CHECKPOINT_DIR_CLEANUP'] = 'true'
        try:
            create_model_repo(self.INPUT_MODEL_DIRECTORY, **kwargs)
        except Exception as exc:
            # Chain explicitly so the underlying failure stays visible in the
            # traceback of the fatal error.
            raise CompilationFatalError(
                f"Encountered an error during TRT-LLM compilation: {exc}"
            ) from exc

    def get_properties(self):
        """Get properties from serving.properties and/or environment variables."""
        self.properties = update_kwargs_with_env_vars({})
        # Applied second, so entries loaded from the input model directory
        # overwrite same-named entries obtained from the environment.
        self.properties.update(load_properties(self.INPUT_MODEL_DIRECTORY))

    def generate_properties_file(self):
        """Generate serving.properties file in output repo, so compiled artifacts can be deployed."""
        # Keys that are not copied into the generated file.
        skipped_keys = ("option.model_id", "option.model_dir")
        properties_path = os.path.join(self.OUTPUT_MODEL_DIRECTORY,
                                       "serving.properties")
        with open(properties_path, "w") as f:
            f.write("engine=MPI\n")
            # NOTE(review): an `engine` key already present in self.properties
            # is written again below the hard-coded line - confirm intended.
            for key, value in self.properties.items():
                if key not in skipped_keys:
                    f.write(f"{key}={value}\n")

    def neo_partition(self):
        """Run the full flow: set HF cache, load properties, compile, write properties."""
        update_dataset_cache_location(self.HF_CACHE_LOCATION)
        self.get_properties()
        self.run_partition()
        self.generate_properties_file()


def main():
    """
    Convert from SageMaker Neo interface to DJL-Serving format for TRT-LLM compilation

    Configures logging to stdout, runs ``NeoTRTLLMPartitionService`` and, on
    any failure, records the exception via ``write_error_to_file`` before
    re-raising it.
    """
    logging.basicConfig(stream=sys.stdout,
                        format="%(message)s",
                        level=logging.INFO,
                        force=True)

    # Bound before the try block so the error handler can run even when the
    # constructor itself raises.
    neo_trtllm_partition_service = None
    try:
        neo_trtllm_partition_service = NeoTRTLLMPartitionService()
        neo_trtllm_partition_service.neo_partition()
    except Exception as exc:
        # Falls back to None when the service was never constructed, instead
        # of failing with UnboundLocalError and hiding the real error.
        compilation_error_file = getattr(neo_trtllm_partition_service,
                                         "COMPILATION_ERROR_FILE", None)
        write_error_to_file(exc, compilation_error_file)
        raise


Expand Down
9 changes: 2 additions & 7 deletions serving/docker/partition/trt_llm_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,11 @@
import sys
from tensorrt_llm_toolkit import create_model_repo

from utils import update_kwargs_with_env_vars, load_properties
from utils import update_kwargs_with_env_vars, load_properties, remove_option_from_properties


def create_trt_llm_repo(properties, args):
kwargs = {}
for key, value in properties.items():
if key.startswith("option."):
kwargs[key[7:]] = value
else:
kwargs[key] = value
kwargs = remove_option_from_properties(properties)
kwargs['trt_llm_model_repo'] = args.trt_llm_model_repo
kwargs["tensor_parallel_degree"] = args.tensor_parallel_degree
model_id_or_path = args.model_path or kwargs['model_id']
Expand Down
8 changes: 8 additions & 0 deletions serving/docker/partition/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,11 @@ def init_hf_tokenizer(model_id_or_path: str, hf_configs):
revision=hf_configs.revision,
)
return tokenizer


def update_dataset_cache_location(
        hf_cache_location: "str | os.PathLike[str]") -> None:
    """Point the HuggingFace Datasets cache at ``hf_cache_location``.

    Logs the new location and exports it through the ``HF_DATASETS_CACHE``
    environment variable of the current process.

    Args:
        hf_cache_location: cache directory, as a string or any
            ``os.PathLike`` object (e.g. ``pathlib.Path``).

    Raises:
        TypeError: if ``hf_cache_location`` is neither a string nor path-like.
    """
    # os.environ only accepts str values, so normalise path-like input first.
    cache_dir = os.fspath(hf_cache_location)
    logging.info(
        f"Updating HuggingFace Datasets cache directory to: {cache_dir}")
    os.environ['HF_DATASETS_CACHE'] = cache_dir
    #os.environ['HF_DATASETS_OFFLINE'] = "1"
Loading