Skip to content

Commit

Permalink
SoftMax: use last dimension for pytorch softmax
Browse files Browse the repository at this point in the history
Handles 1d or 2d input.
  • Loading branch information
kmantel committed Dec 17, 2024
1 parent a6e322c commit b9c2b1a
Showing 1 changed file with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3583,22 +3583,22 @@ def _gen_pytorch_fct(self, device, context=None):
mask_threshold = self._get_pytorch_fct_param_value('mask_threshold', device, context)

if isinstance(gain, str) and gain == ADAPTIVE:
return lambda x: (torch.softmax(self._gen_pytorch_adapt_gain_fct(device, context)(x) * x, 0))
return lambda x: (torch.softmax(self._gen_pytorch_adapt_gain_fct(device, context)(x) * x, -1))

elif mask_threshold:
def pytorch_thresholded_softmax(_input: torch.Tensor) -> torch.Tensor:
# Mask elements of input below threshold
_mask = (torch.abs(_input) > mask_threshold)
# Subtract off the max value in the input to eliminate extreme values, exponentiate, and apply mask
masked_exp = _mask * torch.exp(gain * (_input - torch.max(_input, 0, keepdim=True)[0]))
masked_exp = _mask * torch.exp(gain * (_input - torch.max(_input, -1, keepdim=True)[0]))
if not any(masked_exp):
return masked_exp
return masked_exp / torch.sum(masked_exp, 0, keepdim=True)
return masked_exp / torch.sum(masked_exp, -1, keepdim=True)
# Return the function
return pytorch_thresholded_softmax

else:
return lambda x: (torch.softmax(gain * x, 0))
return lambda x: (torch.softmax(gain * x, -1))

def _gen_pytorch_adapt_gain_fct(self, device, context=None):
scale = self._get_pytorch_fct_param_value('adapt_scale', device, context)
Expand Down

0 comments on commit b9c2b1a

Please sign in to comment.