diff --git a/Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/Environment.py b/Scripts/Environment.py similarity index 100% rename from Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/Environment.py rename to Scripts/Environment.py diff --git a/Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/DeclanParams.py b/Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/DeclanParams.py similarity index 84% rename from Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/DeclanParams.py rename to Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/DeclanParams.py index 7209121c186..c2dbbbf8180 100644 --- a/Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/DeclanParams.py +++ b/Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/DeclanParams.py @@ -39,10 +39,10 @@ def calc_prob(em_preds, test_ys): # Names: name = "EGO Model CSW", + em_name = "EM", state_input_layer_name = "STATE", previous_state_layer_name = "PREVIOUS STATE", context_layer_name = 'CONTEXT', - em_name = "EM", prediction_layer_name = "PREDICTION", # Structural @@ -50,10 +50,10 @@ def calc_prob(em_preds, test_ys): previous_state_d = 11, # length of state vector context_d = 11, # length of context vector memory_capacity = ALL, # number of entries in EM memory; ALL=> match to number of stims - # memory_init = (0,.0001), # Initialize memory with random values in interval - memory_init = None, # Initialize with zeros - concatenate_queries = False, - # concatenate_queries = True, + memory_init = (0,.0001), # Initialize memory with random values in interval + # memory_init = None, # Initialize with zeros + # concatenate_queries = False, + concatenate_queries = True, # environment # curriculum_type = 'Interleaved', @@ -63,20 +63,23 @@ def calc_prob(em_preds, test_ys): # Processing integration_rate = .69, # rate at which state is integrated into new context - # state_weight = 1, # weight of the state used during memory retrieval # context_weight = 1, # weight of the context used during memory retrieval - state_weight = .5, # weight of the state used during memory retrieval + previous_state_weight = .5, # weight of the state used during memory retrieval context_weight = .5, # weight of the context used during memory retrieval + state_weight = None, # weight of the state used during memory retrieval # normalize_field_weights = False, # whether to normalize the field weights during memory retrieval normalize_field_weights = True, # whether to normalize the field weights during memory retrieval + normalize_memories = False, # whether to normalize the memory during memory retrieval + # normalize_memories = True, # whether to normalize the memory during memory retrieval # softmax_temperature = None, # temperature of the softmax used during memory retrieval (smaller means more argmax-like) softmax_temperature = .1, # temperature of the softmax used during memory retrieval (smaller means more argmax-like) # softmax_temperature = ADAPTIVE, # temperature of the softmax used during memory retrieval (smaller means more argmax-like) # softmax_temperature = CONTROL, # temperature of the softmax used during memory retrieval (smaller means more argmax-like) # softmax_threshold = None, # threshold used to mask out small values in softmax softmax_threshold = .001, # threshold used to mask out small values in
softmax - enable_learning=[False, False, True], # Enable learning for PREDICTION (STATE) but not CONTEXT or PREVIOUS STATE - learn_field_weights = False, + # target_fields=[True, False, False], # Enable learning for PREDICTION (STATE) but not CONTEXT or PREVIOUS STATE + enable_learning = True, loss_spec = Loss.BINARY_CROSS_ENTROPY, # loss_spec = Loss.MSE, learning_rate = .5, diff --git a/Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/EGO CSW Model (using RNN).py b/Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/EGO CSW Model (using RNN).py similarity index 100% rename from Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/EGO CSW Model (using RNN).py rename to Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/EGO CSW Model (using RNN).py diff --git a/Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/EGO CSW Model.py b/Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/EGO CSW Model.py similarity index 91% rename from Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/EGO CSW Model.py rename to Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/EGO CSW Model.py index 18d3ba419b0..52423c7dbbf 100644 --- a/Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/EGO CSW Model.py +++ b/Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/EGO CSW Model.py @@ -174,7 +174,7 @@ 'CONTEXT', 'PREVIOUS STATE'], start=0) -state_retrieval_weight = 0 +state_retrieval_weight = None RANDOM_WEIGHTS_INITIALIZATION=RandomMatrix(center=0.0, range=0.1) # Matrix spec used to initialize all Projections if is_numeric_scalar(model_params['softmax_temperature']): # translate to gain of softmax retrieval function @@ -194,7 +194,7 @@ def construct_model(model_name:str=model_params['name'], state_size:int=model_params['state_d'], # Previous state - previous_state_input_name:str=model_params['previous_state_layer_name'], + previous_state_name:str=model_params['previous_state_layer_name'], # Context representation (learned): context_name:str=model_params['context_layer_name'], @@ -205,12 +205,15 @@ def construct_model(model_name:str=model_params['name'], em_name:str=model_params['em_name'], retrieval_softmax_gain=retrieval_softmax_gain, retrieval_softmax_threshold=model_params['softmax_threshold'], - state_retrieval_weight:Union[float,int]=state_retrieval_weight, - previous_state_retrieval_weight:Union[float,int]=model_params['state_weight'], + # state_retrieval_weight:Union[float,int]=state_retrieval_weight, + # previous_state_retrieval_weight:Union[float,int]=model_params['state_weight'], + state_retrieval_weight:Union[float,int]=model_params['state_weight'], + previous_state_retrieval_weight:Union[float,int]=model_params['previous_state_weight'], context_retrieval_weight:Union[float,int]=model_params['context_weight'], normalize_field_weights = model_params['normalize_field_weights'], + normalize_memories = model_params['normalize_memories'], concatenate_queries = model_params['concatenate_queries'], - learn_field_weights = model_params['learn_field_weights'], + enable_learning = model_params['enable_learning'], memory_capacity = memory_capacity, memory_init=model_params['memory_init'], @@ -219,7 +222,7 @@ def construct_model(model_name:str=model_params['name'], # Learning loss_spec=model_params['loss_spec'], - enable_learning=model_params['enable_learning'], + # target_fields=model_params['target_fields'], learning_rate = 
model_params['learning_rate'], device=model_params['device'] @@ -233,7 +236,7 @@ def construct_model(model_name:str=model_params['name'], # ---------------------------------------------------------------------------------------------------------------- state_input_layer = ProcessingMechanism(name=state_input_name, input_shapes=state_size) - previous_state_layer = ProcessingMechanism(name=previous_state_input_name, input_shapes=state_size) + previous_state_layer = ProcessingMechanism(name=previous_state_name, input_shapes=state_size) # context_layer = ProcessingMechanism(name=context_name, input_shapes=context_size) context_layer = TransferMechanism(name=context_name, input_shapes=context_size, @@ -241,6 +244,8 @@ def construct_model(model_name:str=model_params['name'], integrator_mode=True, integration_rate=integration_rate) + + em = EMComposition(name=em_name, memory_template=[[0] * state_size, # state [0] * state_size, # previous state @@ -250,6 +255,15 @@ def construct_model(model_name:str=model_params['name'], memory_decay_rate=0, softmax_gain=retrieval_softmax_gain, softmax_threshold=retrieval_softmax_threshold, + fields = {state_input_name: {FIELD_WEIGHT: state_retrieval_weight, + LEARN_FIELD_WEIGHT: False, + TARGET_FIELD: True}, + previous_state_name: {FIELD_WEIGHT:previous_state_retrieval_weight, + LEARN_FIELD_WEIGHT: False, + TARGET_FIELD: False}, + context_name: {FIELD_WEIGHT:context_retrieval_weight, + LEARN_FIELD_WEIGHT: False, + TARGET_FIELD: False}}, # Input Nodes: # field_names=[state_input_name, # previous_state_input_name, @@ -259,19 +273,20 @@ def construct_model(model_name:str=model_params['name'], # previous_state_retrieval_weight, # context_retrieval_weight # ), - field_names=[previous_state_input_name, - context_name, - state_input_name, - ], - field_weights=(previous_state_retrieval_weight, - context_retrieval_weight, - state_retrieval_weight, - ), + # field_names=[previous_state_input_name, + # context_name, + # state_input_name, + # ], + # field_weights=(previous_state_retrieval_weight, + # context_retrieval_weight, + # state_retrieval_weight, + # ), normalize_field_weights=normalize_field_weights, + normalize_memories=normalize_memories, concatenate_queries=concatenate_queries, - learn_field_weights=learn_field_weights, - learning_rate=learning_rate, enable_learning=enable_learning, + learning_rate=learning_rate, + # target_fields=target_fields, device=device ) @@ -311,7 +326,7 @@ def construct_model(model_name:str=model_params['name'], em] previous_state_to_em_pathway = [previous_state_layer, MappingProjection(sender=previous_state_layer, - receiver=em.nodes[previous_state_input_name+QUERY], + receiver=em.nodes[previous_state_name+QUERY], matrix=IDENTITY_MATRIX, learnable=False), em] diff --git a/Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/Environment.py b/Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/Environment.py new file mode 100644 index 00000000000..124de532c83 --- /dev/null +++ b/Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/Environment.py @@ -0,0 +1,54 @@ +import numpy as np +import torch +from torch.utils.data import dataset +from random import randint + +def one_hot_encode(labels, num_classes): + """ + One hot encode labels and convert to tensor. 
+ """ + return torch.tensor((np.arange(num_classes) == labels[..., None]).astype(float),dtype=torch.float32) + +class DeterministicCSWDataset(dataset.Dataset): + def __init__(self, n_samples_per_context, contexts_to_load) -> None: + super().__init__() + raw_xs = np.array([ + [[9,1,3,5,7],[9,2,4,6,8]], + [[10,1,4,5,8],[10,2,3,6,7]] + ]) + + item_indices = np.random.choice(raw_xs.shape[1],sum(n_samples_per_context),replace=True) + task_names = [0,1] # Flexible so these can be renamed later + task_indices = [task_names.index(name) for name in contexts_to_load] + + context_indices = np.repeat(np.array(task_indices),n_samples_per_context) + self.xs = one_hot_encode(raw_xs[context_indices,item_indices],11) + + self.xs = self.xs.reshape((-1,11)) + self.ys = torch.cat([self.xs[1:],one_hot_encode(np.array([0]),11)],dim=0) + context_indices = np.repeat(np.array(task_indices),[x*5 for x in n_samples_per_context]) + self.contexts = one_hot_encode(context_indices, len(task_names)) + + # Remove the last transition since there's no next state available + self.xs = self.xs[:-1] + self.ys = self.ys[:-1] + self.contexts = self.contexts[:-1] + + def __len__(self): + return len(self.xs) + + def __getitem__(self, idx): + return self.xs[idx], self.contexts[idx], self.ys[idx] + +def generate_dataset(condition='Blocked'): + # Generate the dataset for either the blocked or interleaved condition + if condition=='Blocked': + contexts_to_load = [0,1,0,1] + [randint(0,1) for _ in range(40)] + n_samples_per_context = [40,40,40,40] + [1]*40 + elif condition == 'Interleaved': + contexts_to_load = [0,1]*80 + [randint(0,1) for _ in range(40)] + n_samples_per_context = [1]*160 + [1]*40 + else: + raise ValueError(f'Unknown dataset condition: {condition}') + + return DeterministicCSWDataset(n_samples_per_context, contexts_to_load) diff --git a/Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/Figures/EGO CSW Model (PyTorch).pdf b/Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/Figures/EGO CSW Model (PyTorch).pdf similarity index 100% rename from Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/Figures/EGO CSW Model (PyTorch).pdf rename to Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/Figures/EGO CSW Model (PyTorch).pdf diff --git a/Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/Figures/EGO CSW Model (basic).pdf b/Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/Figures/EGO CSW Model (basic).pdf similarity index 100% rename from Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/Figures/EGO CSW Model (basic).pdf rename to Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/Figures/EGO CSW Model (basic).pdf diff --git a/Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/Figures/EGO CSW Model (learning and store).pdf b/Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/Figures/EGO CSW Model (learning and store).pdf similarity index 100% rename from Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/Figures/EGO CSW Model (learning and store).pdf rename to Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/Figures/EGO CSW Model (learning and store).pdf diff --git a/Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/Figures/EGO CSW Model (learning).pdf b/Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/Figures/EGO CSW Model (learning).pdf similarity 
index 100% rename from Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/Figures/EGO CSW Model (learning).pdf rename to Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/Figures/EGO CSW Model (learning).pdf diff --git a/Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/Figures/EGO CSW Model - EM (with PNL learning).pdf b/Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/Figures/EGO CSW Model - EM (with PNL learning).pdf similarity index 100% rename from Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/Figures/EGO CSW Model - EM (with PNL learning).pdf rename to Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/Figures/EGO CSW Model - EM (with PNL learning).pdf diff --git a/Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/Figures/EGO CSW Model - EM.pdf b/Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/Figures/EGO CSW Model - EM.pdf similarity index 100% rename from Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/Figures/EGO CSW Model - EM.pdf rename to Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/Figures/EGO CSW Model - EM.pdf diff --git a/Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/Figures/EGO Paper Figure.jpg b/Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/Figures/EGO Paper Figure.jpg similarity index 100% rename from Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/Figures/EGO Paper Figure.jpg rename to Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/Figures/EGO Paper Figure.jpg diff --git a/Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/Figures/EMComposition (example BIG).pdf b/Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/Figures/EMComposition (example BIG).pdf similarity index 100% rename from Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/Figures/EMComposition (example BIG).pdf rename to Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/Figures/EMComposition (example BIG).pdf diff --git a/Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/ScriptControl.py b/Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/ScriptControl.py similarity index 93% rename from Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/ScriptControl.py rename to Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/ScriptControl.py index 43016886d3a..f61ec5f75d4 100644 --- a/Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/ScriptControl.py +++ b/Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/ScriptControl.py @@ -3,8 +3,8 @@ # Settings for running script: -MODEL_PARAMS = 'TestParams' -# MODEL_PARAMS = 'DeclanParams' +# MODEL_PARAMS = 'TestParams' +MODEL_PARAMS = 'DeclanParams' CONSTRUCT_MODEL = True # THIS MUST BE SET TO True to run the script DISPLAY_MODEL = ( # Only one of the following can be uncommented: @@ -13,7 +13,7 @@ # # 'show_pytorch': True, # show pytorch graph of model # 'show_learning': True, # # 'show_nested_args': {'show_learning': pnl.ALL}, - # 'show_projections_not_in_composition': True, + # # 'show_projections_not_in_composition': True, # # 'show_nested': {'show_node_structure': True}, # # 'exclude_from_gradient_calc_style': 'dashed'# show target mechanisms for learning # # 
'show_node_structure': True # show detailed view of node structures and projections diff --git a/Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/TestParams.py b/Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/TestParams.py similarity index 86% rename from Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/TestParams.py rename to Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/TestParams.py index e9893eff726..2ba7073f178 100644 --- a/Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/TestParams.py +++ b/Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/TestParams.py @@ -1,14 +1,16 @@ from psyneulink.core.llvm import ExecutionMode from psyneulink.core.globals.keywords import ALL, ADAPTIVE, CONTROL, CPU, Loss, MPS, OPTIMIZATION_STEP, RUN, TRIAL + + model_params = dict( # Names: name = "EGO Model CSW", + em_name = "EM", state_input_layer_name = "STATE", previous_state_layer_name = "PREVIOUS STATE", context_layer_name = 'CONTEXT', - em_name = "EM", prediction_layer_name = "PREDICTION", # Structural @@ -20,7 +22,6 @@ # memory_init = None, # Initialize with zeros concatenate_queries = False, # concatenate_queries = True, - # environment # curriculum_type = 'Interleaved', curriculum_type = 'Blocked', @@ -33,18 +34,19 @@ context_weight = 1, # weight of the context used during memory retrieval # normalize_field_weights = False, # whether to normalize the field weights during memory retrieval normalize_field_weights = True, # whether to normalize the field weights during memory retrieval + normalize_memories = False, # whether to normalize the memory during memory retrieval # softmax_temperature = None, # temperature of the softmax used during memory retrieval (smaller means more argmax-like) softmax_temperature = .1, # temperature of the softmax used during memory retrieval (smaller means more argmax-like) # softmax_temperature = ADAPTIVE, # temperature of the softmax used during memory retrieval (smaller means more argmax-like) # softmax_temperature = CONTROL, # temperature of the softmax used during memory retrieval (smaller means more argmax-like) # softmax_threshold = None, # threshold used to mask out small values in softmax softmax_threshold = .001, # threshold used to mask out small values in softmax - enable_learning=[True, False, False], # Enable learning for PREDICTION (STATE) but not CONTEXT or PREVIOUS STATE - # enable_learning=[True, True, True] - # enable_learning=True, - # enable_learning=False, - learn_field_weights = True, - # learn_field_weights = False, + # target_fields=[True, False, False], # Enable learning for PREDICTION (STATE) but not CONTEXT or PREVIOUS STATE + # target_fields=[True, True, True] + # target_fields=True, + # target_fields=False, + enable_learning = True, + # enable_learning = False, loss_spec = Loss.BINARY_CROSS_ENTROPY, # loss_spec = Loss.CROSS_ENTROPY, # loss_spec = Loss.MSE, @@ -53,8 +55,8 @@ synch_weights = RUN, synch_values = RUN, synch_results = RUN, - execution_mode = ExecutionMode.Python, - # execution_mode = ExecutionMode.PyTorch, + # execution_mode = ExecutionMode.Python, + execution_mode = ExecutionMode.PyTorch, device = CPU, # device = MPS, ) diff --git a/Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/__init__.py b/Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/__init__.py similarity index 100% rename from Scripts/Models (Under Development)/EGO/Using EMComposition/CSW/__init__.py
rename to Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/__init__.py diff --git a/docs/source/_static/EMComposition_Example_fig.svg b/docs/source/_static/EMComposition_Example_fig.svg index f3a5662f21e..7456c5d2b38 100644 --- a/docs/source/_static/EMComposition_Example_fig.svg +++ b/docs/source/_static/EMComposition_Example_fig.svg @@ -1,56 +1,94 @@
[SVG markup elided: diagram nodes relabeled from "VALUE INPUT", "KEY INPUT", "MATCH KEY", "SOFTMAX", "SOFTMAX GAIN CONTROL", "RETRIEVAL", "KEY RETRIEVED", "VALUE RETRIEVED" to "KEY [QUERY]", "VALUE [VALUE]", "KEY [MATCH to KEYS]", "STORE", "RETRIEVE", "KEY [RETRIEVED]", "VALUE [RETRIEVED]"]
\ No newline at end of file diff --git a/docs/source/_static/EMComposition_field_weights_different.pdf b/docs/source/_static/EMComposition_field_weights_different.pdf new file mode 100644 index 00000000000..97ebdb43148 Binary files /dev/null and b/docs/source/_static/EMComposition_field_weights_different.pdf differ diff --git a/docs/source/_static/EMComposition_field_weights_different.svg b/docs/source/_static/EMComposition_field_weights_different.svg index 94aab6b6a7c..eeb15badcc4 100644 --- a/docs/source/_static/EMComposition_field_weights_different.svg +++ b/docs/source/_static/EMComposition_field_weights_different.svg @@ -1,103 +1,209 @@
[SVG markup elided: diagram nodes relabeled from "VALUE INPUT", "KEY 0/1 INPUT", "MATCH KEY 0/1", "RETRIEVAL WEIGHTING 0/1", "SOFTMAX 0/1", "SOFTMAX GAIN CONTROL 0/1", "WEIGHT RETRIEVALS", "KEY 0/1 RETRIEVED", "VALUE RETRIEVED" to "0-3 [QUERY]", "0-3 [MATCH to KEYS]", "0-3 [WEIGHTED MATCH]", "0-3 [WEIGHT]", "COMBINE MATCHES", "STORE", "RETRIEVE", "0-3 [RETRIEVED]"]
\ No newline at end of file diff --git a/docs/source/_static/EMComposition_field_weights_equal_fig.svg b/docs/source/_static/EMComposition_field_weights_equal_fig.svg index dfa96297ffb..a093260a155 100644 --- a/docs/source/_static/EMComposition_field_weights_equal_fig.svg +++ b/docs/source/_static/EMComposition_field_weights_equal_fig.svg @@ -1,104 +1,209 @@
[SVG markup elided: diagram nodes relabeled from "KEY 0/1 INPUT", "VALUE 0/1 INPUT", "MATCH KEY 0/1", "SOFTMAX 0/1", "SOFTMAX GAIN CONTROL 0/1", "RETRIEVAL WEIGHTING 0/1", "RETRIEVAL", "KEY 0/1 RETRIEVED", "VALUE 0/1 RETRIEVED" to "0-3 [QUERY]", "0-3 [MATCH to KEYS]", "0-3 [WEIGHTED MATCH]", "0-3 [WEIGHT]", "COMBINE MATCHES", "STORE", "RETRIEVE", "0-3 [RETRIEVED]"]
\ No newline at end of file diff --git a/psyneulink/core/components/functions/nonstateful/transformfunctions.py b/psyneulink/core/components/functions/nonstateful/transformfunctions.py index 86c1db6b7b5..99733aad2bc 100644 --- a/psyneulink/core/components/functions/nonstateful/transformfunctions.py +++ b/psyneulink/core/components/functions/nonstateful/transformfunctions.py @@ -2216,7 +2216,7 @@ def _function(self, elif operation == L0: if normalize: - normalization = np.sum(np.abs(vector - matrix)) + normalization = np.sum(np.abs(vector - matrix)) or 1 result = np.sum((1 - (np.abs(vector - matrix)) / normalization),axis=0) else: result = np.sum((np.abs(vector - matrix)),axis=0) diff --git a/psyneulink/core/components/ports/inputport.py b/psyneulink/core/components/ports/inputport.py index ce123b7118c..8fdcb4a2501 100644 --- a/psyneulink/core/components/ports/inputport.py +++ b/psyneulink/core/components/ports/inputport.py @@ -713,7 +713,8 @@ class InputPort(Port_Base): is executed and its variable is assigned None. If *default_input* is assigned *DEFAULT_VARIABLE*, then the `default value ` for the InputPort's `variable ` is used as its value. This is useful for assignment to a Mechanism that needs a constant (i.e., fixed value) as the input to its - `function `. + `function ` (such as a `bias unit ` in an + `AutodiffComposition`). .. note:: If `default_input ` is assigned *DEFAULT_VARIABLE*, then its `internal_only diff --git a/psyneulink/core/compositions/composition.py b/psyneulink/core/compositions/composition.py index 08f5e91e2bd..81b98a5aecd 100644 --- a/psyneulink/core/compositions/composition.py +++ b/psyneulink/core/compositions/composition.py @@ -3369,7 +3369,7 @@ class NodeRole(enum.Enum): BIAS A `Node ` for which one or more of its `InputPorts ` is assigned *DEFAULT_VARIABLE* as its `default_input ` (which provides it a prespecified - input that is constant across executions). Such a node can also be assigned as an `INPUT` and/or `ORIGIN`, + input that is constant across executions). Such a node can also be assigned as an `INPUT` and/or `ORIGIN`, if it receives input from outside the Composition and/or does not receive any `Projections ` from other Nodes within the Composition, respectively. This role cannot be modified programmatically.
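
As a minimal sketch of the BIAS pattern documented in the two docstrings above (names, sizes, and the exact InputPort specification here are illustrative assumptions, not taken from this diff):

    import numpy as np
    import psyneulink as pnl

    # Constant-input node: its InputPort uses its default variable (here, ones) as a
    # fixed input on every execution, so it needs no external input.
    bias = pnl.ProcessingMechanism(name='BIAS',
                                   default_variable=np.ones(3),
                                   input_ports=[pnl.InputPort(default_input=pnl.DEFAULT_VARIABLE)])
    layer = pnl.ProcessingMechanism(name='LAYER', input_shapes=3)

    # The diagonal matrix holds the initial bias values; if the Projection is
    # learnable, the biases can be trained.
    comp = pnl.Composition(pathways=[[bias,
                                      pnl.MappingProjection(matrix=np.diag([.1, .1, .1])),
                                      layer]])

    # Per the docstring above, a node configured this way should be assigned NodeRole.BIAS:
    print(comp.get_nodes_by_role(pnl.NodeRole.BIAS))
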
diff --git a/psyneulink/library/components/mechanisms/modulatory/learning/EMstoragemechanism.py b/psyneulink/library/components/mechanisms/modulatory/learning/EMstoragemechanism.py index f9d296eef87..fbd49f4d7a8 100644 --- a/psyneulink/library/components/mechanisms/modulatory/learning/EMstoragemechanism.py +++ b/psyneulink/library/components/mechanisms/modulatory/learning/EMstoragemechanism.py @@ -642,7 +642,7 @@ def _validate_params(self, request_set, target_set=None, context=None): f"a list or 2d np.array containing entries that have the same shape " f"({memory_matrix.shape}) as an entry (row) in 'memory_matrix' arg.") - # Ensure the number of fields is equal to the numbder of items in variable + # Ensure the number of fields is equal to the number of items in variable if FIELDS in request_set: fields = request_set[FIELDS] if len(fields) != len(self.variable): diff --git a/psyneulink/library/compositions/autodiffcomposition.py b/psyneulink/library/compositions/autodiffcomposition.py index 5ce5f1eb188..003b43db767 100644 --- a/psyneulink/library/compositions/autodiffcomposition.py +++ b/psyneulink/library/compositions/autodiffcomposition.py @@ -110,10 +110,17 @@ AutodiffComposition does not (currently) support the *automatic* construction of separate bias parameters. Thus, when constructing a model using an AutodiffComposition that corresponds to one in PyTorch, the `bias ` parameter of PyTorch modules should be set -to `False`. Trainable biases *can* be specified explicitly in an AutodiffComposition by including a -TransferMechanism that projects to the relevant Mechanism (i.e., implementing that layer of the network to -receive the biases) using a `MappingProjection` with a `matrix ` parameter that -implements a diagnoal matrix with values corresponding to the initial value of the biases. +to `False`. + + .. hint:: + Trainable biases *can* be specified explicitly in an AutodiffComposition by including a `ProcessingMechanism` + that projects to the relevant Mechanism (i.e., implementing that layer of the network to receive the biases) + using a `MappingProjection` with a `matrix ` parameter that implements a diagonal + matrix with values corresponding to the initial value of the biases, and setting the `default_input + ` Parameter of one of the ProcessingMechanism's `input_ports + ` to *DEFAULT_VARIABLE*, and its `default_variable ` + equal to 1. ProcessingMechanisms configured in this way are assigned `NodeRole` `BIAS`, and the MappingProjection + is subject to learning. .. _AutodiffComposition_Nesting: @@ -951,8 +958,9 @@ def create_pathway(node)->list: return pathways - # Construct a pathway for each INPUT Node (except the TARGET Node) - pathways = [pathway for node in self.get_nodes_by_role(NodeRole.INPUT) + # Construct a pathway for each INPUT Node (including BIAS Nodes, except the TARGET Node) + pathways = [pathway + for node in (self.get_nodes_by_role(NodeRole.INPUT) + self.get_nodes_by_role(NodeRole.BIAS)) if node not in self.get_nodes_by_role(NodeRole.TARGET) for pathway in _get_pytorch_backprop_pathway(node)] @@ -1055,8 +1063,7 @@ def _get_loss(self, loss_spec): # and therefore requires a wrapper function to properly package inputs.
return lambda x, y: nn.CrossEntropyLoss()(torch.atleast_2d(x), torch.atleast_2d(y.type(x.type()))) elif loss_spec == Loss.BINARY_CROSS_ENTROPY: - if version.parse(torch.version.__version__) >= version.parse('1.12.0'): - return nn.BCELoss() + return nn.BCELoss() elif loss_spec == Loss.L1: return nn.L1Loss(reduction='sum') elif loss_spec == Loss.NLL: @@ -1118,7 +1125,7 @@ def autodiff_forward(self, inputs, targets, trial_loss = 0 for i in range(len(curr_tensors_for_trained_outputs[component])): trial_loss += self.loss_function(curr_tensors_for_trained_outputs[component][i], - curr_target_tensors_for_trained_outputs[component][i]) + curr_target_tensors_for_trained_outputs[component][i]) pytorch_rep.minibatch_loss += trial_loss pytorch_rep.minibatch_loss_count += 1 diff --git a/psyneulink/library/compositions/emcomposition.py b/psyneulink/library/compositions/emcomposition.py index af84a5b7685..b8506217601 100644 --- a/psyneulink/library/compositions/emcomposition.py +++ b/psyneulink/library/compositions/emcomposition.py @@ -7,242 +7,8 @@ # ********************************************* EMComposition ************************************************* -# -# TODO: -# - QUESTION: -# - SHOULD differential of SoftmaxGainControl Node be included in learning? -# - SHOULD MEMORY DECAY OCCUR IF STORAGE DOES NOT? CURRENTLY IT DOES NOT (SEE EMStorage Function) - -# - FIX: Refactor field_weights to use None instead of 0 to specify value fields, and allow inputs to field_nodes -# - FIX: ALLOW SOFTMAX SPEC TO BE A DICT WITH PARAMETERS FOR _get_softmax_gain() FUNCTION -# - FIX: Concatenation: -# - LLVM for function and derivative -# - Add Concatenate to pytorchcreator_function -# - Deal with matrix assignment in LearningProjection LINE 643 -# - Reinstate test for execution of Concatenate with learning in test_emcomposition (currently commented out) -# - FIX: Softmax Gain Control: -# Test if it current works (they are added to Composition but not in BackProp processing pathway) -# Does backprop have to run through this if not learnable? -# If so, need to add PNL Function, with derivative and LLVM and Pytorch implementations -# - FIX: WRITE MORE TESTS FOR EXECUTION, WARNINGS, AND ERROR MESSAGES -# - learning (with and without learning field weights -# - 3d tuple with first entry != memory_capacity if specified -# - list with number of entries > memory_capacity if specified -# - input is added to the correct row of the matrix for each key and value for -# for non-contiguous keys (e.g, field_weights = [1,0,1])) -# - illegal field weight assignment -# - explicitly that storage occurs after retrieval -# - FIX: WARNING NOT OCCURRING FOR Normalize ON ZEROS WITH MULTIPLE ENTRIES (HAPPENS IF *ANY* KEY IS EVER ALL ZEROS) -# - FIX: IMPLEMENT LearningMechanism FOR RETRIEVAL WEIGHTS: -# - what is learning_update: AFTER doing? Use for scheduling execution of storage_node? -# ?? implement derivative for concatenate -# - FIX: implement add_storage_pathway to handle addition of storage_node as learning mechanism -# - in "_create_storage_learning_components()" assign "learning_update" arg -# as BEORE OR DURING instead of AFTER (assigned to learning_enabled arg of LearningMechanism) -# - FIX: Add StorageMechanism LearningProjections to Composition? -> CAUSES TEST FAILURES; NEEDS INVESTIGATION -# - FIX: Thresholded version of SoftMax gain (per Kamesh) -# - FIX: DEAL WITH INDEXING IN NAMES FOR NON-CONTIGUOUS KEYS AND VALUES (reorder to keep all keys together?) 
-# - FIX: _import_composition: -# - MOVE LearningProjections -# - MOVE Condition? (e.g., AllHaveRun) (OR PUT ON MECHANISM?) -# - FIX: IMPLEMENT _integrate_into_composition METHOD THAT CALLS _import_composition ON ANOTHER COMPOSITION -# - AND TRANSFERS RELEVANT ATTRIBUTES (SUCH AS MEMORY, query_input_nodeS, ETC., POSSIBLY APPENDING NAMES) -# - FIX: ADD Option to suppress field_weights when computing norm for weakest entry in EMStorageMechanism -# - FIX: GENERATE ANIMATION w/ STORAGE (uses Learning but not in usual way) -# - IMPLEMENT use OF multiple inheritance of EMComposition from AutoDiff and Composition - -# - FIX: DOCUMENTATION: -# - enable_learning vs. learning_field_weights -# - USE OF EMStore.storage_location (NONE => LOCAL, SPECIFIED => GLOBAL) -# - define "keys" and "values" explicitly -# - define "key weights" explicitly as field_weights for all non-zero values -# - make it clear that full size of memory is initialized (rather than "filling up" w/ use) -# - write examples for run() -# - FIX: ADD NOISE -# - FIX: ?ADD add_memory() METHOD FOR STORING W/O RETRIEVAL, OR JUST ADD retrieval_prob AS modulable Parameter -# - FIX: CONFIDENCE COMPUTATION (USING SIGMOID ON DOT PRODUCTS) AND REPORT THAT (EVEN ON FIRST CALL) -# MISC: -# - WRITE TESTS FOR INPUT_PORT and MATRIX SPECS CORRECT IN LATEST BRANCHs -# - ACCESSIBILITY OF DISTANCES (SEE BELOW): MAKE IT A LOGGABLE PARAMETER (I.E., WITH APPROPRIATE SETTER) -# ADD COMPILED VERSION OF NORMED LINEAR_COMBINATION FUNCTION TO LinearCombination FUNCTION: dot / (norm a * norm b) -# - DECAY WEIGHTS BY: -# ? 1-SOFTMAX / N (WHERE N = NUMBER OF ITEMS IN MEMORY) -# or -# 1/N (where N=number of items in memory, and thus gets smaller as N gets -# larger) on each storage (to some asymptotic minimum value), and store the new memory to the unit with the -# smallest weights (randomly selected among “ties" [i.e., within epsilon of each other]), I think we have a -# mechanism that can adaptively use its limited capacity as sensibly as possible, by re-cycling the units -# that have the least used memories. -# - MAKE "_store_memory" METHOD USE LEARNING INSTEAD OF ASSIGNMENT -# - make LearningMechanism that, instead of error, simply adds relevant input to weights (with all others = 0) -# - (relationship to Steven's Hebbian / DPP model?): - -# - ADD ADDITIONAL PARAMETERS FROM CONTENTADDRESSABLEMEMORY FUNCTION -# - ADAPTIVE TEMPERATURE: KAMESH FOR FORMULA -# - ADD MEMORY_DECAY TO ContentAddressableMemory FUNCTION (and compiled version by Samyak) -# - MAKE memory_template A CONSTRUCTOR ARGUMENT FOR default_variable - -# - FIX: PSYNEULINK: -# - TESTS: -# - WRITE TESTS FOR DriftOnASphere variable = scalar, 2d vector or 1d vector of correct and incorrect lengths -# - WRITE TESTS FOR LEARNING WITH LinearCombination of 1, 2 and 3 inputs -# -# - COMPILATION: -# - Remove CIM projections on import to another composition -# - Autodiff support for IdentityFunction -# - MatrixTransform to add normalization -# - _store() method to assign weights to memory -# - LLVM problem with ComparatorMechanism -# -# - pytorchcreator_function: -# SoftMax implementation: torch.nn.Softmax(dim=0) is not getting passed correctly -# Implement LinearCombination -# - MatrixTransform Function: -# -# - LEARNING - Backpropagation LearningFunction / LearningMechanism -# - DOCUMENTATION: -# - weight_change_matrix = gradient (result of delta rule) * learning_rate -# - ERROR_SIGNAL is OPTIONAL (only implemented when there is an error_source specified) -# - Backprop: (related to above?) 
handle call to constructor with default_variable = None -# - WRITE TESTS FOR USE OF COVARIATES AND RELATED VIOLATIONS: (see ScratchPad) -# - Use of LinearCombination with PRODUCT in output_source -# - Use of LinearCombination with PRODUCT in InputPort of output_source -# - Construction of LearningMechanism with Backprop: -# - MappingProjection / LearningMechanism: -# - Add learning_rate parameter to MappingProjection (if learnable is True) -# - Refactor LearningMechanism to use MappingProjection learning_rate specification if present -# - CHECK FOR EXISTING LM ASSERT IN pytests -# -# - AutodiffComposition: -# - replace handling / flattening of nested compositions with Pytorch.add_module (which adds "child" modules) -# - Check that error occurs for adding a controller to an AutodiffComposition -# - Check that if "epochs" is not in input_dict for Autodiff, then: -# - set to num_trials as default, -# - leave it to override num_trials if specified (add this to DOCUMENTATION) -# - Input construction has to be: -# - same for Autodiff in Python mode and PyTorch mode -# (NOTE: used to be that autodiff could get left in Python mode -# so only where tests for Autodiff happened did it branch) -# - AND different from Composition (in Python mode) -# - support use of pathway argument in Autodff -# - the following format doesn't work for LLVM (see test_identicalness_of_input_types: -# xor = pnl.AutodiffComposition(nodes=[input_layer,hidden_layer,output_layer]) -# xor.add_projections([input_to_hidden_wts, hidden_to_output_wts]) -# - DOCUMENTATION: execution_mode=ExecutionMode.Python allowed -# - Add warning of this on initial call to learn() -# -# - Composition: -# - Add default_execution_mode attribute to allow nested Compositions to be executed in -# different model than outer Composition -# - _validate_input_shapes_and_expand_for_all_trials: consolidate with get_input_format() -# - Generalize treatment of FEEDBACK specification: - # - FIX: ADD TESTS FOR FEEDBACK TUPLE SPECIFICATION OF Projection, DIRECT SPECIFICATION IN CONSTRUCTOR -# - FIX: why aren't FEEDBACK_SENDER and FEEDBACK_RECEIVER roles being assigned when feedback is specified? -# - add property that keeps track of warnings that have been issued, and suppresses repeats if specified -# - add property of Composition that lists it cycles -# - Add warning if termination_condition is trigged (and verbosePref is set) -# - Addition of projections to a ControlMechanism seems too dependent on the order in which the -# the ControlMechanism is constructed with respect to its afferents (if it comes before one, -# the projection to it (i.e., for monitoring) does not get added to the Composition -# - - IMPLEMENTATION OF LEARNING: NEED ERROR IF TRY TO CALL LEARN ON A COMPOSITION THAT HAS NO LEARNING MECHANISMS -# INCLUDING IN PYTHON MODE?? OR JUST ALLOW IT TO CONSTRUCT THE PATHWAY AUTOMATICALLY? -# - Change size argument in constructor to use standard numpy shape format if tupe, and PNL format if list -# - Write convenience Function for returning current time from context -# - requires it be called from execution within aComposition, error otherwise) -# - takes argument for time scale (e.g., TimeScale.TRIAL, TimeScale.RUN, etc.) 
-# - Add TimeMechanism for which this is the function, and can be configured to report at a timescale -# - Add Composition.run_status attribute assigned a context flag, with is_preparing property that checks it -# (paralleling handling of is_initializing) -# - Allow set of lists as specification for pathways in Composition -# - Add support for set notation in add_backpropagation_learning_pathway (to match add_linear_processing_pathway) -# see ScratchPad: COMPOSITION 2 INPUTS UNNESTED VERSION: MANY-TO-MANY -# - Make sure that shadow inputs (see InputPort_Shadow_Inputs) uses the same matrix as shadowed input. -# - composition.add_backpropagation_learning_pathway(): support use of set notation for multiple nodes that -# project to a single one. -# - add LearningProjections executed in EXECUTION_PHASE to self.projections -# and then remove MODIFIED 8/1/23 in _check_for_unused_projections -# - Why can't verbosePref be set directly on a composition? -# - Composition.add_nodes(): -# - should check, on each call to add_node, to see if one that has a releavantprojection and, if so, add it. -# - Allow [None] as argument and treat as [] -# - IF InputPort HAS default_input = DEFAULT_VARIABLE, -# THEN IT SHOULD BE IGNORED AS AN INPUT NODE IN A COMPOSITION -# - Add use of dict in pathways specification to map outputs from a set to inputs of another set -# (including nested comps) -# -# - ShowGraph: (show_graph) -# - don't show INPUT/OUTPUT Nodes for nested Comps in green/red -# (as they don't really receive input or generate output on a run -# - show feedback projections as pink (shouldn't that already be the case?) -# - add mode for showing projections as diamonds without show_learning (e.g., "show_projections") -# - figure out how to get storage_node to show without all other learning stuff -# - show 'operation' parameter for LinearCombination in show_node_structure=ALL -# - specify set of nodes to show and only show those -# - fix: show_learning=ALL (or merge from EM branch) -# -# - ControlMechanism -# - refactor ControlMechanism per notes of 11/3/21, including: -# FIX: 11/3/21 - MOVE _parse_monitor_specs TO HERE FROM ObjectiveMechanism -# - EpisodicMemoryMechanism: -# - make storage_prob and retrieval_prob parameters linked to function -# - make distance_field_weights a parameter linked to function -# -# - LinearCombination Function: -# - finish adding derivative (for if exponents are specified) -# - remove properties (use getter and setter for Parameters) -# -# - ContentAddressableMemory Function: -# - rename "cue" -> "query" -# - add field_weights as parameter of EM, and make it a shared_parameter ?as well as a function_parameter? - -# - DDM: -# - make reset_stateful_function_when a Parameter and arg in constructor -# and align with reset Parameter of IntegratorMechanism) -# -# - FIX: BUGS: -# - composition: -# - If any MappingProjection is specified from nested node to outer node, -# then direct projections are instantiated to the output_CIM of the outer comp, and the -# nested comp is treated as OUTPUT Node of outer comp even if all its projections are to nodes in outer comp -# LOOK IN add_projections? 
for nested comps -# - composition (?add_backpropagation_learning_pathway?): -# THIS FAILS: -# comp = Composition(name='a_outer') -# comp.add_backpropagation_learning_pathway([input_1, hidden_1, output_1]) -# comp.add_backpropagation_learning_pathway([input_1, hidden_1, output_2]) -# BUT THE FOLLOWING WORKS (WITH IDENTICAL show_graph(show_learning=True)): -# comp = Composition(name='a_outer') -# comp.add_backpropagation_learning_pathway([input_1, hidden_1, output_1]) -# comp.add_backpropagation_learning_pathway([hidden_1, output_2]) -# - show_graph(): QUIRK (BUT NOT BUG?): -# SHOWS TWO PROJECTIONS FROM a_inner.input_CIM -> hidden_x: -# ?? BECAUSE hidden_x HAS TWO input_ports SINCE ITS FUNCTION IS LinearCombination? -# a_inner = AutodiffComposition([hidden_x],name='a_inner') -# a_outer = AutodiffComposition([[input_1, a_inner, output_1], -# [a_inner, output_2]], -# a_outer.show_graph(show_cim=True) - -# -LearningMechanism / Backpropagation LearningFunction: -# - Construction of LearningMechanism on its own fails; e.g.: -# lm = LearningMechanism(learning_rate=.01, learning_function=BackPropagation()) -# causes the following error: -# TypeError("Logistic.derivative() missing 1 required positional argument: 'self'") -# - Adding GatingMechanism after Mechanisms they gate fails to implement gating projections -# (example: reverse order of the following in _construct_pathways -# self.add_nodes(self.softmax_nodes) -# self.add_nodes(self.field_weight_nodes) -# - add Normalize as option -# - Anytime a row's norm is 0, replace with 1s -# - WHY IS Concatenate NOT WORKING AS FUNCTION OF AN INPUTPORT (WASN'T THAT USED IN CONTEXT OF BUFFER? -# SEE NOTES TO KATHERINE -# -# - TESTS -# For duplicate Projections (e.g., assign a Mechanism in **monitor** of ControlMechanism -# and use comp.add_projection(MappingProjection(mointored, control_mech) -> should generate a duplicate -# then search for other instances of the same error message """ - Contents -------- @@ -250,17 +16,18 @@ - `Organization ` - `Operation ` * `EMComposition_Creation` - - `Fields ` + - `Memory ` - `Capacity ` + - `Fields ` - `Storage and Retrieval ` - `Learning ` * `EMComposition_Structure` - `Input ` - - `Memory ` + - `Memory ` - `Output ` * `EMComposition_Execution` - `Processing ` - - `Learning ` + - `Learning ` * `EMComposition_Examples` - `Memory Template and Fill ` - `Field Weights ` @@ -271,27 +38,36 @@ Overview -------- -The EMComposition implements a configurable, content-addressable form of episodic, or eternal memory, that emulates +The EMComposition implements a configurable, content-addressable form of episodic (or external) memory. It emulates an `EpisodicMemoryMechanism` -- reproducing all of the functionality of its `ContentAddressableMemory` `Function` -- -in the form of an `AutodiffComposition` that is capable of learning how to differentially weight different cues used -for retrieval,, and that adds the capability for `memory_decay `. Its `memory -` is configured using two arguments of its constructor: **memory_template** argument, that defines -how each entry in `memory ` is structured (the number of fields in each entry and the length -of each field); and **field_weights** argument, that defines which fields are used as cues for retrieval, i.e., "keys", -including whether and how they are differentially weighted in the match process used for retrieval); and which -fields are treated as "values" that are stored retrieved, but not used by the match process. 
The inputs to an -EMComposition, corresponding to each key ("query") and value field are assigned to each of its `INPUT ` -`Nodes ` (listed in its `query_input_nodes ` and `value_input_nodes -` attributes, respectively), and the retrieved values are represented as `OUTPUT -` `Nodes ` of the EMComposition. The `memory ` can be -accessed using its `memory ` attribute. +in the form of an `AutodiffComposition`. This allows it to backpropagate error signals based on retrieved values to +its inputs, and learn how to differentially weight cues (queries) used for retrieval. It also adds the capability for +`memory_decay `. In these respects, it implements a variant of a `Modern Hopfield +Network `_, as well as some of the features of a `Transformer +`_. + +The `memory ` of an EMComposition is configured using two arguments of its constructor: +the **memory_template** argument, that defines the overall structure of its `memory ` (the +number of fields in each entry, the length of each field, and the number of entries); and the **fields** argument, that +defines which fields are used as cues for retrieval (i.e., as "keys"), including whether and how they are weighted in +the match process used for retrieval, which fields are treated as "values" that are stored and retrieved but not used by +the match process, and which are involved in learning. The inputs to an EMComposition, corresponding to its keys and +values, are assigned to each of its `INPUT ` `Nodes `: inputs to be matched to keys +(i.e., used as "queries") are assigned to its `query_input_nodes `; and the remaining +inputs are assigned to its `value_input_nodes `. When the EMComposition is executed, the +retrieved values for all fields are returned as the result, and recorded in its `results ` +attribute. The value for each field is assigned as the `value ` of its `OUTPUT ` +`Nodes `. The input is then stored in its `memory `, with a probability +determined by its `storage_prob ` `Parameter`, and all previous memories decayed by its +`memory_decay_rate `. The `memory ` can be accessed using its +`memory ` Parameter. .. technical_note:: - The memories of an EMComposition are actually stored in the `matrix ` attribute of a - set of `MappingProjections ` (see `note below `). The `memory - ` attribute compiles and formats these as a single 3d array, the rows of which (axis 0) - are each entry, the columns of which (axis 1) are the fields of each entry, and the items of which (axis 2) - are the values of each field (see `EMComposition_Memory` for additional details). + The memories of an EMComposition are actually stored in the `matrix ` `Parameter` + of a set of `MappingProjections ` (see `note below `). The + `memory ` Parameter compiles and formats these as a single 3d array, the rows of which + (axis 0) are each entry, the columns of which (axis 1) are the fields of each entry, and the items of which + (axis 2) are the values of each field (see `EMComposition_Memory_Configuration` for additional details). .. _EMComposition_Organization: @@ -302,14 +78,14 @@ *Entries and Fields*. Each entry in memory can have an arbitrary number of fields, and each field can have an arbitrary length. However, all entries must have the same number of fields, and the corresponding fields must all have the same length across entries.
Each field is treated as a separate "channel" for storage and retrieval, and is associated with -its own corresponding input (key or value) and output (retrieved value) `Node ` some or all of +its own corresponding input (key or value) and output (retrieved value) `Node `, some or all of which can be used to compute the similarity of the input (key) to entries in memory, that is used for retrieval. Fields can be differentially weighted to determine the influence they have on retrieval, using the `field_weights -` parameter (see `retrieval ` below). The number and -shape of the fields in each entry is specified in the **memory_template** argument of the EMComposition's constructor -(see `memory_template `). Which fields treated as keys (i.e., matched against queries during -retrieval) and which are treated as values (i.e., retrieved but not used for matching retrieval) is specified in the -**field_weights** argument of the EMComposition's constructor (see `field_weights `). +` parameter (see `retrieval ` below). The number and shape +of the fields in each entry are specified in the **memory_template** argument of the EMComposition's constructor (see +`memory_template `). Which fields are treated as keys (i.e., matched against queries +during retrieval) and which are treated as values (i.e., retrieved but not used for matching) are specified in +the **field_weights** argument of the EMComposition's constructor (see `field_weights `). .. _EMComposition_Operation: @@ -317,39 +93,46 @@ *Retrieval.* The values retrieved from `memory ` (one for each field) are based on the relative similarity of the keys to the entries in memory, computed as the distance of each key and the -values in the corresponding field for each entry in memory. By default, normalized dot products (comparable to cosine -similarity) are used to compute the similarity of each query to each key in memory. These distances are then -weighted by the corresponding `field_weights ` for each field (if specified) and then -summed, and the sum is softmaxed to produce a softmax distribution over the entries in memory. That is then used to -generate a softmax-weighted average of the retrieved values across all fields, which is returned as the `result -` of the EMComposition's `execution ` (an EMComposition can also be -configured to return the entry with the lowest distance weighted by field, however then it is not compatible -with learning; see `softmax_choice `). +values in the corresponding field for each entry in memory. By default, for queries and keys that are vectors, +normalized dot products (comparable to cosine similarity) are used to compute the similarity of each query to each +key in memory; and if they are scalars the L0 norm is used. These distances are then weighted by the corresponding +`field_weights ` for each field (if specified) and then summed, and the sum is softmaxed +to produce a softmax distribution over the entries in memory. That is then used to generate a softmax-weighted average +of the retrieved values across all fields, which is returned as the `result ` of the EMComposition's +`execution ` (an EMComposition can also be configured to return the exact entry with the lowest +distance (weighted by field), however, it is then not compatible with learning; see `softmax_choice +`).
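
Conceptually, the retrieval computation just described reduces to a few lines of numpy. The following is an illustrative sketch of the math only (hypothetical helper names, not the EMComposition API), assuming vector queries and keys:

    import numpy as np

    def softmax(x, gain=10.0):
        e = np.exp(gain * (x - x.max()))          # gain plays the role of 1/temperature
        return e / e.sum()

    def retrieve(queries, key_memories, value_memories, field_weights, gain=10.0):
        # normalized dot product (cosine-like similarity) of each query with every stored key
        matches = [km @ q / (np.linalg.norm(km, axis=1) * np.linalg.norm(q) + 1e-12)
                   for q, km in zip(queries, key_memories)]
        combined = sum(w * m for w, m in zip(field_weights, matches))  # weight and sum across key fields
        weights = softmax(combined, gain)                              # distribution over memory entries
        return [weights @ vm for vm in value_memories]                 # softmax-weighted average per field
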
COMMENT: TBD DISTANCE ATTRIBUTES: - The distances used for the last retrieval is stored in XXXX and the distances of each of their corresponding fields + The distance used for the last retrieval is stored in XXXX, and the distances of each of their corresponding fields (weighted by `distance_field_weights `), are returned in XXX, respectively. COMMENT -*Storage.* The `inputs ` to the EMComposition's fields are stored in `memory -` after each execution, with a probability determined by `storage_prob -`. If `memory_decay_rate ` is specified, then the `memory -` is decayed by that amount after each execution. If `memory_capacity -` has been reached, then each new memory replaces the weakest entry (i.e., the one -with the smallest norm across all of its fields) in `memory `. +*Storage.* The `inputs ` to the EMComposition's fields are stored +in `memory ` after each execution, with a probability determined by `storage_prob +`. If `memory_decay_rate ` is specified, then +the `memory ` is decayed by that amount after each execution. If `memory_capacity +` has been reached, then each new memory replaces the weakest entry +(i.e., the one with the smallest norm across all of its fields) in `memory `. .. _EMComposition_Creation: Creation -------- -An EMComposition is created by calling its constructor, that takes the following arguments: +An EMComposition is created by calling its constructor. There are four major elements that can be configured: +the structure of its `memory `; the `fields ` for the entries +in memory; how `storage and retrieval ` operate; and whether and how `learning +` is carried out. + +.. _EMComposition_Memory_Specification: - .. _EMComposition_Fields: +*Memory Specification* +~~~~~~~~~~~~~~~~~~~~~~ -*Field Specification* +These arguments are used to specify the shape and number of memory entries. .. _EMComposition_Memory_Template: @@ -394,18 +177,6 @@ zeros, and **memory_fill** is specified, then the matrix is filled with the value specified in **memory_fill**; otherwise, zeros are used to fill all entries. -.. _EMComposition_Memory_Capacity: - -*Memory Capacity* - -* **memory_capacity**: specifies the number of items that can be stored in the EMComposition's memory; when - `memory_capacity ` is reached, each new entry overwrites the weakest entry (i.e., the - one with the smallest norm across all of its fields) in `memory `. If `memory_template - ` is specified as a 3-item tuple or 3d list or array (see above), then that is used - to determine `memory_capacity ` (if it is specified and conflicts with either of those - an error is generated). Otherwise, it can be specified using a numerical value, with a default of 1000. The - `memory_capacity ` cannot be modified once the EMComposition has been constructed. - .. _EMComposition_Memory_Fill: * **memory_fill**: specifies the value used to fill the `memory `, based on the shape specified @@ -420,66 +191,130 @@ This can be ignored, as it does not affect the results of execution, but it can be averted by specifying `memory_fill ` to use small random values (e.g., ``memory_fill=(0,.001)``). +.. _EMComposition_Memory_Capacity: + +* **memory_capacity**: specifies the number of items that can be stored in the EMComposition's memory; when + `memory_capacity ` is reached, each new entry overwrites the weakest entry (i.e., the
If `memory_template
+  ` is specified as a 3-item tuple or 3d list or array (see above), then that is used
+  to determine `memory_capacity ` (if it is specified and conflicts with either of those,
+  an error is generated). Otherwise, it can be specified using a numerical value, with a default of 1000. The
+  `memory_capacity ` cannot be modified once the EMComposition has been constructed.
+
+.. _EMComposition_Fields:
+
+*Fields*
+~~~~~~~~
+
+These arguments are used to specify the names of the fields in a memory entry, which fields are used as keys,
+how they are weighted for retrieval, and whether those weights are learned.
+
+.. _EMComposition_Field_Specification_Dict:
+
+* **fields**: a dict that specifies the names of the fields and their attributes. There must be an entry for each
+  field specified in the **memory_template**, and each entry must have the following format:
+
+  * *key*: a string that specifies the name of the field.
+
+  * *value*: a dict or tuple with three entries; if a dict, the key to each entry must be the keyword specified below,
+    and if a tuple, the entries must appear in the following order:
+
+    - *FIELD_WEIGHT* `specification ` - value must be a scalar or None. If it is a scalar,
+      the field is treated as a `retrieval key ` in `memory ` that
+      is weighted by that value during retrieval; if None, it is treated as a value in `memory `
+      and the field cannot be reconfigured later.
+
+    - *LEARN_FIELD_WEIGHT* `specification ` - value must be a boolean or a float;
+      if False, the field_weight for that field is not learned; if True, the field weight is learned using the
+      EMComposition's `learning_rate `; if a float, that is used as its learning_rate.
+
+    - *TARGET_FIELD* `specification ` - value must be a boolean; if True, the value of the
+      `retrieved_node ` for that field contributes to the error computed during learning
+      and backpropagated through the EMComposition (see `Backpropagation of error `);
+      if False, the retrieved value for that field does not contribute to the error; however, its field_weight can
+      still be learned if that is specified in `learn_field_weight `.
+
+  .. note::
+     The **fields** argument is provided as a convenient and reliable way of specifying field attributes;
+     the dict itself is not retained as a `Parameter` or attribute of the EMComposition.
+
+  The specifications provided in the **fields** argument are assigned to the corresponding Parameters of
+  the EMComposition which, alternatively, can be specified individually using the **field_names**, **field_weights**,
+  **learn_field_weights** and **target_fields** arguments of the EMComposition's constructor, as described below.
+  However, these and the **fields** argument cannot both be used together; doing so raises an error.
+
+.. _EMComposition_Field_Names:
+
+* **field_names**: a list that specifies the names that can be assigned to the fields. The number of names specified
+  must match the number of fields specified in the memory_template. If specified, the names are used to label the
+  nodes of the EMComposition; otherwise, the fields are labeled generically as "Key 0", "Key 1", and "Value 1",
+  "Value 2", etc.

 .. _EMComposition_Field_Weights:

-* **field_weights**: specifies which fields are used as keys, and how they are weighted during retrieval. The
-  number of entries specified must match the number of fields specified in **memory_template** (i.e., the size of
-  of its first dimension (axis 0)).
All non-zero entries must be positive; these designate *keys* -- fields
-  that are used to match queries against entries in memory for retrieval (see `Match memories by field
-  `). Entries of 0 designate *values* -- fields that are ignored during the matching
-  process, but the values of which are retrieved and assigned as the `value ` of the
-  corresponding `retrieved_node `. This distinction between keys and value corresponds
+* **field_weights**: specifies which fields are used as keys, and how they are weighted during retrieval. Fields
+  designated as keys are used to match inputs (queries) against entries in memory for retrieval (see `Match memories by
+  field `); entries designated as *values* are ignored during the matching process, but
+  their values in memory are retrieved and assigned as the `value ` of the corresponding
+  `retrieved_node `. This distinction between keys and values corresponds
  to the format of a standard "dictionary," though in that case only a single key and value are allowed, whereas
-  here there can be one or more keys and any number of values; if all fields are keys, this implements a full form of
-  content-addressable memory. If **learn_field_weights** is True (and `enable_learning`
-  is either True or a list with True for at least one entry), then the field_weights can be modified during training
-  (this functions similarly to the attention head of a Transformer model, although at present the field can only be
-  scalar values rather than vecdtors); if **learn_field_weights** is False, then the field_weights are fixed.
-  The following options can be used to specify **field_weights**:
-
-  * *None* (the default): all fields except the last are treated as keys, and are weighted equally for retrieval,
-    while the last field is treated as a value field;
-
-  * *single entry*: all fields are treated as keys (i.e., used for retrieval) and weighted equally for retrieval.
-    if `normalize_field_weights ` is True, the value is ignored and all
-    of keys are weighted by 1 / number of keys (i.e., normalized), whereas if `normalize_field_weights
-    ` is False, then the value specified is used to weight the retrieval of
-    every keys.
-
-  * *multiple non-zero entries*: If all entries are identical, the value is ignored and the corresponding keys
-    are weighted equally for retrieval; if the non-zero entries are non-identical, they are used to weight the
-    corresponding fields during retrieval (see `Weight fields `). In either case,
-    the remaining fields (with zero weights) are treated as value fields.
-
-  _EMComposition_Field_Weights_Note:
+  in an EMComposition there can be one or more keys and any number of values; if all fields are keys, this implements a
+  full form of content-addressable memory. The following options can be used to specify **field_weights**:
+
+  * *None* (the default): all fields except the last are treated as keys, and are assigned a weight of 1,
+    while the last field is treated as a value field (the same as assigning it None in a list or tuple; see below).
+
+  * *scalar*: all fields are treated as keys (i.e., used for retrieval) and weighted equally for retrieval. If
+    `normalize_field_weights ` is True, the value is divided by the number
+    of keys, whereas if `normalize_field_weights ` is False, then the value
+    specified is used to weight the retrieval of every key.
+
+    .. note::
+       At present these have the same result, since the `SoftMax` function is used to normalize the match between
+       queries and keys.
However, other retrieval functions could be used in the future that would be affected by
+       the value of the `field_weights `. Therefore, it is recommended to leave
+       `normalize_field_weights ` set to True (the default) to ensure that
+       the `field_weights ` are normalized to sum to 1.0.
+
+  * *list or tuple*: the number of entries must match the number of fields specified in **memory_template**, and
+    all entries must be either 0, a positive scalar value, or None. If all entries are identical, they are treated
+    as if a single value was specified (see above); if the entries are non-identical, any entries that are not None
+    are used to weight the corresponding fields during retrieval (see `Weight fields `),
+    including those that are 0 (though these will not be used in the retrieval process unless/until they are changed
+    to a positive value). If `normalize_field_weights ` is True, all non-None
+    entries are normalized so that they sum to 1.0; if False, the raw values are used to weight the retrieval of
+    the corresponding fields. All entries of None are treated as value fields, are not assigned a `field_weight_node
+    `, and are ignored during retrieval. These *cannot be modified* after the
+    EMComposition has been constructed (see note below).
+
+  .. _EMComposition_Field_Weights_Change_Note:
+
  .. note::
     The field_weights can be modified after the EMComposition has been constructed, by assigning a new set of weights
     to its `field_weights ` `Parameter`. However, only field_weights associated with
-     key fields (i.e., were initially assigned non-zero field_weights) can be modified; the weights for value fields
-     (i.e., ones that were initially assigned a field_weight of 0) cannot be modified, and an attempt to do so will
-     generate an error. If a field initially used as a value may later need to be used as a key, it should be
-     assigned a non-zero field_weight when the EMComposition is constructed; it can then be assigned 0 just after
-     construction, and later changed as needed.
+     key fields (i.e., that were initially assigned a scalar field_weight) can be modified; the weights for value
+     fields (i.e., ones that were initially assigned a field_weight of None) cannot be modified, and doing so raises
+     an error. If a field that will be used initially as a value may later need to be used as a key, it should be
+     assigned a `field_weight ` of 0 at construction (rather than None), which can then
+     later be changed as needed.

  .. technical_note::
-     The reason that only field_weights for keys can be modified is that only `field_weight_nodes
-     ` for keys are constructed, since ones for values would have no effect on the
-     retrieval process and thus are uncecessary.
+     The reason that field_weights can be modified only for keys is that `field_weight_nodes
+     ` are constructed only for keys, since ones for values would have no effect
+     on the retrieval process and therefore are unnecessary (and can be misleading).

-.. _EMComposition_Normalize_Field_Weights:
-* **normalize_field_weights**: specifies whether the `field_weights ` are normalized
-  or their raw values are used. If True, the `field_weights ` are normalized so that
-  they sum to 1.0, and are used to weight (i.e., multiply) the corresponding fields during retrieval (see `Weight
-  fields `). If False, the raw values of the `field_weights `
-  are used to weight the retrieved value of each field. This setting is ignored if **field_weights**
-  is None or `concatenate_queries ` is in effect.
+* **learn_field_weights**: if **enable_learning** is True, this specifies which field_weights are subject to learning,
+  and optionally the `learning_rate ` for each (see `learn_field_weights
+  ` below for details of specification).

-.. _EMComposition_Field_Names:
+.. _EMComposition_Normalize_Field_Weights:

-* **field_names**: specifies names that can be assigned to the fields. The number of names specified must
-  match the number of fields specified in the memory_template. If specified, the names are used to label the
-  nodes of the EMComposition. If not specified, the fields are labeled generically as "Key 0", "Key 1", etc..
+* **normalize_field_weights**: specifies whether the `field_weights ` are normalized or
+  their raw values are used. If True, the values of all non-None `field_weights ` are
+  normalized so that they sum to 1.0, and the normalized values are used to weight (i.e., multiply) the corresponding
+  fields during retrieval (see `Weight fields `). If False, the raw values of the
+  `field_weights ` are used to weight the retrieved value of each field. This setting
+  is ignored if **field_weights** is None or `concatenate_queries ` is True.

.. _EMComposition_Concatenate_Queries:

@@ -503,27 +338,20 @@
  are always preserved, even when `concatenate_queries ` is True, so that
  separate inputs can be provided for each key, and the value of each key can be retrieved separately.

-.. _EMComposition_Memory_Decay_Rate
-
-* **memory_decay_rate**: specifies the rate at which items in the EMComposition's memory decay; the default rate
-  is *AUTO*, which sets it to 1 / `memory_capacity `, such that the oldest memories
-  are the most likely to be replaced when `memory_capacity ` is reached. If
-  **memory_decay_rate** is set to 0 None or False, then memories do not decay and, when `memory_capacity
-  ` is reached, the weakest memories are replaced, irrespective of order of entry.
-
.. _EMComposition_Retrieval_Storage:

*Retrieval and Storage*
+~~~~~~~~~~~~~~~~~~~~~~~

-* **storage_prob** : specifies the probability that the inputs to the EMComposition will be stored as an item in
+* **storage_prob**: specifies the probability that the inputs to the EMComposition will be stored as an item in
  `memory ` on each execution.

-* **normalize_memories** : specifies whether queries and keys in memory are normalized before computing their dot
+* **normalize_memories**: specifies whether queries and keys in memory are normalized before computing their dot
  products.

.. _EMComposition_Softmax_Gain:

-* **softmax_gain** : specifies the gain (inverse temperature) used for softmax normalizing the combined distances
+* **softmax_gain**: specifies the gain (inverse temperature) used for softmax normalizing the combined distances
  used for retrieval (see `EMComposition_Execution` below). The following options can be used:

  * numeric value: the value is used as the gain of the `SoftMax` Function for the EMComposition's
@@ -548,7 +376,7 @@

.. _EMComposition_Softmax_Choice:

-* **softmax_choice** : specifies how the `SoftMax` Function of the EMComposition's `softmax_node
+* **softmax_choice**: specifies how the `SoftMax` Function of the EMComposition's `softmax_node
  ` is used, with the combined distances, to generate a retrieved item;
  the following are the options that can be used and the retrieved value they produce:

@@ -562,7 +390,7 @@

  .. warning::
     Use of the *ARG_MAX* and *PROBABILISTIC* options is not compatible with learning, as these implement a discrete
     choice and thus are not differentiable.
Constructing an EMComposition with **softmax_choice** set to either of
-     these options and **enable_learning** set to True (or a list with any True entries) will generate a warning, and
+     these options and **learn_field_weights** set to True (or a list with any True entries) will generate a warning, and
     calling the EMComposition's `learn ` method will generate an error; it must be changed to
     *WEIGHTED_AVG* to execute learning.

@@ -571,37 +399,91 @@
     passed as *ARG_MAX_INDICATOR*; and *PROBABILISTIC* is passed as *PROB_INDICATOR*;
     the other SoftMax options are not currently supported.

+.. _EMComposition_Memory_Decay_Rate:
+
+* **memory_decay_rate**: specifies the rate at which items in the EMComposition's memory decay; the default rate
+  is *AUTO*, which sets it to 1 / `memory_capacity `, such that the oldest memories
+  are the most likely to be replaced when `memory_capacity ` is reached. If
+  **memory_decay_rate** is set to 0, None or False, then memories do not decay and, when `memory_capacity
+  ` is reached, the weakest memories are replaced, irrespective of order of entry.
+
+.. _EMComposition_Purge_by_Weight:
+
+* **purge_by_field_weight**: specifies whether `field_weights ` are used in determining
+  which memory entry is replaced when a new memory is `stored `. If True, the norm of each
+  entry is multiplied by its `field_weight ` to determine which entry is the weakest and
+  will be replaced.
+
.. _EMComposition_Learning:

*Learning*
+~~~~~~~~~~

-EMComposition supports two forms of learning -- error backpropagation and the learning of `field_weights
-` -- that can be configured by the following arguments of the EMComposition's constructor:
-
-* **enable_learning** : specifies whether learning is enabled for the EMComposition and, if so, which `retrieved_nodes
-  ` are used to compute errors, and propagate these back through the network. If
-  **enable_learning** is False, then no learning occurs, including of `field_weights `).
-  If it is True, then all of the `retrieved_nodes ` participate in learning: For
-  those that do not project to an outer Composition (i.e., one in which the EMComposition is `nested
-  `), a `TARGET ` node is constructed for each, and used to compute errors that
-  are backpropagated through the network to its `query_input_nodes ` and
-  `value_input_nodes `, and on to any nodes that project to it from a composition
-  in which the EMComposition is `nested `; retrieved_nodes that *do* project to an outer
-  Composition receive their errors from those nodes, which are also backpropagated through the EMComposition.
-  If **enable_learning** is a list, then only the `retrieved_nodes ` specified in the
-  list participate in learning, and errors are computed only for those nodes. The list must contain the same
-  number of entries as there are `fields ` and corresponding `retreived_nodes
-  `, and each entry must be a boolean that specifies whether the corresponding
-  `retrieved_node ` is used for learning.
-
-* **learn_field_weights** : specifies whether `field_weights ` are modifiable during
-  learning (see `field_weights ` and `Learning ` for additional
-  information. For learning of `field_weights ` to occur, **enable_learning** must
-  also be True, or it must be a list with at least one True entry. If **learn_field_weights** is True,
-  **use_gating_for_weighting** must be False (see `note `).
-
-* **learning_rate** : specifies the rate at which `field_weights ` are learned if
-  **learn_field_weights** is True; see `Learning ` for additional information.
+
+EMComposition supports two forms of learning: error backpropagation through the entire Composition, and the learning
+of `field_weights ` within it. Learning is enabled by setting the **enable_learning**
+argument of the EMComposition's constructor to True, and optionally specifying the **learn_field_weights** argument
+(as detailed below). If **enable_learning** is False, no learning of any kind occurs; if it is True, then both forms
+of learning are enabled.
+
+.. _EMComposition_Error_BackPropagation:
+
+*Backpropagation of error*. If **enable_learning** is True, then the values retrieved from `memory
+` when the EMComposition is executed during learning can be used for error computation
+and backpropagation through the EMComposition to its inputs. By default, the values of all of its `retrieved_nodes
+` are included. For those that do not project to an outer Composition (i.e., one in
+which the EMComposition is `nested `), a `TARGET ` node is constructed
+for each, and used to compute errors that are backpropagated through the network to its `query_input_nodes
+` and `value_input_nodes `, and on to any
+nodes that project to those from a Composition within which the EMComposition is `nested `.
+Retrieved_nodes that *do* project to an outer Composition receive their errors from those nodes, which are also
+backpropagated through the EMComposition. Fields can be selectively specified for learning in the **fields** argument
+or the **target_fields** argument of the EMComposition's constructor, as detailed below.
+
+*Field Weight Learning*. If **enable_learning** is True, then the `field_weights ` can
+be learned, by specifying these either in the **fields** argument or the **learn_field_weights** argument of the
+EMComposition's constructor, as detailed below. Learning field_weights implements a function comparable to the learning
+in an attention head of the `Transformer `_ architecture, although at present the
+field_weights can only be scalar values rather than vectors or matrices, and they cannot receive input. These
+capabilities will be added in the future.
+
+The following arguments of the EMComposition's constructor can be used to configure learning:
+
+* **enable_learning**: specifies whether any learning is enabled for the EMComposition. If False,
+  no learning occurs; if True, then both error backpropagation and learning of `field_weights
+  ` can occur. If **enable_learning** is True, **use_gating_for_weighting**
+  must be False (see `note `).
+
+.. _EMComposition_Target_Fields:
+
+* **target_fields**: specifies which `retrieved_nodes ` are used to compute
+  errors, and propagate these back through the EMComposition to its `query ` and
+  `value_input_nodes `. If this is None (the default), all `retrieved_nodes
+  ` are used; if it is a list or tuple, then it must have the same number of entries
+  as there are fields, and each entry must be a boolean specifying whether the corresponding `retrieved_node
+  ` participates in learning; errors are computed only for those nodes. This can
+  also be specified in a dict for the **fields** argument (see `fields `).
+
+.. _EMComposition_Field_Weights_Learning:
+
+* **learn_field_weights**: specifies which field_weights are subject to learning, and optionally the `learning_rate
+  ` for each; this can also be specified in a dict for the **fields** argument (see
+  `fields `).
The following specifications can be used:
+
+  * *None*: all field_weights are subject to learning, and the `learning_rate ` for the
+    EMComposition is used as the learning_rate for all field_weights.
+
+  * *bool*: If True, all field_weights are subject to learning, and the `learning_rate `
+    for the EMComposition is used as the learning_rate for all field_weights; if False, no field_weights are
+    subject to learning, regardless of `enable_learning `.
+
+  * *list* or *tuple*: must be the same length as the number of fields specified in the memory_template, and each entry
+    must be either True, False or a positive scalar value. If True, the corresponding field_weight is subject to
+    learning and the `learning_rate ` for the EMComposition is used to specify the
+    learning_rate for that field; if False, the corresponding field_weight is not subject to learning; if a scalar
+    value is specified, it is used as the `learning_rate` for that field.
+
+* **learning_rate**: specifies the learning_rate for any `field_weights ` for which a
+  learning_rate is not individually specified in the **learn_field_weights** argument (see above).

.. _EMComposition_Structure:

@@ -617,7 +499,7 @@
  ` of the EMComposition, listed in its `query_input_nodes
  ` and `value_input_nodes ` attributes, respectively,

-.. _EMComposition_Memory:
+.. _EMComposition_Memory_Structure:

*Memory*
~~~~~~~~

@@ -672,8 +554,8 @@

* **Input**. The inputs to the EMComposition are provided to the `query_input_nodes `
  and `value_input_nodes `. The former are used for matching to the corresponding
-  `fields ` of the `memory `, while the latter are retrieved but not used
-  for matching.
+  `fields ` of the `memory `, while the latter are retrieved
+  but not used for matching.

* **Concatenation**. By default, the input to every `query_input_node ` is passed to
  its own `match_node ` through a `MappingProjection` that computes its

@@ -700,9 +582,9 @@
  (or the `concatenate_queries_node ` if `concatenate_queries
  ` attribute is True) are passed through a `MappingProjection` that computes
  the distance between the corresponding input (query) and each memory (key) for the corresponding field,
-  the result of which is possed to the corresponding `match_node `. By default, the
-  distance is computed as the normalized dot product (i.e., between the normalized query vector and the normalized
-  key for the corresponding `field `, that is comparable to using cosine similarity). However,
+  the result of which is passed to the corresponding `match_node `. By default, the distance
+  is computed as the normalized dot product (i.e., between the normalized query vector and the normalized key for the
+  corresponding `field `, which is comparable to using cosine similarity). However,
  if `normalize_memories ` is set to False, just the raw dot product is computed.
  The distance can also be customized by specifying a different `function ` for the
  `MappingProjection` to the `match_node `. The result is assigned as the `value

@@ -751,9 +633,12 @@
  `gain ` parameter; if None is specified, the default value of the `Softmax` Function
  is used as the `gain ` parameter (see `Softmax_Gain ` for additional details).

+.. _EMComposition_Retreived_Values:
+
* **Retrieve values by field**. The vector of softmax weights for each memory generated by the `softmax_node
  ` is passed through the Projections to each of the `retrieved_nodes
-  ` to compute the retrieved value for each field.
+  ` to compute the retrieved value for each field, which is assigned as the value
+  of the corresponding `retrieved_node `.

* **Decay memories**. If `memory_decay ` is True, then each of the memories is decayed
  by the amount specified in `memory_decay_rate `.

@@ -768,19 +653,19 @@

.. _EMComposition_Storage:

-* **Store memories**. After the values have been retrieved, the inputs to for each field (i.e., values in the
-  `query_input_nodes ` and `value_input_nodes `)
-  are added by the `storage_node ` as a new entry in `memory `,
-  replacing the weakest one if `memory_capacity ` has been reached.
+* **Store memories**. After the values have been retrieved, the `storage_node `
+  adds the inputs to each field (i.e., values in the `query_input_nodes ` and
+  `value_input_nodes `) as a new entry in `memory `,
+  replacing the weakest one. The weakest memory is the one with the lowest norm, multiplied by its `field_weight
+  ` if `purge_by_field_weight ` is True.

  .. technical_note::
-     This is done by adding the input vectors to the the corresponding rows of the `matrix `
-     of the `MappingProjection` from the `combined_matches_node ` to each
-     of the `retrieved_nodes `, as well as the `matrix `
-     parameter of the `MappingProjection` from each `query_input_node ` to the
-     corresponding `match_node ` (see note `above ` for
-     additional details). If `memory_capacity ` has been reached, then the weakest
-     memory (i.e., the one with the lowest norm across all fields) is replaced by the new memory.
+     The norm of each entry is calculated by adding the input vectors to the corresponding rows of
+     the `matrix ` of the `MappingProjection` from the `combined_matches_node
+     ` to each of the `retrieved_nodes `,
+     as well as the `matrix ` parameter of the `MappingProjection` from each
+     `query_input_node ` to the corresponding `match_node
+     ` (see note `above ` for additional details).

COMMENT:
FROM CodePilot: (OF HISTORICAL INTEREST?)
@@ -798,25 +683,24 @@

*Training*
~~~~~~~~~~

-If `learn ` is called, `enable_learning ` is True or a list with
-any True entries, then errors will be computed for each of the `retrieved_nodes `
-that is specified for learning (see `Learning ` for details about specification). These errors
-are derived either from any errors backprpated to the EMComposition from an outer Composition in which it is `nested
-`, or locally by the difference between the `retrieved_nodes `
-and the `target_nodes ` that are created for each of the `retrieved_nodes
-` that do not project to an outer Composition. These errors are then backpropagated
-through the EMComposition to the `query_input_nodes ` and `value_input_nodes
-`, and on to any nodes that project to it from a composition in which the
-EMComposition is `nested `.
-
-If `learn_field_weights ` is also True, then the `field_weights
-` are modified to minimize the error passed to the EMComposition retrieved nodes, using the
-`learning_rate ` specified in the `learning_rate ` attribute.
-If `learn_field_weights ` is False (or `run ` is called, then the
-If `learn_field_weights ` is False), then the `field_weights
-` are not modified and the EMComposition is simply executed
-without any modification, and error signals are passed to the nodes that project to its `query_input_nodes
-` and `value_input_nodes `.
+If `learn ` is called and `enable_learning ` is True, then errors
+will be computed for each of the `retrieved_nodes ` that is specified for learning
+(see `Learning ` for details about specification).
These errors are derived either from any
+errors backpropagated to the EMComposition from an outer Composition in which it is `nested `,
+or locally by the difference between the `retrieved_nodes ` and the `target_nodes
+` that are created for each of the `retrieved_nodes `
+that do not project to an outer Composition. These errors are then backpropagated through the EMComposition to the
+`query_input_nodes ` and `value_input_nodes `,
+and on to any nodes that project to it from a composition in which the EMComposition is `nested `.
+
+If `learn_field_weights` is also specified, then the corresponding `field_weights ` are
+modified to minimize the error passed to the EMComposition's retrieved nodes that have been specified for learning,
+using the `learning_rate ` specified for them in `learn_field_weights
+` or the default `learning_rate ` for the EMComposition.
+If `enable_learning ` is False (or `run ` is called rather than
+`learn `), then the `field_weights ` are not modified, and no error
+signals are passed to the nodes that project to its `query_input_nodes ` and
+`value_input_nodes `.

.. note::
   The only parameters modifiable by learning in the EMComposition are its `field_weights
@@ -827,7 +711,7 @@
   Although memory storage is implemented as a form of learning (through modification of MappingProjection
   `matrix ` parameters; see `memory storage `),
   this occurs irrespective of how EMComposition is run (i.e., whether `learn ` or `run
-   ` is called), and is not affected by the `learn_field_weights `
+   ` is called), and is not affected by the `enable_learning `
   or `learning_rate ` attributes, which pertain only to whether the `field_weights
   ` are modified during learning. Furthermore, when run in PyTorch mode, storage
   is executed after the forward() and backward() passes are complete, and is not considered as part of the

@@ -898,7 +782,7 @@

    >>> em = EMComposition(memory_template=(4,2,5))

both of which create a memory with 4 entries, each with 2 fields of length 5.
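An EMComposition with named fields can be constructed similarly (an illustrative sketch using the **field_names**
and **field_weights** arguments described above, with one key and one value field)::

    >>> em = EMComposition(memory_template=(4,2,5),
    ...                    field_names=['key', 'value'],
    ...                    field_weights=[1, None])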
The contents of `memory
-` can be inspected using the `memory ` attribute::
+` can be inspected using the `memory ` attribute::

    >>> em.memory
    [[array([0., 0., 0., 0., 0.]), array([0., 0., 0., 0., 0.])],
@@ -1038,7 +922,6 @@
---------------
"""
import numpy as np
-import graph_scheduler as gs
import warnings

import psyneulink.core.scheduling.condition as conditions
@@ -1066,12 +949,20 @@
from psyneulink.core.llvm import ExecutionMode

-__all__ = ['EMComposition', 'EMCompositionError', 'WEIGHTED_AVG', 'PROBABILISTIC']
+__all__ = ['EMComposition', 'EMCompositionError', 'FIELD_WEIGHT', 'LEARN_FIELD_WEIGHT',
+           'PROBABILISTIC', 'TARGET_FIELD', 'WEIGHTED_AVG']

+# softmax_choice options:
STORAGE_PROB = 'storage_prob'
WEIGHTED_AVG = ALL
PROBABILISTIC = PROB_INDICATOR

+# specs for entry of fields specification dict
+FIELD_WEIGHT = 'field_weight'
+LEARN_FIELD_WEIGHT = 'learn_field_weight'
+TARGET_FIELD = 'target_field'
+
+# Node names
QUERY_NODE_NAME = 'QUERY'
QUERY_AFFIX = f' [{QUERY_NODE_NAME}]'
VALUE_NODE_NAME = 'VALUE'
@@ -1123,23 +1014,26 @@ def field_weights_setter(field_weights, owning_component=None, context=None):
        raise EMCompositionError(f"The number of field_weights ({len(field_weights)}) must match the number of fields "
                                 f"{len(owning_component.field_weights)}")
    if owning_component.normalize_field_weights:
-        field_weights = field_weights / np.sum(field_weights)
+        denominator = np.sum([fw for fw in field_weights if fw is not None])
+        field_weights = [fw / denominator if fw is not None else None for fw in field_weights]
+
+    # Assign new field_weights to default_variable of field_weight_nodes
    field_wt_node_idx = 0  # Needed since # of field_weight_nodes may be less than # of fields
+                           # and no way to know if user has assigned a value where there used to be a None
    for i, field_weight in enumerate(field_weights):
-        # Check if original value was 0 (i.e., a value node), in which case disallow change
-        if not owning_component.parameters.field_weights.default_value[i]:
+        # Check if original value was None (i.e., a value node), in which case disallow change
+        if owning_component.parameters.field_weights.default_value[i] is None:
            if field_weight:
                raise EMCompositionError(f"Field '{owning_component.field_names[i]}' of '{owning_component.name}' "
-                                         f"was originally assigned as a value node (i.e., with a field_weight = 0); "
+                                         f"was originally assigned as a value node (i.e., with a field_weight = None); "
                                         f"this cannot be changed after construction. If you want to change it to a "
-                                         f"key field, you must re-construct the EMComposition using a non-zero value "
-                                         f"for its field in the `field_weights` arg, "
-                                         f"which can then be changed to 0 after construction.")
+                                         f"key field, you must re-construct the EMComposition using a scalar "
+                                         f"for its field in the `field_weights` arg (including 0).")
            continue
        owning_component.field_weight_nodes[field_wt_node_idx].input_port.defaults.variable = field_weights[i]
        owning_component.field_weights[i] = field_weights[i]
        field_wt_node_idx += 1
-    return field_weights
+    return np.array(field_weights)

def get_softmax_gain(v, scale=1, base=1, entropy_weighting=.1)->float:
    """Compute the softmax gain (inverse temperature) based on the entropy of the distribution of values.
@@ -1166,17 +1060,19 @@ class EMComposition(AutodiffComposition):
        memory_template=[[0],[0]],       \
        memory_fill=0,                   \
        memory_capacity=None,            \
+        fields=None,                     \
+        field_names=None,                \
        field_weights=None,              \
+        learn_field_weights=False,       \
+        learning_rate=None,              \
        normalize_field_weights=True,    \
-        field_names=None,                \
        concatenate_queries=False,       \
        normalize_memories=True,         \
        softmax_gain=THRESHOLD,          \
        storage_prob=1.0,                \
        memory_decay_rate=AUTO,          \
        enable_learning=True,            \
-        learn_field_weights=True,        \
-        learning_rate=True,              \
+        target_fields=None,              \
        use_gating_for_weighting=False,  \
        name="EM_Composition"            \
        )

@@ -1190,49 +1086,67 @@ class EMComposition(AutodiffComposition):
    ---------

    memory_template : tuple, list, 2d or 3d array : default [[0],[0]]
-        specifies the shape of an item to be stored in the EMComposition's memory;
-        see `memory_template ` for details.
+        specifies the shape of an item to be stored in the EMComposition's memory
+        (see `memory_template ` for details).

    memory_fill : scalar or tuple : default 0
-        specifies the value used to fill the memory when it is initialized;
-        see `memory_fill ` for details.
+        specifies the value used to fill the memory when it is initialized
+        (see `memory_fill ` for details).

    memory_capacity : int : default None
        specifies the number of items that can be stored in the EMComposition's memory
-        see `memory_capacity ` for details.
+        (see `memory_capacity ` for details).
+
+    fields : dict[tuple[field weight, learning specification]] : default None
+        each key must be a string that is the name of a field, and its value a dict or tuple that specifies that
+        field's `field_weight `, `learn_field_weights `, and
+        `target_fields ` specifications (see `fields ` for details
+        of specification format). The **fields** arg replaces the **field_names**, **field_weights**,
+        **learn_field_weights**, and **target_fields** arguments, and specifying any of these raises an error.
+
+    field_names : list or tuple : default None
+        specifies the names assigned to each field in the memory_template (see `field names `
+        for details). If the **fields** argument is specified, this is not necessary, and specifying it raises an
+        error.
+
+    field_weights : list or tuple : default (1,0)
+        specifies the relative weight assigned to each key when matching an item in memory (see `field weights
+        ` for additional details). If the **fields** argument is specified, this
+        is not necessary, and specifying it raises an error.

-    field_weights : tuple : default (1,0)
-        specifies the relative weight assigned to each key when matching an item in memory;
-        see `field weights ` for additional details.
+    learn_field_weights : bool or list[bool, int, float] : default False
+        specifies whether the `field_weights ` are learnable and, if so, optionally what
+        the learning_rate is for each field (see `learn_field_weights ` for
+        specifications). If the **fields** argument is specified, this is not necessary, and specifying it raises an
+        error.
+
+    learning_rate : float : default .01
+        specifies the default learning_rate for `field_weights ` not
+        specified in `learn_field_weights ` (see `learning_rate
+        ` for additional details).

    normalize_field_weights : bool : default True
        specifies whether the **field_weights** are normalized over the number of keys, or used as absolute
-        weighting values when retrieving an item from memory; see `normalize_field weights
-        ` for additional details.
-
-    field_names : list : default None
-        specifies the optional names assigned to each field in the memory_template;
-        see `field names ` for details.
+        weighting values when retrieving an item from memory (see `normalize_field weights
+        ` for additional details).

    concatenate_queries : bool : default False
        specifies whether to concatenate the keys into a single field before matching them to items in
-        the corresponding fields in memory; see `concatenate keys ` for details.
+        the corresponding fields in memory (see `concatenate keys ` for details).

    normalize_memories : bool : default True
-        specifies whether keys and memories are normalized before computing their dot product (similarity);
-        see `Match memories by field ` for additional details.
+        specifies whether keys and memories are normalized before computing their dot product (similarity)
+        (see `Match memories by field ` for additional details).

    softmax_gain : float, ADAPTIVE or CONTROL : default 1.0
-        specifies the temperature used for softmax normalizing the distance of queries and keys in memory;
-        see `Softmax normalize matches over fields ` for additional details.
+        specifies the temperature used for softmax normalizing the distance of queries and keys in memory
+        (see `Softmax normalize matches over fields ` for additional details).

    softmax_threshold : float : default .0001
-        specifies the threshold used to mask out small values in the softmax calculation;
+        specifies the threshold used to mask out small values in the softmax calculation
        (see *mask_threshold* under `Thresholding and Adaptive Gain ` for details).

    softmax_choice : WEIGHTED_AVG, ARG_MAX, PROBABILISTIC : default WEIGHTED_AVG
-        specifies how the softmax over distances of queries and keys in memory is used for retrieval;
-        see `softmax_choice ` for a description of each option.
+        specifies how the softmax over distances of queries and keys in memory is used for retrieval
+        (see `softmax_choice ` for a description of each option).

    storage_prob : float : default 1.0
        specifies the probability that an item will be stored in `memory `
@@ -1240,23 +1154,23 @@ class EMComposition(AutodiffComposition):
        additional details).

    memory_decay_rate : float : AUTO
-        specifies the rate at which items in the EMComposition's memory decay;
-        see `memory_decay_rate ` for details.
+        specifies the rate at which items in the EMComposition's memory decay
+        (see `memory_decay_rate ` for details).

-    enable_learning : bool or list[bool]: default True
-        specifies whether a learning pathway is constructed for each `field `
-        of the EMComposition. If it is a list, each item must be ``True`` or ``False`` and the number of items
-        must be equal to the number of `fields specified; see `enable_learning
-        ` for additional details.
+    purge_by_field_weights : bool : default False
+        specifies whether `field_weights ` are used to determine which memory to
+        replace when a new one is stored (see `purge_by_field_weight ` for details).

-    learn_field_weights : bool : default True
-        specifies whether `field_weights ` are learnable during training;
-        requires **enable_learning** to be True to have any effect, and **use_gating_for_weighting** must be False;
-        see `learn_field_weights ` for additional details.
+    enable_learning : bool : default True
+        specifies whether learning is enabled for the EMComposition (see `Learning `
+        for additional details); **use_gating_for_weighting** must be False.

-    learning_rate : float : default .01
-        specifies rate at which `field_weights ` are learned
-        if `learn_field_weights ` is True.
+    target_fields : list[bool] : default None
+        specifies whether a learning pathway is constructed for each `field `
+        of the EMComposition.
If it is a list, each item must be ``True`` or ``False`` and the number of items
+        must be equal to the number of fields specified (see `Target Fields
+        ` for additional details). If the **fields** argument is specified,
+        this is not necessary, and specifying it raises an error.

    # 7/10/24 FIX: STILL TRUE?  DOES IT PRECLUDE USE OF EMComposition as a nested Composition??
    .. technical_note::
@@ -1275,7 +1189,8 @@ class EMComposition(AutodiffComposition):

    memory : ndarray
        3d array of entries in memory, in which each row (axis 0) is an entry, each column (axis 1) is a field, and
-        each item (axis 2) is the value for the corresponding field; see `EMComposition_Memory` for additional details.
+        each item (axis 2) is the value for the corresponding field (see `EMComposition_Memory_Specification` for
+        additional details).

        .. note::
           This is a read-only attribute; memories can be added to the EMComposition's memory either by
@@ -1287,8 +1202,12 @@ class EMComposition(AutodiffComposition):

    .. _EMComposition_Parameters:

    memory_capacity : int
-        determines the number of items that can be stored in `memory `; see `memory_capacity
-        ` for additional details.
+        determines the number of items that can be stored in `memory `
+        (see `memory_capacity ` for additional details).
+
+    field_names : list[str]
+        determines the names that can be used to label fields in `memory `
+        (see `field_names ` for additional details).

    field_weights : tuple[float]
        determines which fields of the input are treated as "keys" (non-zero values) that are used to match entries in
@@ -1298,37 +1217,42 @@ class EMComposition(AutodiffComposition):
        see `field_weights ` for additional details. The field_weights can be changed by
        assigning a new list of weights to the `field_weights ` attribute, however only
        the weights for fields used as `keys ` can be changed (see
-        `EMComposition_Field_Weights_Note` for additional details).
+        `EMComposition_Field_Weights_Change_Note` for additional details).

-    normalize_field_weights : bool : default True
-        determines whether `fields_weights ` are normalized over the number of keys, or
-        used as absolute weighting values when retrieving an item from memory; see `normalize_field weights
-        ` for additional details.
+    learn_field_weights : bool or list[bool, int, float]
+        determines whether the `field_weight ` for each field
+        is learnable (see `learn_field_weights ` for additional details).

-    field_names : list[str]
-        determines which names that can be used to label fields in `memory `; see
-        `field_names ` for additional details.
+    learning_rate : float
+        determines the default learning_rate for `field_weights `
+        not specified in `learn_field_weights `
+        (see `learning_rate ` for additional details).
+
+    normalize_field_weights : bool
+        determines whether `field_weights ` are normalized over the number of keys, or
+        used as absolute weighting values when retrieving an item from memory (see `normalize_field weights
+        ` for additional details).

    concatenate_queries : bool
        determines whether keys are concatenated into a single field before matching them to items in `memory
-        `; see `concatenate keys ` for additional details.
+        ` (see `concatenate keys ` for additional details).

    normalize_memories : bool
-        determines whether keys and memories are normalized before computing their dot product (similarity);
-        see `Match memories by field ` for additional details.
+        determines whether keys and memories are normalized before computing their dot product (similarity)
+        (see `Match memories by field ` for additional details).
softmax_gain : float, ADAPTIVE or CONTROL
-        determines gain (inverse temperature) used for softmax normalizing the summed distances of queries and keys in
-        memory by the `SoftMax` Function of the `softmax_node `; see `Softmax normalize
-        distances ` for additional details.
+        determines gain (inverse temperature) used for softmax normalizing the summed distances of queries
+        and keys in memory by the `SoftMax` Function of the `softmax_node `
+        (see `Softmax normalize distances ` for additional details).

    softmax_threshold : float
-        determines the threshold used to mask out small values in the softmax calculation;
-        see *mask_threshold* under `Thresholding and Adaptive Gain ` for details).
+        determines the threshold used to mask out small values in the softmax calculation
+        (see *mask_threshold* under `Thresholding and Adaptive Gain ` for details).

    softmax_choice : WEIGHTED_AVG, ARG_MAX or PROBABILISTIC
-        determines how the softmax over distances of queries and keys in memory is used for retrieval;
-        see `softmax_choice ` for a description of each option.
+        determines how the softmax over distances of queries and keys in memory is used for retrieval
+        (see `softmax_choice ` for a description of each option).

    storage_prob : float
        determines the probability that an item will be stored in `memory `
@@ -1336,26 +1260,20 @@ class EMComposition(AutodiffComposition):
        additional details).

    memory_decay_rate : float
-        determines the rate at which items in the EMComposition's memory decay (see `memory_decay_rate
-        ` for details).
-
-    enable_learning : bool or list[bool]
-        determines whether `learning ` is enabled for the EMComposition, allowing any error
-        received by the `retrieved_nodes ` to be propagated to the corresponding
-        `query_input_nodes ` and `value_input_nodes
-        `, and on to any `Nodes ` that project to them.
-        If True, learning is enabled for all fields and if False learning is disabled for all fields; If it is a
-        list, then each entry specifies whether learning is enabled or disabled for the corresponding field
-        see `Learning ` and `Fields ` for additional details.
-
-    learn_field_weights : bool
-        determines whether `field_weights ` are learnable during training;
-        requires `enable_learning ` to be True or a list with at least one True
-        entry for the corresponding field; see `Learning ` for additional details.
+        determines the rate at which items in the EMComposition's memory decay
+        (see `memory_decay_rate ` for details).

-    learning_rate : float
-        determines whether the rate at which `field_weights ` are learned
-        if `learn_field_weights` is True; see `Learning ` for additional details.
+    purge_by_field_weights : bool
+        determines whether `field_weights ` are used to determine which memory to
+        replace when a new one is stored (see `purge_by_field_weight ` for details).
+
+    enable_learning : bool
+        determines whether learning is enabled for the EMComposition
+        (see `Learning ` for additional details).
+
+    target_fields : list[bool]
+        determines which fields convey error signals during learning
+        (see `Target Fields ` for additional details).

    .. _EMComposition_Nodes:

@@ -1394,7 +1312,7 @@ class EMComposition(AutodiffComposition):
        as the corresponding `query_input_nodes `.

    weighted_match_nodes : list[ProcessingMechanism]
-        `ProcessingMechanisms ` that combine the `field weight `
+        `ProcessingMechanisms ` that combine the `field weight `
        for each `key field ` with the dot product computed by the corresponding
        `match_node `.
These are only implemented if `use_gating_for_weighting
        ` is False (see `Weight distances `

@@ -1426,7 +1344,7 @@ class EMComposition(AutodiffComposition):
        ` (see `Retrieve values by field ` for
        additional details). These are assigned the same names as the `query_input_nodes
        ` and `value_input_nodes ` to
        which they correspond appended with the suffix
-        * [RETRIEVED]*, and are in the same order as `input_nodes_by_fields `
+        * [RETRIEVED]*, and are in the same order as `input_nodes `
        to which they correspond.

    storage_node : EMStorageMechanism
@@ -1441,13 +1359,13 @@ class EMComposition(AutodiffComposition):
        any subsequent processing is done (i.e., in a composition in which the EMComposition may be embedded).

    input_nodes : list[ProcessingMechanism]
-        Full list of `INPUT ` `Nodes ` ordered with query_input_nodes first
-        followed by value_input_nodes; used primarily for internal computations
-
-    input_nodes_by_fields : list[ProcessingMechanism]
        Full list of `INPUT ` `Nodes ` in the same order specified in the
        **field_names** argument of the constructor and in `self.field_names `.

+    query_and_value_input_nodes : list[ProcessingMechanism]
+        Full list of `INPUT ` `Nodes ` ordered with query_input_nodes first
+        followed by value_input_nodes; used primarily for internal computations.
+
    """

    componentCategory = EM_COMPOSITION
@@ -1472,7 +1390,7 @@ class Parameters(AutodiffComposition.Parameters):
                see `enable_learning `

                :default value: True
-                :type: ``bool`` or ``list``
+                :type: ``bool``

            field_names
                see `field_names `
@@ -1486,18 +1404,18 @@ class Parameters(AutodiffComposition.Parameters):
                :default value: None
                :type: ``numpy.ndarray``

+            learn_field_weights
+                see `learn_field_weights `
+
+                :default value: False
+                :type: ``bool``
+
            learning_rate
                see `learning_rate `

                :default value: .001
                :type: ``float``

-            learn_field_weights
-                see `learn_field_weights `
-
-                :default value: True
-                :type: ``bool``
-
            memory
                see `memory `
@@ -1534,6 +1452,12 @@ class Parameters(AutodiffComposition.Parameters):
                :default value: True
                :type: ``bool``

+            purge_by_field_weights
+                see `purge_by_field_weights `
+
+                :default value: False
+                :type: ``bool``
+
            random_state
                see `random_state `
@@ -1564,9 +1488,11 @@ class Parameters(AutodiffComposition.Parameters):
        memory = Parameter(None, loggable=True, getter=_memory_getter, read_only=True)
        memory_template = Parameter([[0],[0]], structural=True, valid_types=(tuple, list, np.ndarray), read_only=True)
        memory_capacity = Parameter(1000, structural=True)
-        field_weights = Parameter(None, setter=field_weights_setter)
-        normalize_field_weights = Parameter(True)
        field_names = Parameter(None, structural=True)
+        field_weights = Parameter([1], setter=field_weights_setter)
+        learn_field_weights = Parameter(False, structural=True)
+        learning_rate = Parameter(.001, modulable=True)
+        normalize_field_weights = Parameter(True)
        concatenate_queries = Parameter(False, structural=True)
        normalize_memories = Parameter(True)
        softmax_gain = Parameter(1.0, modulable=True)
@@ -1574,9 +1500,9 @@ class Parameters(AutodiffComposition.Parameters):
        softmax_choice = Parameter(WEIGHTED_AVG, modulable=False, specify_none=True)
        storage_prob = Parameter(1.0, modulable=True, aliases=[MULTIPLICATIVE_PARAM])
        memory_decay_rate = Parameter(AUTO, modulable=True)
+        purge_by_field_weights = Parameter(False, structural=True)
        enable_learning = Parameter(True, structural=True)
-        learn_field_weights = Parameter(True, structural=True)
-        learning_rate = Parameter(.001, modulable=True)
+        target_fields = Parameter(None,
read_only=True, structural=True)
        random_state = Parameter(None, loggable=False, getter=_random_state_getter, dependencies='seed')
        seed = Parameter(DEFAULT_SEED(), modulable=True, setter=_seed_setter)

@@ -1597,27 +1523,29 @@ def _validate_memory_template(self, memory_template):
            else:
                return f"must be tuple of length 2 or 3, or a list or array that is either 2 or 3d."

+        def _validate_field_names(self, field_names):
+            if field_names and not all(isinstance(item, str) for item in field_names):
+                return f"must be a list of strings."
+
        def _validate_field_weights(self, field_weights):
            if field_weights is not None:
                if not np.atleast_1d(field_weights).ndim == 1:
                    return f"must be a scalar, list of scalars, or 1d array."
-                if any([field_weight < 0 for field_weight in field_weights]):
+                if len(field_weights) == 1 and field_weights[0] is None:
+                    raise EMCompositionError(f"must be a scalar, since there is only one field specified.")
+                if any([field_weight < 0 for field_weight in field_weights if field_weight is not None]):
                    return f"must all be positive values."

        def _validate_normalize_field_weights(self, normalize_field_weights):
            if not isinstance(normalize_field_weights, bool):
                return f"must be a boolean value."

-        def _validate_field_names(self, field_names):
-            if field_names and not all(isinstance(item, str) for item in field_names):
-                return f"must be a list of strings."
-
-        def _validate_enable_learning(self, enable_learning):
-            if isinstance(enable_learning, list):
-                if not all(isinstance(item, bool) for item in enable_learning):
-                    return f"can only contains bools as entries."
-            elif not isinstance(enable_learning, bool):
-                return f"must be a bool or list of bools."
+        def _validate_learn_field_weights(self, learn_field_weights):
+            if isinstance(learn_field_weights, (list, np.ndarray)):
+                if not all(isinstance(item, (bool, int, float)) for item in learn_field_weights):
+                    return f"can only contain bools, ints or floats as entries."
+            elif not isinstance(learn_field_weights, bool):
+                return f"must be a bool or list of bools, ints and/or floats."
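        # For reference, an illustrative construction that satisfies the validators above
        # (a sketch, using only arguments documented in this file: 4 entries, 3 fields of
        # length 5, with two keys and one value field):
        #
        #     em = EMComposition(memory_template=(4, 3, 5),
        #                        field_weights=[1, 1, None],             # two keys, one value field
        #                        learn_field_weights=[True, 0.1, False], # default rate, rate=.1, fixed
        #                        enable_learning=True)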
def _validate_memory_decay_rate(self, memory_decay_rate): if memory_decay_rate is None or memory_decay_rate == AUTO: @@ -1642,8 +1570,11 @@ def __init__(self, memory_template:Union[tuple, list, np.ndarray]=[[0],[0]], memory_capacity:Optional[int]=None, memory_fill:Union[int, float, tuple, RANDOM]=0, + fields:Optional[dict]=None, field_names:Optional[list]=None, - field_weights:tuple=None, + field_weights:Union[int,float,list,tuple]=None, + learn_field_weights:Union[bool,list,tuple]=None, + learning_rate:float=None, normalize_field_weights:bool=True, concatenate_queries:bool=False, normalize_memories:bool=True, @@ -1652,9 +1583,9 @@ def __init__(self, softmax_choice:Optional[Union[WEIGHTED_AVG, ARG_MAX, PROBABILISTIC]]=WEIGHTED_AVG, storage_prob:float=1.0, memory_decay_rate:Union[float,AUTO]=AUTO, - enable_learning:Union[bool,list]=True, - learn_field_weights:bool=True, - learning_rate:float=None, + purge_by_field_weights:bool=False, + enable_learning:bool=True, + target_fields:Optional[Union[list, tuple, np.ndarray]]=None, use_storage_node:bool=True, use_gating_for_weighting:bool=False, random_state=None, @@ -1665,17 +1596,30 @@ def __init__(self, # Construct memory -------------------------------------------------------------------------------- memory_fill = memory_fill or 0 # FIX: GET RID OF THIS ONCE IMPLEMENTED AS A Parameter - self._validate_memory_specs(memory_template, memory_capacity, memory_fill, field_weights, field_names, name) + self._validate_memory_specs(memory_template, + memory_capacity, + memory_fill, + field_weights, + field_names, + name) + memory_template, memory_capacity = self._parse_memory_template(memory_template, memory_capacity, memory_fill) - field_weights, field_names, concatenate_queries = self._parse_fields(field_weights, - normalize_field_weights, - field_names, - concatenate_queries, - normalize_memories, - learning_rate, - name) + (field_names, + field_weights, + learn_field_weights, + target_fields, + concatenate_queries) = self._parse_fields(fields, + field_names, + field_weights, + learn_field_weights, + learning_rate, + normalize_field_weights, + concatenate_queries, + normalize_memories, + target_fields, + name) if memory_decay_rate is AUTO: memory_decay_rate = 1 / memory_capacity @@ -1690,27 +1634,29 @@ def __init__(self, super().__init__(name=name, memory_template = memory_template, memory_capacity = memory_capacity, - field_weights = field_weights, field_names = field_names, + field_weights = field_weights, + learn_field_weights=learn_field_weights, + learning_rate = learning_rate, normalize_field_weights = normalize_field_weights, concatenate_queries = concatenate_queries, + normalize_memories = normalize_memories, softmax_gain = softmax_gain, softmax_threshold = softmax_threshold, softmax_choice = softmax_choice, storage_prob = storage_prob, memory_decay_rate = memory_decay_rate, - normalize_memories = normalize_memories, - enable_learning=enable_learning, - learn_field_weights = learn_field_weights, - learning_rate = learning_rate, + purge_by_field_weights = purge_by_field_weights, + enable_learning = enable_learning, + target_fields = target_fields, random_state = random_state, seed = seed, **kwargs ) - self._validate_options_with_learning(enable_learning, + self._validate_options_with_learning(learn_field_weights, use_gating_for_weighting, - learn_field_weights, + enable_learning, softmax_choice) self._construct_pathways(self.memory_template, @@ -1724,8 +1670,8 @@ def __init__(self, self.storage_prob, self.memory_decay_rate, 
self._use_storage_node,
-                                  self.enable_learning,
                                  self.learn_field_weights,
+                                  self.enable_learning,
                                  self._use_gating_for_weighting)

        # if torch_available:
@@ -1767,7 +1713,7 @@ def __init__(self,
        #     self.scheduler.add_condition(self.storage_node, conditions.AllHaveRun(*self.retrieved_nodes))
        #   # Generates the desired execution set for a single pass, and runs with expected results,
-        #   # but generates warning messages for every node of the following sort:
+        #   # but raises a warning message for every node of the following sort:
        #   #   /Users/jdc/PycharmProjects/PsyNeuLink/psyneulink/core/scheduling/scheduler.py:120:
        #   #   UserWarning: BeforeNCalls((EMStorageMechanism STORAGE MECHANISM), 1) is dependent on
        #   #   (EMStorageMechanism STORAGE MECHANISM), but you are assigning (EMStorageMechanism STORAGE MECHANISM)

@@ -1827,7 +1773,7 @@ def _validate_memory_specs(self, memory_template, memory_capacity, memory_fill,
            for entry in memory_template:
                if not (len(entry) == num_fields
                        and np.all([len(entry[i]) == len(memory_template[0][i]) for i in range(num_fields)])):
-                    raise EMCompositionError(f"The 'memory_template' arg for {self.name} must specify a list "
+                    raise EMCompositionError(f"The 'memory_template' arg for {name} must specify a list "
                                             f"or 2d array that has the same shape for all entries.")

        # Validate memory_fill specification (int, float, or tuple with two scalars)
            raise EMCompositionError(f"The 'memory_fill' arg ({memory_fill}) specified for {name} "
                                     f"must be a float, int or 2-item tuple of ints and/or floats.")

-        # If enable_learning is a list of bools, it must match the len of 1st dimension (axis 0) of memory_template:
-        if isinstance(self.enable_learning, list) and len(self.enable_learning) != num_fields:
-            raise EMCompositionError(f"The number of items ({len(self.enable_learning)}) in the 'enable_learning' arg "
-                                     f"for {name} must match the number of fields in memory "
-                                     f"({num_fields}).")
+        # If learn_field_weights is a list of bools, it must match the len of 1st dimension (axis 0) of memory_template:
+        if isinstance(self.learn_field_weights, list) and len(self.learn_field_weights) != num_fields:
+            raise EMCompositionError(f"The number of items ({len(self.learn_field_weights)}) in the "
+                                     f"'learn_field_weights' arg for {name} must match the number of "
+                                     f"fields in memory ({num_fields}).")
+
+        _field_wts = np.atleast_1d(field_weights)
+        _field_wts_len = len(_field_wts)

        # If len of field_weights > 1, must match the len of 1st dimension (axis 0) of memory_template:
-        field_weights_len = len(np.atleast_1d(field_weights))
-        if field_weights is not None and field_weights_len > 1 and field_weights_len != num_fields:
-            raise EMCompositionError(f"The number of items ({field_weights_len}) in the 'field_weights' arg "
-                                     f"for {name} must match the number of items in an entry of memory "
-                                     f"({num_fields}).")
+        if field_weights is not None:
+            if (_field_wts_len > 1 and _field_wts_len != num_fields):
+                raise EMCompositionError(f"The number of items ({_field_wts_len}) in the 'field_weights' arg "
+                                         f"for {name} must match the number of items in an entry of memory "
+                                         f"({num_fields}).")
+            # Deal with this here instead of Parameter._validate_field_weights since this is called before super()
+            if all([fw is None for fw in _field_wts]):
+                raise EMCompositionError(f"The entries in 'field_weights' arg for {name} can't all be 'None' "
+                                         f"since that will preclude the construction of any keys.")
+            if all([fw in {0,
None} for fw in _field_wts]): + warnings.warn(f"All of the entries in the 'field_weights' arg for {name} are either None or " + f"set to 0; this will result in no retrievals unless/until the 0(s) is(are) changed " + f"to a positive value.") # If field_names has more than one value it must match the first dimension (axis 0) of memory_template: if field_names and len(field_names) != num_fields: raise EMCompositionError(f"The number of items ({len(field_names)}) " f"in the 'field_names' arg for {name} must match " - f"the number of fields ({field_weights_len}).") + f"the number of fields ({_field_wts_len}).") def _parse_memory_template(self, memory_template, memory_capacity, memory_fill)->(np.ndarray,int): """Construct memory from memory_template and memory_fill @@ -1943,15 +1900,55 @@ def _construct_entries(entry_template, num_entries, memory_fill=None)->np.ndarra return memory, memory_capacity def _parse_fields(self, + fields, + field_names, field_weights, + learn_field_weights, + learning_rate, normalize_field_weights, - field_names, concatenate_queries, normalize_memories, - learning_rate, - name): + target_fields, + name)->(list, list, list, bool): + + def _parse_fields_dict(name, fields, num_fields)->(list,list,list,list): + """Parse fields dict into field_names, field_weights, learn_field_weights, and target_fields""" + if len(fields) != num_fields: + raise EMCompositionError(f"The number of entries ({len(fields)}) in the dict specified in the 'fields' " + f"arg of '{name}' does not match the number of fields in its memory " + f"({self.num_fields}).") + field_names = [None] * num_fields + field_weights = [None] * num_fields + learn_field_weights = [None] * num_fields + target_fields = [None] * num_fields + for i, field_name in enumerate(fields): + field_names[i] = field_name + if isinstance(fields[field_name], (tuple, list)): + # field specified as tuple or list + field_weights[i] = fields[field_name][0] + learn_field_weights[i] = fields[field_name][1] + target_fields[i] = fields[field_name][2] + elif isinstance(fields[field_name], dict): + # field specified as dict + field_weights[i] = fields[field_name][FIELD_WEIGHT] + learn_field_weights[i] = fields[field_name][LEARN_FIELD_WEIGHT] + target_fields[i] = fields[field_name][TARGET_FIELD] + else: + raise EMCompositionError(f"Unrecognized specification for field '{field_name}' in the 'fields' " + f"arg of '{name}'; it must be a tuple, list or dict.") + return field_names, field_weights, learn_field_weights, target_fields - num_fields = len(self.entry_template) + self.num_fields = len(self.entry_template) + + if fields: + # If a fields dict has been specified, use that to assign field_names, field_weights & learn_field_weights + if any([field_names, field_weights, learn_field_weights, target_fields]): + warnings.warn(f"The 'fields' arg for '{name}' was specified, so any of the 'field_names', " + f"'field_weights', 'learn_field_weights' or 'target_fields' args will be ignored.") + (field_names, + field_weights, + learn_field_weights, + target_fields) = _parse_fields_dict(name, fields, self.num_fields) # Deal with default field_weights if field_weights is None: @@ -1959,36 +1956,43 @@ def _parse_fields(self, field_weights = [1] else: # Default is to treat all fields as keys except the last one, which is the value - field_weights = [1] * num_fields - field_weights[-1] = 0 + field_weights = [1] * self.num_fields + field_weights[-1] = None field_weights = np.atleast_1d(field_weights) - # Fill out field_weights, normalizing if specified: - if 
len(field_weights) == 1: - if normalize_field_weights: - parsed_field_weights = np.repeat(field_weights / np.sum(field_weights), len(self.entry_template)) - else: - parsed_field_weights = np.repeat(field_weights[0], len(self.entry_template)) + if normalize_field_weights and not all([fw == 0 for fw in field_weights]): # noqa: E127 + fld_wts_0s_for_Nones = [fw if fw is not None else 0 for fw in field_weights] + parsed_field_weights = fld_wts_0s_for_Nones / np.sum(fld_wts_0s_for_Nones) + parsed_field_weights = [pfw if fw is not None else None + for pfw, fw in zip(parsed_field_weights, field_weights)] else: - if normalize_field_weights: - parsed_field_weights = np.array(field_weights) / np.sum(field_weights) - else: - parsed_field_weights = field_weights + parsed_field_weights = field_weights + + # If only one field_weight was specified, but there is more than one field, + # repeat the single weight for each field + if len(field_weights) == 1 and self.num_fields > 1: + parsed_field_weights = np.repeat(parsed_field_weights, self.num_fields) + + # Make sure field_weight learning was not specified for any value fields (since they don't have field_weights) + if isinstance(learn_field_weights, (list, tuple, np.ndarray)): + for i, lfw in enumerate(learn_field_weights): + if parsed_field_weights[i] is None and lfw is not False: + warnings.warn(f"Learning was specified for field '{field_names[i]}' in the 'learn_field_weights' " + f"arg for '{name}', but it is not allowed for value fields; it will be ignored.") # Memory structure Parameters parsed_field_names = field_names.copy() if field_names is not None else None # Set memory field attributes - self.num_fields = len(self.entry_template) - keys_weights = [i for i in parsed_field_weights if i != 0] + keys_weights = [i for i in parsed_field_weights if i is not None] self.num_keys = len(keys_weights) # Get indices of field_weights that specify keys and values: - self.key_indices = np.flatnonzero(parsed_field_weights) + self.key_indices = [i for i, pfw in enumerate(parsed_field_weights) if pfw is not None] assert len(self.key_indices) == self.num_keys, \ f"PROGRAM ERROR: number of keys ({self.num_keys}) does not match number of " \ f"non-zero values in field_weights ({len(self.key_indices)})." 
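The branch above normalizes only the numeric (key) weights and re-inserts None for value fields, so a None entry survives normalization untouched. A minimal standalone sketch of that behavior, assuming only that None marks a value field (the helper name is illustrative, not part of the EMComposition API):

    import numpy as np

    def normalize_preserving_nones(field_weights):
        """Scale numeric weights to sum to 1; keep None (value fields) as None."""
        numeric = [fw if fw is not None else 0 for fw in field_weights]
        if all(fw == 0 for fw in numeric):
            return list(field_weights)   # nothing to normalize
        scaled = np.array(numeric) / np.sum(numeric)
        return [s if fw is not None else None
                for s, fw in zip(scaled, field_weights)]

    # e.g. normalize_preserving_nones([1, None, 2, 0, None])
    # -> [0.333..., None, 0.666..., 0.0, None], matching the expected values
    #    asserted in test_field_args_and_map_assignments below.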
- self.value_indices = np.where(parsed_field_weights==0)[0] + self.value_indices = [i for i, pfw in enumerate(parsed_field_weights) if pfw is None] self.num_values = self.num_fields - self.num_keys assert len(self.value_indices) == self.num_values, \ f"PROGRAM ERROR: number of values ({self.num_values}) does not match number of " \ @@ -2014,7 +2018,6 @@ def _parse_fields(self, # field weights are not all equal and/or # normalize_memories is False and/or # there is only one key - fw_error_msg = nm_error_msg = fw_correction_msg = nm_correction_msg = None if self.num_keys == 1: error_msg = f"there is only one key" correction_msg = "" @@ -2028,8 +2031,17 @@ def _parse_fields(self, warnings.warn(f"The 'concatenate_queries' arg for '{name}' is True but {error_msg}; " f"concatenation will be ignored.{correction_msg}") + # Deal with default target_fields + if target_fields is None: + target_fields = [True] * self.num_fields + self.learning_rate = learning_rate - return parsed_field_weights, parsed_field_names, parsed_concatenate_queries + + return (parsed_field_names, + parsed_field_weights, + learn_field_weights, + target_fields, + parsed_concatenate_queries) def _parse_memory_shape(self, memory_template): """Parse shape of memory_template to determine number of entries and fields""" @@ -2063,8 +2075,8 @@ def _construct_pathways(self, storage_prob, memory_decay_rate, use_storage_node, - enable_learning, learn_field_weights, + enable_learning, use_gating_for_weighting, ): """Construct Nodes and Pathways for EMComposition""" @@ -2076,15 +2088,15 @@ def _construct_pathways(self, # First, construct Nodes of Composition with their Projections self.query_input_nodes = self._construct_query_input_nodes(field_weights) self.value_input_nodes = self._construct_value_input_nodes(field_weights) - self.input_nodes = self.query_input_nodes + self.value_input_nodes + self.query_and_value_input_nodes = self.query_input_nodes + self.value_input_nodes # Get list of nodes in order specified in self.field_names - self.input_nodes_by_fields = [None] * len(field_weights) + self.input_nodes = [None] * len(field_weights) for i in range(self.num_keys): - self.input_nodes_by_fields[self.key_indices[i]] = self.query_input_nodes[i] + self.input_nodes[self.key_indices[i]] = self.query_input_nodes[i] for i in range(self.num_values): - self.input_nodes_by_fields[self.value_indices[i]] = self.value_input_nodes[i] - assert all(self.input_nodes_by_fields), "PROGRAM ERROR: input_nodes_by_fields not fully populated." + self.input_nodes[self.value_indices[i]] = self.value_input_nodes[i] + assert all(self.input_nodes), "PROGRAM ERROR: input_nodes not fully populated." self.concatenate_queries_node = self._construct_concatenate_queries_node(concatenate_queries) self.match_nodes = self._construct_match_nodes(memory_template, memory_capacity, @@ -2118,6 +2130,42 @@ def _construct_pathways(self, assert not self.field_weight_nodes, \ f"PROGRAM ERROR: There should be no field_weight_nodes for concatenated queries." 
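For reference, the two field-specification formats accepted by _parse_fields_dict above can be exercised as in the following sketch (argument values are illustrative only; the tuple order (field_weight, learn_field_weight, target_field) and the FIELD_WEIGHT / LEARN_FIELD_WEIGHT / TARGET_FIELD keys are taken from the parsing code above):

    import psyneulink as pnl

    em = pnl.EMComposition(memory_template=(3, 2),   # 3 fields, 2 elements each
                           memory_capacity=4,
                           fields={'KEY A': (1, True, False),     # weighted key, learnable weight
                                   'KEY B': (2, False, False),    # fixed-weight key
                                   'VALUE': (None, False, True)}) # value field, used as target

    # Equivalent dict form for a single field:
    #   'KEY A': {pnl.FIELD_WEIGHT: 1, pnl.LEARN_FIELD_WEIGHT: True, pnl.TARGET_FIELD: False}

Per the warning added above, specifying fields together with field_names, field_weights, learn_field_weights or target_fields causes the latter to be ignored.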
+ # Create field_index map for nodes and projections + _field_index_map = {} + for i in range(len(self.input_nodes)): + _field_index_map[self.input_nodes[i]] = i + if self._use_storage_node: + _field_index_map[self.storage_node.path_afferents[i]] = i + _field_index_map[self.retrieved_nodes[i]] = i + _field_index_map[self.retrieved_nodes[i].path_afferents[0]] = i + if self.concatenate_queries: + for proj in self.concatenate_queries_node.path_afferents: + _field_index_map[proj] = _field_index_map[proj.sender.owner] + _field_index_map[self.concatenate_queries_node] = None + _field_index_map[self.match_nodes[0]] = None + _field_index_map[self.match_nodes[0].path_afferents[0]] = None + _field_index_map[self.match_nodes[0].efferents[0]] = None + else: + # Input nodes, Projections to storage_node, retrieval Projections and retrieved_nodes + for match_node in self.match_nodes: + field_index = _field_index_map[match_node.path_afferents[0].sender.owner] + # match_node + _field_index_map[match_node] = field_index + # afferent MEMORY Projection + _field_index_map[match_node.path_afferents[0]] = field_index + # efferent Projection to weighted_match_node + _field_index_map[match_node.efferents[0]] = field_index + # weighted_match_node + _field_index_map[match_node.efferents[0].receiver.owner] = field_index + # Projection to combined_matches_node + _field_index_map[match_node.efferents[0].receiver.owner.efferents[0]] = field_index + for field_weight_node in self.field_weight_nodes: + # Weight nodes; + _field_index_map[field_weight_node] = _field_index_map[field_weight_node.efferents[0].receiver.owner] + # Weight Projections; + _field_index_map[field_weight_node.efferents[0]] = _field_index_map[field_weight_node] + self._field_index_map = _field_index_map + # Construct Pathways -------------------------------------------------------------------------------- # LEARNING NOT ENABLED -------------------------------------------------- @@ -2141,8 +2189,8 @@ def _construct_pathways(self, # Query-specific pathways if not self.concatenate_queries: if self.num_keys == 1: - self.add_linear_processing_pathway([self.query_input_nodes[i], - self.match_nodes[i], + self.add_linear_processing_pathway([self.query_input_nodes[0], + self.match_nodes[0], self.softmax_node]) else: for i in range(self.num_keys): @@ -2204,15 +2252,12 @@ def _construct_value_input_nodes(self, field_weights)->list: where i is selected randomly without replacement from (0->memory_capacity) """ - # Get indices of field_weights that specify keys: - value_indices = np.where(field_weights == 0)[0] - - assert len(value_indices) == self.num_values, \ + assert len(self.value_indices) == self.num_values, \ f"PROGRAM ERROR: number of values ({self.num_values}) does not match number of " \ - f"non-zero values in field_weights ({len(value_indices)})." + f"non-zero values in field_weights ({len(self.value_indices)})." 
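The _field_index_map built above lets later code (e.g., _set_learning_attributes and the PyTorch storage path below) recover which memory field a given node or Projection serves; components shared across fields under concatenate_queries map to None. A hypothetical inspection, using a throwaway construction (the attribute is private, so this is for illustration only):

    import psyneulink as pnl

    em = pnl.EMComposition(memory_template=(2, 3), memory_capacity=4,
                           field_names=['KEY', 'VAL'],
                           field_weights=[1, None])   # one key field, one value field
    for component, field_idx in em._field_index_map.items():
        # field_idx is the memory-field index the component serves,
        # or None for components shared across fields
        print(f"{component.name} -> field {field_idx}")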
value_input_nodes = [ProcessingMechanism( - input_shapes=len(self.entry_template[value_indices[i]]), + input_shapes=len(self.entry_template[self.value_indices[i]]), name= f'{self.value_names[i]} [VALUE]') for i in range(self.num_values)] @@ -2271,6 +2316,7 @@ def _construct_match_nodes(self, memory_template, memory_capacity, concatenate_q normalize=args[0][NORMALIZE]), name=f'MEMORY')}, name='MATCH')] + match_nodes[0]._field_idx = 0 # One node for each key else: @@ -2414,19 +2460,19 @@ def _construct_softmax_node(self, memory_capacity, softmax_gain, softmax_thresho return softmax_node def _validate_options_with_learning(self, - enable_learning, - use_gating_for_weighting, learn_field_weights, + use_gating_for_weighting, + enable_learning, softmax_choice): - if use_gating_for_weighting and learn_field_weights: - warnings.warn(f"The 'learn_field_weights' option for '{self.name}' cannot be used with " + if use_gating_for_weighting and enable_learning: + warnings.warn(f"The 'enable_learning' option for '{self.name}' cannot be used with " f"'use_gating_for_weighting' set to True; this will generate an error if its " f"'learn' method is called. Set 'use_gating_for_weighting' to False in order " f"to enable learning of field weights.") if softmax_choice in {ARG_MAX, PROBABILISTIC} and enable_learning: warnings.warn(f"The 'softmax_choice' arg of '{self.name}' is set to '{softmax_choice}' with " - f"'enable_learning' set to True (or a list); this will generate an error if its " + f"'enable_learning' set to True; this will generate an error if its " f"'learn' method is called. Set 'softmax_choice' to WEIGHTED_AVG before learning.") def _construct_retrieved_nodes(self, memory_template)->list: @@ -2474,7 +2520,8 @@ def _construct_storage_node(self, and from the value_input_node to the retrieved_node for values. The `function ` of the `EMStorageMechanism` takes the following arguments: - - **variable** -- template for an `entry ` in `memory`; + - **variable** -- template for an `entry ` + in `memory`; - **fields** -- the `input_nodes ` for the corresponding `fields ` of an `entry ` in `memory `; @@ -2503,7 +2550,7 @@ def _construct_storage_node(self, storage_node = EMStorageMechanism(default_variable=[self.input_nodes[i].value[0] for i in range(self.num_fields)], fields=[self.input_nodes[i] for i in range(self.num_fields)], - field_types=[0 if weight == 0 else 1 for weight in field_weights], + field_types=[0 if weight is None else 1 for weight in field_weights], concatenation_node=concatenate_queries_node, memory_matrix=memory_template, learning_signals=learning_signals, @@ -2519,12 +2566,36 @@ def _set_learning_attributes(self): # 7/10/24 FIX: SHOULD THIS ALSO BE CONSTRAINED BY VALUE OF field_weights FOR CORRESPONDING FIELD? # (i.e., if it is zero then not learnable? or is that a valid initial condition?)
for projection in self.projections: - if (projection.sender.owner in self.field_weight_nodes - and self.enable_learning - and self.learn_field_weights): - projection.learnable = True + + projection_is_field_weight = projection.sender.owner in self.field_weight_nodes + + if self.enable_learning is False or not projection_is_field_weight: + projection.learnable = False + continue + + # Use globally specified learning_rate + if self.learn_field_weights is None: # Default, which should be treated the same as True + learning_rate = True + elif isinstance(self.learn_field_weights, (bool, int, float)): + learning_rate = self.learn_field_weights + + # Use individually specified learning_rate else: + # FIX: THIS NEEDS TO USE field_index_map, BUT THAT DOESN'T SEEM TO HAVE THE WEIGHT PROJECTION YET + learning_rate = self.learn_field_weights[self._field_index_map[projection]] + + if learning_rate is False: projection.learnable = False + continue + elif learning_rate is True: + # Default (EMComposition's learning_rate) is used for all field_weight Projections: + learning_rate = self.learning_rate + assert isinstance(learning_rate, (int, float)), \ + (f"PROGRAM ERROR: learning_rate for {projection.sender.owner.name} is not a valid value.") + + projection.learnable = True + if projection.learning_mechanism: + projection.learning_mechanism.learning_rate = learning_rate #endregion @@ -2567,10 -2638,9 @@ def _encode_memory(self, context=None): """ # Get least used slot (i.e., weakest memory = row of matrix with lowest weights) computed across all fields - purge_by_field_weights = False field_norms = np.array([np.linalg.norm(field, axis=1) for field in [row for row in self.parameters.memory.get(context)]]) - if purge_by_field_weights: + if self.purge_by_field_weights: field_norms *= self.field_weights row_norms = np.sum(field_norms, axis=1) idx_of_min = np.argmin(row_norms) @@ -2623,11 +2693,11 @@ def learn(self, *args, **kwargs)->list: """Override to check for inappropriate use of ARG_MAX or PROBABILISTIC options for retrieval with learning""" softmax_choice = self.parameters.softmax_choice.get(kwargs[CONTEXT]) use_gating_for_weighting = self._use_gating_for_weighting - learn_field_weights = self.parameters.learn_field_weights.get(kwargs[CONTEXT]) + enable_learning = self.parameters.enable_learning.get(kwargs[CONTEXT]) - if use_gating_for_weighting and learn_field_weights: + if use_gating_for_weighting and enable_learning: raise EMCompositionError(f"Field weights cannot be learned when 'use_gating_for_weighting' is True; " - f"Construct '{self.name}' with the 'learn_field_weights' arg set to False.") + f"Construct '{self.name}' with the 'enable_learning' arg set to False.") if softmax_choice in {ARG_MAX, PROBABILISTIC}: raise EMCompositionError(f"The ARG_MAX and PROBABILISTIC options for the 'softmax_choice' arg " @@ -2646,19 +2716,19 @@ def _get_execution_mode(self, execution_mode): return execution_mode def _identify_target_nodes(self, context)->list: - """Identify retrieval_nodes specified by **enable_learning** as TARGET nodes""" - enable_learning = self.parameters.enable_learning._get(context) - if enable_learning is False: - if self.learn_field_weights: - warnings.warn(f"The 'learn_field_weights' arg for {self.name} is True " - f"but its 'enable_learning' is False, so learn_field_weights will have no effect.") + """Identify retrieval_nodes specified by **target_fields** as TARGET nodes""" + target_fields = self.target_fields + if target_fields is False: + if self.enable_learning: +
warnings.warn(f"The 'enable_learning' arg for {self.name} is True " + f"but its 'target_fields' is False, so enable_learning will have no effect.") target_nodes = [] - elif enable_learning is True: + elif target_fields is True: target_nodes = [node for node in self.retrieved_nodes] - elif isinstance(enable_learning, list): - target_nodes = [node for node in self.retrieved_nodes if enable_learning[self.retrieved_nodes.index(node)]] + elif isinstance(target_fields, list): + target_nodes = [node for node in self.retrieved_nodes if target_fields[self.retrieved_nodes.index(node)]] else: - assert False, (f"PROGRAM ERROR: enable_learning arg for {self.name}: {enable_learning} " + assert False, (f"PROGRAM ERROR: target_fields arg for {self.name}: {target_fields} " f"is neither True, False nor a list of bools as it should be.") super()._identify_target_nodes(context) return target_nodes @@ -2666,7 +2736,7 @@ def _identify_target_nodes(self, context)->list: def infer_backpropagation_learning_pathways(self, execution_mode, context=None): if self.concatenate_queries: raise EMCompositionError(f"EMComposition does not support learning with 'concatenate_queries'=True.") - super().infer_backpropagation_learning_pathways(execution_mode, context=context) + return super().infer_backpropagation_learning_pathways(execution_mode, context=context) def do_gradient_optimization(self, retain_in_pnl_options, context, optimization_num=None): # 7/10/24 - MAKE THIS CONTEXT DEPENDENT: CALL super() IF BEING EXECUTED ON ITS OWN? diff --git a/psyneulink/library/compositions/pytorchEMcompositionwrapper.py b/psyneulink/library/compositions/pytorchEMcompositionwrapper.py index 38c67017cac..fca4856e4eb 100644 --- a/psyneulink/library/compositions/pytorchEMcompositionwrapper.py +++ b/psyneulink/library/compositions/pytorchEMcompositionwrapper.py @@ -46,16 +46,16 @@ def __init__(self, *args, **kwargs): # ProjectionWrappers for match nodes learning_signals_for_match_nodes = pnl_storage_mech.learning_signals[:num_match_fields] pnl_match_projs = [match_node_learning_signal.efferents[0].receiver.owner - for match_node_learning_signal in learning_signals_for_match_nodes] + for match_node_learning_signal in learning_signals_for_match_nodes] self.match_projection_wrappers = [self.projections_map[pnl_match_proj] - for pnl_match_proj in pnl_match_projs] + for pnl_match_proj in pnl_match_projs] # ProjectionWrappers for retrieve nodes learning_signals_for_retrieve_nodes = pnl_storage_mech.learning_signals[num_match_fields:] pnl_retrieve_projs = [retrieve_node_learning_signal.efferents[0].receiver.owner - for retrieve_node_learning_signal in learning_signals_for_retrieve_nodes] + for retrieve_node_learning_signal in learning_signals_for_retrieve_nodes] self.retrieve_projection_wrappers = [self.projections_map[pnl_retrieve_proj] - for pnl_retrieve_proj in pnl_retrieve_projs] + for pnl_retrieve_proj in pnl_retrieve_projs] def execute_node(self, node, variable, optimization_num, context): """Override to handle storage of entry to memory_matrix by EMStorage Function""" @@ -134,19 +134,26 @@ def store_memory(self, memory_to_store, context): idx_of_weakest_memory = torch.argmin(row_norms) values = [] - for i, field_projection in enumerate(self.match_projection_wrappers + self.retrieve_projection_wrappers): - if i < num_match_fields: - # For match projections, get entry to store from value of sender of Projection matrix - # (this is to accomodate concatenation_node) - axis = 0 + for field_projection in self.match_projection_wrappers + 
self.retrieve_projection_wrappers: + field_idx = self._composition._field_index_map[field_projection._pnl_proj] + if field_projection in self.match_projection_wrappers: + # For match projections: + # - get entry to store from value of sender of Projection matrix (to accommodate concatenation_node) entry_to_store = field_projection.sender.output + # - store in row + axis = 0 if concatenation_node is None: - assert (entry_to_store == memory_to_store[i]).all(), \ - f"PROGRAM ERROR: misalignment between inputs and fields for storing them" + # Double check that the memory passed in is the output of the projection for the correct field + assert (entry_to_store == + memory_to_store[field_idx]).all(), \ + (f"PROGRAM ERROR: misalignment between memory to be stored (input passed to store_memory) " + f"and value of projection to corresponding field.") else: - # For retrieve projections, get entry to store from memory_to_store (which has inputs to all fields) + # For retrieve projections: + # - get entry to store from memory_to_store (which has inputs to all fields) + entry_to_store = memory_to_store[field_idx] + # - store in column axis = 1 - entry_to_store = memory_to_store[i - num_match_fields] # Get matrix containing memories for the field from the Projection field_memory_matrix = field_projection.matrix diff --git a/tests/composition/test_emcomposition.py b/tests/composition/test_emcomposition.py index d2af70bee95..e70a683a7ca 100644 --- a/tests/composition/test_emcomposition.py +++ b/tests/composition/test_emcomposition.py @@ -43,7 +43,7 @@ def test_two_calls_no_args(self): test_structure_data = [ # NOTE: None => use default value (i.e., don't specify in constructor, rather than forcing None as value of arg) # ------------------ SPECS --------------------------------------------- ------- EXPECTED ------------------- - # memory_template memory_fill field_wts cncat_ky nmlze sm_gain repeat #fields #keys #vals concat + # memory_template memory_fill field_wts cncat_qy nmlze sm_gain repeat #fields #keys #vals concat (0, (2,3), None, None, None, None, None, False, 2, 1, 1, False,), (0.1, (2,3), .1, None, None, None, None, False, 2, 1, 1, False,), (0.2, (2,3), (0,.1), None, None, None, None, False, 2, 1, 1, False,), @@ -61,35 +61,38 @@ def test_two_calls_no_args(self): (6, [[0,0,0],[0],[0,0]], None, [1,1,1], False, None, None, False, 3, 3, 0, False,), (7, [[0,0,0],[0],[0,0]], None, [1,1,1], True, None, None, False, 3, 3, 0, True,), (7.1, [[0,0,0],[0],[0,0]], None, [1,1,1], True , False, None, False, 3, 3, 0, False,), - (8, [[0,0],[0,0],[0,0]], None, [1,2,0], None, None, None, False, 3, 2, 1, False,), - (8.1, [[0,0],[0,0],[0,0]], None, [1,2,0], True, None, None, False, 3, 2, 1, False,), - (9, [[0,1],[0,0],[0,0]], None, [1,2,0], None, None, None, [0,1], 3, 2, 1, False,), - (9.1, [[0,1],[0,0,0],[0,0]], None, [1,2,0], None, None, None, [0,1], 3, 2, 1, False,), - (10, [[0,1],[0,0,0],[0,0]], .1, [1,2,0], None, None, None, [0,1], 3, 2, 1, False,), - (11, [[0,0],[0,0,0],[0,0]], .1, [1,2,0], None, None, None, False, 3, 2, 1, False,), + (8, [[0,0],[0,0],[0,0]], None, [1,2,None], None, None, None, False, 3, 2, 1, False,), + (8.1, [[0,0],[0,0],[0,0]], None, [1,2,None], True, None, None, False, 3, 2, 1, False,), + (8.2, [[0,0],[0,0],[0,0]], None, [1,1,None], True, None, None, False, 3, 2, 1, True,), + (8.3, [[0,0],[0,0],[0,0]], None, [1,1,0], True, None, None, False, 3, 3, 0, False,), + (8.4, [[0,0],[0,0],[0,0]], None, [0,0,0], True, None, None, False, 3, 3, 0, True,), + (9, [[0,1],[0,0],[0,0]], None, [1,2,None], 
None, None, None, [0,1], 3, 2, 1, False,), + (9.1, [[0,1],[0,0,0],[0,0]], None, [1,2,None], None, None, None, [0,1], 3, 2, 1, False,), + (10, [[0,1],[0,0,0],[0,0]], .1, [1,2,None], None, None, None, [0,1], 3, 2, 1, False,), + (11, [[0,0],[0,0,0],[0,0]], .1, [1,2,None], None, None, None, False, 3, 2, 1, False,), (12, [[[0,0],[0,0],[0,0]], # two entries specified, fields all same length, both entries have all 0's [[0,0],[0,0],[0,0]]], .1, [1,1,1], None, None, None, 2, 3, 3, 0, False,), (12.1, [[[0,0],[0,0,0],[0,0]], # two entries specified, fields have different lenghts, entries all have 0's - [[0,0],[0,0,0],[0,0]]], .1, [1,1,0], None, None, None, 2, 3, 2, 1, False,), + [[0,0],[0,0,0],[0,0]]], .1, [1,1,None], None, None, None, 2, 3, 2, 1, False,), (12.2, [[[0,0],[0,0,0],[0,0]], # two entries specified, first has 0's - [[0,2],[0,0,0],[0,0]]], .1, [1,1,0], None, None, None, 2, 3, 2, 1, False,), + [[0,2],[0,0,0],[0,0]]], .1, [1,1,None], None, None, None, 2, 3, 2, 1, False,), (12.3, [[[0,1],[0,0,0],[0,0]], # two entries specified, fields have same weights, but concatenate is False - [[0,2],[0,0,0],[0,0]]], .1, [1,1,0], None, None, None, 2, 3, 2, 1, False), + [[0,2],[0,0,0],[0,0]]], .1, [1,1,None], None, None, None, 2, 3, 2, 1, False), (13, [[[0,1],[0,0,0],[0,0]], # two entries specified, fields have same weights, and concatenate_queries is True - [[0,2],[0,0,0],[0,0]]], .1, [1,1,0], True, None, None, 2, 3, 2, 1, True), + [[0,2],[0,0,0],[0,0]]], .1, [1,1,None], True, None, None, 2, 3, 2, 1, True), (14, [[[0,1],[0,0,0],[0,0]], # two entries specified, all fields are keys [[0,2],[0,0,0],[0,0]]], .1, [1,1,1], None, None, None, 2, 3, 3, 0, False), (15, [[[0,1],[0,0,0],[0,0]], # two entries specified; fields have different weights, constant memory_fill - [[0,2],[0,0,0],[0,0]]], .1, [1,2,0], None, None, None, 2, 3, 2, 1, False), + [[0,2],[0,0,0],[0,0]]], .1, [1,2,None], None, None, None, 2, 3, 2, 1, False), (15.1, [[[0,1],[0,0,0],[0,0]], # two entries specified; fields have different weights, random memory_fill - [[0,2],[0,0,0],[0,0]]], (0,.1), [1,2,0], None, None, None, 2, 3, 2, 1, False), + [[0,2],[0,0,0],[0,0]]], (0,.1),[1,2,None], None, None, None, 2, 3, 2, 1, False), (16, [[[0,1],[0,0,0],[0,0]], # three entries specified [[0,2],[0,0,0],[0,0]], - [[0,3],[0,0,0],[0,0]]], .1, [1,2,0], None, None, None, 3, 3, 2, 1, False), + [[0,3],[0,0,0],[0,0]]], .1, [1,2,None], None, None, None, 3, 3, 2, 1, False), (17, [[[0,1],[0,0,0],[0,0]], # all four entries allowed by memory_capacity specified [[0,2],[0,0,0],[0,0]], [[0,3],[0,0,0],[0,0]], - [[0,4],[0,0,0],[0,0]]], .1, [1,2,0], None, None, None, 4, 3, 2, 1, False), + [[0,4],[0,0,0],[0,0]]], .1, [1,2,None], None, None, None, 4, 3, 2, 1, False), ] args_names = "test_num, memory_template, memory_fill, field_weights, concatenate_queries, normalize_memories, " \ "softmax_gain, repeat, num_fields, num_keys, num_values, concatenate_node" @@ -244,14 +247,204 @@ def test_softmax_choice_error(self, softmax_choice): em.parameters.softmax_choice.set(softmax_choice) em.learn() - @pytest.mark.parametrize("softmax_choice", [pnl.ARG_MAX, pnl.PROBABILISTIC]) - def test_softmax_choice_warn(self, softmax_choice): - warning_msg = (f"The 'softmax_choice' arg of '.*' is set to '{softmax_choice}' with " - f"'enable_learning' set to True \\(or a list\\); this will generate an error if its " - f"'learn' method is called. 
Set 'softmax_choice' to WEIGHTED_AVG before learning.") + for softmax_choice in [pnl.ARG_MAX, pnl.PROBABILISTIC]: + with pytest.warns(UserWarning) as warning: + em = EMComposition(softmax_choice=softmax_choice, enable_learning=True) + warning_msg = (f"The 'softmax_choice' arg of '{em.name}' is set to '{softmax_choice}' with " + f"'enable_learning' set to True; this will generate an error if its " + f"'learn' method is called. Set 'softmax_choice' to WEIGHTED_AVG before learning.") + assert warning_msg in str(warning[0].message) + + def test_fields_arg(self): + + em = EMComposition(memory_template=(5,1), + memory_capacity=1, + normalize_field_weights=False, + fields={'A': (1.2, 3.4, True), + 'B': (None, False, True), + 'C': (0, True, True), + 'D': (7.8, False, True), + 'E': (5.6, True, True)}) + assert em.num_fields == 5 + assert em.num_keys == 4 + assert (em.field_weights == [1.2, None, 0, 7.8, 5.6]).all() + assert (em.learn_field_weights == [3.4, False, True, False, True]).all() + np.testing.assert_allclose(em.target_fields, [True, True, True, True, True]) + + # # Test wrong number of entries + with pytest.raises(EMCompositionError) as error_text: + EMComposition(memory_template=(3,1), memory_capacity=1, fields={'A': (1.2, 3.4)}) + assert error_text.value.error_value == (f"The number of entries (1) in the dict specified in the 'fields' arg " + f"of 'EM_Composition' does not match the number of fields in its " + f"memory (3).") + # Test dual specification of fields and corresponding args and learning specified for value field + with pytest.warns(UserWarning) as warning: + EMComposition(memory_template=(2,1), + memory_capacity=1, + fields={'A': (1.2, 3.4, True), + 'B': (None, True, True)}, + field_weights=[10, 11.0]) + warning_msg_1 = (f"The 'fields' arg for 'EM_Composition' was specified, so any of the 'field_names', " + f"'field_weights', 'learn_field_weights' or 'target_fields' args will be ignored.") + warning_msg_2 = (f"Learning was specified for field 'B' in the 'learn_field_weights' arg for " + f"'EM_Composition', but it is not allowed for value fields; it will be ignored.") + assert warning_msg_1 in str(warning[0].message) + assert warning_msg_2 in str(warning[1].message) + + + + field_names = ['KEY A','VALUE A', 'KEY B','KEY VALUE','VALUE LEARN'] + field_weights = [1, None, 2, 0, None] + learn_field_weights = [True, False, .01, False, False] + target_fields = [True, False, False, True, True] + dict_subdict = {} + for i, fn in enumerate(field_names): + dict_subdict[fn] = {pnl.FIELD_WEIGHT: field_weights[i], + pnl.LEARN_FIELD_WEIGHT: learn_field_weights[i], + pnl.TARGET_FIELD: target_fields[i]} + dict_tuple = {fn:(fw,lfw,tf) for fn,fw,lfw,tf in zip(field_names, + field_weights, + learn_field_weights, + target_fields)} + test_field_map_and_args_assignment_data = [ + ('args', None, field_names, field_weights, learn_field_weights, target_fields), + ('dict-subdict', dict_subdict, None, None, None, None), + ('dict-tuple', dict_tuple, None, None, None, None)] + field_arg_names = "format, fields, field_names, field_weights, learn_field_weights, target_fields" + + @pytest.mark.parametrize(field_arg_names, test_field_map_and_args_assignment_data, + ids=[x[0] for x in test_field_map_and_args_assignment_data]) + def test_field_args_and_map_assignments(self, + format, + fields, + field_names, + field_weights, + learn_field_weights, + target_fields): + # individual args + em = EMComposition(memory_template=(5,2), + memory_capacity=2, + fields=fields, + field_names=field_names, + 
field_weights=field_weights, + learn_field_weights=learn_field_weights, + target_fields=target_fields, + learning_rate=0.5) + assert em.num_fields == 5 + assert em.num_keys == 3 + for actual, expected in zip(em.field_weights, [0.33333333, None, 0.66666667, 0, None]): + if expected is None: + assert actual is None + else: + np.testing.assert_allclose(actual, expected) + + # Validate targets for target_fields + np.testing.assert_allclose(em.target_fields, [True, False, False, True, True]) + learning_components = em.infer_backpropagation_learning_pathways(pnl.ExecutionMode.PyTorch) + assert len(learning_components) == 3 + assert 'TARGET for KEY A [RETRIEVED]' in learning_components[0].name + assert 'TARGET for KEY VALUE [RETRIEVED]' in learning_components[1].name + assert 'TARGET for VALUE LEARN [RETRIEVED]' in learning_components[2].name + + # Validate learning specs for field weights + # Presence or absence of field weight components based on keys vs. values: + assert ['KEY A [WEIGHT]' in node.name for node in em.nodes] + assert ['KEY B [WEIGHT]' in node.name for node in em.nodes] + assert ['KEY VALUE [WEIGHT]' in node.name for node in em.nodes] + assert not any('VALUE A [WEIGHT]' in node.name for node in em.nodes) + assert not any('VALUE LEARN [WEIGHT]' in node.name for node in em.nodes) + assert not any('WEIGHT to WEIGHTED MATCH for VALUE A' in proj.name for proj in em.projections) + assert not any('WEIGHT to WEIGHTED MATCH for VALUE LEARN' in proj.name for proj in em.projections) + # Learnability and learning rate for field weights + # FIX: ONCE LEARNING IS FULLY IMPLEMENTED FOR FIELD WEIGHTS, VALIDATE THAT: + # KEY A USES COMPOSITION DEFAULT LEARNING RATE OF .5 + # KEY B USES INDIVIDUALLY ASSIGNED LEARNING RATE OF .01 + assert em.learn_field_weights == [True, False, .01, False, False] + assert em.projections['WEIGHT to WEIGHTED MATCH for KEY A'].learnable + assert em.projections['WEIGHT to WEIGHTED MATCH for KEY B'].learnable + assert not em.projections['WEIGHT to WEIGHTED MATCH for KEY VALUE'].learnable - with pytest.warns(UserWarning, match=warning_msg): - EMComposition(softmax_choice=softmax_choice, enable_learning=True) + # Validate _field_index_map + assert em._field_index_map[[k for k in em._field_index_map.keys() + if ('MappingProjection from KEY A [QUERY][OutputPort-0] to STORE[InputPort-0]') + in k.name][0]]==0 + assert em._field_index_map[[k for k in em._field_index_map.keys() if 'KEY A [QUERY]' in k.name][0]]==0 + assert em._field_index_map[[k for k in em._field_index_map.keys() if 'KEY A [MATCH to KEYS]' in k.name][0]]==0 + assert em._field_index_map[[k for k in em._field_index_map.keys() if 'KEY A [WEIGHTED MATCH]' in k.name][0]]==0 + assert em._field_index_map[[k for k in em._field_index_map.keys() if 'KEY A [RETRIEVED]' in k.name][0]]==0 + assert em._field_index_map[[k for k in em._field_index_map.keys() if 'MEMORY FOR KEY A [RETRIEVE KEY]' + in k.name][0]]==0 + assert em._field_index_map[[k for k in em._field_index_map.keys() if 'VALUE A [VALUE]' in k.name][0]] == 1 + assert em._field_index_map[[k for k in em._field_index_map.keys() if + ('VALUE A [VALUE][OutputPort-0] to STORE[InputPort-1]') in k.name][0]] == 1 + assert em._field_index_map[[k for k in em._field_index_map.keys() if 'VALUE A [RETRIEVED]' in k.name][0]] == 1 + assert em._field_index_map[[k for k in em._field_index_map.keys() + if 'MEMORY FOR VALUE A' in k.name][0]] == 1 + assert em._field_index_map[[k for k in em._field_index_map.keys() if 'KEY B [QUERY]' in k.name][0]] == 2 + assert em._field_index_map[[k 
for k in em._field_index_map.keys() + if ('KEY B [QUERY][OutputPort-0] to STORE[InputPort-2]') in k.name][0]] == 2 + assert em._field_index_map[[k for k in em._field_index_map.keys() if 'KEY B [RETRIEVED]' in k.name][0]] == 2 + assert (em._field_index_map[[k for k in em._field_index_map.keys() + if 'MEMORY FOR KEY B [RETRIEVE KEY]' in k.name][0]] == 2) + assert em._field_index_map[[k for k in em._field_index_map.keys() if 'KEY VALUE [QUERY]' in k.name][0]] == 3 + assert em._field_index_map[[k for k in em._field_index_map.keys() + if 'KEY VALUE [QUERY][OutputPort-0] to STORE[InputPort-3]' in k.name][0]] == 3 + assert em._field_index_map[[k for k in em._field_index_map.keys() if 'KEY VALUE [RETRIEVED]' in k.name][0]] == 3 + assert em._field_index_map[[k for k in em._field_index_map.keys() + if 'MEMORY FOR KEY VALUE [RETRIEVE KEY]' in k.name][0]] == 3 + assert em._field_index_map[[k for k in em._field_index_map.keys() if 'VALUE LEARN [VALUE]' in k.name][0]] == 4 + assert em._field_index_map[[k for k in em._field_index_map.keys() + if 'VALUE LEARN [VALUE][OutputPort-0] to STORE[InputPort-4]' in k.name][0]] == 4 + assert (em._field_index_map[[k for k in em._field_index_map.keys() + if 'VALUE LEARN [RETRIEVED]' in k.name][0]] == 4) + assert em._field_index_map[[k for k in em._field_index_map.keys() if 'VALUE LEARN [VALUE]' in k.name][0]] == 4 + assert em._field_index_map[[k for k in em._field_index_map.keys() + if 'MEMORY FOR VALUE LEARN [RETRIEVE VALUE]' in k.name][0]] == 4 + assert (em._field_index_map[[k for k in em._field_index_map.keys() + if 'MEMORY for KEY A [KEY]' in k.name][0]] == 0) + assert em._field_index_map[[k for k in em._field_index_map.keys() + if 'MATCH to WEIGHTED MATCH for KEY A' in k.name][0]] == 0 + assert em._field_index_map[[k for k in em._field_index_map.keys() + if 'WEIGHTED MATCH for KEY A to COMBINE MATCHES' in k.name][0]] == 0 + assert em._field_index_map[[k for k in em._field_index_map.keys() if 'KEY B [MATCH to KEYS]' in k.name][0]] == 2 + assert em._field_index_map[[k for k in em._field_index_map.keys() + if 'MEMORY for KEY B [KEY]' in k.name][0]] == 2 + assert em._field_index_map[[k for k in em._field_index_map.keys() + if 'MATCH to WEIGHTED MATCH for KEY B' in k.name][0]] == 2 + assert (em._field_index_map[[k for k in em._field_index_map.keys() + if 'KEY B [WEIGHTED MATCH]' in k.name][0]] == 2) + assert em._field_index_map[[k for k in em._field_index_map.keys() + if 'WEIGHTED MATCH for KEY B to COMBINE MATCHES' in k.name][0]] == 2 + assert (em._field_index_map[[k for k in em._field_index_map.keys() + if 'KEY VALUE [MATCH to KEYS]' in k.name][0]] == 3) + assert em._field_index_map[[k for k in em._field_index_map.keys() if + 'MEMORY for KEY VALUE [KEY]' in k.name][0]] == 3 + assert em._field_index_map[[k for k in em._field_index_map.keys() + if 'MATCH to WEIGHTED MATCH for KEY VALUE' in k.name][0]] == 3 + assert (em._field_index_map[[k for k in em._field_index_map.keys() + if 'KEY VALUE [WEIGHTED MATCH]' in k.name][0]] == 3) + assert em._field_index_map[[k for k in em._field_index_map.keys() + if 'WEIGHTED MATCH for KEY VALUE to COMBINE MATCHES' in k.name][0]] == 3 + assert em._field_index_map[[k for k in em._field_index_map.keys() if 'KEY B [WEIGHT]' in k.name][0]] == 2 + assert em._field_index_map[[k for k in em._field_index_map.keys() if 'KEY VALUE [WEIGHT]' in k.name][0]] == 3 + assert em._field_index_map[[k for k in em._field_index_map.keys() + if 'WEIGHT to WEIGHTED MATCH for KEY VALUE' in k.name][0]] == 3 + assert em._field_index_map[[k for k in 
em._field_index_map.keys() + if 'WEIGHT to WEIGHTED MATCH for KEY A' in k.name][0]] == 0 + assert em._field_index_map[[k for k in em._field_index_map.keys() + if 'WEIGHT to WEIGHTED MATCH for KEY B' in k.name][0]] == 2 + + def test_field_weights_all_None_and_or_0(self): + with pytest.raises(EMCompositionError) as error_text: + EMComposition(memory_template=(3,1), memory_capacity=1, field_weights=[None, None, None]) + assert error_text.value.error_value == (f"The entries in 'field_weights' arg for EM_Composition can't all " + f"be 'None' since that will preclude the construction of any keys.") + + with pytest.warns(UserWarning) as warning: + EMComposition(memory_template=(3,1), memory_capacity=1, field_weights=[0, None, 0]) + warning_msg = (f"All of the entries in the 'field_weights' arg for EM_Composition are either None or set to 0; " + f"this will result in no retrievals unless/until the 0(s) is(are) changed to a positive value.") + assert warning_msg in str(warning[0].message) @pytest.mark.pytorch @@ -265,21 +458,21 @@ class TestExecution: # ---------------------------------------------------------------------------------- ------------------------ (0, [[[1,2,3],[4,6]], [[1,2,5],[4,8]], - [[1,2,10],[4,10]]], None, 3, 0, [1,0], None, None, 100, 0, [[[1, 2, 3]]], [[1., 2., 3.16585899], - [4., 6.16540637]]), + [[1,2,10],[4,10]]], None, 3, 0, [1,None], None, None, 100, 0, [[[1, 2, 3]]], [[1., 2., 3.16585899], + [4., 6.16540637]]), (1, [[[1,2,3],[4,6]], [[1,2,5],[4,8]], - [[1,2,10],[4,10]]], None, 3, 0, [1,0], None, None, 100, 0, [[1, 2, 3], - [4, 6]], [[1., 2., 3.16585899], - [4., 6.16540637]]), + [[1,2,10],[4,10]]], None, 3, 0, [1,None], None, None, 100, 0, [[1, 2, 3], + [4, 6]], [[1., 2., 3.16585899], + [4., 6.16540637]]), (2, [[[1,2,3],[4,6]], [[1,2,5],[4,8]], - [[1,2,10],[4,10]]], None, 3, 0, [1,0], None, None, 100, 0, [[1, 2, 3], - [4, 8]], [[1., 2., 3.16585899], - [4., 6.16540637]]), + [[1,2,10],[4,10]]], None, 3, 0, [1,None], None, None, 100, 0, [[1, 2, 3], + [4, 8]], [[1., 2., 3.16585899], + [4., 6.16540637]]), (3, [[[1,2,3],[4,6]], [[1,2,5],[4,8]], - [[1,2,10],[4,10]]], (0,.01), 4, 0, [1,0], None, None, 100, 0, [[1, 2, 3], + [[1,2,10],[4,10]]], (0,.01), 4, 0, [1,None], None, None, 100, 0, [[1, 2, 3], [4, 8]], [[0.99998628, 1.99997247, 3.1658154 ], @@ -352,11 +545,11 @@ class TestExecution: 6.38682264]]), (12, [[[1],[2],[3]], # Scalar keys - exact match (this tests use of L0 for retreieval in MEMORY matrix) - [[10],[0],[100]]], (0,.01), 3, 0, [1,1,0], None, None, pnl.ARG_MAX, 1, [[10],[0],[100]], + [[10],[0],[100]]], (0,.01), 3, 0, [1,1,None], None, None, pnl.ARG_MAX, 1, [[10],[0],[100]], [[10],[0],[100]]), (13, [[[1],[2],[3]], # Scalar keys - close match (this tests use of L0 for retreieval in MEMORY matrix - [[10],[0],[100]]], (0,.01), 3, 0, [1,1,0], None, None, pnl.ARG_MAX, 1, [[2],[3],[4]], [[1],[2],[3]]), + [[10],[0],[100]]], (0,.01), 3, 0, [1,1,None], None, None, pnl.ARG_MAX, 1, [[2],[3],[4]], [[1],[2],[3]]), ] args_names = "test_num, memory_template, memory_fill, memory_capacity, memory_decay_rate, field_weights, " \ @@ -364,13 +557,13 @@ class TestExecution: @pytest.mark.parametrize(args_names, test_execution_data, ids=[x[0] for x in test_execution_data]) - @pytest.mark.parametrize('enable_learning', [False, True], ids=['no_learning','learning']) + @pytest.mark.parametrize('learn_field_weights', [False, True], ids=['no_learning','learning']) @pytest.mark.composition @pytest.mark.parametrize('exec_mode', [pnl.ExecutionMode.Python, pnl.ExecutionMode.PyTorch], 
ids=['Python','PyTorch']) def test_simple_execution_without_learning(self, exec_mode, - enable_learning, + learn_field_weights, test_num, memory_template, memory_capacity, @@ -388,12 +581,12 @@ def test_simple_execution_without_learning(self, # # pytest.skip('Execution of EMComposition not yet supported for LLVM Mode.') # Restrict testing of learning configurations (which are much larger) to select tests - if enable_learning and test_num not in {10}: + if learn_field_weights and test_num not in {10}: pytest.skip('Limit tests of learning to subset of parametrizations (for efficiency)') params = {'memory_template': memory_template, 'memory_capacity': memory_capacity, - 'enable_learning': enable_learning, + 'learn_field_weights': learn_field_weights, } # Add explicit argument specifications only for args that are not None # (to avoid forcing to None in constructor) @@ -406,7 +599,7 @@ def test_simple_execution_without_learning(self, if concatenate_queries is not None: params.update({'concatenate_queries': concatenate_queries}) # FIX: DELETE THE FOLLOWING ONCE CONCATENATION IS IMPLEMENTED FOR LEARNING - params.update({'enable_learning': False}) + params.update({'learn_field_weights': False}) if normalize_memories is not None: params.update({'normalize_memories': normalize_memories}) if softmax_gain is not None: @@ -470,58 +663,129 @@ def test_simple_execution_without_learning(self, memory_fill = memory_fill or 0 assert all(elem == memory_fill for elem in em.memory[-1]) - @pytest.mark.parametrize('data', - (([[[5], [0], [10]], # 1d template + @pytest.mark.parametrize('test_field_weights_0_vs_None_data', + (([[[5], [0], [10]], # 1d memory template [[0], [5], [10]], [[0.1], [0.1], [10]], [[0.1], [0.1], [10]]], - [[5], [5], [10]], # 1d query - pnl.L0 # 1d retrieval operation + [[5], [5], [10]], # 1d query + pnl.L0), # 1d retrieval operation + ([[[5,0], [0,5], [10,10]], # 2d memory template + [[0,5], [5,0], [10,10]], + [[0.1, 0.1], [0.1, 0.1], [0.1, 0.1]], + [[0.1, 0.1], [0.1, 0.1], [0.1, 0.1]]], + [[5,0], [5,0], [10,10]], # 2d query + pnl.DOT_PRODUCT), # 2d retrieval operation ), - ([[[5,0], [0,5], [10]], # 2d template - [[0,5], [5,0], [10]], - [[0.1, 0.1], [0.1, 0.1], [0.1]], - [[0.1, 0.1], [0.1, 0.1], [0.1]]], - [[5,0], [5,0], [10]], # 2d query - pnl.DOT_PRODUCT)), # 2d retrieval operation ids=['1d', '2d']) + @pytest.mark.parametrize('field_weights', [[.75, .25, 0], [.75, .25, None]], ids=['0','None']) + @pytest.mark.parametrize('softmax_choice', [pnl.MAX_VAL, pnl.ARG_MAX], ids=['MAX_VAL','ARG_MAX']) + @pytest.mark.parametrize('exec_mode', [pnl.ExecutionMode.Python, + pnl.ExecutionMode.PyTorch, + # pnl.ExecutionMode.LLVM + ], + ids=['Python', + 'PyTorch', + # 'LLVM' + ]) @pytest.mark.composition - @pytest.mark.parametrize('exec_mode', [pnl.ExecutionMode.Python, pnl.ExecutionMode.PyTorch]) - def test_em_field_weights_assignment(self, exec_mode, data): - EM_assign_template = data[0] - em = pnl.EMComposition(memory_template=EM_assign_template, + def test_assign_field_weights_and_0_vs_None(self, + field_weights, + softmax_choice, + test_field_weights_0_vs_None_data, + exec_mode): + memory_template = test_field_weights_0_vs_None_data[0] + query = test_field_weights_0_vs_None_data[1] + operation = test_field_weights_0_vs_None_data[2] + + em = pnl.EMComposition(memory_template=memory_template, memory_capacity=4, memory_decay_rate= 0, - memory_fill=0.001, - enable_learning = False, - softmax_choice=pnl.ARG_MAX, - field_weights=(.75,.25,0), + learn_field_weights = False, + softmax_choice=softmax_choice, 
+ field_weights=field_weights, field_names=['A','B','C']) - # Confirm initial weight assginments (that favor A) + # Confirm initial weight assignments (that favor A) assert em.nodes['A [WEIGHT]'].input_port.defaults.variable == [.75] assert em.nodes['B [WEIGHT]'].input_port.defaults.variable == [.25] + if field_weights[2] == 0: + assert 'C [QUERY]' in em.nodes.names + assert len(em.field_weight_nodes) == 3 + assert em.nodes['C [WEIGHT]'].input_port.defaults.variable == [0] + elif field_weights[2] is None: + assert 'C [VALUE]' in em.nodes.names + assert len(em.field_weight_nodes) == 2 + assert 'C [WEIGHT]' not in em.nodes.names + # Confirm use of L0 for retrieval since keys for A and B are scalars - assert em.projections['MEMORY for A [KEY]'].function.operation == data[2] - assert em.projections['MEMORY for B [KEY]'].function.operation == data[2] - # Change fields weights to favor B - em.field_weights = [0,1,0] - # Ensure weights got changed - assert em.nodes['A [WEIGHT]'].input_port.defaults.variable == [0] - assert em.nodes['B [WEIGHT]'].input_port.defaults.variable == [1] - # Note: The input matches both fields A and B; - test_input = {em.nodes['A [QUERY]']: [data[1][0]], - em.nodes['B [QUERY]']: [data[1][1]], - em.nodes['C [VALUE]']: [data[1][2]]} - result = em.run(test_input, execution_mode=exec_mode) - # If the weights change DIDN'T get used, it should favor field A and return [5,0,10] as the best match - # If weights change DID get used, it should favor field B and return [0,5,10] as the best match - for i,j in zip(result, data[0][1]): - assert (i == j).all() - # Change weights back and confirm that it now favors A - em.field_weights = [1,0,0] + assert em.projections['MEMORY for A [KEY]'].function.operation == operation + assert em.projections['MEMORY for B [KEY]'].function.operation == operation + if field_weights[2] == 0: + assert em.projections['MEMORY for C [KEY]'].function.operation == operation + + A = em.nodes['A [QUERY]'] + B = em.nodes['B [QUERY]'] + C = em.nodes['C [QUERY]' if field_weights[2] == 0 else 'C [VALUE]'] + + # Note: The input matches both fields A and B + test_input = {A: [query[0]], + B: [query[1]], + C: [query[2]]} result = em.run(test_input, execution_mode=exec_mode) - for i,j in zip(result, data[0][0]): - assert (i == j).all() + # Note: field_weights favors A + if softmax_choice == pnl.MAX_VAL: + if operation == pnl.L0: + expected = [[1.70381182], [0.], [3.40762364]] + else: + expected = [[1.56081243, 0.0], [0.0, 1.56081243], [3.12162487, 3.12162487]] + else: + expected = memory_template[0] + np.testing.assert_allclose(result, expected) + + # Change fields weights to favor C + if field_weights[2] is None: + with pytest.raises(EMCompositionError) as error_text: + em.field_weights = np.array([0,0,1]) + assert error_text.value.error_value == (f"Field 'C' of 'EM_Composition' was originally assigned " + f"as a value node (i.e., with a field_weight = None); " + f"this cannot be changed after construction. 
If you want to " + f"change it to a key field, you must re-construct the " + f"EMComposition using a scalar for its field in the " + f"`field_weights` arg (including 0.") + else: + em.field_weights = np.array([0,0,1]) + # Ensure weights got changed + assert em.nodes['A [WEIGHT]'].input_port.defaults.variable == [0] + assert em.nodes['B [WEIGHT]'].input_port.defaults.variable == [0] + assert em.nodes['C [WEIGHT]'].input_port.defaults.variable == [1] + # Note: The input matches both fields A and B; + test_input = {em.nodes['A [QUERY]']: [query[0]], + em.nodes['B [QUERY]']: [query[1]], + em.nodes['C [QUERY]']: [query[2]]} + result = em.run(test_input, execution_mode=exec_mode) + # If the weights change DIDN'T get used, it should favor field A and return [5,0,10] as the best match + # If weights change DID get used, it should favor field B and return [0,5,10] as the best match + if softmax_choice == pnl.MAX_VAL: + if operation == pnl.L0: + expected = [[2.525], [2.525], [10]] + else: + expected = [[2.525, 1.275], [2.525, 1.275], [7.525, 7.525]] + else: + expected = memory_template[0] + np.testing.assert_allclose(result, expected) + + # Change weights back and confirm that it now favors A + em.field_weights = [0,1,0] + result = em.run(test_input, execution_mode=exec_mode) + if softmax_choice == pnl.MAX_VAL: + if operation == pnl.L0: + expected = [[3.33333333], [5], [10]] + else: + expected = [[3.33333333, 1.66666667], [5, 0], [10, 10]] + else: + expected = memory_template[1] + np.testing.assert_allclose(result, expected) + @pytest.mark.composition @pytest.mark.parametrize('exec_mode', [pnl.ExecutionMode.Python, pnl.ExecutionMode.PyTorch]) @@ -540,7 +804,9 @@ def test_multiple_trials_concatenation_and_storage_node(self, exec_mode, concate softmax_gain=100, memory_fill=(0,.001), concatenate_queries=concatenate, - enable_learning=learning, + # learn_field_weights=learning, + learn_field_weights=False, + enable_learning=True, use_storage_node=use_storage_node) inputs = [[[[1,2,3]],[[4,5,6]],[[10,20,30]],[[40,50,60]],[[100,200,300]],[[400,500,600]]], @@ -574,3 +840,221 @@ def test_multiple_trials_concatenation_and_storage_node(self, exec_mode, concate [[2.5, 3.125, 3.75 ], [2.5625, 3.1875, 3.8125]]] em.learn(inputs=inputs, execution_mode=exec_mode) np.testing.assert_equal(em.memory, expected_memory) + + @pytest.mark.composition + def test_backpropagation_of_error_in_learning(self): + """This test is based on the EGO CSW Model""" + + import torch + torch.manual_seed(0) + state_input_layer = pnl.ProcessingMechanism(name='STATE', input_shapes=11) + previous_state_layer = pnl.ProcessingMechanism(name='PREVIOUS STATE', input_shapes=11) + context_layer = pnl.TransferMechanism(name='CONTEXT', + input_shapes=11, + function=pnl.Tanh, + integrator_mode=True, + integration_rate=.69) + em = EMComposition(name='EM', + memory_template=[[0] * 11, [0] * 11, [0] * 11], # context + memory_fill=(0,.0001), + memory_capacity=50, + memory_decay_rate=0, + softmax_gain=10, + softmax_threshold=.001, + fields = {'STATE': {pnl.FIELD_WEIGHT: None, + pnl.LEARN_FIELD_WEIGHT: False, + pnl.TARGET_FIELD: True}, + 'PREVIOUS_STATE': {pnl.FIELD_WEIGHT:.5, + pnl.LEARN_FIELD_WEIGHT: False, + pnl.TARGET_FIELD: False}, + 'CONTEXT': {pnl.FIELD_WEIGHT:.5, + pnl.LEARN_FIELD_WEIGHT: False, + pnl.TARGET_FIELD: False}}, + normalize_field_weights=True, + normalize_memories=False, + concatenate_queries=False, + enable_learning=True, + learning_rate=.5, + device=pnl.CPU + ) + prediction_layer = pnl.ProcessingMechanism(name='PREDICTION', 
input_shapes=11) + + QUERY = ' [QUERY]' + VALUE = ' [VALUE]' + RETRIEVED = ' [RETRIEVED]' + + # Pathways + state_to_previous_state_pathway = [state_input_layer, + pnl.MappingProjection(matrix=pnl.IDENTITY_MATRIX, + learnable=False), + previous_state_layer] + state_to_context_pathway = [state_input_layer, + pnl.MappingProjection(matrix=pnl.IDENTITY_MATRIX, + learnable=False), + context_layer] + state_to_em_pathway = [state_input_layer, + pnl.MappingProjection(sender=state_input_layer, + receiver=em.nodes['STATE' + VALUE], + matrix=pnl.IDENTITY_MATRIX, + learnable=False), + em] + previous_state_to_em_pathway = [previous_state_layer, + pnl.MappingProjection(sender=previous_state_layer, + receiver=em.nodes['PREVIOUS_STATE' + QUERY], + matrix=pnl.IDENTITY_MATRIX, + learnable=False), + em] + context_learning_pathway = [context_layer, + pnl.MappingProjection(sender=context_layer, + matrix=pnl.IDENTITY_MATRIX, + receiver=em.nodes['CONTEXT' + QUERY], + learnable=True), + em, + pnl.MappingProjection(sender=em.nodes['STATE' + RETRIEVED], + receiver=prediction_layer, + matrix=pnl.IDENTITY_MATRIX, + learnable=False), + prediction_layer] + + # Composition + EGO = pnl.AutodiffComposition([state_to_previous_state_pathway, + state_to_context_pathway, + state_to_em_pathway, + previous_state_to_em_pathway, + context_learning_pathway], + learning_rate=.5, + loss_spec=pnl.Loss.BINARY_CROSS_ENTROPY, + device=pnl.CPU) + + learning_components = EGO.infer_backpropagation_learning_pathways(pnl.ExecutionMode.PyTorch) + assert len(learning_components) == 1 + assert learning_components[0].name == 'TARGET for PREDICTION' + EGO.add_projection(pnl.MappingProjection(sender=state_input_layer, + receiver=learning_components[0], + learnable=False)) + + EGO.scheduler.add_condition(em, pnl.BeforeNodes(previous_state_layer, context_layer)) + + INPUTS = [[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 
1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0]] + + result = EGO.learn(inputs={'STATE':INPUTS}, learning_rate=.5, execution_mode=pnl.ExecutionMode.PyTorch) + expected = [[ 0.00000000e+00, 1.35476414e-03, 1.13669378e-03, 2.20434260e-03, 6.61008388e-04, 9.88672202e-01, + 6.52088276e-04, 1.74149507e-03, 1.09769133e-03, 2.47971436e-03, 0.00000000e+00], + [ 0.00000000e+00, -6.75284069e-02, -1.28930436e-03, -2.10726610e-01, -1.41050716e-03, -5.92286989e-01, + -2.75196416e-03, -2.21010605e-03, -7.14369243e-03, -2.05167374e-02, 0.00000000e+00], + [ 0.00000000e+00, 1.18578255e-03, 1.29393181e-03, 1.35476414e-03, 1.13669378e-03, 2.20434260e-03, + 6.61008388e-04, 9.88672202e-01, 6.52088276e-04, 2.83918640e-03, 0.00000000e+00]] + np.testing.assert_allclose(result, expected) + + # Plot (for during debugging): + # + # TARGETS = [[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + # [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + # [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + # [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + # [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + # [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + # [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + # [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + # [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + # [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + # [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + # [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + # [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + # [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + # [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + # [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + # [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + # [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + # [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + # [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + # [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + # [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + # [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + # [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + # [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + # [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + # [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + # [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + # [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + # [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + # [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + # [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + # [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + # [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + # [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + # [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + # [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + # [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + # [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + # [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]] + # + # fig, axes = plt.subplots(3, 1, figsize=(5, 12)) + # axes[0].imshow(EGO.projections[7].parameters.matrix.get(EGO.name), interpolation=None) + # axes[1].plot((1 - np.abs(EGO.results[1:50,2]-TARGETS[:49])).sum(-1)) + # axes[1].set_xlabel('Stimuli') + # axes[1].set_ylabel('loss_spec') + # axes[2].plot( (EGO.results[1:50,2]*TARGETS[:49]).sum(-1) ) + # axes[2].set_xlabel('Stimuli') + # axes[2].set_ylabel('Correct Logit') + # plt.suptitle(f"Blocked Training") + # plt.show()
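Taken together, the changes exercised by these tests replace the old overloaded enable_learning arg with three orthogonal controls: enable_learning (master switch for learning), learn_field_weights (global or per-field learnability and learning rates for the field-weight Projections), and target_fields (which retrieved fields get TARGET nodes for backpropagation). A hedged end-to-end sketch of how they combine, with argument values that are illustrative only and patterned on the tests above:

    import psyneulink as pnl

    em = pnl.EMComposition(memory_template=(3, 2),
                           memory_capacity=10,
                           field_names=['A', 'B', 'C'],
                           field_weights=[1, 2, None],              # A and B are keys; C is a value field
                           learn_field_weights=[True, .01, False],  # A: composition default rate; B: .01; C: n/a
                           target_fields=[False, False, True],      # backpropagate error only through C
                           enable_learning=True,
                           learning_rate=.5)

    # Keys are queried and the value field is retrieved:
    result = em.run({em.nodes['A [QUERY]']: [[1, 0]],
                     em.nodes['B [QUERY]']: [[0, 1]],
                     em.nodes['C [VALUE]']: [[0, 0]]})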