Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Saving external data for large ONNX models #255

Closed
wants to merge 15 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,13 @@
from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
from .io_binding import TypeHelper
from .modeling_ort import ORTModel
from .utils import ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, get_provider_for_device, parse_device
from .utils import (
ONNX_DECODER_NAME,
ONNX_DECODER_WITH_PAST_NAME,
_get_external_data_paths,
get_provider_for_device,
parse_device,
)


if TYPE_CHECKING:
Expand Down Expand Up @@ -475,12 +481,16 @@ def _save_pretrained(
"""
src_paths = [self.decoder_model_path]
dst_file_names = [decoder_file_name]

if self.use_cache:
src_paths.append(self.decoder_with_past_model_path)
dst_file_names.append(decoder_with_past_file_name)

# add external data paths in case of large models
src_paths, dst_file_names = _get_external_data_paths(src_paths, dst_file_names)

for src_path, dst_file_name in zip(src_paths, dst_file_names):
dst_path = Path(save_directory).joinpath(dst_file_name)
dst_path = Path(save_directory) / dst_file_name
shutil.copyfile(src_path, dst_path)

@classmethod
Expand Down
13 changes: 10 additions & 3 deletions optimum/onnxruntime/modeling_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from .io_binding import IOBindingHelper, TypeHelper
from .utils import (
ONNX_WEIGHTS_NAME,
_get_external_data_paths,
get_device_for_provider,
get_provider_for_device,
parse_device,
Expand Down Expand Up @@ -301,9 +302,15 @@ def _save_pretrained(self, save_directory: Union[str, Path], file_name: str = ON
file_name (`str`, *optional*, defaults to the value of `optimum.onnxruntime.utils.ONNX_WEIGHTS_NAME`):
The filename to use when saving the model.
"""
# TODO: support models with external data
dst_path = Path(save_directory).joinpath(file_name)
shutil.copyfile(self.model_path, dst_path)
src_paths = [self.model_path]
dst_file_names = [file_name]

# add external data paths in case of large models
src_paths, dst_file_names = _get_external_data_paths(src_paths, dst_file_names)

for src_path, dst_file_name in zip(src_paths, dst_file_names):
dst_path = Path(save_directory) / dst_file_name
shutil.copyfile(src_path, dst_path)

@staticmethod
def _generate_regular_names_for_filename(filename: str):
Expand Down
10 changes: 7 additions & 3 deletions optimum/onnxruntime/modeling_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
ONNX_DECODER_NAME,
ONNX_DECODER_WITH_PAST_NAME,
ONNX_ENCODER_NAME,
_get_external_data_paths,
get_provider_for_device,
parse_device,
validate_provider_availability,
Expand Down Expand Up @@ -900,13 +901,16 @@ def _save_pretrained(
The decoder with past key values model file name overwriting the default file name, allowing to save
the decoder model with a different name.
"""
src_file_names = [self.encoder_model_path, self.decoder_model_path]
src_paths = [self.encoder_model_path, self.decoder_model_path]
dst_file_names = [encoder_file_name, decoder_file_name]
if self.use_cache:
src_file_names.append(self.decoder_with_past_model_path)
src_paths.append(self.decoder_with_past_model_path)
dst_file_names.append(decoder_with_past_file_name)

for src_path, dst_file_name in zip(src_file_names, dst_file_names):
# add external data paths in case of large models
src_paths, dst_file_names = _get_external_data_paths(src_paths, dst_file_names)

for src_path, dst_file_name in zip(src_paths, dst_file_names):
dst_path = Path(save_directory) / dst_file_name
shutil.copyfile(src_path, dst_path)

Expand Down
23 changes: 22 additions & 1 deletion optimum/onnxruntime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
import importlib.util
import os
from enum import Enum
from typing import Dict, Tuple, Union
from pathlib import Path
from typing import Dict, List, Tuple, Union

import torch
from transformers.onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
Expand All @@ -25,6 +26,7 @@
import onnx
import onnxruntime as ort
import pkg_resources
from onnx.external_data_helper import ExternalDataInfo, _get_initializer_tensors

from ..onnx import OnnxConfigWithLoss, OnnxConfigWithPastAndLoss, OnnxSeq2SeqConfigWithPastAndLoss

Expand Down Expand Up @@ -270,3 +272,22 @@ class ORTQuantizableOperator(Enum):
Resize = "Resize"
AveragePool = "AveragePool"
Concat = "Concat"


def _get_external_data_paths(src_paths: List[Path], dst_file_names: List[str]) -> Tuple[List[Path], List[str]]:
"""
Get external data paths from the model and add them to the list of files to copy.
"""
model_paths = src_paths.copy()
for model_path in model_paths:
model = onnx.load(str(model_path), load_external_data=False)
# filter out tensors that are not external data
model_tensors = _get_initializer_tensors(model)
model_tensors_ext = [
ExternalDataInfo(tensor).location
for tensor in model_tensors
if tensor.HasField("data_location") and tensor.data_location == onnx.TensorProto.EXTERNAL
]
src_paths.extend([model_path.parent / tensor_name for tensor_name in model_tensors_ext])
dst_file_names.extend(model_tensors_ext)
return src_paths, dst_file_names