From a18bb7ca35049a4a4f2b6bb16cfea482cb57284c Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Mon, 11 Jul 2022 19:02:10 -0700 Subject: [PATCH 1/2] feat: Update Pytorch version to 1.12 Signed-off-by: Dheeraj Peri --- README.md | 2 +- WORKSPACE | 8 ++++---- py/requirements.txt | 2 +- tests/modules/requirements.txt | 1 - tests/py/requirements.txt | 2 +- 5 files changed, 7 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 9bf645af0b..8770f29969 100644 --- a/README.md +++ b/README.md @@ -111,7 +111,7 @@ torch.jit.save(trt_ts_module, "trt_torchscript_module.ts") # save the TRT embedd These are the following dependencies used to verify the testcases. Torch-TensorRT can work with other versions, but the tests are not guaranteed to pass. - Bazel 5.1.1 -- Libtorch 1.11.0 (built with CUDA 11.3) +- Libtorch 1.12.0 (built with CUDA 11.3) - CUDA 11.3 - cuDNN 8.2.1 - TensorRT 8.2.4.2 diff --git a/WORKSPACE b/WORKSPACE index 2779e93cc7..9a6b62a7a3 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -56,17 +56,17 @@ new_local_repository( http_archive( name = "libtorch", build_file = "@//third_party/libtorch:BUILD", - sha256 = "8d9e829ce9478db4f35bdb7943308cf02e8a2f58cf9bb10f742462c1d57bf287", + sha256 = "80f089939de20e68e3fcad4dfa72a26c8bf91b5e77b11042f671f39ebac35865", strip_prefix = "libtorch", - urls = ["https://download.pytorch.org/libtorch/cu113/libtorch-cxx11-abi-shared-with-deps-1.11.0%2Bcu113.zip"], + urls = ["https://download.pytorch.org/libtorch/cu113/libtorch-cxx11-abi-shared-with-deps-1.12.0%2Bcu113.zip"], ) http_archive( name = "libtorch_pre_cxx11_abi", build_file = "@//third_party/libtorch:BUILD", - sha256 = "90159ecce3ff451f3ef3f657493b6c7c96759c3b74bbd70c1695f2ea2f81e1ad", + sha256 = "8e35371403f7052d9e9b43bcff383980dbde4df028986dc1dab539953481d55f", strip_prefix = "libtorch", - urls = ["https://download.pytorch.org/libtorch/cu113/libtorch-shared-with-deps-1.11.0%2Bcu113.zip"], + urls = ["https://download.pytorch.org/libtorch/cu113/libtorch-shared-with-deps-1.12.0%2Bcu113.zip"], ) # Download these tarballs manually from the NVIDIA website diff --git a/py/requirements.txt b/py/requirements.txt index 8d12c108aa..0fe116416f 100644 --- a/py/requirements.txt +++ b/py/requirements.txt @@ -1,5 +1,5 @@ -f https://download.pytorch.org/whl/torch_stable.html -f https://download.pytorch.org/whl/torch/ --extra-index-url https://download.pytorch.org/whl/cu113 -torch==1.11.0+cu113 +torch==1.12.0+cu113 pybind11==2.6.2 diff --git a/tests/modules/requirements.txt b/tests/modules/requirements.txt index b1a922e034..3f52484ca8 100644 --- a/tests/modules/requirements.txt +++ b/tests/modules/requirements.txt @@ -1,4 +1,3 @@ -f https://download.pytorch.org/whl/torch_stable.html -#torch==1.11.0+cu113 timm==v0.4.12 transformers==4.17.0 diff --git a/tests/py/requirements.txt b/tests/py/requirements.txt index 91e97eed3e..0ea1c76a29 100644 --- a/tests/py/requirements.txt +++ b/tests/py/requirements.txt @@ -1,2 +1,2 @@ -torchvision==0.12.0+cu113 +torchvision==0.13.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html From dfd4a8323b83325580b05aa9b197506edcddf64b Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Fri, 22 Jul 2022 20:18:31 -0700 Subject: [PATCH 2/2] chore: Minor fix Signed-off-by: Dheeraj Peri --- py/torch_tensorrt/ts/_compile_spec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/torch_tensorrt/ts/_compile_spec.py b/py/torch_tensorrt/ts/_compile_spec.py index 204f4cf91c..4c7b8b5b5d 100644 --- a/py/torch_tensorrt/ts/_compile_spec.py +++ b/py/torch_tensorrt/ts/_compile_spec.py @@ -350,7 +350,7 @@ def TensorRTCompileSpec(inputs=[], backend_spec._set_workspace_size(parsed_spec.workspace_size) backend_spec._set_dla_sram_size(parsed_spec.dla_sram_size) backend_spec._set_dla_local_dram_size(parsed_spec.dla_local_dram_size) - backend_spec._set_dla_global_dram_size(parsed_spec._set_dla_global_dram_size) + backend_spec._set_dla_global_dram_size(parsed_spec.dla_global_dram_size) backend_spec._set_truncate_long_and_double(parsed_spec.truncate_long_and_double) backend_spec._set_ptq_calibrator(parsed_spec._get_calibrator_handle())