Patch ORTTrainer inference with ONNX Runtime backend #737

Merged: 9 commits, Feb 2, 2023

Changes from all commits
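This patch restores ONNX Runtime inference inside ORTTrainer evaluation and prediction: the model is exported to ONNX with a loss output (and, for decoders, optionally with past key values), then reloaded as an ORTModel. A minimal usage sketch with placeholder model and datasets, assuming the feature names and the `inference_with_ort` flag of this optimum version:

from optimum.onnxruntime import ORTTrainer, ORTTrainingArguments

args = ORTTrainingArguments(output_dir="out")
trainer = ORTTrainer(
    model=model,             # placeholder: a transformers PreTrainedModel
    args=args,
    train_dataset=train_ds,  # placeholder datasets
    eval_dataset=eval_ds,
    feature="causal-lm-with-past",  # "-with-past" now enables use_cache (see trainer.py below)
)
trainer.train()
metrics = trainer.evaluate(inference_with_ort=True)  # export to ONNX, evaluate with ONNX Runtime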
31 changes: 24 additions & 7 deletions optimum/onnxruntime/base.py
@@ -195,19 +195,19 @@ def forward(
input_ids: torch.LongTensor,
attention_mask: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
labels: Optional[torch.LongTensor] = None,
) -> CausalLMOutputWithCrossAttentions:
known_output_shapes = {}
# Flatten the past_key_values
if past_key_values is not None:
past_key_values = [past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer]

if self.device.type == "cuda" and self.parent_model.use_io_binding:
past_key_values_shapes = self.compute_past_key_values_output_shapes(
known_output_shapes = self.compute_past_key_values_output_shapes(
input_ids,
past_key_values=past_key_values,
)

past_key_values_inputs = past_key_values if past_key_values is not None else [None]

model_inputs = [input_ids]

if "attention_mask" in self.input_names:
@@ -216,10 +216,14 @@ def forward(
if past_key_values is not None:
model_inputs += past_key_values

if "labels" in self.input_names:
model_inputs.append(labels)
known_output_shapes.update({"loss": []})

io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding(
self.session,
*model_inputs,
known_output_shapes=past_key_values_shapes,
known_output_shapes=known_output_shapes,
)

io_binding.synchronize_inputs()
@@ -236,6 +240,10 @@ def forward(
past_key_values = tuple(past_key_values[i : i + num_pkv] for i in range(0, len(past_key_values), num_pkv))

logits = output_buffers["logits"].view(output_shapes["logits"])

loss = None
if "loss" in self.output_names:
loss = output_buffers["loss"].view(output_shapes["loss"])
else:
onnx_inputs = {
"input_ids": input_ids.cpu().detach().numpy(),
@@ -247,6 +255,9 @@ def forward(
for input_name, past_key_value in zip(self.key_value_input_names, past_key_values):
onnx_inputs[input_name] = past_key_value.cpu().detach().numpy()

if "labels" in self.input_names:
onnx_inputs["labels"] = labels.cpu().detach().numpy()

# Run inference
outputs = self.session.run(None, onnx_inputs)

@@ -262,7 +273,11 @@ def forward(
past_key_values = tuple(past_key_values[i : i + num_pkv] for i in range(0, len(past_key_values), num_pkv))
logits = torch.from_numpy(outputs[self.output_names["logits"]]).to(self.device)

return CausalLMOutputWithCrossAttentions(logits=logits, past_key_values=past_key_values)
loss = None
if "loss" in self.output_names:
loss = torch.from_numpy(outputs[self.output_names["loss"]]).to(self.device)

return CausalLMOutputWithCrossAttentions(loss=loss, logits=logits, past_key_values=past_key_values)


class ORTDecoderForSeq2Seq(ORTDecoder):
@@ -299,14 +314,15 @@ def forward(
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
labels: Optional[torch.LongTensor] = None,
) -> Seq2SeqLMOutput:
known_output_shapes = {}
# Flatten the past_key_values
if past_key_values is not None:
past_key_values = tuple(
past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer
)

if self.parent_model.device.type == "cuda" and self.parent_model.use_io_binding:
past_key_values_shapes = self.compute_past_key_values_output_shapes(
known_output_shapes = self.compute_past_key_values_output_shapes(
input_ids,
encoder_hidden_states,
past_key_values=past_key_values,
@@ -330,11 +346,12 @@ def filter_out_output(output_name):

if "labels" in self.input_names:
model_inputs.append(labels)
known_output_shapes.update({"loss": []})

io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding(
self.session,
*model_inputs,
known_output_shapes=past_key_values_shapes,
known_output_shapes=known_output_shapes,
forward_function=self.forward,
outputs_to_not_bind=outputs_to_not_bind,
)
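For reference, a self-contained sketch of the flatten/regroup pattern used by both forwards above; tensor shapes are illustrative:

import torch

num_layers, num_pkv = 2, 2  # (key, value) per decoder layer
past = tuple((torch.rand(1, 4, 3, 8), torch.rand(1, 4, 3, 8)) for _ in range(num_layers))

# Flatten the tuple-of-tuples into the flat list ONNX Runtime expects as inputs
flat = [t for pkv_per_layer in past for t in pkv_per_layer]

# Regroup per layer after session.run, exactly as done with num_pkv above
regrouped = tuple(flat[i : i + num_pkv] for i in range(0, len(flat), num_pkv))
assert len(regrouped) == num_layers and len(regrouped[0]) == num_pkv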
12 changes: 10 additions & 2 deletions optimum/onnxruntime/modeling_decoder.py
@@ -543,19 +543,27 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
labels: Optional[torch.LongTensor] = None,
**kwargs,
) -> CausalLMOutputWithCrossAttentions:

if past_key_values is None or self.decoder_with_past is None:
outputs = self.decoder(input_ids=input_ids, attention_mask=attention_mask)
outputs = self.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
)
else:
outputs = self.decoder_with_past(
input_ids=input_ids[:, -1:],
past_key_values=past_key_values,
attention_mask=attention_mask,
labels=labels,
)

return CausalLMOutputWithCrossAttentions(logits=outputs.logits, past_key_values=outputs.past_key_values)
return CausalLMOutputWithCrossAttentions(
loss=outputs.get("loss", None), logits=outputs.logits, past_key_values=outputs.past_key_values
)

# Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
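With the labels plumbing above, a loaded ORTModelForCausalLM can be scored like its transformers counterpart. A sketch, assuming the from_transformers export path of this optimum version; loss is only non-None when the ONNX graph actually has a loss output:

from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM

model = ORTModelForCausalLM.from_pretrained("gpt2", from_transformers=True)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

inputs = tokenizer("Hello world", return_tensors="pt")
outputs = model(**inputs, labels=inputs["input_ids"])
# A plain export has no "loss" output, so outputs.loss is None here;
# graphs exported with wrap_onnx_config_for_loss return a scalar loss.
print(outputs.loss, outputs.logits.shape)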
7 changes: 5 additions & 2 deletions optimum/onnxruntime/modeling_ort.py
@@ -622,7 +622,11 @@ def _prepare_output_buffer(self, model: ort.InferenceSession, output_shape: Tupl
"""Prepares the buffer of output_name with a 1D tensor."""
ort_type = TypeHelper.get_output_type(model, output_name)
torch_type = TypeHelper.ort_type_to_torch_type(ort_type)
output_buffer = torch.empty(np.prod(output_shape), dtype=torch_type, device=self.device).contiguous()
if len(output_shape) > 0:
output_buffer = torch.empty(np.prod(output_shape), dtype=torch_type, device=self.device).contiguous()
else:
# Case when the output is a scalar
output_buffer = torch.tensor(0, dtype=torch_type, device=self.device).contiguous()
return output_buffer

def _output_shape_inference(self, axis_name: Union[str, int], dimensions: Dict[str, int]) -> Union[str, int]:
@@ -727,7 +731,6 @@ def _prepare_io_binding(
tuple(tensor.shape),
tensor.data_ptr(),
)

dimensions = {}
for input_ in model.get_inputs():
shape = input_.shape
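The scalar branch matters because a scalar output such as loss has shape [], and np.prod([]) is a float, which torch.empty rejects as a size. A quick illustration:

import numpy as np
import torch

shape = []  # known output shape of a scalar like "loss"
print(np.prod(shape))  # 1.0 (np.float64), not a valid size for torch.empty

buffer = torch.tensor(0, dtype=torch.float32)  # zero-dim buffer used instead
print(buffer.shape)        # torch.Size([])
print(buffer.view(shape))  # view([]) keeps it scalar when read back from IO binding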
3 changes: 3 additions & 0 deletions optimum/onnxruntime/modeling_seq2seq.py
@@ -774,6 +774,9 @@ def forward(
encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

# Decode
if decoder_input_ids is None:
raise ValueError("You have to specify decoder_input_ids.")

if past_key_values is None or self.decoder_with_past is None:
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
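The new check makes the decoder_input_ids requirement explicit rather than failing later in the decoder. When only labels are available, decoder inputs are conventionally the labels shifted right; a sketch using the BART helper from transformers:

import torch
from transformers.models.bart.modeling_bart import shift_tokens_right

labels = torch.tensor([[42, 43, 44, 2]])  # toy target ids, 2 = eos
decoder_input_ids = shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=2)
print(decoder_input_ids)  # tensor([[ 2, 42, 43, 44]])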
78 changes: 50 additions & 28 deletions optimum/onnxruntime/trainer.py
@@ -88,7 +88,7 @@
from transformers.utils import logging

from ..exporters import TasksManager
from ..exporters.onnx import OnnxConfigWithPast, export
from ..exporters.onnx import OnnxConfigWithPast, export, export_models, get_decoder_models_for_export
from .modeling_decoder import ORTModelForCausalLM
from .modeling_ort import (
ORTModel,
@@ -163,6 +163,16 @@ def get_model_class_for_feature(feature: str) -> Type:

return ORTFeaturesManager._TASKS_TO_ORTMODELS[feature]

@staticmethod
def do_use_cache(feature: str) -> bool:
"""
Infers the value of `use_cache` from the feature name: True for `-with-past` features.
"""
if "-with-past" in feature:
return True
else:
return False


class ORTTrainer(Trainer):
"""
@@ -967,29 +977,26 @@ def evaluation_loop_ort(

# With `label_smoother` the loss will be computed outside modeling
with_loss = has_labels and not self.label_smoother
self._export(onnx_model_path, with_loss=with_loss, device=export_device)
use_cache = ORTFeaturesManager.do_use_cache(self.feature)
self._export(onnx_model_path, with_loss=with_loss, device=export_device, use_cache=use_cache)

self.exported_with_loss = with_loss
self.onnx_model_path = onnx_model_path.as_posix()
logger.info("[INFO] ONNX model is stored in:\n", self.onnx_model_path)

# Load ORT model
if not self.exported_with_loss and self.feature in ORTFeaturesManager.SUPPORTED_FEATURES:
if self.feature in ORTFeaturesManager.SUPPORTED_FEATURES:
# Exported with standard outputs, use specific ORTModels
ort_model_cls = ORTFeaturesManager.get_model_class_for_feature(self.feature)
else:
ort_model_cls = ORTModelForCustomTasks

model_id = self.onnx_model_path
args = self.args
# Temporary fix for decoder, now `use_cache` set to False which
# TODO: Use cache once `ORTModelForCausalLM` supports `loss` as output
if ort_model_cls is ORTModelForCausalLM:
ort_model = ort_model_cls.from_pretrained(
model_id=model_id, use_cache=False, provider="CUDAExecutionProvider"
)
ort_model = ort_model_cls.from_pretrained(model_id=model_id, use_cache=use_cache).to(args.device)
else:
ort_model = ort_model_cls.from_pretrained(model_id=model_id, provider="CUDAExecutionProvider")
ort_model = ort_model_cls.from_pretrained(model_id=model_id).to(args.device)

prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only

@@ -1450,6 +1457,7 @@ def _export(
opset: Optional[int] = None,
device: str = "cpu",
with_loss: bool = True,
use_cache: bool = False,
) -> None:
"""
Load and export a model to an ONNX format.
@@ -1472,14 +1480,6 @@ def _export(
self.model.to("cpu")
model = unwrap_model(self.model)

# TODO: Remove once `ORTModelForCausalLM` supports `loss` as an output
if "-with-past" in self.feature:
correct_feature = self.feature.replace("-with-past", "")
raise NotImplementedError(
"`use_cache` is not yet supported for ONNX Runtime inference in `ORTTrainer`, please replace "
f"{self.feature} task by {correct_feature}."
)

onnx_config_constructor = TasksManager.get_exporter_config_constructor(
model=model, exporter="onnx", task=self.feature
)
@@ -1488,18 +1488,40 @@ def _export(

is_decoder = isinstance(onnx_config, OnnxConfigWithPast)

if with_loss:
onnx_config = wrap_onnx_config_for_loss(onnx_config)
opset = max(opset, 12) # Operators like `nll_loss` are added for opset>=12
if is_decoder:
output_names = [ONNX_DECODER_NAME]
if use_cache is True:
output_names.append(ONNX_DECODER_WITH_PAST_NAME)

models_and_onnx_configs = get_decoder_models_for_export(model, onnx_config)
if with_loss is True:
opset = max(opset, 12)
models_and_onnx_configs_with_loss = {}
for decoder_name, (decoder, decoder_config) in models_and_onnx_configs.items():
models_and_onnx_configs_with_loss[decoder_name] = (
decoder,
wrap_onnx_config_for_loss(decoder_config),
)

output_path = model_path / ONNX_DECODER_NAME if is_decoder else model_path / ONNX_WEIGHTS_NAME
_ = export(
model=model,
config=onnx_config,
opset=opset,
output=output_path,
device=device,
)
export_models(
models_and_onnx_configs=models_and_onnx_configs_with_loss if with_loss else models_and_onnx_configs,
opset=opset,
output_dir=model_path,
output_names=output_names,
)
else:
if with_loss is True:
onnx_config = wrap_onnx_config_for_loss(onnx_config)
opset = max(opset, 12) # Operators like `nll_loss` are added for opset>=12

output_path = model_path / ONNX_WEIGHTS_NAME
_ = export(
model=model,
config=onnx_config,
opset=opset,
output=output_path,
device=device,
)

model.config.save_pretrained(model_path)

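For reference, the new decoder branch of _export condenses to the following; all names are from the diff above, and decoder_model.onnx / decoder_with_past_model.onnx are the usual values of ONNX_DECODER_NAME / ONNX_DECODER_WITH_PAST_NAME:

models_and_onnx_configs = get_decoder_models_for_export(model, onnx_config)

if with_loss:
    opset = max(opset, 12)  # loss ops such as nll_loss need opset >= 12
    models_and_onnx_configs = {
        name: (decoder, wrap_onnx_config_for_loss(config))
        for name, (decoder, config) in models_and_onnx_configs.items()
    }

output_names = [ONNX_DECODER_NAME] + ([ONNX_DECODER_WITH_PAST_NAME] if use_cache else [])
export_models(
    models_and_onnx_configs=models_and_onnx_configs,
    opset=opset,
    output_dir=model_path,
    output_names=output_names,
)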
13 changes: 7 additions & 6 deletions optimum/onnxruntime/trainer_seq2seq.py
@@ -201,9 +201,7 @@ def evaluation_loop_ort(

args = self.args
# Load ORT model
self.ort_model = ORTModelForSeq2SeqLM.from_pretrained(
model_id=self.onnx_model_path, provider="CUDAExecutionProvider"
)
self.ort_model = ORTModelForSeq2SeqLM.from_pretrained(model_id=self.onnx_model_path).to(args.device)

prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only

@@ -586,7 +584,10 @@ def prediction_step_ort(
else:
generation_inputs = inputs[self.model.main_input_name]

generated_tokens = model.generate(
if torch.cuda.is_available():
self.model.to("cuda")

generated_tokens = self.model.generate(
generation_inputs,
**gen_kwargs,
)
@@ -776,8 +777,6 @@ def _export(
self.model.to("cpu")
model = unwrap_model(self.model)

use_cache = kwargs.get("use_cache", True)

onnx_config_constructor = TasksManager.get_exporter_config_constructor(
model=model, exporter="onnx", task=self.feature
)
@@ -814,7 +813,9 @@ def _export(
output=Path(save_dir).joinpath(ONNX_DECODER_NAME),
device=device,
)

# Export the decoder with the past key values
use_cache = kwargs.get("use_cache", True)
if use_cache:
export(
model=model,
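End to end, the seq2seq path mirrors the causal-LM one. A sketch with placeholder model and dataset, assuming the feature names and the `inference_with_ort` flag of this optimum version:

from optimum.onnxruntime import ORTSeq2SeqTrainer, ORTSeq2SeqTrainingArguments

args = ORTSeq2SeqTrainingArguments(output_dir="out", predict_with_generate=True)
trainer = ORTSeq2SeqTrainer(
    model=model,           # placeholder: a transformers seq2seq model, e.g. T5 or BART
    args=args,
    eval_dataset=eval_ds,  # placeholder dataset
    feature="seq2seq-lm",  # "seq2seq-lm-with-past" also exports the with-past decoder
)
metrics = trainer.evaluate(inference_with_ort=True)  # runs evaluation_loop_ort above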