From 5817ee6e52af6b66675bc7dd381e15489a5af37e Mon Sep 17 00:00:00 2001
From: shademe
Date: Wed, 24 Aug 2022 14:12:32 +0200
Subject: [PATCH 1/3] `PyTorchGradScaler`: Cache `_found_inf` on the CPU

This prevents unnecessary overhead from launching kernels on the GPU in
hot backward passes.
---
 thinc/shims/pytorch_grad_scaler.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/thinc/shims/pytorch_grad_scaler.py b/thinc/shims/pytorch_grad_scaler.py
index 999ed0047..10b7ea986 100644
--- a/thinc/shims/pytorch_grad_scaler.py
+++ b/thinc/shims/pytorch_grad_scaler.py
@@ -52,6 +52,8 @@ def __init__(
         self._growth_interval = growth_interval
 
         self._found_inf = torch.full((1,), 0.0)
+        # Caches `self._found_inf` to minimize DTH transfers in hotloops.
+        self._found_inf_bool = False
         self._growth_tracker = torch.full((1,), 0, dtype=torch.int)
         self._scale = torch.full((1,), init_scale)
 
@@ -132,7 +134,7 @@ def _tensors_per_device(self, tensors):
 
     @property
     def found_inf(self):
-        return bool(self._found_inf) != 0
+        return self._found_inf_bool
 
     def unscale(self, tensors):
         """Unscale the given tensors. Returns True if any of the gradients were infinite."""
@@ -154,7 +156,8 @@ def unscale(self, tensors):
 
         self._found_inf += found_inf_device.to(self._found_inf.device)
 
-        return bool(self._found_inf != 0)
+        self._found_inf_bool = bool(self._found_inf != 0)
+        return self._found_inf_bool
 
     def update(self):
         """
@@ -176,3 +179,4 @@ def update(self):
 
         # Clear infinity found status
         self._found_inf = torch.zeros_like(self._found_inf)
+        self._found_inf_bool = False

From f4af6298992c1de64afefe38177e7950b534ecc0 Mon Sep 17 00:00:00 2001
From: shademe
Date: Thu, 25 Aug 2022 10:33:54 +0200
Subject: [PATCH 2/3] Only pin `_found_inf` to the CPU

---
 thinc/shims/pytorch_grad_scaler.py | 17 +++++++----------
 1 file changed, 7 insertions(+), 10 deletions(-)

diff --git a/thinc/shims/pytorch_grad_scaler.py b/thinc/shims/pytorch_grad_scaler.py
index 10b7ea986..b16c4e076 100644
--- a/thinc/shims/pytorch_grad_scaler.py
+++ b/thinc/shims/pytorch_grad_scaler.py
@@ -51,14 +51,12 @@ def __init__(
         self._backoff_factor = backoff_factor
         self._growth_interval = growth_interval
 
-        self._found_inf = torch.full((1,), 0.0)
-        # Caches `self._found_inf` to minimize DTH transfers in hotloops.
-        self._found_inf_bool = False
         self._growth_tracker = torch.full((1,), 0, dtype=torch.int)
         self._scale = torch.full((1,), init_scale)
+        # Pin to the CPU to minimize DTH transfers in hot loops.
+        self._found_inf = torch.full((1,), 0.0, device="cpu")
 
     def to_(self, device):
-        self._found_inf = self._found_inf.to(device)
         self._growth_tracker = self._growth_tracker.to(device)
         self._scale = self._scale.to(device)
 
@@ -134,7 +132,7 @@ def _tensors_per_device(self, tensors):
 
     @property
     def found_inf(self):
-        return self._found_inf_bool
+        return bool(self._found_inf != 0)
 
     def unscale(self, tensors):
         """Unscale the given tensors. Returns True if any of the gradients were infinite."""
@@ -156,8 +154,7 @@ def unscale(self, tensors):
 
         self._found_inf += found_inf_device.to(self._found_inf.device)
 
-        self._found_inf_bool = bool(self._found_inf != 0)
-        return self._found_inf_bool
+        return bool(self._found_inf != 0)
 
     def update(self):
         """
@@ -168,15 +165,15 @@ def update(self):
         if not self._enabled:
             return
 
+        found_inf_device = self._found_inf.to(self._scale.device)
         torch._amp_update_scale_(
             self._scale,
             self._growth_tracker,
-            self._found_inf,
+            found_inf_device,
             self._growth_factor,
             self._backoff_factor,
             self._growth_interval,
         )
 
         # Clear infinity found status
-        self._found_inf = torch.zeros_like(self._found_inf)
-        self._found_inf_bool = False
+        self._found_inf = torch.zeros((1,), device="cpu")

From 39287ede7959f66b2766cdd8981fd9c008c48ecc Mon Sep 17 00:00:00 2001
From: shademe
Date: Thu, 25 Aug 2022 10:40:14 +0200
Subject: [PATCH 3/3] Always store `_found_inf` as a `bool`

---
 thinc/shims/pytorch_grad_scaler.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/thinc/shims/pytorch_grad_scaler.py b/thinc/shims/pytorch_grad_scaler.py
index b16c4e076..8db11bcae 100644
--- a/thinc/shims/pytorch_grad_scaler.py
+++ b/thinc/shims/pytorch_grad_scaler.py
@@ -53,8 +53,7 @@ def __init__(
 
         self._growth_tracker = torch.full((1,), 0, dtype=torch.int)
         self._scale = torch.full((1,), init_scale)
-        # Pin to the CPU to minimize DTH transfers in hot loops.
-        self._found_inf = torch.full((1,), 0.0, device="cpu")
+        self._found_inf = False
 
     def to_(self, device):
         self._growth_tracker = self._growth_tracker.to(device)
@@ -132,7 +131,7 @@ def _tensors_per_device(self, tensors):
 
     @property
     def found_inf(self):
-        return bool(self._found_inf != 0)
+        return self._found_inf
 
     def unscale(self, tensors):
        """Unscale the given tensors. Returns True if any of the gradients were infinite."""
@@ -152,9 +151,10 @@ def unscale(self, tensors):
             device_tensors, found_inf_device, inv_scale_device
         )
 
-        self._found_inf += found_inf_device.to(self._found_inf.device)
+        if bool(found_inf_device != 0):
+            self._found_inf = True
 
-        return bool(self._found_inf != 0)
+        return self._found_inf
 
     def update(self):
         """
@@ -165,7 +165,9 @@ def update(self):
         if not self._enabled:
             return
 
-        found_inf_device = self._found_inf.to(self._scale.device)
+        found_inf_device = torch.full(
+            (1,), 1.0 if self._found_inf else 0.0, device=self._scale.device
+        )
         torch._amp_update_scale_(
             self._scale,
             self._growth_tracker,
@@ -176,4 +178,4 @@ def update(self):
         )
 
         # Clear infinity found status
-        self._found_inf = torch.zeros((1,), device="cpu")
+        self._found_inf = False
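
A note for readers of this series: the pattern the three patches converge on can be shown outside the diff as a small standalone sketch (illustrative only; `FoundInfFlag`, `record`, and `as_tensor` are not thinc names). The "found inf" state lives as a plain Python bool, `unscale` pays at most one device-to-host read, and a tensor is materialized only at the point where `torch._amp_update_scale_` would consume it:

import torch


class FoundInfFlag:
    """Sketch of the caching pattern: keep the flag as a plain bool."""

    def __init__(self) -> None:
        self._found_inf = False  # no tensor held, no per-step GPU kernels

    def record(self, found_inf_device: torch.Tensor) -> bool:
        # One device-to-host read per unscale call; the result is cached
        # so later `found_inf` checks are free.
        if bool(found_inf_device != 0):
            self._found_inf = True
        return self._found_inf

    @property
    def found_inf(self) -> bool:
        return self._found_inf

    def as_tensor(self, device) -> torch.Tensor:
        # Materialize the flag only where a tensor is actually required,
        # e.g. as the `found_inf` argument of `torch._amp_update_scale_`.
        return torch.full((1,), 1.0 if self._found_inf else 0.0, device=device)

    def reset(self) -> None:
        self._found_inf = False


flag = FoundInfFlag()
flag.record(torch.ones(1))    # pretend the unscale kernel reported an inf
assert flag.found_inf is True
print(flag.as_tensor("cpu"))  # tensor([1.])
flag.reset()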