diff --git a/psyneulink/core/components/functions/nonstateful/transferfunctions.py b/psyneulink/core/components/functions/nonstateful/transferfunctions.py index f3936f74be..26c96cd3a8 100644 --- a/psyneulink/core/components/functions/nonstateful/transferfunctions.py +++ b/psyneulink/core/components/functions/nonstateful/transferfunctions.py @@ -3591,7 +3591,7 @@ def pytorch_thresholded_softmax(_input: torch.Tensor) -> torch.Tensor: _mask = (torch.abs(_input) > mask_threshold) # Subtract off the max value in the input to eliminate extreme values, exponentiate, and apply mask masked_exp = _mask * torch.exp(gain * (_input - torch.max(_input, -1, keepdim=True)[0])) - if not any(masked_exp): + if (masked_exp == 0).all(): return masked_exp return masked_exp / torch.sum(masked_exp, -1, keepdim=True) # Return the function diff --git a/psyneulink/core/components/functions/nonstateful/transformfunctions.py b/psyneulink/core/components/functions/nonstateful/transformfunctions.py index 99733aad2b..e3258ffc0f 100644 --- a/psyneulink/core/components/functions/nonstateful/transformfunctions.py +++ b/psyneulink/core/components/functions/nonstateful/transformfunctions.py @@ -1593,14 +1593,14 @@ def _gen_pytorch_fct(self, device, context=None): weights = torch.tensor(weights, device=device).double() if self.operation == SUM: if weights is not None: - return lambda x: torch.sum(torch.stack(x) * weights, 0) + return lambda x: torch.sum(x * weights, 0) else: - return lambda x: torch.sum(torch.stack(x), 0) + return lambda x: torch.sum(x, 0) elif self.operation == PRODUCT: if weights is not None: - return lambda x: torch.prod(torch.stack(x) * weights, 0) + return lambda x: torch.prod(x * weights, 0) else: - return lambda x: torch.prod(torch.stack(x), 0) + return lambda x: torch.prod(x, 0) else: from psyneulink.library.compositions.autodiffcomposition import AutodiffCompositionError raise AutodiffCompositionError(f"The 'operation' parameter of {function.componentName} is not supported " diff --git a/psyneulink/library/compositions/autodiffcomposition.py b/psyneulink/library/compositions/autodiffcomposition.py index 003b43db76..f9223aa761 100644 --- a/psyneulink/library/compositions/autodiffcomposition.py +++ b/psyneulink/library/compositions/autodiffcomposition.py @@ -1124,8 +1124,13 @@ def autodiff_forward(self, inputs, targets, for component in curr_tensors_for_trained_outputs.keys(): trial_loss = 0 for i in range(len(curr_tensors_for_trained_outputs[component])): - trial_loss += self.loss_function(curr_tensors_for_trained_outputs[component][i], - curr_target_tensors_for_trained_outputs[component][i]) + # loss only accepts 0 or 1d target. reshape assuming pytorch_rep.minibatch_loss dim is correct + comp_loss = self.loss_function( + curr_tensors_for_trained_outputs[component][i], + torch.atleast_1d(curr_target_tensors_for_trained_outputs[component][i].squeeze()) + ) + comp_loss = comp_loss.reshape_as(pytorch_rep.minibatch_loss) + trial_loss += comp_loss pytorch_rep.minibatch_loss += trial_loss pytorch_rep.minibatch_loss_count += 1 diff --git a/psyneulink/library/compositions/pytorchwrappers.py b/psyneulink/library/compositions/pytorchwrappers.py index 5439300570..ab8ce37e7c 100644 --- a/psyneulink/library/compositions/pytorchwrappers.py +++ b/psyneulink/library/compositions/pytorchwrappers.py @@ -17,7 +17,6 @@ from enum import Enum, auto -from psyneulink.core.components.functions.nonstateful.transformfunctions import LinearCombination, PRODUCT, SUM from psyneulink.core.components.functions.stateful.integratorfunctions import IntegratorFunction from psyneulink.core.components.functions.stateful import StatefulFunction from psyneulink.core.components.mechanisms.processing.transfermechanism import TransferMechanism @@ -30,7 +29,7 @@ NODE, NODE_VALUES, NODE_VARIABLES, OUTPUTS, RESULTS, RUN, TARGETS, TARGET_MECHANISM, ) from psyneulink.core.globals.context import Context, ContextFlags, handle_external_context -from psyneulink.core.globals.utilities import convert_to_np_array, get_deepcopy_with_shared, convert_to_list +from psyneulink.core.globals.utilities import convert_to_list, convert_to_np_array, get_deepcopy_with_shared from psyneulink.core.globals.log import LogCondition from psyneulink.core import llvm as pnlvm @@ -632,7 +631,7 @@ def forward(self, inputs, optimization_rep, context=None)->dict: variable.append(input[i]) elif input_port.default_input == DEFAULT_VARIABLE: # input_port uses a bias, so get that - variable.append(input_port.defaults.variable) + variable.append(torch.from_numpy(input_port.defaults.variable)) # Input for the Mechanism is *not* explicitly specified, but its input_port(s) may have been else: @@ -644,15 +643,14 @@ def forward(self, inputs, optimization_rep, context=None)->dict: variable.append(inputs[input_port]) elif input_port.default_input == DEFAULT_VARIABLE: # input_port uses a bias, so get that - variable.append(input_port.defaults.variable) + variable.append(torch.from_numpy(input_port.defaults.variable)) elif not input_port.internal_only: # otherwise, use the node's input_port's afferents - variable.append(node.aggregate_afferents(i).squeeze(0)) - if len(variable) == 1: - variable = variable[0] + variable.append(node.aggregate_afferents(i)) else: # Node is not INPUT to Composition or BIAS, so get all input from its afferents variable = node.aggregate_afferents() + variable = node.execute_input_ports(variable) if node.exclude_from_gradient_calc: if node.exclude_from_gradient_calc == AFTER: @@ -926,6 +924,11 @@ def __init__(self, self.integrator_function = PytorchFunctionWrapper(mechanism.integrator_function, device, context) self.integrator_previous_value = mechanism.integrator_function._get_pytorch_fct_param_value('initializer', device, context) + self.input_ports = [ + PytorchFunctionWrapper(ip.function, device, context) + for ip in mechanism.input_ports + ] + def add_efferent(self, efferent): """Add ProjectionWrapper for efferent from MechanismWrapper. Implemented for completeness; not currently used @@ -955,54 +958,83 @@ def aggregate_afferents(self, port=None): proj_wrapper._curr_sender_value = proj_wrapper.sender.output[proj_wrapper._value_idx] else: proj_wrapper._curr_sender_value = torch.tensor(proj_wrapper.default_value) + proj_wrapper._curr_sender_value = torch.atleast_1d(proj_wrapper._curr_sender_value) # Specific port is specified # FIX: USING _port_idx TO INDEX INTO sender.value GETS IT WRONG IF THE MECHANISM HAS AN OUTPUT PORT # USED BY A PROJECTION NOT IN THE CURRENT COMPOSITION if port is not None: - return sum(proj_wrapper.execute(proj_wrapper._curr_sender_value).unsqueeze(0) - for proj_wrapper in self.afferents - if proj_wrapper._pnl_proj - in self._mechanism.input_ports[port].path_afferents) - # Has only one input_port - elif len(self._mechanism.input_ports) == 1: - # Get value corresponding to port from which each afferent projects - return sum((proj_wrapper.execute(proj_wrapper._curr_sender_value).unsqueeze(0) - for proj_wrapper in self.afferents)) - # Has multiple input_ports + res = [ + proj_wrapper.execute(proj_wrapper._curr_sender_value) + for proj_wrapper in self.afferents + if proj_wrapper._pnl_proj in self._mechanism.input_ports[port].path_afferents + ] else: - return [sum(proj_wrapper.execute(proj_wrapper._curr_sender_value).unsqueeze(0) - for proj_wrapper in self.afferents - if proj_wrapper._pnl_proj in input_port.path_afferents) - for input_port in self._mechanism.input_ports] + res = [] + for input_port in self._mechanism.input_ports: + ip_res = [] + for proj_wrapper in self.afferents: + if proj_wrapper._pnl_proj in input_port.path_afferents: + ip_res.append(proj_wrapper.execute(proj_wrapper._curr_sender_value)) + res.append(torch.stack(ip_res)) + try: + res = torch.stack(res) + except (RuntimeError, TypeError): + # is ragged, will handle ports individually during execute + pass + return res + + def execute_input_ports(self, variable): + from psyneulink.core.components.functions.nonstateful.transformfunctions import TransformFunction + + if not isinstance(variable, torch.Tensor): + try: + variable = torch.stack(variable) + except (RuntimeError, TypeError): + # ragged + pass + + # must iterate over at least 1d input per port + variable = torch.atleast_2d(variable) + + res = [] + for i in range(len(self.input_ports)): + v = variable[i] + if isinstance(self.input_ports[i]._pnl_function, TransformFunction): + # atleast_2d to account for input port dimension reduction + v = torch.atleast_2d(v) + + res.append(self.input_ports[i].function(v)) + + try: + res = torch.stack(res) + except (RuntimeError, TypeError): + # ragged + pass + return res def execute(self, variable, context): """Execute Mechanism's _gen_pytorch version of function on variable. Enforce result to be 2d, and assign to self.output """ - def execute_function(function, variable, fct_has_mult_args=False, is_combination_fct=False): + def execute_function(function, variable, fct_has_mult_args=False): """Execute _gen_pytorch_fct on variable, enforce result to be 2d, and return it If fct_has_mult_args is True, treat each item in variable as an arg to the function If False, compute function for each item in variable and return results in a list """ - if ((isinstance(variable, list) and len(variable) == 1) - or (isinstance(variable, torch.Tensor) and len(variable.squeeze(0).shape) == 1) - or isinstance(self._mechanism.function, LinearCombination)): - # Enforce 2d on value of MechanismWrapper (using unsqueeze) for single InputPort - # or if TransformFunction (which reduces output to single item from multi-item input) - if isinstance(variable, torch.Tensor): - variable = variable.squeeze(0) - return function(variable).unsqueeze(0) - elif is_combination_fct: - # Function combines the elements - return function(variable) - elif fct_has_mult_args: - # Assign each element of variable as an arg to the function - return function(*variable) + from psyneulink.core.components.functions.nonstateful.transformfunctions import TransformFunction + if fct_has_mult_args: + res = function(*variable) + # variable is ragged + elif isinstance(variable, list): + res = [function(variable[i]) for i in range(len(variable))] else: - # Treat each item in variable as a separate input to the function and get result for each in a list: - # make return value 2d by creating list of the results of function returned for each item in variable - return [function(variable[i].squeeze(0)) for i in range(len(variable))] + res = function(variable) + # TransformFunction can reduce output to single item from + # multi-item input + if isinstance(function._pnl_function, TransformFunction): + res = res.unsqueeze(0) + return res # If mechanism has an integrator_function and integrator_mode is True, # execute it first and use result as input to the main function; @@ -1017,9 +1049,7 @@ def execute_function(function, variable, fct_has_mult_args=False, is_combination self.input = variable # Compute main function of mechanism and return result - from psyneulink.core.components.functions.nonstateful.transformfunctions import TransformFunction - self.output = execute_function(self.function, variable, - is_combination_fct=isinstance(self._mechanism.function, TransformFunction)) + self.output = execute_function(self.function, variable) return self.output def _gen_llvm_execute(self, ctx, builder, state, params, mech_input, data):