Replace stack/mask/reduce by indexing in _hsv2rgb #7754
Conversation
Not sure what is going on with the out-of-memory errors in CI. The PR actually reduces peak memory usage according to:

```python
rgb = torch.rand((16, 3, 704, 1024), device='cuda')
hsv = _rgb2hsv(rgb)
torch.cuda.reset_peak_memory_stats()
_hsv2rgb(hsv)
print(torch.cuda.max_memory_allocated() // (1024 ** 2))
```

I cannot reproduce locally.
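For reference (not part of the original comment), the same measurement can be wrapped into a small reusable helper; this is just a sketch assuming a CUDA device and the private `_functional_tensor` helpers, with an explicit synchronize before reading the stat:

```python
import torch
from torchvision.transforms._functional_tensor import _rgb2hsv, _hsv2rgb

def peak_mem_mb(fn, *args):
    # peak CUDA allocation of a single call, in MiB
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    fn(*args)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() // (1024 ** 2)

rgb = torch.rand((16, 3, 704, 1024), device="cuda")
hsv = _rgb2hsv(rgb)
print(peak_mem_mb(_hsv2rgb, hsv))
```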
Thanks for the PR @nlgranger
Can you please provide some simple benchmarks illustrating the gains in memory and perf?
Using the following code:

```python
import time

import torch

from torchvision.transforms._functional_tensor import _rgb2hsv, _hsv2rgb

device = "cuda"

shapes = [
    (3, 320, 320),
    (8, 3, 320, 320),
    (32, 3, 320, 320),
    (3, 640, 768),
    (8, 3, 640, 768),
    (32, 3, 640, 768),
]

for s in shapes:
    rgb = torch.rand(s, device=device)
    hsv = _rgb2hsv(rgb)

    durations = []
    peak_mem = []

    for _ in range(10):
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
        t0 = time.monotonic()
        _hsv2rgb(hsv)
        torch.cuda.synchronize()
        t1 = time.monotonic()

    for _ in range(100):
        torch.cuda.synchronize()
        t0 = time.monotonic()
        _hsv2rgb(hsv)
        torch.cuda.synchronize()
        t1 = time.monotonic()
        durations.append(t1 - t0)
        if device == "cuda":
            peak_mem.append(torch.cuda.max_memory_allocated())

    if device == "cuda":
        print(f"{str(s):20s} : {sum(durations) * 10:7.2f}ms {sum(peak_mem) / 100 / 1024 ** 2:7.2f}MB")
    else:
        print(f"{str(s):20s} : {sum(durations) / len(durations) * 1000:7.2f}ms")
```

On GPU (Quadro T1000 mobile):

On CPU (i7-9850H):
I ran my benchmark script and your script with more iterations and see the following:
I propose to use this improved function (a mix of the v2 function + indexing and gather):

```python
def fn_new2(img: Tensor) -> Tensor:
    h, s, v = img.unbind(dim=-3)
    h6 = h.mul(6)
    i = torch.floor(h6)
    f = h6.sub_(i)
    i = i.to(dtype=torch.int32)

    sxf = s * f
    one_minus_s = 1.0 - s
    q = (1.0 - sxf).mul_(v).clamp_(0.0, 1.0)
    t = sxf.add_(one_minus_s).mul_(v).clamp_(0.0, 1.0)
    p = one_minus_s.mul_(v).clamp_(0.0, 1.0)
    i.remainder_(6)

    vpqt = torch.stack((v, p, q, t), dim=-3)

    # vpqt -> rgb mapping based on i
    select = torch.tensor(
        [[0, 2, 1, 1, 3, 0], [3, 0, 0, 2, 1, 1], [1, 1, 3, 0, 0, 2]],
        dtype=torch.long, device=img.device
    )
    select = select[:, i]
    if select.ndim > 3:
        select = select.transpose(0, 1)
    return vpqt.gather(-3, select)
```

@nlgranger what do you think?
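A quick way to sanity-check such a change (not part of the original comment; it assumes `fn_new2` above is in scope with `from torch import Tensor`, and that the private `_functional_tensor` helpers are importable) is to compare it against the current implementation on random inputs:

```python
import torch
from torchvision.transforms._functional_tensor import _rgb2hsv, _hsv2rgb

# compare the proposed function against the existing _hsv2rgb on random data
rgb = torch.rand(8, 3, 64, 64)
hsv = _rgb2hsv(rgb)
out_ref = _hsv2rgb(hsv)
out_new = fn_new2(hsv)
print(torch.allclose(out_ref, out_new, atol=1e-6),
      (out_ref - out_new).abs().max().item())
```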
(branch updated: 27bb519 → 1fd4730)
@vfdev-5 I have included your in-place optimizations as well, thank you.
(branch updated: 88bb7d7 → 2e07135)
@nlgranger by the way, we have to update only the v2 implementation; let's keep the v1 implementation as it is.
(branch updated: b0a89c3 → bc9142f)
Thanks for working on this PR!
However, once again, please revert all changes in _functional_tensor.py (https://github.com/pytorch/vision/pull/7754/files#r1276480612) as transforms v2 will replace v1 soon.
I left a few other comments to address.
These out-of-memory issues keep showing up during the tests. I don't see how the modifications I made could cause them.
Probably it was a flaky CI. We can update the branch once again to see if we can reproduce the OOM. Concerning whether gather needs contiguous indices: I'm not sure about that. Let's see in terms of perfs, and it could also be possible that PyTorch itself does that internally.
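For what it's worth, whether `gather` is sensitive to a non-contiguous index tensor could be probed directly with a rough micro-benchmark along these lines (the shapes are illustrative, not taken from the PR):

```python
import time
import torch

x = torch.rand(16, 4, 704, 1024, device="cuda")
# build a non-contiguous index tensor of the right shape via a transpose
idx = torch.randint(0, 4, (16, 3, 1024, 704), device="cuda").transpose(-1, -2)

def bench(index, iters=100):
    # average time of x.gather(-3, index) in milliseconds
    torch.cuda.synchronize()
    t0 = time.monotonic()
    for _ in range(iters):
        x.gather(-3, index)
    torch.cuda.synchronize()
    return (time.monotonic() - t0) / iters * 1000

print("non-contiguous index:", bench(idx))
print("contiguous index:    ", bench(idx.contiguous()))
```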
@nlgranger I reran my benchmark on your latest commit vs an implementation without the contiguous call and using non-blocking (
How about using
Sure, but the tests won't pass anyway and I have no clue why.
Which tests specifically are you talking about? Currently, we have a lot of flaky failing tests...
The slower speed is probably due to the
Btw, I found where the
@nlgranger can you please update the code to rerun the CI and see if there are any OOMs? If you are busy, would you mind me pushing to the branch to move forward?
LGTM, thanks @nlgranger!
Let's see if there are any related OOMs in the CI -> No OOMs seen on CI. Merging.
Hey @vfdev-5! You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py
Summary:
Co-authored-by: vfdev <vfdev.5@gmail.com>
Reviewed By: matteobettini
Differential Revision: D48642248
fbshipit-source-id: 24f789cb0ddfb5810c423e4f3ef9e3d28cc2a8a6
Fixes #7753
cc @vfdev-5