
Codegen model fails #589

Closed
PoodleWang opened this issue Dec 14, 2022 · 4 comments
Labels
bug Something isn't working

Comments

@PoodleWang

System Info

from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer, pipeline

# Export the PyTorch checkpoint to ONNX on the fly and run it with the CUDA execution provider
ort_model = ORTModelForCausalLM.from_pretrained(
    "Salesforce/codegen-6B-mono",
    from_transformers=True,
    provider="CUDAExecutionProvider",
    use_io_binding=False,
)

tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-6B-mono")

# Wrap the ONNX Runtime model in a standard transformers text-generation pipeline
pipe = pipeline(task="text-generation", model=ort_model, tokenizer=tokenizer)
result = pipe("Both the music and visual were astounding, not to mention the actors performance.")
print(result)

home/tiger/.local/lib/python3.7/site-packages/transformers/models/codegen/modeling_codegen.py:167: UserWarning: where received a uint8 condition tensor. This behavior is deprecated and will be removed in a future version of PyTorch. Use a boolean condition instead. (Triggered internally at  /pytorch/aten/src/ATen/native/TensorCompare.cpp:255.)
  attn_weights = torch.where(causal_mask, attn_weights, mask_value)
Traceback (most recent call last):
  File "/opt/tiger/genius/tensorrt/load.py", line 8, in <module>
    use_io_binding= False
  File "/home/tiger/.local/lib/python3.7/site-packages/optimum/onnxruntime/modeling_ort.py", line 280, in from_pretrained
    **kwargs,
  File "/home/tiger/.local/lib/python3.7/site-packages/optimum/modeling_base.py", line 263, in from_pretrained
    **model_kwargs,
  File "/home/tiger/.local/lib/python3.7/site-packages/optimum/onnxruntime/modeling_ort.py", line 412, in _from_transformers
    output=save_dir.joinpath(ONNX_WEIGHTS_NAME),
  File "/home/tiger/.local/lib/python3.7/site-packages/transformers/onnx/convert.py", line 353, in export
    return export_pytorch(preprocessor, model, config, opset, output, tokenizer=tokenizer, device=device)
  File "/home/tiger/.local/lib/python3.7/site-packages/transformers/onnx/convert.py", line 189, in export_pytorch
    opset_version=opset,
  File "/home/tiger/.local/lib/python3.7/site-packages/torch/onnx/__init__.py", line 280, in export
    custom_opsets, enable_onnx_checker, use_external_data_format)
  File "/home/tiger/.local/lib/python3.7/site-packages/torch/onnx/utils.py", line 94, in export
    use_external_data_format=use_external_data_format)
  File "/home/tiger/.local/lib/python3.7/site-packages/torch/onnx/utils.py", line 695, in _export
    dynamic_axes=dynamic_axes)
  File "/home/tiger/.local/lib/python3.7/site-packages/torch/onnx/utils.py", line 467, in _model_to_graph
    module=module)
  File "/home/tiger/.local/lib/python3.7/site-packages/torch/onnx/utils.py", line 200, in _optimize_graph
    graph = torch._C._jit_pass_onnx(graph, operator_export_type)
  File "/home/tiger/.local/lib/python3.7/site-packages/torch/onnx/__init__.py", line 313, in _run_symbolic_function
    return utils._run_symbolic_function(*args, **kwargs)
  File "/home/tiger/.local/lib/python3.7/site-packages/torch/onnx/utils.py", line 994, in _run_symbolic_function
    return symbolic_fn(g, *inputs, **attrs)
  File "/home/tiger/.local/lib/python3.7/site-packages/torch/onnx/symbolic_helper.py", line 172, in wrapper
    return fn(g, *args, **kwargs)
  File "/home/tiger/.local/lib/python3.7/site-packages/torch/onnx/symbolic_opset13.py", line 73, in split
    if self.type().sizes()[dim] is not None:
TypeError: 'NoneType' object is not subscriptable
(Occurred when translating split).


PyTorch was installed with:
pip3 install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
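This section gives the PyTorch install command but not the versions of transformers, optimum, or ONNX Runtime. A minimal sketch for collecting them in one go; `collect_versions` is a hypothetical helper and the package list is an assumption (running with `CUDAExecutionProvider` requires the `onnxruntime-gpu` package, so both ONNX Runtime distributions are listed). It uses `importlib.metadata` where available and falls back to each module's `__version__` on Python 3.7, which does not have that module.

```python
import importlib
import platform

try:  # Python 3.8+
    from importlib.metadata import PackageNotFoundError
    from importlib.metadata import version as dist_version
except ImportError:  # Python 3.7, as used in this report, has no importlib.metadata
    dist_version = None

# Assumed list of distributions relevant to this issue
PACKAGES = ("torch", "transformers", "optimum", "onnx", "onnxruntime", "onnxruntime-gpu")


def _version_of(name):
    """Return the installed version of one distribution, or "not installed"."""
    if dist_version is not None:
        try:
            return dist_version(name)
        except PackageNotFoundError:
            return "not installed"
    # Fallback: import the module (dashes become underscores) and read __version__
    try:
        module = importlib.import_module(name.replace("-", "_"))
    except ImportError:
        return "not installed"
    return getattr(module, "__version__", "unknown")


def collect_versions(packages=PACKAGES):
    """Map each distribution name to its installed version string."""
    return {name: _version_of(name) for name in packages}


if __name__ == "__main__":
    print("python", platform.python_version())
    for name, ver in collect_versions().items():
        print(name, ver)
```

On the Python 3.7 fallback path, the GPU build of ONNX Runtime is reported under `onnxruntime`, since both distributions install the same `onnxruntime` module.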

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Python 3.7 on an A100-80G GPU

Expected behavior

The model runs correctly.

PoodleWang added the bug label Dec 14, 2022
JingyaHuang self-assigned this Dec 14, 2022
@JingyaHuang
Contributor

Hi @PoodleWang,

Running your snippet, I also get an error, though not exactly the same one as yours:

Traceback (most recent call last):
  File "test_codegen.py", line 10, in <module>
    ort_model = ORTModelForCausalLM.from_pretrained(
  File "/workspace/optimum/onnxruntime/modeling_ort.py", line 523, in from_pretrained
    return super().from_pretrained(
  File "/workspace/optimum/modeling_base.py", line 325, in from_pretrained
    return from_pretrained_method(
  File "/workspace/optimum/onnxruntime/modeling_decoder.py", line 666, in _from_transformers
    return cls._from_pretrained(
  File "/workspace/optimum/onnxruntime/modeling_decoder.py", line 551, in _from_pretrained
    model = cls.load_model(
  File "/workspace/optimum/onnxruntime/modeling_decoder.py", line 437, in load_model
    decoder_session = onnxruntime.InferenceSession(
  File "/usr/local/lib/python3.8/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 347, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/usr/local/lib/python3.8/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 395, in _create_inference_session
    sess.initialize_session(providers, provider_options, disabled_optimizers)
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Deserialize tensor onnx::MatMul_17298 failed.tensorprotoutils.cc:622 GetExtDataFromTensorProto External initializer: onnx::MatMul_17298 offset: 0 size to read: 268435456 given file_length: 201326592 are out of bounds or can not be read in full.

This seems to be caused by some large external tensors exceeding the 2 GB limit. Did you succeed in exporting the ONNX model?

I will take a closer look tomorrow. There are also some ongoing PRs to improve the export of large ONNX models: #255 and #586.
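The numbers in the ONNX Runtime error can be checked directly: the initializer needs 268435456 bytes (256 MiB) starting at offset 0, but the file holding it is only 201326592 bytes (192 MiB), 64 MiB short. That is consistent with an external data file that was not written completely or does not match the `.onnx` graph. A minimal sketch for running the same bounds check over every external initializer before creating an `InferenceSession`; it assumes the `onnx` package is installed, reads only the graph (`load_external_data=False`), and the helpers `load_entries` and `find_out_of_bounds` are hypothetical.

```python
import os
from typing import Dict, Iterable, List, Optional, Tuple

# (tensor name, data file location, byte offset, byte length or None = "to end of file")
Entry = Tuple[str, str, int, Optional[int]]


def find_out_of_bounds(entries: Iterable[Entry], file_sizes: Dict[str, int]) -> List[str]:
    """Describe every external tensor whose bytes cannot be read in full."""
    problems = []
    for name, location, offset, length in entries:
        size = file_sizes.get(location)
        if size is None:
            problems.append(f"{name}: data file {location!r} is missing")
            continue
        end = offset + (length or 0)
        if offset > size or end > size:
            problems.append(f"{name}: needs bytes [{offset}, {end}) but {location!r} has {size} bytes")
    return problems


def load_entries(model_path: str) -> Tuple[List[Entry], Dict[str, int]]:
    """Read external-data metadata from an ONNX graph without loading the weights."""
    import onnx  # imported lazily so find_out_of_bounds works without onnx installed

    model = onnx.load(model_path, load_external_data=False)
    base_dir = os.path.dirname(os.path.abspath(model_path))
    entries: List[Entry] = []
    sizes: Dict[str, int] = {}
    for tensor in model.graph.initializer:
        if tensor.data_location != onnx.TensorProto.EXTERNAL:
            continue
        # external_data is a list of key/value pairs: location, offset, length, ...
        info = {kv.key: kv.value for kv in tensor.external_data}
        location = info["location"]
        length = int(info["length"]) if "length" in info else None
        entries.append((tensor.name, location, int(info.get("offset", 0)), length))
        path = os.path.join(base_dir, location)
        if os.path.exists(path):
            sizes[location] = os.path.getsize(path)
    return entries, sizes
```

For example, `find_out_of_bounds(*load_entries(path))`, with `path` pointing at the exported `.onnx` file, lists each tensor that the session would fail to read.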

@PoodleWang
Author

PoodleWang commented Dec 15, 2022

@JingyaHuang

I tried the PR you mentioned, but it still fails. This is my code:

I converted codegen-6B-mono to an ONNX model, stored it locally, and also put a config in the same folder:
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer, pipeline

model = ORTModelForCausalLM.from_pretrained("/opt/tiger/genius/checkpoints/codegen-6B-mono-onnx", use_transformers=True, onnx_folder="onnx", use_io_binding=False)

tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-6B-mono")
pie = pipeline('text-generation', model=model, tokenizer=tokenizer)
a = pie("Hello world")

Log:
2022-12-15 09:39:53.798385: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable TF_ENABLE_ONEDNN_OPTS=0.
Setting pad_token_id to eos_token_id:50256 for open-end generation.
/home/tiger/.local/lib/python3.7/site-packages/transformers/generation_utils.py:1364: UserWarning: Neither max_length nor max_new_tokens has been set, max_length will default to 50 (self.config.max_length). Controlling max_length via the config is deprecated and max_length will be removed from the config in v5 of Transformers -- we recommend using max_new_tokens to control the maximum length of the generation.
UserWarning,
Traceback (most recent call last):
  File "/opt/tiger/genius/tensorrt/load.py", line 23, in <module>
    a = pie("Hello world")
  File "/home/tiger/.local/lib/python3.7/site-packages/transformers/pipelines/text_generation.py", line 187, in __call__
    return super().__call__(text_inputs, **kwargs)
  File "/home/tiger/.local/lib/python3.7/site-packages/transformers/pipelines/base.py", line 1074, in __call__
    return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
  File "/home/tiger/.local/lib/python3.7/site-packages/transformers/pipelines/base.py", line 1081, in run_single
    model_outputs = self.forward(model_inputs, **forward_params)
  File "/home/tiger/.local/lib/python3.7/site-packages/transformers/pipelines/base.py", line 990, in forward
    model_outputs = self._forward(model_inputs, **forward_params)
  File "/home/tiger/.local/lib/python3.7/site-packages/transformers/pipelines/text_generation.py", line 229, in _forward
    generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
  File "/home/tiger/.local/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/home/tiger/.local/lib/python3.7/site-packages/transformers/generation_utils.py", line 1553, in generate
    **model_kwargs,
  File "/home/tiger/.local/lib/python3.7/site-packages/transformers/generation_utils.py", line 2486, in sample
    output_hidden_states=output_hidden_states,
  File "/usr/local/lib/python3.7/dist-packages/optimum/modeling_base.py", line 60, in __call__
    return self.forward(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/optimum/onnxruntime/modeling_ort.py", line 1455, in forward
    logits = torch.from_numpy(outputs[self.model_outputs["logits"]]).to(self.device)
KeyError: 'logits'

Environment:
optimum == 1.5.1
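The `KeyError` is raised at `self.model_outputs["logits"]`, so the loaded session apparently exposes no output named `logits`. One possible cause, offered as an assumption rather than a diagnosis: a graph exported with the `transformers.onnx` `default` feature names its output `last_hidden_state`, whereas the `causal-lm` feature produces `logits`. A minimal sketch for checking the output names of the locally exported file without loading its weights; it assumes the `onnx` package is installed, and `graph_output_names` and `missing_outputs` are hypothetical helpers.

```python
from typing import Iterable, List, Sequence


def missing_outputs(output_names: Iterable[str], required: Sequence[str] = ("logits",)) -> List[str]:
    """Return the required output names that the graph does not provide, in order."""
    present = set(output_names)
    return [name for name in required if name not in present]


def graph_output_names(model_path: str) -> List[str]:
    """List the graph's output names without reading the external weight files."""
    import onnx  # imported lazily; only needed when inspecting a real file

    model = onnx.load(model_path, load_external_data=False)
    return [output.name for output in model.graph.output]
```

`missing_outputs(graph_output_names(path))` returns an empty list when the exported graph does have a `logits` output.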

@PoodleWang
Author

@JingyaHuang
Experiment 3: I think your recent change to onnxruntime_inference_collection.py is not correct. Please take a look at that file; the device handling in it was changed recently.
After I added the parameter use_io_binding=True, the model could be loaded.

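Code:
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer, pipeline

model = ORTModelForCausalLM.from_pretrained("/opt/tiger/genius/checkpoints/codegen-6B-mono-onnx", use_transformers=True, use_auth_token=True, onnx_folder="onnx")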

tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-6B-mono")

pie = pipeline('text-generation', model=model, tokenizer=tokenizer)
a = pie("Hello world")

Environment:
A100-80G
optimum 1.5.1
pip3 install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html

Log:
2022-12-15 10:05:13.552564: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable TF_ENABLE_ONEDNN_OPTS=0.
Setting pad_token_id to eos_token_id:50256 for open-end generation.
/home/tiger/.local/lib/python3.7/site-packages/transformers/generation_utils.py:1364: UserWarning: Neither max_length nor max_new_tokens has been set, max_length will default to 50 (self.config.max_length). Controlling max_length via the config is deprecated and max_length will be removed from the config in v5 of Transformers -- we recommend using max_new_tokens to control the maximum length of the generation.
UserWarning,
/home/tiger/.local/lib/python3.7/site-packages/transformers/generation_utils.py:1449: UserWarning: You are calling .generate() with the input_ids being on a device type different than your model's device. input_ids is on cpu, whereas the model is on cuda. You may experience unexpected behaviors or slower generation. Please make sure that you have put input_ids to the correct device by calling for example input_ids = input_ids.to('cuda') before running .generate().
UserWarning,
Traceback (most recent call last):
  File "/opt/tiger/genius/tensorrt/load.py", line 24, in <module>
    a = pie("Hello world")
  File "/home/tiger/.local/lib/python3.7/site-packages/transformers/pipelines/text_generation.py", line 187, in __call__
    return super().__call__(text_inputs, **kwargs)
  File "/home/tiger/.local/lib/python3.7/site-packages/transformers/pipelines/base.py", line 1074, in __call__
    return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
  File "/home/tiger/.local/lib/python3.7/site-packages/transformers/pipelines/base.py", line 1081, in run_single
    model_outputs = self.forward(model_inputs, **forward_params)
  File "/home/tiger/.local/lib/python3.7/site-packages/transformers/pipelines/base.py", line 990, in forward
    model_outputs = self._forward(model_inputs, **forward_params)
  File "/home/tiger/.local/lib/python3.7/site-packages/transformers/pipelines/text_generation.py", line 229, in _forward
    generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
  File "/home/tiger/.local/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/home/tiger/.local/lib/python3.7/site-packages/transformers/generation_utils.py", line 1553, in generate
    **model_kwargs,
  File "/home/tiger/.local/lib/python3.7/site-packages/transformers/generation_utils.py", line 2486, in sample
    output_hidden_states=output_hidden_states,
  File "/usr/local/lib/python3.7/dist-packages/optimum/modeling_base.py", line 60, in __call__
    return self.forward(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/optimum/onnxruntime/modeling_ort.py", line 1437, in forward
    io_binding, output_shapes, output_buffers = self.prepare_io_binding(input_ids, attention_mask)
  File "/usr/local/lib/python3.7/dist-packages/optimum/onnxruntime/modeling_ort.py", line 1392, in prepare_io_binding
    input_ids.data_ptr(),
  File "/home/tiger/.local/lib/python3.7/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 461, in bind_input
    device_id,
TypeError: __init__(): incompatible constructor arguments. The following argument types are supported:
    1. onnxruntime.capi.onnxruntime_pybind11_state.OrtDevice(arg0: int, arg1: int, arg2: int)

Invoked with: 0, 0, None
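The last frame shows `OrtDevice` being constructed with `None` as its third argument, the device id. In PyTorch, a device created without an explicit index, such as `torch.device("cuda")`, reports `index` as `None`, so passing such an index straight through would produce exactly this call; the earlier warning that `input_ids` is on CPU while the model is on CUDA may be related to the same device handling. A minimal sketch of the kind of normalization that avoids a `None` id, assuming device strings in PyTorch's `type[:index]` form; `normalize_device` is a hypothetical helper, not part of optimum or onnxruntime.

```python
from typing import Tuple


def normalize_device(device: str, default_index: int = 0) -> Tuple[str, int]:
    """Split a PyTorch-style device string such as "cuda", "cuda:1" or "cpu" into (type, index).

    torch.device("cuda").index is None; here a missing index falls back to
    default_index so the id handed to ONNX Runtime is always an int.
    """
    device_type, _, index = device.partition(":")
    if not device_type:
        raise ValueError(f"invalid device string: {device!r}")
    return device_type, int(index) if index else default_index
```

Resolving the index once, up front, keeps every later binding call working with a concrete integer instead of re-checking for `None` at each call site.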

@NouamaneTazi
Member

@PoodleWang, can you check out PR #586 and see if it solves your problem?
