Skip to content

Commit

Permalink
Avoid duplicated calls of postprocess in training frontend (#4579)
Browse files: browse the repository at this point in the history
  • Loading branch information
BowenBao authored Aug 8, 2020
1 parent 77c69a0 commit abbb7f6
Showing 1 changed file with 17 additions and 13 deletions.
30 changes: 17 additions & 13 deletions orttraining/orttraining/python/ort_trainer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -278,7 +278,7 @@ def forward(self, *inputs):

return model

def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs, opset_version=DEFAULT_OPSET_VERSION, _enable_internal_postprocess=True):
def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs, opset_version=DEFAULT_OPSET_VERSION):
# example: {input0:{0:'batch'}, input1:{0:'batch'}}
dynamic_axes = {}
for input in model_desc.inputs_:
Expand Down Expand Up @@ -367,9 +367,6 @@ def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs, op
"Initializer names do not match between PyTorch model and ONNX model, " \
"please report a bug to ONNX Runtime."

if _enable_internal_postprocess:
onnx_model = postprocess.run_postprocess(onnx_model)

return onnx_model

def create_ort_training_session_with_optimizer(model, device, training_optimizer_name, lr_params_feed_name,
Expand Down Expand Up @@ -609,6 +606,9 @@ def __init__(self, model, loss_fn, model_desc, training_optimizer_name, map_opti

self.torch_model_ = None
self.onnx_model_ = None
self._enable_internal_postprocess = _enable_internal_postprocess
self._extra_postprocess = _extra_postprocess

if isinstance(model, torch.nn.Module):
self.torch_model_ = model
self.loss_fn_ = loss_fn
Expand All @@ -619,6 +619,12 @@ def __init__(self, model, loss_fn, model_desc, training_optimizer_name, map_opti
# TODO: accept loss_fn as an onnx model. build self.onnx_model_ with model and loss_fn
self.loss_fn_ = None

if self._enable_internal_postprocess:
postprocess.run_postprocess(self.onnx_model_)

if self._extra_postprocess:
self._extra_postprocess(self.onnx_model_)

self.model_desc_ = model_desc
self.input_desc_with_lr = [*self.model_desc_.inputs_, learning_rate_description]

Expand All @@ -640,7 +646,6 @@ def __init__(self, model, loss_fn, model_desc, training_optimizer_name, map_opti
# we use self.global_step_ to count optimizations being performed.
# it is used to calculate learning rate if self.get_lr_this_step_ is provided.
self.global_step_ = global_step
self._extra_postprocess = _extra_postprocess
self.get_lr_this_step_ = get_lr_this_step
self.loss_scaler_ = loss_scaler

Expand All @@ -655,7 +660,6 @@ def __init__(self, model, loss_fn, model_desc, training_optimizer_name, map_opti
self.frozen_weights_ = frozen_weights
self.opset_version_ = _opset_version
self.state_dict_ = None
self._enable_internal_postprocess = _enable_internal_postprocess
self._use_deterministic_compute = _use_deterministic_compute
self.use_invertible_layernorm_grad = use_invertible_layernorm_grad

Expand All @@ -669,12 +673,6 @@ def _init_session(self):
if self.onnx_model_ is None:
return

if self._enable_internal_postprocess:
self._onnx_model_ = postprocess.run_postprocess(self.onnx_model_)

if self._extra_postprocess:
self._extra_postprocess(self.onnx_model_)

self._verify_fully_optimized_model(self.onnx_model_)
self.session, self.train_io_binding, self.eval_io_binding, self.output_name, _, self.output_types = \
create_ort_training_session_with_optimizer(
Expand Down Expand Up @@ -732,7 +730,13 @@ def _init_onnx_model(self, inputs):
torch_buffers = list(dict(self.torch_model_.named_buffers()).keys())
self.frozen_weights_ = self.frozen_weights_ + torch_buffers
self.onnx_model_ = convert_model_loss_fn_to_onnx(
self.torch_model_, self.loss_fn_, self.model_desc_, torch.device('cpu'), inputs, opset_version=self.opset_version_, _enable_internal_postprocess=self._enable_internal_postprocess)
self.torch_model_, self.loss_fn_, self.model_desc_, torch.device('cpu'), inputs, opset_version=self.opset_version_)

if self._enable_internal_postprocess:
postprocess.run_postprocess(self.onnx_model_)

if self._extra_postprocess:
self._extra_postprocess(self.onnx_model_)

self._init_session()

Expand Down

0 comments on commit abbb7f6

Please sign in to comment.