From b9c2b1a60d50cfbb118316c6c825e42551e007bc Mon Sep 17 00:00:00 2001
From: Katherine Mantel
Date: Wed, 11 Dec 2024 03:29:56 +0000
Subject: [PATCH] SoftMax: use last dimension for pytorch softmax

handles 1d or 2d input
---
 .../components/functions/nonstateful/transferfunctions.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/psyneulink/core/components/functions/nonstateful/transferfunctions.py b/psyneulink/core/components/functions/nonstateful/transferfunctions.py
index 5e2550b150..f3936f74be 100644
--- a/psyneulink/core/components/functions/nonstateful/transferfunctions.py
+++ b/psyneulink/core/components/functions/nonstateful/transferfunctions.py
@@ -3583,22 +3583,22 @@ def _gen_pytorch_fct(self, device, context=None):
         mask_threshold = self._get_pytorch_fct_param_value('mask_threshold', device, context)
 
         if isinstance(gain, str) and gain == ADAPTIVE:
-            return lambda x: (torch.softmax(self._gen_pytorch_adapt_gain_fct(device, context)(x) * x, 0))
+            return lambda x: (torch.softmax(self._gen_pytorch_adapt_gain_fct(device, context)(x) * x, -1))
 
         elif mask_threshold:
             def pytorch_thresholded_softmax(_input: torch.Tensor) -> torch.Tensor:
                 # Mask elements of input below threshold
                 _mask = (torch.abs(_input) > mask_threshold)
                 # Subtract off the max value in the input to eliminate extreme values, exponentiate, and apply mask
-                masked_exp = _mask * torch.exp(gain * (_input - torch.max(_input, 0, keepdim=True)[0]))
+                masked_exp = _mask * torch.exp(gain * (_input - torch.max(_input, -1, keepdim=True)[0]))
                 if not any(masked_exp):
                     return masked_exp
-                return masked_exp / torch.sum(masked_exp, 0, keepdim=True)
+                return masked_exp / torch.sum(masked_exp, -1, keepdim=True)
             # Return the function
             return pytorch_thresholded_softmax
 
         else:
-            return lambda x: (torch.softmax(gain * x, 0))
+            return lambda x: (torch.softmax(gain * x, -1))
 
     def _gen_pytorch_adapt_gain_fct(self, device, context=None):
         scale = self._get_pytorch_fct_param_value('adapt_scale', device, context)
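
Reviewer note, not part of the patch: a minimal sketch of why softmax over the
last dimension (dim=-1) handles both 1d and 2d input, while the old dim=0 only
behaved correctly for 1d vectors. The tensors `v` and `m` below are
hypothetical example inputs, not names from PsyNeuLink.

    import torch

    v = torch.tensor([1.0, 2.0, 3.0])   # 1d input: a single activation vector
    m = torch.stack([v, v + 1.0])       # 2d input: a batch of two vectors

    # For a 1d tensor, dim=0 and dim=-1 are the same axis, so results agree.
    assert torch.allclose(torch.softmax(v, 0), torch.softmax(v, -1))

    # For a 2d tensor, dim=-1 normalizes each row (each input vector)...
    rowwise = torch.softmax(m, -1)
    print(rowwise.sum(dim=-1))   # tensor([1., 1.]) -- each vector sums to 1

    # ...while dim=0 normalizes down the batch axis, mixing the two inputs.
    colwise = torch.softmax(m, 0)
    print(colwise.sum(dim=-1))   # row sums are not 1; only columns sum to 1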