Unexpected GPU runtime behavior with ORTModelForSeq2SeqLM #362

Closed
tranmanhdat opened this issue Aug 26, 2022 · 3 comments

Labels: bug (Something isn't working) · inference (Related to Inference) · onnxruntime (Related to ONNX Runtime)

@tranmanhdat

System Info

OS: Ubuntu 20.04.4 LTS
CARD: RTX 3080

Libs:
python 3.10.4
onnx==1.12.0
onnxruntime-gpu==1.12.1
torch==1.12.1
transformers==4.21.2

Who can help?

@lewtun @michaelbenayoun @JingyaHuang @echarlaix

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Steps to reproduce the behavior:

  1. Convert a public translation model from here: vinai-translate-en2vi
from optimum.onnxruntime import ORTModelForSeq2SeqLM
save_directory = "models/en2vi_onnx"
# Load a model from transformers and export it through the ONNX format
model = ORTModelForSeq2SeqLM.from_pretrained('vinai/vinai-translate-en2vi', from_transformers=True)
# Save the onnx model and tokenizer
model.save_pretrained(save_directory)
  2. Load the model with a script modified from the original model creator's example
from transformers import AutoTokenizer, pipeline
from optimum.onnxruntime import ORTModelForSeq2SeqLM
import torch
import time
device = "cuda:0" if torch.cuda.is_available() else "cpu"
tokenizer_en2vi = AutoTokenizer.from_pretrained("vinai/vinai-translate-en2vi", src_lang="en_XX")
model_en2vi = ORTModelForSeq2SeqLM.from_pretrained("models/en2vi_onnx")
model_en2vi.to(device)

# onnx_en2vi = pipeline("translation_en_to_vi", model=model_en2vi, tokenizer=tokenizer_en2vi, device=0)
# en_text = '''It's very cold to go out.'''
# start = time.time()
# outpt = onnx_en2vi(en_text)
# end = time.time()
# print(outpt)
# print("time: ", end - start)

def translate_en2vi(en_text: str) -> str:
    start = time.time()
    input_ids = tokenizer_en2vi(en_text, return_tensors="pt").input_ids.to(device)
    end = time.time()
    print("Tokenize time: {:.2f}s".format(end - start))
    # print(input_ids.shape)
    # print(input_ids)
    start = time.time()
    output_ids = model_en2vi.generate(
        input_ids,
        do_sample=True,
        top_k=100,
        top_p=0.8,
        decoder_start_token_id=tokenizer_en2vi.lang_code_to_id["vi_VN"],
        num_return_sequences=1,
    )
    end = time.time()
    print("Generate time: {:.2f}s".format(end - start))
    vi_text = tokenizer_en2vi.batch_decode(output_ids, skip_special_tokens=True)
    vi_text = " ".join(vi_text)
    return vi_text

en_text = '''It's very cold to go out.''' # long paragraph 

start = time.time()
result = translate_en2vi(en_text)
print(result)
end = time.time()
print('{:.2f} seconds'.format((end - start)))

I changed line 167 in optimum/onnxruntime/utils.py to return "CUDAExecutionProvider" so the model runs on the GPU instead of raising an error (see the provider check sketch after these steps).
  3. Run the original model creator's example on the GPU and compare runtimes.
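
To double-check that the onnxruntime-gpu build can actually see the GPU (rather than silently falling back to CPU), here is a minimal sketch using only the standard onnxruntime API:

import onnxruntime as ort

# "GPU" means the onnxruntime-gpu build is installed and CUDA is usable.
print("ONNX Runtime device:", ort.get_device())

providers = ort.get_available_providers()
print("Available execution providers:", providers)

# If CUDAExecutionProvider is missing, sessions fall back to
# CPUExecutionProvider, which would explain a large slowdown.
assert "CUDAExecutionProvider" in providers, "onnxruntime-gpu / CUDA not set up correctly"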

Expected behavior

The ONNX model was expected to run faster, but the result is unexpected:

  • Runtime of the original model on GPU is 3-5 s, using about 3.5 GB of GPU memory
    (screenshot: 2022-08-26 09-08-19)
  • Runtime of the ONNX-converted model on GPU is 70-80 s, using about 7.7 GB of GPU memory
    (screenshot: 2022-08-26 09-07-45)
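
Note that the first generate() call on GPU typically includes one-time provider initialization and CUDA memory-arena allocation, so a fairer comparison warms the model up once and averages over a few runs. A minimal timing sketch, reusing the translate_en2vi function defined above:

import time

# Warm-up run: the first GPU call pays one-time initialization costs.
translate_en2vi(en_text)

# Average over a few timed runs for a steadier estimate.
n_runs = 5
start = time.time()
for _ in range(n_runs):
    translate_en2vi(en_text)
print("Average time per run: {:.2f}s".format((time.time() - start) / n_runs))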
@JingyaHuang (Contributor)

Hi @tranmanhdat,

Previously, running inference on device had significant data-copying overhead, as ONNX Runtime by default copies inputs from CPU to the device and copies outputs from the device back to CPU. We have just merged IOBinding support for ONNX Runtime, which allocates memory for inputs and pre-allocates memory for outputs on the device to reduce the time spent on copies.
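
For context, a rough sketch of what IOBinding looks like with the raw onnxruntime Python API; the file and tensor names below are illustrative rather than the exact ones optimum uses internally:

import numpy as np
import onnxruntime as ort

# Illustrative: a standalone exported encoder.
sess = ort.InferenceSession("models/en2vi_onnx/encoder_model.onnx",
                            providers=["CUDAExecutionProvider"])

io_binding = sess.io_binding()
input_ids = np.array([[0, 123, 456, 2]], dtype=np.int64)
attention_mask = np.ones_like(input_ids)

# Inputs are copied to the device once; outputs are allocated on the device,
# avoiding an extra host round trip per call.
io_binding.bind_cpu_input("input_ids", input_ids)
io_binding.bind_cpu_input("attention_mask", attention_mask)
io_binding.bind_output("last_hidden_state", "cuda")

sess.run_with_iobinding(io_binding)
encoder_hidden_states = io_binding.copy_outputs_to_cpu()[0]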

Can you try with optimum built from source to see if it improves the performance in your case? Thank you!

No code change is needed, as the ORTModels use IOBinding by default with the CUDA execution provider. You can install optimum from the main branch with the following command:

python -m pip install git+https://github.com/huggingface/optimum.git#egg=optimum[onnxruntime-gpu]

@tranmanhdat (Author)

I will try it out and report back as soon as possible, thank you.

@JingyaHuang (Contributor)

This issue should have been fixed with ORT IOBinding support, so I will close it. Feel free to re-open it if you have any further questions, thanks!
