diff --git a/test/test_models.py b/test/test_models.py index 4633617d477..f6eeb7c28c8 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -29,8 +29,10 @@ def disable_tf32(): previous = torch.backends.cudnn.allow_tf32 torch.backends.cudnn.allow_tf32 = False - yield - torch.backends.cudnn.allow_tf32 = previous + try: + yield + finally: + torch.backends.cudnn.allow_tf32 = previous def list_model_fns(module):