Enable traced model for text-generation task #22072
Conversation
Force-pushed from 8bf921c to 080fd1e.
@sgugger please help review.
@gante Could you have a first look at the changes in generate?
Hey everyone 👋 Here's my two cents.
In terms of technical correctness, no issues. Regarding long-term maintenance, I'm not supportive of the PR as it is. We are adding trace_graph references to the main body of generate, but there is no definition or explanation of trace_graph outside the example -- how can a user know what it means? It also requires dedicated tensor manipulation for this specific use case, which is undesirable (a source of maintenance problems and reduced code readability).
If the purpose of this PR is to create an example, I'd suggest moving the new generation logic to the example itself. For instance, in the example, you could redefine model.prepare_inputs_for_generation() to also include your input tuple preparation and model.traced_graph() to wrap the output in our model output classes. These two changes would allow running .generate() as-is while illustrating the speedups of your changes :)
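A rough sketch of that example-level alternative (an editor's illustration, not code from the PR): the model name, the input filtering, and the wrapping of the traced output into CausalLMOutputWithPast are assumptions, and the trace is simplified to input_ids only.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.modeling_outputs import CausalLMOutputWithPast

# Assumed model; the PR benchmarks gptj-6b, so GPT-J is used here for illustration.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-j-6B", torch_dtype=torch.bfloat16, torchscript=True
)
model.eval()

# Trace the forward pass once with representative inputs (simplified to input_ids only).
example = tokenizer("Hello world", return_tensors="pt")
with torch.no_grad():
    traced = torch.jit.trace(model, (example["input_ids"],), strict=False)

_orig_prepare = model.prepare_inputs_for_generation

def prepare_inputs_for_generation(*args, **kwargs):
    # Reuse the model's own input preparation, then keep only what the traced graph expects.
    inputs = _orig_prepare(*args, **kwargs)
    return {"input_ids": inputs["input_ids"]}

def traced_forward(input_ids=None, **kwargs):
    # Run the traced graph and re-wrap its raw tuple output in the usual model
    # output class, so the rest of generate() works unchanged.
    logits = traced(input_ids)[0]
    return CausalLMOutputWithPast(logits=logits)

model.prepare_inputs_for_generation = prepare_inputs_for_generation
model.forward = traced_forward

inputs = tokenizer("The weather today is", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=32)[0], skip_special_tokens=True))
```

Since this simplified trace drops past_key_values, every step recomputes the full prefix; a real example would trace with the cache inputs as well.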
@gante Hi, Gante. Thanks for your thoughtful comment; it's reasonable and I agree with it.
I strongly recommend the first way. There are many ways to optimize, BTW. Thanks!
@jiqing-feng Thank you for your comment. To clarify my position further, in an attempt to find a solution that pleases us all: from my perspective, I haven't seen any request for this feature in generate itself. This doesn't mean that my perspective is static on the subject! I've suggested above what can be done to showcase the speedups in the example. I apologize if this is not the answer you'd like to read, but we do have to be picky with the changes we introduce in actively maintained cross-model functionality. I'm also working towards increasing the modularity of generate.
Just my +1: generation speed improvement, especially with torch 2.0, is something very nice for making the model production-ready.
Yes, echoing that. With PyTorch 2.0 introduced, I suppose we will see more and more performance benefit from jit for deployment.
@sywangyi
Enable traced model for text-generation task.
I changed beam_search and greedy_search in generation to support a traced model. If a traced model has been set on the "trace_graph" attribute, then we use model.trace_graph for the forward pass. I also changed the text-generation example and found that a model optimized by jit trace performs better on the text-generation task. The data, collected on an A100, is as below; a usage sketch follows the numbers:
model: gptj-6b
beam search: input_tokens=32, output_tokens=32, num_beam=4
data type: bf16
original model's latency: 0.96s
jit trace model's latency: 0.72s
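For reference, going from 0.96 s to 0.72 s is roughly a 1.33x speedup (about a 25% latency reduction). Below is a minimal sketch of how the traced path could be exercised with this PR's generate changes applied; the model checkpoint name and the simplified tracing inputs (input_ids only) are illustrative, not the exact setup used for the benchmark.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load GPT-J in bf16, matching the benchmark configuration above.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-j-6B", torch_dtype=torch.bfloat16, torchscript=True
)
model.eval()

# Trace the forward pass once with example inputs.
example = tokenizer("Hello, my dog is cute", return_tensors="pt")
with torch.no_grad():
    traced = torch.jit.trace(model, (example["input_ids"],), strict=False)

# Per the PR description, generate() looks for a `trace_graph` attribute and,
# when it is set, calls it inside greedy_search / beam_search instead of the
# eager forward pass. This relies on the PR's modified generate being installed.
model.trace_graph = traced

inputs = tokenizer("The weather today is", return_tensors="pt")
outputs = model.generate(**inputs, num_beams=4, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```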