From 17a40ba8f89d1149146c531744e4168d52132d2a Mon Sep 17 00:00:00 2001
From: Valery Chernov
Date: Mon, 23 Aug 2021 10:03:03 +0300
Subject: [PATCH] fixes after review. GRU test was implemented for pytorch frontend

---
 python/tvm/relay/frontend/common.py           |   8 +-
 .../pytorch/{test_lstms.py => test_rnns.py}   | 307 +++++++++++++++---
 2 files changed, 257 insertions(+), 58 deletions(-)
 rename tests/python/frontend/pytorch/{test_lstms.py => test_rnns.py} (53%)

diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py
index e0dce1e212c25..ce048105ae8b2 100755
--- a/python/tvm/relay/frontend/common.py
+++ b/python/tvm/relay/frontend/common.py
@@ -672,7 +672,7 @@ def gru_cell(
 ):
     """
     Common implementation of GRU cell for all frontends of TVM
-    TODO(vvchernov): currently it is used by pytorch. Extend for other frontends
+    TODO(vvchernov): currently it is used by pytorch and ONNX. Extend for other frontends

     Parameters
     ----------
@@ -709,8 +709,7 @@ def gru_cell(
         xwt = _op.nn.dense(x_t, w_inp)
         if linear_before_reset:
             hwt = _op.nn.dense(hidden_state, w_hid)
-            # TODO(vvchernov): It is assumed that both bias are or not
-            if b_inp is not None:
+            if b_inp is not None and b_hid is not None:
                 xwt += b_inp
                 hwt += b_hid
             i_r, i_z, i_n = _op.split(xwt, 3, axis=-1)
@@ -723,8 +722,7 @@ def gru_cell(
             w_hr, w_hz, w_hn = _op.split(w_hid, 3, axis=0)
             r_gate = i_r + _op.nn.dense(hidden_state, w_hr)
             z_gate = i_z + _op.nn.dense(hidden_state, w_hz)
-            # TODO(vvchernov): It is assumed that both bias are or not
-            if b_inp is not None:
+            if b_inp is not None and b_hid is not None:
                 b_ir, b_iz, b_in = _op.split(b_inp, 3, axis=-1)
                 b_hr, b_hz, b_hn = _op.split(b_hid, 3, axis=-1)
                 r_gate += b_ir + b_hr
diff --git a/tests/python/frontend/pytorch/test_lstms.py b/tests/python/frontend/pytorch/test_rnns.py
similarity index 53%
rename from tests/python/frontend/pytorch/test_lstms.py
rename to tests/python/frontend/pytorch/test_rnns.py
index 967245e1ef9d0..4c8c9c81c3501 100644
--- a/tests/python/frontend/pytorch/test_lstms.py
+++ b/tests/python/frontend/pytorch/test_rnns.py
@@ -22,22 +22,101 @@
 import onnx
 import io
 import sys
-import pytest

 from tvm import relay
 from tvm.contrib import graph_executor

 from torch import nn

-## Model parameters
-model_feature_size = 16
-model_hidden_size = 32
-model_num_layers = 2
-seqs_length = 2
+## LSTM parameters
+lstm_feature_size = 16
+lstm_hidden_size = 32
+lstm_num_layers = 2
 projection_size = 20
+
+## GRU parameters
+gru_feature_size = 8
+gru_hidden_size = 16
+gru_num_layers = 2
+
+seqs_length = 2
 batch_size = 2


+class GRU_Model(nn.Module):
+    def __init__(
+        self,
+        device,
+        seq_len=seqs_length,
+        batch_size=batch_size,
+        feature_size=gru_feature_size,
+        hidden_size=gru_hidden_size,
+        batch_first=False,
+        layer_num=1,
+        bidirectional=False,
+        use_bias=True,
+        rnd_weights_init=False,
+    ):
+        super().__init__()
+
+        self.batch_first = batch_first
+        self.seqs_length = seq_len
+        self.batch_size = batch_size
+        self.feature_size = feature_size
+
+        self.gru = nn.GRU(
+            input_size=self.feature_size,
+            hidden_size=hidden_size,
+            num_layers=layer_num,
+            bidirectional=bidirectional,
+            batch_first=batch_first,
+            bias=use_bias,
+        ).to(device)
+
+        if rnd_weights_init:
+            self.gen_rnd_weights()
+
+    def forward(self, input, hidden_init=None):
+        """
+        Computes the output tensor after input inference along GRU layer.
+
+        :param input: batch of data as a tensor of shape (seqs_length, batch_size, feature_size) or (batch_size, seqs_length, feature_size) if self.batch_first = True
+        :param hidden_init: initial hidden state of the GRU as a tensor of shape (num_layers, batch_size, hidden_size). Will default to a tensor of zeros if None.
+        :return: the output tensor of shape (seqs_length, batch_size, num_directions * hidden_size) or (batch_size, seqs_length, num_directions * hidden_size) if self.batch_first = True
+        """
+        out, hidden = self.gru(input, hidden_init)
+
+        return out
+
+    def gen_rnd_weights(self):
+        """
+        Generate random weights for the model with biases
+        For first uni- and bidirectional weights group:
+            Wi (3*hidden_size, feature_size)
+            Wh (3*hidden_size, hidden_size)
+            Bi (3*hidden_size)
+            Bh (3*hidden_size)
+        For other weights group:
+            Wi (3*hidden_size, hidden_size)
+            Wh (3*hidden_size, hidden_size)
+            Bi (3*hidden_size)
+            Bh (3*hidden_size)
+        When the model is created without biases, the Bi and Bh weights are skipped
+        """
+        with torch.no_grad():
+            for weight_group in self.gru.all_weights:
+                for weight in weight_group:
+                    weight.data = torch.rand(weight.shape)
+
+    def get_dummy_input(self):
+        shape = [self.seqs_length, self.batch_size, self.feature_size]
+        if self.batch_first:
+            shape = [self.batch_size, self.seqs_length, self.feature_size]
+        res = torch.rand(shape)
+
+        return res, shape
+
+
 def check_torch_version_for_proj_in_lstm():
     """
     proj_size parameter is supported in torch.nn.LSTM layer started from 1.8.0 torch version
@@ -75,8 +154,8 @@ def __init__(

         if check_torch_version_for_proj_in_lstm():
             self.lstm = nn.LSTM(
-                input_size=model_feature_size,
-                hidden_size=model_hidden_size,
+                input_size=lstm_feature_size,
+                hidden_size=lstm_hidden_size,
                 num_layers=layer_num,
                 bidirectional=bidirectional,
                 proj_size=proj_size,
@@ -91,8 +170,8 @@ def __init__(
             )
             # sys.exit()
             self.lstm = nn.LSTM(
-                input_size=model_feature_size,
-                hidden_size=model_hidden_size,
+                input_size=lstm_feature_size,
+                hidden_size=lstm_hidden_size,
                 num_layers=layer_num,
                 bidirectional=bidirectional,
                 batch_first=batch_first,
@@ -106,9 +185,9 @@ def forward(self, input, hidden_init=None):
         """
         Computes the output tensor after input inference along LSTM layer.

-        :param input: batch of data as a tensor of shape (seqs_length, batch_size, model_feature_size) or (batch_size, seqs_length, model_feature_size) if self.batch_first = True
+        :param input: batch of data as a tensor of shape (seqs_length, batch_size, lstm_feature_size) or (batch_size, seqs_length, lstm_feature_size) if self.batch_first = True
         :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers, batch_size, hidden_size). Will default to a tensor of zeros if None.
-        :return: the output tensor of shape (batch_size, model_hidden_size)
+        :return: the output tensor of shape (batch_size, lstm_hidden_size)
         """
         # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
         # and the final cell state.
@@ -121,49 +200,50 @@ def gen_rnd_weights(self): Generate random weigths for the model with biases Without projection: For first weights group: - Wi (4*model_hidden_size, model_feature_size) - Wh (4*model_hidden_size, model_hidden_size) - Bi (4*model_hidden_size) - Bh (4*model_hidden_size) + Wi (4*lstm_hidden_size, lstm_feature_size) + Wh (4*lstm_hidden_size, lstm_hidden_size) + Bi (4*lstm_hidden_size) + Bh (4*lstm_hidden_size) For first bidirectional weights group: - Wi (4*model_hidden_size, model_feature_size) - Wh (4*model_hidden_size, model_hidden_size) - Bi (4*model_hidden_size) - Bh (4*model_hidden_size) + Wi (4*lstm_hidden_size, lstm_feature_size) + Wh (4*lstm_hidden_size, lstm_hidden_size) + Bi (4*lstm_hidden_size) + Bh (4*lstm_hidden_size) For other weights group: - Wi (4*model_hidden_size, model_hidden_size) - Wh (4*model_hidden_size, model_hidden_size) - Bi (4*model_hidden_size) - Bh (4*model_hidden_size) + Wi (4*lstm_hidden_size, lstm_hidden_size) + Wh (4*lstm_hidden_size, lstm_hidden_size) + Bi (4*lstm_hidden_size) + Bh (4*lstm_hidden_size) With projection: For first weights group: - Wi (4*model_hidden_size, model_feature_size) - Wh (4*model_hidden_size, proj_size) - Bi (4*model_hidden_size) - Bh (4*model_hidden_size) - P (proj_size, model_hidden_size) + Wi (4*lstm_hidden_size, lstm_feature_size) + Wh (4*lstm_hidden_size, proj_size) + Bi (4*lstm_hidden_size) + Bh (4*lstm_hidden_size) + P (proj_size, lstm_hidden_size) For first bidirectional weights group: - Wi (4*model_hidden_size, model_feature_size) - Wh (4*model_hidden_size, proj_size) - Bi (4*model_hidden_size) - Bh (4*model_hidden_size) - P (proj_size, model_hidden_size) + Wi (4*lstm_hidden_size, lstm_feature_size) + Wh (4*lstm_hidden_size, proj_size) + Bi (4*lstm_hidden_size) + Bh (4*lstm_hidden_size) + P (proj_size, lstm_hidden_size) For other weights group: - Wi (4*model_hidden_size, proj_size * num_directions) - Wh (4*model_hidden_size, proj_size) - Bi (4*model_hidden_size) - Bh (4*model_hidden_size) - P (proj_size, model_hidden_size) + Wi (4*lstm_hidden_size, proj_size * num_directions) + Wh (4*lstm_hidden_size, proj_size) + Bi (4*lstm_hidden_size) + Bh (4*lstm_hidden_size) + P (proj_size, lstm_hidden_size) For generation of random weigths for the model without biases Bi and Bh are skipped """ - for weight_group in self.lstm.all_weights: - for weight in weight_group: - weight.data = torch.rand(weight.shape) + with torch.no_grad(): + for weight_group in self.lstm.all_weights: + for weight in weight_group: + weight.data = torch.rand(weight.shape) def get_dummy_input(self): - shape = [seqs_length, batch_size, model_feature_size] + shape = [seqs_length, batch_size, lstm_feature_size] if self.batch_first: - shape = [batch_size, seqs_length, model_feature_size] + shape = [batch_size, seqs_length, lstm_feature_size] res = torch.rand(shape) return res, shape @@ -173,6 +253,117 @@ def compare(input, gold_data, rtol=1e-5, atol=1e-5): tvm.testing.assert_allclose(input, gold_data, rtol=rtol, atol=atol) +def check_gru_with_type( + gru_type, target=tvm.target.Target("llvm -mcpu=core-avx2"), dev=tvm.cpu(0) +): + device = torch.device("cpu") + hidden_layers_num = 1 + model = None + for batch_first in (True, False): + for use_bias in (True, False): + for rnd_weights in [True]: # (True, False): + if gru_type == "uni": + model = GRU_Model( + device, + batch_first=batch_first, + rnd_weights_init=rnd_weights, + use_bias=use_bias, + ) + elif gru_type == "b": + model = GRU_Model( + device, + batch_first=batch_first, + bidirectional=True, + 
rnd_weights_init=rnd_weights, + use_bias=use_bias, + ) + hidden_layers_num = 2 + elif gru_type == "s": + model = GRU_Model( + device, + batch_first=batch_first, + layer_num=gru_num_layers, + rnd_weights_init=rnd_weights, + use_bias=use_bias, + ) + hidden_layers_num = gru_num_layers + elif gru_type == "sb": + model = GRU_Model( + device, + batch_first=batch_first, + bidirectional=True, + layer_num=gru_num_layers, + rnd_weights_init=rnd_weights, + use_bias=use_bias, + ) + hidden_layers_num = 2 * gru_num_layers + else: + print("WARNING: GRU type {} is not supported here!".format(gru_type)) + return + + model.eval() + + # Get golden output from original model + input_hidden_shape = (hidden_layers_num, batch_size, gru_hidden_size) + dummy_input, input_shape = model.get_dummy_input() + golden_output_batch = model.forward(dummy_input.to(device)).detach().cpu().numpy() + + dtype = "float32" + h_zeros = np.zeros(input_hidden_shape, dtype=dtype) + + tvm_output = None + for format in ["ts"]: # ["ts", "onnx"]: + if format == "ts": + # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing. + traced_script_module = torch.jit.trace(model, dummy_input).eval() + + # Import model to Relay + shape_list = [("input", input_shape)] + mod, params = relay.frontend.from_pytorch(traced_script_module, shape_list) + + # Model compilation by tvm + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(mod, target=target, params=params) + elif format == "onnx": + onnx_io = io.BytesIO() + with torch.no_grad(): + h0 = torch.rand(input_hidden_shape) + input_names = ["input", "h0"] + + # default export (without dynamic input) + torch.onnx.export( + model, (dummy_input, h0), onnx_io, input_names=input_names + ) + onnx_io.seek(0, 0) + onnx_model = onnx.load_model(onnx_io) + + # Import model to Relay + shape_dict = { + "input": input_shape, + "h0": input_hidden_shape, + } + mod, params = relay.frontend.from_onnx(onnx_model, shape_dict) + + # Model compilation by tvm + with tvm.transform.PassContext(opt_level=1): + lib = relay.build(mod, target=target, params=params) + + # Inference of the model with given input data + m = graph_executor.GraphModule(lib["default"](dev)) + + # Set inputs + m.set_input( + input=tvm.nd.array(dummy_input.numpy().astype(dtype)), + h0=tvm.nd.array(h_zeros), + ) + # Execute + m.run() + # Get outputs (converted to numpy array) + tvm_output = m.get_output(0).numpy() + + compare(tvm_output, golden_output_batch) + + def check_lstm_with_type( lstm_type, target=tvm.target.Target("llvm -mcpu=core-avx2"), dev=tvm.cpu(0) ): @@ -212,31 +403,31 @@ def check_lstm_with_type( model = LSTM_Model( device, batch_first=batch_first, - layer_num=model_num_layers, + layer_num=lstm_num_layers, rnd_weights_init=rnd_weights, use_bias=use_bias, ) - hidden_layers_num = model_num_layers + hidden_layers_num = lstm_num_layers elif lstm_type == "sb": model = LSTM_Model( device, batch_first=batch_first, bidirectional=True, - layer_num=model_num_layers, + layer_num=lstm_num_layers, rnd_weights_init=rnd_weights, use_bias=use_bias, ) - hidden_layers_num = 2 * model_num_layers + hidden_layers_num = 2 * lstm_num_layers elif lstm_type == "sp": model = LSTM_Model( device, batch_first=batch_first, - layer_num=model_num_layers, + layer_num=lstm_num_layers, proj_size=projection_size, rnd_weights_init=rnd_weights, use_bias=use_bias, ) - hidden_layers_num = model_num_layers + hidden_layers_num = lstm_num_layers elif lstm_type == "bp": model = LSTM_Model( device, @@ -252,12 +443,12 @@ def check_lstm_with_type( device, 
batch_first=batch_first, bidirectional=True, - layer_num=model_num_layers, + layer_num=lstm_num_layers, proj_size=projection_size, rnd_weights_init=rnd_weights, use_bias=use_bias, ) - hidden_layers_num = 2 * model_num_layers + hidden_layers_num = 2 * lstm_num_layers else: print("WARNING: LSTM type {} is not supported here!".format(lstm_type)) return @@ -265,7 +456,7 @@ def check_lstm_with_type( model.eval() # Get golden output from original model - input_hidden_shape = (hidden_layers_num, batch_size, model_hidden_size) + input_hidden_shape = (hidden_layers_num, batch_size, lstm_hidden_size) input_hidden_shape_with_proj = (hidden_layers_num, batch_size, projection_size) dummy_input, input_shape = model.get_dummy_input() golden_output_batch = model.forward(dummy_input.to(device)).detach().cpu().numpy() @@ -346,6 +537,15 @@ def check_lstm_with_type( compare(tvm_output, golden_output_batch) +@tvm.testing.uses_gpu +def test_grus(): + for target, dev in tvm.testing.enabled_targets(): + check_gru_with_type("uni", target, dev) + check_gru_with_type("s", target, dev) + check_gru_with_type("b", target, dev) + check_gru_with_type("sb", target, dev) + + @tvm.testing.uses_gpu def test_lstms(): for target, dev in tvm.testing.enabled_targets(): @@ -361,3 +561,4 @@ def test_lstms(): if __name__ == "__main__": test_lstms() + test_grus()
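
For reference, below is a minimal NumPy sketch (not part of the patch) of the single-step GRU update that the reworked bias handling in gru_cell corresponds to. The helper name gru_cell_step and the shapes are illustrative assumptions only; the math follows PyTorch's (r, z, n) gate layout, and the biases are applied only when both are given, mirroring the "if b_inp is not None and b_hid is not None" check in common.py above.

import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gru_cell_step(x, h, w_inp, w_hid, b_inp=None, b_hid=None):
    # x: (batch, feature), h: (batch, hidden)
    # w_inp: (3 * hidden, feature), w_hid: (3 * hidden, hidden)
    # b_inp, b_hid: (3 * hidden,) or None
    xwt = x @ w_inp.T
    hwt = h @ w_hid.T
    # Biases are added only when both are present, as in the patched check
    if b_inp is not None and b_hid is not None:
        xwt = xwt + b_inp
        hwt = hwt + b_hid
    i_r, i_z, i_n = np.split(xwt, 3, axis=-1)
    h_r, h_z, h_n = np.split(hwt, 3, axis=-1)
    r = sigmoid(i_r + h_r)  # reset gate
    z = sigmoid(i_z + h_z)  # update gate
    n = np.tanh(i_n + r * h_n)  # new gate (linear_before_reset form)
    return (1.0 - z) * n + z * h  # next hidden state


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((2, 8))
    h = rng.standard_normal((2, 16))
    w_inp = rng.standard_normal((48, 8))
    w_hid = rng.standard_normal((48, 16))
    print(gru_cell_step(x, h, w_inp, w_hid).shape)  # -> (2, 16)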