Skip to content

Commit

Permalink
fix minor unit test bug
Browse files Browse the repository at this point in the history
  • Loading branch information
eunwoosh committed Jan 2, 2024
1 parent f203986 commit 937d30e
Showing 1 changed file with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, List

import numpy as np
import pytest
Expand Down Expand Up @@ -110,11 +110,11 @@ class TestNormalize:
@pytest.mark.parametrize(
"mean,std,to_rgb,expected",
[
(1.0, 1.0, True, np.array([[[1.0, 0.0, 0.0]]], dtype=np.float32)),
(1.0, 1.0, False, np.array([[[-1.0, 0.0, 0.0]]], dtype=np.float32)),
([1.0 for _ in range(3)], [1.0 for _ in range(3)], True, np.array([[[1.0, 0.0, -1.0]]], dtype=np.float32)),
([1.0 for _ in range(3)], [1.0 for _ in range(3)], False, np.array([[[-1.0, 0.0, 1.0]]], dtype=np.float32)),
],
)
def test_call(self, mean: float, std: float, to_rgb: bool, expected: np.array) -> None:
def test_call(self, mean: List[float], std: List[float], to_rgb: bool, expected: np.array) -> None:
"""Test __call__."""
normalize = Normalize(mean=mean, std=std, to_rgb=to_rgb)
inputs = dict(img=np.arange(3).reshape(1, 1, 3))
Expand Down

0 comments on commit 937d30e

Please sign in to comment.