Enable traced model for text-generation task #22072

Closed

Conversation

@jiqing-feng jiqing-feng commented Mar 10, 2023

@sywangyi
Enable traced model for text-generation task.
I changed beam_search and greedy_search in generation to support traced models. If a traced model has been set on the "trace_graph" attribute, we use model.trace_graph for the forward pass. I also changed the text-generation example and found that a model optimized with jit trace performs better on the text-generation task. The numbers below were measured on an A100:

- model: gptj-6b
- beam search: input_tokens=32, output_tokens=32, num_beam=4
- data type: bf16
- original model's latency: 0.96s
- jit trace model's latency: 0.72s
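
For readers unfamiliar with the setup this assumes, here is a minimal sketch of how a traced graph might be produced and attached as the trace_graph attribute before calling generate. The model id, prompt, torchscript=True flag, and the torch.jit.freeze call are illustrative assumptions on my part, not code from this PR.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Sketch: trace the model once and attach it as `trace_graph`, the attribute
# the modified greedy_search / beam_search look for (not the PR's exact code).
model_id = "EleutherAI/gpt-j-6B"  # model used in the benchmark above
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, torchscript=True
)
model.eval()

# Trace with a representative input; torchscript=True makes the model return tuples.
example_ids = tokenizer("Hello, my name is", return_tensors="pt")["input_ids"]
with torch.no_grad():
    traced = torch.jit.trace(model, example_ids, strict=False)
    traced = torch.jit.freeze(traced)

model.trace_graph = traced  # picked up inside generate by the changes in this PR
```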

@jiqing-feng jiqing-feng force-pushed the main branch 6 times, most recently from 8bf921c to 080fd1e on March 13, 2023 at 08:26
@sywangyi

@sgugger please help review

sgugger commented Mar 13, 2023

@gante Could you have a first look at the changes in generate?

@gante gante left a comment

Hey everyone 👋 Here's my two cents.

In terms of technical correctness, no issues. Regarding long-term maintenance, I'm not supportive of the PR as it is. We are adding trace_graph references to the main body of generate, but there is no definition or explanation of trace_graph outside the example -- how can a user know what it means? It also requires dedicated tensor manipulation for this specific use case, which is undesirable (a source of maintenance problems and reduced code readability).

If the purpose of this PR is to create an example, I'd suggest moving the new generation logic into the example itself. For instance, in the example, you could redefine model.prepare_inputs_for_generation() to also include your input tuple preparation and model.trace_graph() to wrap its output in our model output classes. These two changes would allow running .generate() as is while illustrating the speedups of your changes :)
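
One concrete way to get this effect, sketched here under my own assumptions (the gpt2 stand-in model, the prompt, and the idea of swapping model.forward for a thin wrapper are mine, not code from the PR or from this comment): keep .generate() untouched and do all the tracing and output re-wrapping inside the example script.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.modeling_outputs import CausalLMOutputWithPast

model_id = "gpt2"  # lighter stand-in for the gptj-6b used in the benchmark
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torchscript=True)
model.eval()

# Trace once with a representative input; torchscript=True makes forward return tuples.
# This sketch assumes the traced graph generalizes to the sequence lengths produced
# during generation, as in the PR's benchmark.
example_ids = tokenizer("Hello, my name is", return_tensors="pt")["input_ids"]
with torch.no_grad():
    traced = torch.jit.trace(model, example_ids, strict=False)

def jit_forward(input_ids=None, **kwargs):
    # Run the traced graph and re-wrap its tuple output into the ModelOutput
    # class .generate() expects. The KV cache is ignored in this simple sketch,
    # so every decoding step re-processes the full sequence.
    logits = traced(input_ids)[0]
    return CausalLMOutputWithPast(logits=logits)

model.forward = jit_forward  # .generate() now runs unchanged on the traced graph

inputs = tokenizer("The meaning of life is", return_tensors="pt")
out = model.generate(inputs["input_ids"], num_beams=4, max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```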

jiqing-feng commented Mar 14, 2023

@gante Hi Gante, thanks for your thoughtful comment; it's reasonable and I agree with it.
Here I have two solutions:

  1. For trace_graph in the main body of generate, we can add documentation that explains trace_graph in detail: what it is, how to implement it, and how it helps accelerate inference. As for the tensor manipulation, the method of preparing input tensors for trace_graph is general for the text-generation task across all kinds of models, and it can be adapted to any other task with a few changes (this is in progress), so it is not tied to a specific use case. We could put this method in utils as a general helper.
  2. As you said, we can redefine prepare_inputs_for_generation() to handle both the input preparation and the wrapping of the model.trace_graph outputs. However, redefining model.prepare_inputs_for_generation() is not a general solution, since different model classes have different implementations of prepare_inputs_for_generation(), and it is not convenient to subclass a different model class every time we change the model type.

I strongly recommend the first way. There are many ways to optimize model.forward; if we can support the trace_graph attribute in the main body of generate, it will be convenient for users to pass in their own optimized models.
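
As an illustration only (not the PR's actual diff), the kind of dispatch that the first solution would centralize in utils might look roughly like this; the helper name and the single-tensor input packing are assumptions of this sketch.

```python
from typing import Any, Dict

import torch
from transformers.modeling_outputs import CausalLMOutputWithPast


def forward_step(model: torch.nn.Module, model_inputs: Dict[str, Any]):
    """Run one decoding step, preferring a traced graph when one is attached."""
    if hasattr(model, "trace_graph"):
        # Traced graphs take positional tensors and return tuples, so the inputs
        # are packed here and the tuple output is re-wrapped into a ModelOutput.
        outputs = model.trace_graph(model_inputs["input_ids"])
        return CausalLMOutputWithPast(logits=outputs[0])
    return model(**model_inputs, return_dict=True)
```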

BTW, you set return_dict=True in the main body of generate, so it would not work if I set return_dict=False in .from_pretrained. Could I remove this so that users can decide for themselves whether to return a dictionary?

Thanks!

gante commented Mar 14, 2023

@jiqing-feng Thank you for your comment.

To clarify my position further, in an attempt to find a solution that pleases us all: from the transformers perspective, our current priority is ease of use and experimentation. We also welcome performance-enhancing solutions like the one in this PR, but they must fulfill one of three requirements: (i) they are commonly requested by the community; (ii) they require minimal changes to existing functionality; (iii) the benefits of the new technique are very large, as with int8 quantization. If we don't adhere to these principles, the codebase will quickly become unusable and hard to maintain, as there are many possible strategies to improve the code.

From my perspective, I haven't seen any request for torch.jit support in .generate(), and I get tagged in pretty much everything .generate()-related. This PR also includes a diff of 50 lines to existing functions in utils.py, and the benefit is up to a 20% speedup. This means that, according to the principles stated above, I'm afraid I can't support the changes as they are 🤗

This doesn't mean that my perspective is static on the subject! I've suggested above what can be done to showcase torch.jit in the example. That is a way to increase the visibility of the technique, which may increase the community demand for it -- and, if this demand does materialize, I'd be more than happy to include the additional logic in utils.py.

I apologize if this is not the answer you'd like to read, but we do have to be picky with the changes we introduce in actively maintained cross-model functionality. I'm also working towards increasing the modularity of .generate(), so that use cases like yours can be more easily added!

bratao commented Mar 20, 2023

Just my +1. Generation speed improvements, especially with torch 2.0, are very nice for making models production-ready.

@yao-matrix

Yes, I echo that. With PyTorch 2.0 introduced, I suppose we will see more and more performance benefit from jit for deployment.
