diff --git a/python/tvm/relax/backend/dispatch_sort_scan.py b/python/tvm/relax/backend/dispatch_sort_scan.py index 870e6138d7bd..e25c28e5711a 100644 --- a/python/tvm/relax/backend/dispatch_sort_scan.py +++ b/python/tvm/relax/backend/dispatch_sort_scan.py @@ -227,13 +227,13 @@ def estimate_thrust_workspace_size(self, call: relax.Call) -> int: int32_byte_per_elem = DataType("int32").bits // 8 num_elem = reduce(mul, input_shape, 1) input_size = num_elem * input_byte_per_elem - # Most GPU algorithms take O(n) space or less, we choose 8N + 4MB as a safe estimation + # Most GPU algorithms take O(n) space or less, we choose 8N + 8MB as a safe estimation # for algorithm workspace. # The current thrust sort implementation may need extra int64 and int32 arrays # for temporary data, so we further add this part to the workspace. return ( 8 * input_size - + 4 * 1024 * 1024 + + 8 * 1024 * 1024 + num_elem * (int64_byte_per_elem + int32_byte_per_elem) )