Skip to content

Commit

Permalink
[CodeGen] ONNX model loading to support >2Gb models / two engines (#991)
Browse files Browse the repository at this point in the history
  • Loading branch information
dbogunowicz authored and markurtz committed Jun 8, 2023
1 parent 7f1651d commit b85746d
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 94 deletions.
160 changes: 74 additions & 86 deletions examples/codegen/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
# limitations under the License.

import os
from tempfile import NamedTemporaryFile
from typing import Dict, List, Optional, Tuple, Type, Union

import numpy
import onnx
from onnx import ValueInfoProto
from pydantic import BaseModel, Field
from transformers import AutoConfig, AutoTokenizer

from deepsparse import Context, MultiModelEngine, Pipeline
from deepsparse.pipeline import (
Expand All @@ -28,60 +28,62 @@
Engine,
ORTEngine,
)
from deepsparse.transformers.helpers import overwrite_transformer_onnx_model_inputs
from deepsparse.transformers.pipelines import TransformersPipeline
from scipy.special import softmax


_MODEL_DIR_ONNX_MULTI_TOKEN_NAME = "decoder_model.onnx"
_MODEL_DIR_ONNX_NAME = "model.onnx"

__all__ = ["TextGenerationPipeline"]

def overwrite_multi_token_onnx_model_inputs(
external_inputs: List[ValueInfoProto], batch_size: int, max_length: int
) -> List[str]:
"""
Overwrite the input shape of the onnx model for multi token generation.
:param external_inputs: the external inputs of the onnx model
:param batch_size: the batch size of the input
:param max_length: the max length of the input
:return: the input names of the onnx model
"""
input_names = []
for external_input in external_inputs:
for single_input in external_input.type.tensor_type.shape.dim:
if single_input.dim_param == "batch_size":
single_input.dim_value = batch_size
elif single_input.dim_param == "sequence_length":
single_input.dim_value = max_length
input_names.append(external_input.name)
return input_names


def overwrite_transformer_onnx_model_inputs(
path: str,
batch_size: int = 1,
max_length: int = 128,
output_path: Optional[str] = None,
) -> Tuple[Optional[str], List[str], Optional[NamedTemporaryFile]]:
def overwrite_single_token_onnx_model_inputs(
external_inputs: List[ValueInfoProto], batch_size: int, max_length: int
) -> List[str]:
"""
Overrides an ONNX model's inputs to have the given batch size and sequence lengths.
Assumes that these are the first and second shape indices of the given model inputs
respectively
:param path: path to the ONNX model to override
:param batch_size: batch size to set
:param max_length: max sequence length to set
:param output_path: if provided, the model will be saved to the given path,
otherwise, the model will be saved to a named temporary file that will
be deleted after the program exits
:return: if no output path, a tuple of the saved path to the model, list of
model input names, and reference to the tempfile object will be returned
otherwise, only the model input names will be returned
Overwrite the input shapes of the onnx model of the single token model.
:param external_inputs: the external inputs of the onnx model
:param batch_size: the batch size to overwrite the input shapes with
:param max_length: the max length to overwrite the input shapes with
:return: the input names of the onnx model
"""
# overwrite input shapes
model = onnx.load(path)
initializer_input_names = set([node.name for node in model.graph.initializer])
external_inputs = [
inp for inp in model.graph.input if inp.name not in initializer_input_names
]
input_names = []
for external_input in external_inputs:
# this is removed for now (will need to be accounted for when we start
# supporting deepsparse engine
# external_input.type.tensor_type.shape.dim[0].dim_value = batch_size
# external_input.type.tensor_type.shape.dim[1].dim_value = max_length
for single_input in external_input.type.tensor_type.shape.dim:
if single_input.dim_param == "batch_size":
single_input.dim_value = batch_size
elif single_input.dim_param == "past_sequence_length + sequence_length":
single_input.dim_value = max_length
elif single_input.dim_param == "past_sequence_length + 1":
single_input.dim_value = max_length
input_names.append(external_input.name)
return input_names

# Save modified model
if output_path is None:
tmp_file = NamedTemporaryFile() # file will be deleted after program exit
onnx.save(model, tmp_file.name)

return tmp_file.name, input_names, tmp_file
else:
onnx.save(model, output_path)
return input_names
__all__ = ["TextGenerationPipeline"]


class TextGenerationInput(BaseModel):
Expand Down Expand Up @@ -144,11 +146,6 @@ def __init__(
# initialize the auxiliary multitoken engine
self.multitoken_engine = self._initialize_multitoken_engine()

# re-initialize the target model
# this will be removed once codegen is productionized
self.onnx_path = self._setup_onnx_file_path()
self.engine = self._reinitialize_engine()

if self._batch_size != 1:
raise ValueError(
"For the sake of simplicity, only dynamic"
Expand Down Expand Up @@ -412,6 +409,34 @@ def sample_new_token(
probs = softmax(logits)
return numpy.random.choice(len(probs), p=probs)

def setup_onnx_file_path(self) -> str:
onnx_path = os.path.join(self.model_path, _MODEL_DIR_ONNX_NAME)

config_path = self.model_path
tokenizer_path = self.model_path

self.config = AutoConfig.from_pretrained(
config_path, finetuning_task=self.task if hasattr(self, "task") else None
)
self.tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path,
model_max_length=self.sequence_length,
)
self.config_path = os.path.join(config_path, "config.json")
self.tokenizer_config_path = os.path.join(tokenizer_path, "tokenizer.json")

(
onnx_path,
self.onnx_input_names,
self._temp_model_directory,
) = overwrite_transformer_onnx_model_inputs(
onnx_path,
max_length=self.sequence_length,
custom_input_overwrite_func=overwrite_single_token_onnx_model_inputs,
)

return onnx_path

def _setup_multitoken_onnx_file_path(self) -> str:
# `setup_onnx_file_path` function rewritten
# to setup the multitoken_onnx_file_path
Expand All @@ -424,7 +449,10 @@ def _setup_multitoken_onnx_file_path(self) -> str:
self.multitoken_onnx_input_names,
self._temp_model_directory,
) = overwrite_transformer_onnx_model_inputs(
multitoken_onnx_path, max_length=self.sequence_length
multitoken_onnx_path,
max_length=self.sequence_length,
load_external_data=False,
custom_input_overwrite_func=overwrite_multi_token_onnx_model_inputs,
)

return multitoken_onnx_path
Expand Down Expand Up @@ -452,43 +480,3 @@ def _initialize_multitoken_engine(self) -> Union[Engine, ORTEngine]:
f"Unknown engine_type {self.engine_type}. Supported values include: "
f"{SUPPORTED_PIPELINE_ENGINES}"
)

def _setup_onnx_file_path(self) -> str:
# `setup_onnx_file_path` function rewritten

onnx_path = os.path.join(self.model_path, _MODEL_DIR_ONNX_NAME)
(
onnx_path,
self.onnx_input_names,
self._temp_model_directory,
) = overwrite_transformer_onnx_model_inputs(
onnx_path, max_length=self.sequence_length
)

return onnx_path

def _initialize_engine(self):
return None

def _reinitialize_engine(self) -> Union[Engine, ORTEngine]:
# `_initialize_engine` function rewritten

engine_type = self.engine_type.lower()

if engine_type == DEEPSPARSE_ENGINE:
if self.context is not None and isinstance(self.context, Context):
self._engine_args.pop("num_cores", None)
self._engine_args.pop("scheduler", None)
self._engine_args["context"] = self.context
return MultiModelEngine(
model=self.onnx_path,
**self._engine_args,
)
return Engine(self.onnx_path, **self._engine_args)
elif engine_type == ORT_ENGINE:
return ORTEngine(self.onnx_path, **self._engine_args)
else:
raise ValueError(
f"Unknown engine_type {self.engine_type}. Supported values include: "
f"{SUPPORTED_PIPELINE_ENGINES}"
)
28 changes: 20 additions & 8 deletions src/deepsparse/transformers/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import re
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union

import numpy
import onnx
Expand Down Expand Up @@ -136,6 +136,8 @@ def overwrite_transformer_onnx_model_inputs(
batch_size: int = 1,
max_length: int = 128,
output_path: Optional[str] = None,
load_external_data: bool = True,
custom_input_overwrite_func: Optional[Callable] = None,
) -> Tuple[Optional[str], List[str], Optional[NamedTemporaryFile]]:
"""
Overrides an ONNX model's inputs to have the given batch size and sequence lengths.
Expand All @@ -148,21 +150,32 @@ def overwrite_transformer_onnx_model_inputs(
:param output_path: if provided, the model will be saved to the given path,
otherwise, the model will be saved to a named temporary file that will
be deleted after the program exits
:param load_external_data: if True, external data will be loaded into the model
graph. If False, external data will not be loaded and the model will be
saved without external data
:custom_input_overwrite_func: if provided, this function will be called instead
of the default input overwrite function. This function should take in a list
of external inputs and return a list of the overwritten input names
:return: if no output path, a tuple of the saved path to the model, list of
model input names, and reference to the tempfile object will be returned
otherwise, only the model input names will be returned
"""
# overwrite input shapes
model = onnx.load(path)
model = onnx.load_model(path, load_external_data=load_external_data)
initializer_input_names = set([node.name for node in model.graph.initializer])
external_inputs = [
inp for inp in model.graph.input if inp.name not in initializer_input_names
]
input_names = []
for external_input in external_inputs:
external_input.type.tensor_type.shape.dim[0].dim_value = batch_size
external_input.type.tensor_type.shape.dim[1].dim_value = max_length
input_names.append(external_input.name)
if custom_input_overwrite_func is not None:
input_names = custom_input_overwrite_func(
external_inputs, batch_size, max_length
)
else:
input_names = []
for external_input in external_inputs:
external_input.type.tensor_type.shape.dim[0].dim_value = batch_size
external_input.type.tensor_type.shape.dim[1].dim_value = max_length
input_names.append(external_input.name)

# Save modified model
if output_path is None:
Expand All @@ -171,7 +184,6 @@ def overwrite_transformer_onnx_model_inputs(
return tmp_file.name, input_names, tmp_file
else:
save_onnx(model, output_path)

return input_names


Expand Down

0 comments on commit b85746d

Please sign in to comment.