model.py support in trtllm flow #1041
Conversation
```python
# That base class would look like:
# class TrussExtension(ABC):
#     @abstractmethod
#     def model_override(self):
```
For the ABC, I feel like model override could be optional. I.e. you could have an extension that passes some args to a model without supporting overriding.
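A sketch of what that optional-override shape could look like (class and method names here are hypothetical, following the comment above, not the actual implementation):

```python
from abc import ABC, abstractmethod
from typing import Any, Optional


class TrussExtension(ABC):
    """Hypothetical base class for truss extensions."""

    @abstractmethod
    def model_args(self) -> dict:
        """Arguments injected into the user model's __init__ under the extension's name."""

    def model_override(self) -> Optional[Any]:
        """Optionally return a replacement model object for when model.py is omitted.

        Returning None (the default) means this extension does not support overriding,
        so subclasses that only pass args don't have to implement it.
        """
        return None


class ArgsOnlyExtension(TrussExtension):
    """Example: an extension that passes args without supporting overriding."""

    def model_args(self) -> dict:
        return {"engine": "engine_object"}
```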
```python
    This is used if model.py is omitted, which is allowed when using trt_llm.
    """
    return self._engine
```
Do we have a base ABC for the truss Model class? Thinking about ways we could show that the Engine class is a Model on its own.
We don't have a base ABC for model class right now. It would be really useful to have one, once we create the smaller base truss library for use in runtime.
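One way to express "the Engine class is a Model on its own" without a shared base class would be a structural protocol; a minimal sketch (names are assumptions, not existing truss types):

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class ModelLike(Protocol):
    """Hypothetical structural interface for truss models."""

    def predict(self, model_input: Any) -> Any: ...


class Engine:
    """Stand-in engine; structurally a model by itself, no inheritance needed."""

    def predict(self, model_input: Any) -> Any:
        # A real engine would run inference; this just echoes for illustration.
        return {"output": model_input}
```

With `@runtime_checkable`, `isinstance(Engine(), ModelLike)` holds purely because `Engine` has a `predict` method, which is one way to show an engine satisfies the model contract.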
```python
model_init_params["secrets"] = SecretsResolver.get_secrets(self._config)
if _signature_accepts_keyword_arg(model_class_signature, "lazy_data_resolver"):
    model_init_params["lazy_data_resolver"] = LazyDataResolver(data_dir).fetch()
secrets_resolver = SecretsResolver.get_secrets(self._config)
```
Suggested change:

```diff
- secrets_resolver = SecretsResolver.get_secrets(self._config)
+ secrets = SecretsResolver.get_secrets(self._config)
```
```python
if _signature_accepts_keyword_arg(model_class_signature, "lazy_data_resolver"):
    model_init_params["lazy_data_resolver"] = LazyDataResolver(data_dir).fetch()
secrets_resolver = SecretsResolver.get_secrets(self._config)
lazy_data_resolver = LazyDataResolver(data_dir).fetch()
```
fetch() is what actually performs the resolution (and returns None), so the difference from the existing code is that this will always try to resolve bptrs, regardless of whether the model class signature accepts the resolver. That should be fine, since LazyDataResolver can handle there being no bptr manifest.
Looking at the code, I also noticed we have always been assigning model_init_params["lazy_data_resolver"] = None (when the model class signature accepts the resolver). I assume we then just use it for a presence check like `"lazy_data_resolver" in {"lazy_data_resolver": None}`.
```python
        if _signature_accepts_keyword_arg(signature, ext_name):
            model_init_params[ext_name] = ext.model_args()
    self._model = model_class(**model_init_params)
elif "trt_llm" in extensions:
```
Could add a constant for the "trt_llm" extension name.
🚀 What
Make model.py effective for the trt_llm flow.
💻 How
A concept of a truss extension is introduced, and trt_llm is modeled as an extension. Extensions are bundled with the truss under server/extensions/. Things are pretty hard-coded right now. At runtime, an extension is modeled as a directory.
ModelWrapper loads all extensions first (right now there is only one) and collects arguments to pass to the model class's init method. The init method is passed an argument named after the extension; e.g. for the trt_llm extension, a parameter named trt_llm is passed as the dictionary { "engine": engine_object }. The model can thus use the engine to make predictions.
If the model class is missing from user-provided code, the idea is to load an extension-provided model object as a replacement. Right now this is hard-coded to check only the trt_llm extension (we can change it to some other strategy if/when there is more than one extension).
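The wiring described above could be sketched roughly like this (function and class names are simplified assumptions, not the actual ModelWrapper code):

```python
import inspect


def build_model(model_class, extensions):
    """Collect each extension's args and pass them to the model's __init__
    under the extension's name, if the signature asks for that name."""
    signature = inspect.signature(model_class.__init__)
    model_init_params = {}
    for ext_name, ext in extensions.items():
        if ext_name in signature.parameters:
            model_init_params[ext_name] = ext.model_args()
    return model_class(**model_init_params)


class FakeTrtLlmExtension:
    """Stand-in extension supplying the engine argument."""

    def model_args(self):
        return {"engine": "engine_object"}


class Model:
    def __init__(self, trt_llm):
        # The model receives the extension's args under the extension's name.
        self._engine = trt_llm["engine"]

    def predict(self, model_input):
        return self._engine, model_input
```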
A check is also added: if a trt_llm section is provided in the config, then either the model class's init method must ask for a trt_llm arg in its signature, or there must be no model class at all. There are already a few existing Trusses that use trt_llm config and may contain a default model.py. After this change those Trusses will error out, prompting users to either add the parameter or remove model.py. We expect most of those users to remove model.py to keep the previous behavior (where model.py was ignored). This avoids a situation where a previously defunct model.py suddenly starts being used after this change and most likely fails.
🔬 Testing
A few unit and integration tests have been added here. I've also done some local testing using a GPU. I plan to test on the cluster next.
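The config/model.py consistency check described in the How section could be sketched roughly like this (names and error message are assumptions for illustration, not the actual implementation):

```python
import inspect


def validate_trt_llm_usage(config, model_class):
    """If the config has a trt_llm section, require either no model class
    or a model __init__ that accepts a trt_llm argument."""
    if "trt_llm" not in config:
        return
    if model_class is None:
        # model.py omitted: the extension-provided model replacement is used.
        return
    params = inspect.signature(model_class.__init__).parameters
    if "trt_llm" not in params:
        raise ValueError(
            "Config has a trt_llm section but the model class does not "
            "accept a trt_llm argument; add the parameter or remove model.py."
        )


class CompatibleModel:
    def __init__(self, trt_llm):
        self._trt_llm = trt_llm


class LegacyModel:
    def __init__(self):
        pass
```

Under this sketch, `LegacyModel` with a trt_llm config section fails fast at load time, which matches the migration behavior described above.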