
Add uint8 bicubic support to ResizeV2 #7668

Merged · 3 commits · Jun 14, 2023

Conversation

@NicolasHug (Member) commented Jun 13, 2023

Same as #7557 but for bicubic mode.

See pytorch/pytorch#103252 (comment) for AVX benchmarks. TL;DR: it's ~8X faster for tensors and ~4X for PIL.
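
For context, a minimal usage sketch of the path this PR speeds up (illustrative only; the input size and variable names are made up):

import torch
from torchvision.transforms import v2

img = torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8)
resize = v2.Resize((224, 224), interpolation=v2.InterpolationMode.BICUBIC, antialias=True)
out = resize(img)  # output stays uint8; with this PR the resize runs on the native uint8 kernel instead of a float round-trip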

Unlike bilinear mode, the bicubic uint8 path seems to be faster than the float path even on non-AVX archs:

No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'
Timestamp: 20230613-025557
Torch version: 2.1.0a0+git39bf86a
Torch config: PyTorch built with:
  - GCC 11.3
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2023.1-Product Build 20230303 for Intel(R) 64 architecture applications
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - CPU capability usage: NO AVX
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CXX_COMPILER=/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=1 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DLIBKINETO_NOROCTRACER -DUSE_PYTORCH_QNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=range-loop-construct -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-invalid-partial-specialization -Wno-unused-private-field -Wno-aligned-allocation-unavailable -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_DISABLE_GPU_ASSERTS=ON, TORCH_VERSION=2.1.0, USE_CUDA=0, USE_CUDNN=OFF, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=0, USE_MPI=OFF, USE_NCCL=OFF, USE_NNPACK=0, USE_OPENMP=ON, USE_ROCM=OFF, 

Num threads: 1

PIL version:  9.5.0
/home/nicolashug/dev/pth_interpolate_vec_uint8/run_bench_interp.py:92: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /home/nicolashug/dev/pytorch/torch/csrc/utils/tensor_numpy.cpp:206.)
  expected_pil = torch.from_numpy(np.asarray(output_pil_img)).clone().permute(2, 0, 1).contiguous()
[---------------------------------------------------------------------------------------------- Resize ----------------------------------------------------------------------------------------------]
                                                                                |      Pillow (9.5.0)     |  torch (2.1.0a0+git39bf86a)   |    torchvision resize    |  Native uint8 vs Resize (float)
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      3 torch.uint8 channels_last bicubic (256, 256) -> (32, 32) aa=True        |    579.343 (+-17.959)   |       815.110 (+-17.003)      |   1018.556 (+-32.457)    |         1.250 (+-0.000)        
      3 torch.uint8 channels_last bicubic (256, 256) -> (32, 32) aa=False       |                         |       228.279 (+-4.612)       |    206.602 (+-2.803)     |         0.905 (+-0.000)        
      3 torch.uint8 channels_last bicubic (256, 256) -> (224, 224) aa=True      |   1330.929 (+-14.770)   |      2690.719 (+-28.169)      |   2481.077 (+-23.070)    |         0.922 (+-0.000)        
      3 torch.uint8 channels_last bicubic (256, 256) -> (224, 224) aa=False     |                         |      2677.604 (+-55.063)      |   4630.876 (+-117.176)   |         1.729 (+-0.000)        
      3 torch.uint8 channels_last bicubic (256, 256) -> (320, 320) aa=True      |   2143.452 (+-62.691)   |      4396.780 (+-44.819)      |   3983.286 (+-48.056)    |         0.906 (+-0.000)        
      3 torch.uint8 channels_last bicubic (256, 256) -> (320, 320) aa=False     |                         |      4571.305 (+-341.976)     |   9380.763 (+-262.068)   |         2.052 (+-0.000)        
      3 torch.uint8 channels_last bicubic (520, 520) -> (32, 32) aa=True        |   2124.227 (+-22.774)   |      2890.192 (+-55.722)      |   4015.071 (+-288.245)   |         1.389 (+-0.000)        
      3 torch.uint8 channels_last bicubic (520, 520) -> (32, 32) aa=False       |                         |       444.557 (+-10.736)      |    392.703 (+-6.175)     |         0.883 (+-0.000)        
      3 torch.uint8 channels_last bicubic (520, 520) -> (224, 224) aa=True      |   3666.882 (+-95.871)   |      6426.497 (+-84.957)      |   5433.909 (+-430.537)   |         0.846 (+-0.000)        
      3 torch.uint8 channels_last bicubic (520, 520) -> (224, 224) aa=False     |                         |      6447.592 (+-733.454)     |   6830.963 (+-107.080)   |         1.059 (+-0.000)        
      3 torch.uint8 channels_last bicubic (712, 712) -> (32, 32) aa=True        |   3985.989 (+-100.438)  |      5624.623 (+-45.380)      |   7778.349 (+-873.429)   |         1.383 (+-0.000)        
      3 torch.uint8 channels_last bicubic (712, 712) -> (32, 32) aa=False       |                         |       560.230 (+-5.999)       |    548.018 (+-13.943)    |         0.978 (+-0.000)        
      3 torch.uint8 channels_last bicubic (712, 712) -> (224, 224) aa=True      |   5688.837 (+-65.645)   |      9580.410 (+-186.579)     |   8915.008 (+-662.006)   |         0.931 (+-0.000)        
      3 torch.uint8 channels_last bicubic (712, 712) -> (224, 224) aa=False     |                         |      4594.071 (+-38.363)      |  6618.010 (+-1099.593)   |         1.441 (+-0.000)        
      3 torch.uint8 channels_last bicubic (64, 64) -> (224, 224) aa=True        |    741.229 (+-11.357)   |      1699.632 (+-45.722)      |   1756.545 (+-213.558)   |         1.033 (+-0.000)        
      3 torch.uint8 channels_last bicubic (224, 224) -> (270, 268) aa=True      |   1626.944 (+-23.810)   |       3435.597 (+-9.135)      |   3034.679 (+-67.289)    |         0.883 (+-0.000)        
      3 torch.uint8 channels_last bicubic (256, 256) -> (1024, 1024) aa=True    |   14337.421 (+-81.924)  |     34827.674 (+-434.990)     |  47232.543 (+-7380.735)  |         1.356 (+-0.000)        
      3 torch.uint8 channels_last bicubic (224, 224) -> (64, 64) aa=True        |    568.176 (+-11.795)   |       955.263 (+-21.939)      |    842.830 (+-63.085)    |         0.882 (+-0.000)        
      3 torch.uint8 channels_last bicubic (270, 268) -> (224, 224) aa=True      |   1425.972 (+-16.670)   |      2970.503 (+-48.945)      |   2650.246 (+-37.559)    |         0.892 (+-0.000)        
      3 torch.uint8 channels_last bicubic (1024, 1024) -> (256, 256) aa=True    |  11087.384 (+-138.637)  |     17701.366 (+-643.421)     |  17120.345 (+-1775.004)  |         0.967 (+-0.000)        
      3 torch.uint8 channels_last bicubic (64, 64) -> (224, 224) aa=False       |                         |      1723.015 (+-37.447)      |   4906.660 (+-158.680)   |         2.848 (+-0.000)        
      3 torch.uint8 channels_last bicubic (224, 224) -> (270, 268) aa=False     |                         |      3176.077 (+-81.433)      |   6856.496 (+-106.117)   |         2.159 (+-0.000)        
      3 torch.uint8 channels_last bicubic (256, 256) -> (1024, 1024) aa=False   |                         |     35810.471 (+-457.810)     |   100067.908 (+-0.000)   |         2.794 (+-0.000)        
      3 torch.uint8 channels_last bicubic (224, 224) -> (64, 64) aa=False       |                         |       464.529 (+-21.325)      |    476.989 (+-21.502)    |         1.027 (+-0.000)        
      3 torch.uint8 channels_last bicubic (270, 268) -> (224, 224) aa=False     |                         |      2614.490 (+-30.337)      |   4722.996 (+-123.664)   |         1.806 (+-0.000)        
      3 torch.uint8 channels_last bicubic (1024, 1024) -> (256, 256) aa=False   |                         |      7256.342 (+-173.671)     |   7441.999 (+-188.655)   |         1.026 (+-0.000)        
      4 torch.uint8 channels_last bicubic (256, 256) -> (32, 32) aa=True        |                         |      1110.441 (+-31.092)      |   1296.382 (+-40.163)    |         1.167 (+-0.000)        
      4 torch.uint8 channels_last bicubic (256, 256) -> (32, 32) aa=False       |                         |       289.785 (+-3.488)       |    244.317 (+-3.386)     |         0.843 (+-0.000)        
      4 torch.uint8 channels_last bicubic (256, 256) -> (224, 224) aa=True      |                         |      3460.551 (+-86.592)      |   3177.568 (+-112.095)   |         0.918 (+-0.000)        
      4 torch.uint8 channels_last bicubic (256, 256) -> (224, 224) aa=False     |                         |      3307.787 (+-93.377)      |   5868.070 (+-158.006)   |         1.774 (+-0.000)        
      4 torch.uint8 channels_last bicubic (256, 256) -> (320, 320) aa=True      |                         |      5552.072 (+-101.754)     |   5047.441 (+-123.868)   |         0.909 (+-0.000)        
      4 torch.uint8 channels_last bicubic (256, 256) -> (320, 320) aa=False     |                         |      5752.248 (+-196.970)     |   11344.533 (+-79.634)   |         1.972 (+-0.000)        
      4 torch.uint8 channels_last bicubic (520, 520) -> (32, 32) aa=True        |                         |      3678.043 (+-134.576)     |   5426.964 (+-16.216)    |         1.476 (+-0.000)        
      4 torch.uint8 channels_last bicubic (520, 520) -> (32, 32) aa=False       |                         |       590.729 (+-10.176)      |    450.537 (+-5.030)     |         0.763 (+-0.000)        
      4 torch.uint8 channels_last bicubic (520, 520) -> (224, 224) aa=True      |                         |      8482.118 (+-104.975)     |   7280.506 (+-104.668)   |         0.858 (+-0.000)        
      4 torch.uint8 channels_last bicubic (520, 520) -> (224, 224) aa=False     |                         |      4832.408 (+-118.830)     |   5889.998 (+-116.669)   |         1.219 (+-0.000)        
      4 torch.uint8 channels_last bicubic (712, 712) -> (32, 32) aa=True        |                         |      7816.213 (+-257.136)     |  10947.831 (+-310.221)   |         1.401 (+-0.000)        
      4 torch.uint8 channels_last bicubic (712, 712) -> (32, 32) aa=False       |                         |       768.410 (+-12.089)      |    675.967 (+-31.651)    |         0.880 (+-0.000)        
      4 torch.uint8 channels_last bicubic (712, 712) -> (224, 224) aa=True      |                         |     12770.273 (+-208.609)     |  11811.433 (+-180.334)   |         0.925 (+-0.000)        
      4 torch.uint8 channels_last bicubic (712, 712) -> (224, 224) aa=False     |                         |      5930.326 (+-75.287)      |   6268.420 (+-194.721)   |         1.057 (+-0.000)        
      4 torch.uint8 channels_last bicubic (64, 64) -> (224, 224) aa=True        |                         |      2019.372 (+-30.103)      |   1980.057 (+-52.107)    |         0.981 (+-0.000)        
      4 torch.uint8 channels_last bicubic (224, 224) -> (270, 268) aa=True      |                         |      4136.823 (+-105.343)     |   3630.952 (+-126.245)   |         0.878 (+-0.000)        
      4 torch.uint8 channels_last bicubic (256, 256) -> (1024, 1024) aa=True    |                         |     43421.493 (+-219.089)     |  61742.028 (+-5478.200)  |         1.422 (+-0.000)        
      4 torch.uint8 channels_last bicubic (224, 224) -> (64, 64) aa=True        |                         |      1195.047 (+-22.516)      |   1106.663 (+-26.732)    |         0.926 (+-0.000)        
      4 torch.uint8 channels_last bicubic (270, 268) -> (224, 224) aa=True      |                         |      3635.189 (+-72.015)      |   3205.130 (+-51.798)    |         0.882 (+-0.000)        
      4 torch.uint8 channels_last bicubic (1024, 1024) -> (256, 256) aa=True    |                         |     23423.459 (+-316.660)     |  23598.802 (+-3645.947)  |         1.007 (+-0.000)        
      4 torch.uint8 channels_last bicubic (64, 64) -> (224, 224) aa=False       |                         |      2020.718 (+-37.982)      |   5570.996 (+-39.390)    |         2.757 (+-0.000)        
      4 torch.uint8 channels_last bicubic (224, 224) -> (270, 268) aa=False     |                         |      4101.923 (+-122.654)     |   8595.246 (+-68.286)    |         2.095 (+-0.000)        
      4 torch.uint8 channels_last bicubic (256, 256) -> (1024, 1024) aa=False   |                         |     45452.268 (+-1086.815)    |   137220.414 (+-0.000)   |         3.019 (+-0.000)        
      4 torch.uint8 channels_last bicubic (224, 224) -> (64, 64) aa=False       |                         |       604.601 (+-1.874)       |    628.872 (+-15.449)    |         1.040 (+-0.000)        
      4 torch.uint8 channels_last bicubic (270, 268) -> (224, 224) aa=False     |                         |      3303.836 (+-67.614)      |   5786.897 (+-79.340)    |         1.752 (+-0.000)        
      4 torch.uint8 channels_last bicubic (1024, 1024) -> (256, 256) aa=False   |                         |      9860.581 (+-170.194)     |   8720.393 (+-150.097)   |         0.884 (+-0.000)        
      3 torch.uint8 channels_first bicubic (256, 256) -> (32, 32) aa=True       |    556.622 (+-5.813)    |       539.978 (+-5.632)       |    988.038 (+-37.112)    |         1.830 (+-0.000)        
      3 torch.uint8 channels_first bicubic (256, 256) -> (32, 32) aa=False      |                         |       230.523 (+-4.308)       |    177.409 (+-4.907)     |         0.770 (+-0.000)        
      3 torch.uint8 channels_first bicubic (256, 256) -> (224, 224) aa=True     |   1340.723 (+-33.003)   |      2272.881 (+-37.227)      |   1933.442 (+-34.089)    |         0.851 (+-0.000)        
      3 torch.uint8 channels_first bicubic (256, 256) -> (224, 224) aa=False    |                         |      2071.035 (+-49.652)      |   3129.371 (+-63.411)    |         1.511 (+-0.000)        
      3 torch.uint8 channels_first bicubic (256, 256) -> (320, 320) aa=True     |   2098.626 (+-41.586)   |      3638.615 (+-61.503)      |   3036.742 (+-81.017)    |         0.835 (+-0.000)        
      3 torch.uint8 channels_first bicubic (256, 256) -> (320, 320) aa=False    |                         |      3400.730 (+-52.220)      |   6437.490 (+-75.020)    |         1.893 (+-0.000)        
      3 torch.uint8 channels_first bicubic (520, 520) -> (32, 32) aa=True       |   2097.965 (+-72.688)   |       900.457 (+-17.195)      |   3785.110 (+-94.435)    |         4.204 (+-0.000)        
      3 torch.uint8 channels_first bicubic (520, 520) -> (32, 32) aa=False      |                         |       432.282 (+-10.741)      |    355.114 (+-6.356)     |         0.821 (+-0.000)        
      3 torch.uint8 channels_first bicubic (520, 520) -> (224, 224) aa=True     |   3775.745 (+-74.791)   |      5864.176 (+-16.667)      |   5103.820 (+-29.708)    |         0.870 (+-0.000)        
      3 torch.uint8 channels_first bicubic (520, 520) -> (224, 224) aa=False    |                         |      3570.994 (+-19.865)      |   3560.065 (+-19.907)    |         0.997 (+-0.000)        
      3 torch.uint8 channels_first bicubic (712, 712) -> (32, 32) aa=True       |   4143.920 (+-44.357)   |      1747.473 (+-18.674)      |   8293.558 (+-110.170)   |         4.746 (+-0.000)        
      3 torch.uint8 channels_first bicubic (712, 712) -> (32, 32) aa=False      |                         |       605.460 (+-3.790)       |    548.150 (+-17.708)    |         0.905 (+-0.000)        
      3 torch.uint8 channels_first bicubic (712, 712) -> (224, 224) aa=True     |   6168.689 (+-20.160)   |      8906.277 (+-131.506)     |   7681.346 (+-118.190)   |         0.862 (+-0.000)        
      3 torch.uint8 channels_first bicubic (712, 712) -> (224, 224) aa=False    |                         |      4221.651 (+-76.518)      |   3692.215 (+-99.804)    |         0.875 (+-0.000)        
      3 torch.uint8 channels_first bicubic (64, 64) -> (224, 224) aa=True       |    748.223 (+-7.079)    |      1258.357 (+-19.792)      |   1131.738 (+-45.086)    |         0.899 (+-0.000)        
      3 torch.uint8 channels_first bicubic (224, 224) -> (270, 268) aa=True     |   1516.437 (+-32.132)   |      2485.187 (+-77.856)      |   2211.527 (+-63.854)    |         0.890 (+-0.000)        
      3 torch.uint8 channels_first bicubic (256, 256) -> (1024, 1024) aa=True   |   14350.821 (+-56.939)  |     24134.156 (+-876.331)     |  22369.975 (+-548.179)   |         0.927 (+-0.000)        
      3 torch.uint8 channels_first bicubic (224, 224) -> (64, 64) aa=True       |    559.199 (+-11.889)   |       797.574 (+-14.688)      |    778.741 (+-5.432)     |         0.976 (+-0.000)        
      3 torch.uint8 channels_first bicubic (270, 268) -> (224, 224) aa=True     |   1475.965 (+-38.749)   |      2356.898 (+-47.205)      |   2042.348 (+-86.461)    |         0.867 (+-0.000)        
      3 torch.uint8 channels_first bicubic (1024, 1024) -> (256, 256) aa=True   |   11392.752 (+-92.637)  |     14491.920 (+-413.260)     |  14823.522 (+-956.728)   |         1.023 (+-0.000)        
      3 torch.uint8 channels_first bicubic (64, 64) -> (224, 224) aa=False      |                         |      1161.936 (+-14.510)      |   3080.719 (+-34.868)    |         2.651 (+-0.000)        
      3 torch.uint8 channels_first bicubic (224, 224) -> (270, 268) aa=False    |                         |      2493.449 (+-61.744)      |   4502.654 (+-171.818)   |         1.806 (+-0.000)        
      3 torch.uint8 channels_first bicubic (256, 256) -> (1024, 1024) aa=False  |                         |     22806.061 (+-127.089)     |  83302.979 (+-1760.627)  |         3.653 (+-0.000)        
      3 torch.uint8 channels_first bicubic (224, 224) -> (64, 64) aa=False      |                         |       419.204 (+-6.733)       |    371.334 (+-4.909)     |         0.886 (+-0.000)        
      3 torch.uint8 channels_first bicubic (270, 268) -> (224, 224) aa=False    |                         |      2238.418 (+-70.300)      |   3131.492 (+-46.530)    |         1.399 (+-0.000)        
      3 torch.uint8 channels_first bicubic (1024, 1024) -> (256, 256) aa=False  |                         |      6650.916 (+-93.940)      |   5028.768 (+-113.766)   |         0.756 (+-0.000)        
      4 torch.uint8 channels_first bicubic (256, 256) -> (32, 32) aa=True       |                         |       721.082 (+-12.926)      |   1278.588 (+-21.755)    |         1.773 (+-0.000)        
      4 torch.uint8 channels_first bicubic (256, 256) -> (32, 32) aa=False      |                         |       312.155 (+-5.765)       |    235.242 (+-6.643)     |         0.754 (+-0.000)        
      4 torch.uint8 channels_first bicubic (256, 256) -> (224, 224) aa=True     |                         |      3139.546 (+-33.043)      |   2556.500 (+-80.318)    |         0.814 (+-0.000)        
      4 torch.uint8 channels_first bicubic (256, 256) -> (224, 224) aa=False    |                         |      2765.007 (+-60.966)      |   4142.630 (+-65.127)    |         1.498 (+-0.000)        
      4 torch.uint8 channels_first bicubic (256, 256) -> (320, 320) aa=True     |                         |      4805.107 (+-142.337)     |   4158.664 (+-49.055)    |         0.865 (+-0.000)        
      4 torch.uint8 channels_first bicubic (256, 256) -> (320, 320) aa=False    |                         |      4543.143 (+-103.957)     |   8291.538 (+-178.459)   |         1.825 (+-0.000)        
      4 torch.uint8 channels_first bicubic (520, 520) -> (32, 32) aa=True       |                         |      1167.090 (+-26.580)      |   5067.106 (+-152.671)   |         4.342 (+-0.000)        
      4 torch.uint8 channels_first bicubic (520, 520) -> (32, 32) aa=False      |                         |       568.228 (+-11.592)      |    437.691 (+-7.412)     |         0.770 (+-0.000)        
      4 torch.uint8 channels_first bicubic (520, 520) -> (224, 224) aa=True     |                         |      7484.268 (+-139.652)     |   6569.758 (+-60.580)    |         0.878 (+-0.000)        
      4 torch.uint8 channels_first bicubic (520, 520) -> (224, 224) aa=False    |                         |      4366.352 (+-45.847)      |   4613.428 (+-92.989)    |         1.057 (+-0.000)        
      4 torch.uint8 channels_first bicubic (712, 712) -> (32, 32) aa=True       |                         |      2268.420 (+-45.239)      |  10989.015 (+-199.606)   |         4.844 (+-0.000)        
      4 torch.uint8 channels_first bicubic (712, 712) -> (32, 32) aa=False      |                         |       738.970 (+-8.432)       |    665.469 (+-34.264)    |         0.901 (+-0.000)        
      4 torch.uint8 channels_first bicubic (712, 712) -> (224, 224) aa=True     |                         |     11157.667 (+-238.283)     |  10261.128 (+-121.852)   |         0.920 (+-0.000)        
      4 torch.uint8 channels_first bicubic (712, 712) -> (224, 224) aa=False    |                         |      6039.225 (+-13.195)      |   4870.404 (+-142.421)   |         0.806 (+-0.000)        
      4 torch.uint8 channels_first bicubic (64, 64) -> (224, 224) aa=True       |                         |      1551.588 (+-54.129)      |   1440.497 (+-27.615)    |         0.928 (+-0.000)        
      4 torch.uint8 channels_first bicubic (224, 224) -> (270, 268) aa=True     |                         |      3309.442 (+-50.674)      |   2869.192 (+-80.589)    |         0.867 (+-0.000)        
      4 torch.uint8 channels_first bicubic (256, 256) -> (1024, 1024) aa=True   |                         |     30152.962 (+-334.145)     |  44874.145 (+-1193.926)  |         1.488 (+-0.000)        
      4 torch.uint8 channels_first bicubic (224, 224) -> (64, 64) aa=True       |                         |      1052.361 (+-24.604)      |   1018.679 (+-29.605)    |         0.968 (+-0.000)        
      4 torch.uint8 channels_first bicubic (270, 268) -> (224, 224) aa=True     |                         |      3115.086 (+-50.990)      |   2758.710 (+-68.910)    |         0.886 (+-0.000)        
      4 torch.uint8 channels_first bicubic (1024, 1024) -> (256, 256) aa=True   |                         |      18939.976 (+-62.367)     |  19728.630 (+-748.339)   |         1.042 (+-0.000)        
      4 torch.uint8 channels_first bicubic (64, 64) -> (224, 224) aa=False      |                         |      1556.132 (+-60.834)      |   4097.203 (+-117.064)   |         2.633 (+-0.000)        
      4 torch.uint8 channels_first bicubic (224, 224) -> (270, 268) aa=False    |                         |      3286.220 (+-66.960)      |   5864.638 (+-68.877)    |         1.785 (+-0.000)        
      4 torch.uint8 channels_first bicubic (256, 256) -> (1024, 1024) aa=False  |                         |     30305.451 (+-111.447)     |  91521.671 (+-3108.023)  |         3.020 (+-0.000)        
      4 torch.uint8 channels_first bicubic (224, 224) -> (64, 64) aa=False      |                         |       539.526 (+-4.280)       |    479.700 (+-6.648)     |         0.889 (+-0.000)        
      4 torch.uint8 channels_first bicubic (270, 268) -> (224, 224) aa=False    |                         |      2935.037 (+-64.681)      |   4219.701 (+-261.222)   |         1.438 (+-0.000)        
      4 torch.uint8 channels_first bicubic (1024, 1024) -> (256, 256) aa=False  |                         |      8810.591 (+-86.628)      |  7468.853 (+-2843.493)   |         0.848 (+-0.000)        

Times are in microseconds (us).


cc @vfdev-5

@pytorch-bot bot commented Jun 13, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/7668

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 8e6351b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

ArgsKwargs((33, 26), interpolation=v2_transforms.InterpolationMode.BICUBIC, antialias=True),
ArgsKwargs((34, 25), interpolation=PIL.Image.BICUBIC, antialias=True),
],
closeness_kwargs=dict(rtol=0, atol=21),
@NicolasHug (Member, Author) commented on the diff above:

Had to pull these tests out in order not to affect the atol=1 used for the other tests.
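
For illustration, a rough sketch of the kind of uint8-vs-PIL comparison that needs the looser tolerance (hypothetical input; the real tests use fixed sample images, so a random input may not stay within the bound):

import numpy as np
import PIL.Image
import torch
from torchvision.transforms import v2

resize = v2.Resize((33, 26), interpolation=v2.InterpolationMode.BICUBIC, antialias=True)
img = torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8)
pil_img = PIL.Image.fromarray(img.permute(1, 2, 0).contiguous().numpy())

out_tensor = resize(img)
out_pil = torch.from_numpy(np.asarray(resize(pil_img))).permute(2, 0, 1)

# bicubic over/undershoots more than bilinear, hence atol=21 here instead of the atol=1 used elsewhere
torch.testing.assert_close(out_tensor.int(), out_pil.int(), rtol=0, atol=21)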

@vfdev-5 (Collaborator) left a comment

LGTM, thanks @NicolasHug !

@NicolasHug (Member, Author) commented

Thanks for the review and for the help with the tests @vfdev-5 .

As a quick sanity check I evaluated the accuracy of some of our models that were trained with bicubic resizing. Evaluation was run with PIL inputs, float tensors, and uint8 tensors (this PR); the results are all extremely close. A short sketch of how these input paths differ follows the numbers below.

efficientnet_b0  pil
Test:  Acc@1 77.688 Acc@5 93.536
efficientnet_b0  tensor float
Test:  Acc@1 77.696 Acc@5 93.540
efficientnet_b0  tensor uint8
Test:  Acc@1 77.694 Acc@5 93.538


efficientnet_b1  pil
Test:  Acc@1 78.638 Acc@5 94.186
efficientnet_b1  tensor float
Test:  Acc@1 78.642 Acc@5 94.198
efficientnet_b1  tensor uint8
Test:  Acc@1 78.642 Acc@5 94.190

efficientnet_b7  pil
Test:  Acc@1 84.120 Acc@5 96.908
efficientnet_b7  tensor float
Test:  Acc@1 84.126 Acc@5 96.910
efficientnet_b7  tensor uint8
Test:  Acc@1 84.120 Acc@5 96.904

swin_t  pil
Test:  Acc@1 81.480 Acc@5 95.776
swin_t  tensor float
Test:  Acc@1 81.408 Acc@5 95.762
swin_t  tensor uint8
Test:  Acc@1 81.458 Acc@5 95.772

swin_v2_t  pil
Test:  Acc@1 82.056 Acc@5 96.134
swin_v2_t  tensor float
Test:  Acc@1 82.028 Acc@5 96.114
swin_v2_t  tensor uint8
Test:  Acc@1 82.066 Acc@5 96.138

maxvit_t  pil
Test:  Acc@1 83.696 Acc@5 96.724
maxvit_t  tensor float
Test:  Acc@1 83.688 Acc@5 96.730
maxvit_t  tensor uint8
Test:  Acc@1 83.696 Acc@5 96.718
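
For reference, a short sketch of how the "tensor float" and "tensor uint8" rows above differ only by input dtype (hypothetical sizes; not the actual evaluation script):

import torch
from torchvision.transforms import v2

resize = v2.Resize(224, interpolation=v2.InterpolationMode.BICUBIC, antialias=True)
img = torch.randint(0, 256, (3, 500, 400), dtype=torch.uint8)

out_float = resize(img.to(torch.float32))  # "tensor float": resized through the float kernels
out_uint8 = resize(img)                    # "tensor uint8": resized natively in uint8 (this PR's path)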

facebook-github-bot pushed a commit that referenced this pull request Jun 14, 2023
Summary: Co-authored-by: vfdev-5 <vfdev.5@gmail.com>

Reviewed By: vmoens

Differential Revision: D46724117

fbshipit-source-id: e47e4fbb2be67830fe56e22f1a806b8742d7652c