Skip to content

Commit

Permalink
Merge pull request #43 from NVIDIA-AI-IOT/batchnorm1d_fix
Browse files Browse the repository at this point in the history
added support for NC input to batchnorm1d
  • Loading branch information
jaybdub authored Aug 16, 2019
2 parents b9a2c3f + 7c3dd5a commit 3a02f56
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions torch2trt/converters/BatchNorm1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,22 @@ def convert_BatchNorm2d(ctx):

# reshape to 2D
layer = ctx.network.add_shuffle(input._trt)
layer.reshape_dims = (-1, input.shape[-1], 1)

if len(input.shape) == 2:
layer.reshape_dims = (input.shape[1], 1, 1)
else:
layer.reshape_dims = (input.shape[1], input.shape[2], 1)

layer = ctx.network.add_scale(layer.get_output(0), trt.ScaleMode.CHANNEL, bias, scale, power)

# reshape back to 2D
# reshape back to 1D
layer = ctx.network.add_shuffle(layer.get_output(0))
layer.reshape_dims = (-1, output.shape[-1])
layer.reshape_dims = tuple(output.shape[1:])

output._trt = layer.get_output(0)




@add_module_test(torch.float32, torch.device('cuda'), [(1, 10)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 3)])
def test_BatchNorm1d_basic():
return torch.nn.BatchNorm1d(10)

0 comments on commit 3a02f56

Please sign in to comment.