Patch ORTTrainer inference with ONNX Runtime backend (#737)
* Fix seq2seq

* Adapt export for decoders

* debug

* decoder with cache enabled

* Fix for seq2seq trainer ort inf

* Remove comments

* Replace default decoder_input_ids by raise
JingyaHuang authored Feb 2, 2023
1 parent 17b76db commit 334d3cf
Showing 6 changed files with 99 additions and 45 deletions.
31 changes: 24 additions & 7 deletions optimum/onnxruntime/base.py
@@ -195,19 +195,19 @@ def forward(
input_ids: torch.LongTensor,
attention_mask: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
labels: Optional[torch.LongTensor] = None,
) -> CausalLMOutputWithCrossAttentions:
known_output_shapes = {}
# Flatten the past_key_values
if past_key_values is not None:
past_key_values = [past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer]

if self.device.type == "cuda" and self.parent_model.use_io_binding:
past_key_values_shapes = self.compute_past_key_values_output_shapes(
known_output_shapes = self.compute_past_key_values_output_shapes(
input_ids,
past_key_values=past_key_values,
)

past_key_values_inputs = past_key_values if past_key_values is not None else [None]

model_inputs = [input_ids]

if "attention_mask" in self.input_names:
@@ -216,10 +216,14 @@ def forward(
if past_key_values is not None:
model_inputs += past_key_values

if "labels" in self.input_names:
model_inputs.append(labels)
known_output_shapes.update({"loss": []})

io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding(
self.session,
*model_inputs,
known_output_shapes=past_key_values_shapes,
known_output_shapes=known_output_shapes,
)

io_binding.synchronize_inputs()
@@ -236,6 +240,10 @@ def forward(
past_key_values = tuple(past_key_values[i : i + num_pkv] for i in range(0, len(past_key_values), num_pkv))

logits = output_buffers["logits"].view(output_shapes["logits"])

loss = None
if "loss" in self.output_names:
loss = output_buffers["loss"].view(output_shapes["loss"])
else:
onnx_inputs = {
"input_ids": input_ids.cpu().detach().numpy(),
@@ -247,6 +255,9 @@ def forward(
for input_name, past_key_value in zip(self.key_value_input_names, past_key_values):
onnx_inputs[input_name] = past_key_value.cpu().detach().numpy()

if "labels" in self.input_names:
onnx_inputs["labels"] = labels.cpu().detach().numpy()

# Run inference
outputs = self.session.run(None, onnx_inputs)

@@ -262,7 +273,11 @@ def forward(
past_key_values = tuple(past_key_values[i : i + num_pkv] for i in range(0, len(past_key_values), num_pkv))
logits = torch.from_numpy(outputs[self.output_names["logits"]]).to(self.device)

return CausalLMOutputWithCrossAttentions(logits=logits, past_key_values=past_key_values)
loss = None
if "loss" in self.output_names:
loss = torch.from_numpy(outputs[self.output_names["loss"]]).to(self.device)

return CausalLMOutputWithCrossAttentions(loss=loss, logits=logits, past_key_values=past_key_values)


class ORTDecoderForSeq2Seq(ORTDecoder):
@@ -299,14 +314,15 @@ def forward(
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
labels: Optional[torch.LongTensor] = None,
) -> Seq2SeqLMOutput:
known_output_shapes = {}
# Flatten the past_key_values
if past_key_values is not None:
past_key_values = tuple(
past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer
)

if self.parent_model.device.type == "cuda" and self.parent_model.use_io_binding:
past_key_values_shapes = self.compute_past_key_values_output_shapes(
known_output_shapes = self.compute_past_key_values_output_shapes(
input_ids,
encoder_hidden_states,
past_key_values=past_key_values,
@@ -330,11 +346,12 @@ def filter_out_output(output_name):

if "labels" in self.input_names:
model_inputs.append(labels)
known_output_shapes.update({"loss": []})

io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding(
self.session,
*model_inputs,
known_output_shapes=past_key_values_shapes,
known_output_shapes=known_output_shapes,
forward_function=self.forward,
outputs_to_not_bind=outputs_to_not_bind,
)
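A note on the flatten/regroup pattern these hunks build on: the ONNX Runtime session takes a flat list of named inputs and outputs, so the nested per-layer `(key, value)` tuples of `past_key_values` are flattened before binding and regrouped afterwards in chunks of `num_pkv`. A minimal standalone sketch of that round trip, with illustrative shapes rather than Optimum's actual API:

```python
import torch

# Two decoder layers, each holding a (key, value) pair of tensors.
past_key_values = tuple(
    (torch.rand(1, 4, 3, 8), torch.rand(1, 4, 3, 8)) for _ in range(2)
)

# Flatten before binding, mirroring the comprehension in the diff above.
flat = [t for pkv_per_layer in past_key_values for t in pkv_per_layer]
assert len(flat) == 4

# Regroup after inference: num_pkv tensors per layer (2 for a decoder-only model).
num_pkv = 2
regrouped = tuple(flat[i : i + num_pkv] for i in range(0, len(flat), num_pkv))
assert len(regrouped) == 2 and len(regrouped[0]) == num_pkv
```

The new `known_output_shapes` dict generalizes the old `past_key_values_shapes`: since the loss is a scalar, its shape is registered as `[]` so the IO binding can pre-allocate a buffer for it alongside the cache outputs.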
12 changes: 10 additions & 2 deletions optimum/onnxruntime/modeling_decoder.py
@@ -543,19 +543,27 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
labels: Optional[torch.LongTensor] = None,
**kwargs,
) -> CausalLMOutputWithCrossAttentions:

if past_key_values is None or self.decoder_with_past is None:
outputs = self.decoder(input_ids=input_ids, attention_mask=attention_mask)
outputs = self.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
)
else:
outputs = self.decoder_with_past(
input_ids=input_ids[:, -1:],
past_key_values=past_key_values,
attention_mask=attention_mask,
labels=labels,
)

return CausalLMOutputWithCrossAttentions(logits=outputs.logits, past_key_values=outputs.past_key_values)
return CausalLMOutputWithCrossAttentions(
loss=outputs.get("loss", None), logits=outputs.logits, past_key_values=outputs.past_key_values
)

# Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
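With `labels` threaded through both decoder sessions, `ORTModelForCausalLM` can now return a loss the way its transformers counterpart does. A hedged usage sketch — the checkpoint path is illustrative, and a loss only comes back if the model was exported with `labels`/`loss` wired into the graph, as `ORTTrainer._export` below arranges:

```python
# Assumes an ONNX export that includes a loss output; the path is illustrative.
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = ORTModelForCausalLM.from_pretrained("path/to/exported_model")

inputs = tokenizer("ONNX Runtime inference", return_tensors="pt")
outputs = model(**inputs, labels=inputs["input_ids"])
print(outputs.loss)  # None unless the exported graph has a loss output
```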
7 changes: 5 additions & 2 deletions optimum/onnxruntime/modeling_ort.py
@@ -622,7 +622,11 @@ def _prepare_output_buffer(self, model: ort.InferenceSession, output_shape: Tupl
"""Prepares the buffer of output_name with a 1D tensor."""
ort_type = TypeHelper.get_output_type(model, output_name)
torch_type = TypeHelper.ort_type_to_torch_type(ort_type)
output_buffer = torch.empty(np.prod(output_shape), dtype=torch_type, device=self.device).contiguous()
if len(output_shape) > 0:
output_buffer = torch.empty(np.prod(output_shape), dtype=torch_type, device=self.device).contiguous()
else:
# Case when the output is a scalar
output_buffer = torch.tensor(0, dtype=torch_type, device=self.device).contiguous()
return output_buffer

def _output_shape_inference(self, axis_name: Union[str, int], dimensions: Dict[str, int]) -> Union[str, int]:
@@ -727,7 +731,6 @@ def _prepare_io_binding(
tuple(tensor.shape),
tensor.data_ptr(),
)

dimensions = {}
for input_ in model.get_inputs():
shape = input_.shape
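The scalar branch added above exists because a loss output has the empty shape `()`: `np.prod(())` evaluates to the float `1.0`, which `torch.empty` rejects as a size, so a 0-dim tensor is allocated instead. A standalone sketch of the same logic (an illustrative reimplementation outside the `ORTModel` class):

```python
import numpy as np
import torch

def prepare_output_buffer(output_shape, torch_type=torch.float32, device="cpu"):
    if len(output_shape) > 0:
        # Non-empty shape: np.prod yields the element count for a flat 1D buffer.
        return torch.empty(int(np.prod(output_shape)), dtype=torch_type, device=device).contiguous()
    # Scalar output such as a loss: np.prod(()) == 1.0 (a float), so build
    # a 0-dim tensor directly instead of calling torch.empty.
    return torch.tensor(0, dtype=torch_type, device=device).contiguous()

print(prepare_output_buffer((2, 3)).shape)  # torch.Size([6]): flat 1D buffer
print(prepare_output_buffer(()).shape)      # torch.Size([]): scalar buffer
```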
3 changes: 3 additions & 0 deletions optimum/onnxruntime/modeling_seq2seq.py
@@ -774,6 +774,9 @@ def forward(
encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

# Decode
if decoder_input_ids is None:
raise ValueError("You have to specify either decoder_input_ids.")

if past_key_values is None or self.decoder_with_past is None:
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
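Per the commit message, the silently assumed default `decoder_input_ids` is replaced by an explicit error. Calls that go through `generate()` are unaffected, since generation builds `decoder_input_ids` from the decoder start token; only direct `forward()` calls must now pass them. A hedged sketch of the calling pattern (checkpoint path illustrative):

```python
import torch
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = ORTModelForSeq2SeqLM.from_pretrained("path/to/exported_t5")  # illustrative

batch = tokenizer("translate English to French: hello", return_tensors="pt")

# Fine: generate() creates decoder_input_ids internally.
model.generate(**batch)

# A direct forward() call must now pass decoder_input_ids explicitly,
# otherwise the new ValueError above is raised.
start = torch.tensor([[model.config.decoder_start_token_id]])
outputs = model(**batch, decoder_input_ids=start)
```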
78 changes: 50 additions & 28 deletions optimum/onnxruntime/trainer.py
@@ -88,7 +88,7 @@
from transformers.utils import logging

from ..exporters import TasksManager
from ..exporters.onnx import OnnxConfigWithPast, export
from ..exporters.onnx import OnnxConfigWithPast, export, export_models, get_decoder_models_for_export
from .modeling_decoder import ORTModelForCausalLM
from .modeling_ort import (
ORTModel,
@@ -163,6 +163,16 @@ def get_model_class_for_feature(feature: str) -> Type:

return ORTFeaturesManager._TASKS_TO_ORTMODELS[feature]

@staticmethod
def do_use_cache(feature: str) -> bool:
"""
Gets the value of `use_cache` for the feature.
"""
if "-with-past" in feature:
return True
else:
return False


class ORTTrainer(Trainer):
"""
@@ -967,29 +977,26 @@ def evaluation_loop_ort(

# With `label_smoother` the loss will be computed outside modeling
with_loss = has_labels and not self.label_smoother
self._export(onnx_model_path, with_loss=with_loss, device=export_device)
use_cache = ORTFeaturesManager.do_use_cache(self.feature)
self._export(onnx_model_path, with_loss=with_loss, device=export_device, use_cache=use_cache)

self.exported_with_loss = with_loss
self.onnx_model_path = onnx_model_path.as_posix()
logger.info("[INFO] ONNX model is stored in:\n%s", self.onnx_model_path)

# Load ORT model
if not self.exported_with_loss and self.feature in ORTFeaturesManager.SUPPORTED_FEATURES:
if self.feature in ORTFeaturesManager.SUPPORTED_FEATURES:
# Exported with standard outputs, use specific ORTModels
ort_model_cls = ORTFeaturesManager.get_model_class_for_feature(self.feature)
else:
ort_model_cls = ORTModelForCustomTasks

model_id = self.onnx_model_path
args = self.args
# Temporary fix for decoder, now `use_cache` set to False which
# TODO: Use cache once `ORTModelForCausalLM` supports `loss` as output
if ort_model_cls is ORTModelForCausalLM:
ort_model = ort_model_cls.from_pretrained(
model_id=model_id, use_cache=False, provider="CUDAExecutionProvider"
)
ort_model = ort_model_cls.from_pretrained(model_id=model_id, use_cache=use_cache).to(args.device)
else:
ort_model = ort_model_cls.from_pretrained(model_id=model_id, provider="CUDAExecutionProvider")
ort_model = ort_model_cls.from_pretrained(model_id=model_id).to(args.device)

prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only

@@ -1450,6 +1457,7 @@ def _export(
opset: Optional[int] = None,
device: str = "cpu",
with_loss: bool = True,
use_cache: bool = False,
) -> None:
"""
Load and export a model to an ONNX format.
@@ -1472,14 +1480,6 @@ def _export(
self.model.to("cpu")
model = unwrap_model(self.model)

# TODO: Remove once `ORTModelForCausalLM` supports `loss` as an output
if "-with-past" in self.feature:
correct_feature = self.feature.replace("-with-past", "")
raise NotImplementedError(
"`use_cache` is not yet supported for ONNX Runtime inference in `ORTTrainer`, please replace "
f"{self.feature} task by {correct_feature}."
)

onnx_config_constructor = TasksManager.get_exporter_config_constructor(
model=model, exporter="onnx", task=self.feature
)
@@ -1488,18 +1488,40 @@

is_decoder = isinstance(onnx_config, OnnxConfigWithPast)

if with_loss:
onnx_config = wrap_onnx_config_for_loss(onnx_config)
opset = max(opset, 12)  # Operators like `nll_loss` are added for opset >= 12
if is_decoder:
output_names = [ONNX_DECODER_NAME]
if use_cache is True:
output_names.append(ONNX_DECODER_WITH_PAST_NAME)

models_and_onnx_configs = get_decoder_models_for_export(model, onnx_config)
if with_loss is True:
opset = max(opset, 12)
models_and_onnx_configs_with_loss = {}
for decoder_name, (decoder, decoder_config) in models_and_onnx_configs.items():
models_and_onnx_configs_with_loss[decoder_name] = (
decoder,
wrap_onnx_config_for_loss(decoder_config),
)

output_path = model_path / ONNX_DECODER_NAME if is_decoder else model_path / ONNX_WEIGHTS_NAME
_ = export(
model=model,
config=onnx_config,
opset=opset,
output=output_path,
device=device,
)
export_models(
models_and_onnx_configs=models_and_onnx_configs_with_loss if with_loss else models_and_onnx_configs,
opset=opset,
output_dir=model_path,
output_names=output_names,
)
else:
if with_loss is True:
onnx_config = wrap_onnx_config_for_loss(onnx_config)
opset = max(opset, 12)  # Operators like `nll_loss` are added for opset >= 12

output_path = model_path / ONNX_WEIGHTS_NAME
_ = export(
model=model,
config=onnx_config,
opset=opset,
output=output_path,
device=device,
)

model.config.save_pretrained(model_path)

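The export path now keys everything off the task string: a `-with-past` suffix means the decoder is exported twice (`decoder_model.onnx` and `decoder_with_past_model.onnx` via `export_models`) and reloaded with `use_cache=True`, lifting the previous `NotImplementedError`. A minimal sketch of the new check (a standalone rendition of `ORTFeaturesManager.do_use_cache`, not an import from Optimum):

```python
def do_use_cache(feature: str) -> bool:
    # A "-with-past" suffix in the feature name signals cached
    # (past key/value) decoding, matching the helper added above.
    return "-with-past" in feature

assert do_use_cache("causal-lm-with-past") is True
assert do_use_cache("causal-lm") is False
assert do_use_cache("seq2seq-lm-with-past") is True
```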
13 changes: 7 additions & 6 deletions optimum/onnxruntime/trainer_seq2seq.py
@@ -201,9 +201,7 @@ def evaluation_loop_ort(

args = self.args
# Load ORT model
self.ort_model = ORTModelForSeq2SeqLM.from_pretrained(
model_id=self.onnx_model_path, provider="CUDAExecutionProvider"
)
self.ort_model = ORTModelForSeq2SeqLM.from_pretrained(model_id=self.onnx_model_path).to(args.device)

prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only

@@ -586,7 +584,10 @@ def prediction_step_ort(
else:
generation_inputs = inputs[self.model.main_input_name]

generated_tokens = model.generate(
if torch.cuda.is_available():
self.model.to("cuda")

generated_tokens = self.model.generate(
generation_inputs,
**gen_kwargs,
)
@@ -776,8 +777,6 @@ def _export(
self.model.to("cpu")
model = unwrap_model(self.model)

use_cache = kwargs.get("use_cache", True)

onnx_config_constructor = TasksManager.get_exporter_config_constructor(
model=model, exporter="onnx", task=self.feature
)
@@ -814,7 +813,9 @@
output=Path(save_dir).joinpath(ONNX_DECODER_NAME),
device=device,
)

# Export the decoder with the past key values
use_cache = kwargs.get("use_cache", True)
if use_cache:
export(
model=model,
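Both trainers also stop hardcoding `CUDAExecutionProvider` and instead place the ORT model with `.to(args.device)`. A hedged sketch of that placement pattern (checkpoint path illustrative):

```python
import torch
from optimum.onnxruntime import ORTModelForSeq2SeqLM

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ORTModel.to("cuda") switches the session to CUDAExecutionProvider,
# while .to("cpu") keeps CPUExecutionProvider.
ort_model = ORTModelForSeq2SeqLM.from_pretrained("path/to/exported_model").to(device)
```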
