From 06a3aa3af27ac768adb757bf3a3b7a4b20c08763 Mon Sep 17 00:00:00 2001
From: ydshieh
Date: Wed, 29 Jun 2022 15:14:26 +0200
Subject: [PATCH 1/2] fix nan during sampling

---
 src/transformers/generation_utils.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py
index a5a2c4abffd9..3ef9c9f1ec5a 100644
--- a/src/transformers/generation_utils.py
+++ b/src/transformers/generation_utils.py
@@ -1970,8 +1970,12 @@ def sample(
                     else (outputs.hidden_states,)
                 )
 
+            # Avoid rows that are all `-inf` along the vocab dimension (dim -1), which give `nan` after `softmax`
+            # and an error in `torch.multinomial`.
+            _next_token_scores = torch.max(next_token_scores, torch.tensor(torch.finfo(next_token_scores.dtype).min, dtype=next_token_scores.dtype, device=next_token_scores.device))
+
             # sample
-            probs = nn.functional.softmax(next_token_scores, dim=-1)
+            probs = nn.functional.softmax(_next_token_scores, dim=-1)
             next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
 
             # finished sentences should have their next token be a padding token

From cd53a5102f9ff113928aff083036632e4f74643e Mon Sep 17 00:00:00 2001
From: ydshieh
Date: Wed, 29 Jun 2022 15:14:35 +0200
Subject: [PATCH 2/2] fix style

---
 src/transformers/generation_utils.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py
index 3ef9c9f1ec5a..cc07bfbcb6f6 100644
--- a/src/transformers/generation_utils.py
+++ b/src/transformers/generation_utils.py
@@ -1972,7 +1972,14 @@ def sample(
 
             # Avoid rows that are all `-inf` along the vocab dimension (dim -1), which give `nan` after `softmax`
             # and an error in `torch.multinomial`.
-            _next_token_scores = torch.max(next_token_scores, torch.tensor(torch.finfo(next_token_scores.dtype).min, dtype=next_token_scores.dtype, device=next_token_scores.device))
+            _next_token_scores = torch.max(
+                next_token_scores,
+                torch.tensor(
+                    torch.finfo(next_token_scores.dtype).min,
+                    dtype=next_token_scores.dtype,
+                    device=next_token_scores.device,
+                ),
+            )
 
             # sample
             probs = nn.functional.softmax(_next_token_scores, dim=-1)
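
For context, a minimal standalone sketch (not part of the patch) of the failure mode the first commit guards against: when every score in a row is `-inf`, `softmax` returns `nan` for the whole row and `torch.multinomial` raises, whereas clamping the scores at the dtype's finite minimum first yields a uniform distribution and sampling succeeds.

import torch

# A row of scores that is entirely -inf (e.g. every token was filtered out
# by the logits processors) makes softmax return nan for the whole row.
scores = torch.full((1, 5), float("-inf"))
probs = torch.softmax(scores, dim=-1)
print(probs)  # tensor([[nan, nan, nan, nan, nan]])
# torch.multinomial(probs, num_samples=1) would raise a RuntimeError here.

# Clamping at the dtype's finite minimum, as the patch does, turns the
# degenerate row into equal finite scores, so softmax gives a uniform
# distribution instead of nan.
clamped = torch.max(scores, torch.tensor(torch.finfo(scores.dtype).min))
probs = torch.softmax(clamped, dim=-1)
print(probs)  # tensor([[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]])
next_tokens = torch.multinomial(probs, num_samples=1)  # no error

Note that the clamp only changes behavior for all-`-inf` rows: in any row with at least one finite score, entries clamped from `-inf` to the finite minimum still underflow to zero probability under softmax, so the resulting distribution is unchanged.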