Skip to content

Commit

Permalink
minor fix: Small changes
Browse files Browse the repository at this point in the history
  • Loading branch information
gs-olive committed Jun 6, 2023
1 parent 075a028 commit e92034d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 24 deletions.
23 changes: 13 additions & 10 deletions py/torch_tensorrt/dynamo/test/test_dynamo_backend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
import timm
import pytest
import unittest

import torch_tensorrt as torchtrt
import torchvision.models as models
Expand All @@ -12,6 +13,8 @@
cosine_similarity,
)

assertions = unittest.TestCase()


@pytest.mark.unit
def test_resnet18(ir):
Expand All @@ -32,9 +35,9 @@ def test_resnet18(ir):

trt_mod = torchtrt.compile(model, **compile_spec)
cos_sim = cosine_similarity(model(input), trt_mod(input))
assert (
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
msg=f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

# Clean up model env
Expand Down Expand Up @@ -63,9 +66,9 @@ def test_mobilenet_v2(ir):

trt_mod = torchtrt.compile(model, **compile_spec)
cos_sim = cosine_similarity(model(input), trt_mod(input))
assert (
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
msg=f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

# Clean up model env
Expand Down Expand Up @@ -94,9 +97,9 @@ def test_efficientnet_b0(ir):

trt_mod = torchtrt.compile(model, **compile_spec)
cos_sim = cosine_similarity(model(input), trt_mod(input))
assert (
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
msg=f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

# Clean up model env
Expand Down Expand Up @@ -138,9 +141,9 @@ def test_bert_base_uncased(ir):
for key in model_outputs.keys():
out, trt_out = model_outputs[key], trt_model_outputs[key]
cos_sim = cosine_similarity(out, trt_out)
assert (
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
msg=f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

# Clean up model env
Expand Down Expand Up @@ -169,9 +172,9 @@ def test_resnet18_half(ir):

trt_mod = torchtrt.compile(model, **compile_spec)
cos_sim = cosine_similarity(model(input), trt_mod(input))
assert (
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
f"Resnet18 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
msg=f"Resnet18 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

# Clean up model env
Expand Down
15 changes: 1 addition & 14 deletions py/torch_tensorrt/fx/converters/impl/convolution.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from typing import Any, Callable, Optional, Sequence, Union
from typing import Any, Optional, Sequence, Union

# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
Expand Down Expand Up @@ -62,21 +62,8 @@ def convNd(
# Transform the bias constant into a Numpy array
bias = to_numpy(bias)

# Prepend new dimension (unsqueeze) if the convolution is 1d
if is_conv1d:
bias = np.expand_dims(bias, 0)

elif isinstance(bias, TRTTensor):
bias = get_trt_tensor(network, bias, f"{name}_bias")
# Prepend new dimension (unsqueeze) if the convolution is 1d
if is_conv1d:
kwargs = {
"input": bias,
"dim": 0,
}
bias = acc_ops_unsqueeze(
network, target, tuple(), kwargs, name + "_unsqueeze_bias"
)

elif bias is not None:
raise RuntimeError(
Expand Down

0 comments on commit e92034d

Please sign in to comment.