refactor/emcomposition_field_handling (#3122)
* • emcomposition.py
  _parse_fields:  clean up assignment of self.num_fields

• test_emcomposition.py
  add test_assign_field_weights_and_0_vs_None()
  add test_field_weights_all_None_and_or_0

• emcomposition.py
  - revamp docstring to document new mods
  - add fields arg to specify field_names, field_weights, learn_field_weights
  - implement fields arg to specify field_names, field_weights, learn_field_weights in dict format (see the sketch below)
  - implement support for field-specific learn_field_weight specifications
  - _identify_target_nodes():
    refactor to use target_fields instead of learn_field_weights
  - add target_fields to fields specification dict
  - add dict spec for entries in fields arg
  - start adding field_idx to all components
  - add self._field_index_map

• pytorchEMcompositionwrapper.py
  - store_memory(): use self._field_index_map to assign memories to fields

• test_emcomposition.py
  - test_backpropagation_of_error_in_learning(): use EGO model to test for error backpropagation through EMComposition
  - test_field_args_and_map_assignments(): flesh out _field_index_map validation

* • emcomposition.py
  - update docstring figs
  - add purge_by_field_weights Parameter

* • autodiffcomposition.py
  - infer_backpropagation_learning_pathways(): add NodeRole.BIAS to pathways constructed for learning
jdcpni authored Nov 22, 2024
1 parent c0f73e2 commit ad1e743
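
The central change is the new fields argument to the EMComposition constructor. Below is a minimal sketch of the dict-format specification, assembled from the diff that follows; the vector sizes and memory capacity are illustrative, and the keyword import path is an assumption (the EGO script below uses FIELD_WEIGHT, LEARN_FIELD_WEIGHT, and TARGET_FIELD unqualified):

    from psyneulink import EMComposition
    # Assumed import path for the field-spec keywords used in the diff below:
    from psyneulink.core.globals.keywords import (
        FIELD_WEIGHT, LEARN_FIELD_WEIGHT, TARGET_FIELD)

    # Illustrative sizes and capacity, not taken from the commit itself.
    em = EMComposition(
        name='EM',
        memory_template=[[0] * 11,   # STATE
                         [0] * 11,   # PREVIOUS STATE
                         [0] * 11],  # CONTEXT
        memory_capacity=50,
        fields={'STATE':          {FIELD_WEIGHT: None,        # value field: not used as a retrieval query
                                   LEARN_FIELD_WEIGHT: False,
                                   TARGET_FIELD: True},       # error is backpropagated to this field
                'PREVIOUS STATE': {FIELD_WEIGHT: .5,
                                   LEARN_FIELD_WEIGHT: False,
                                   TARGET_FIELD: False},
                'CONTEXT':        {FIELD_WEIGHT: .5,
                                   LEARN_FIELD_WEIGHT: False,
                                   TARGET_FIELD: False}},
        enable_learning=True)

The per-field dict bundles what was previously spread across the parallel field_names, field_weights, and learn_field_weights args; per the store_memory() bullet above, the new self._field_index_map gives the PyTorch wrapper a lookup from components to field indices when assigning memories to fields.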
Showing 28 changed files with 1,930 additions and 1,038 deletions.
File renamed without changes.
@@ -39,21 +39,21 @@ def calc_prob(em_preds, test_ys):

# Names:
name = "EGO Model CSW",
em_name = "EM",
state_input_layer_name = "STATE",
previous_state_layer_name = "PREVIOUS STATE",
context_layer_name = 'CONTEXT',
em_name = "EM",
prediction_layer_name = "PREDICTION",

# Structural
state_d = 11, # length of state vector
previous_state_d = 11, # length of state vector
context_d = 11, # length of context vector
memory_capacity = ALL, # number of entries in EM memory; ALL=> match to number of stims
# memory_init = (0,.0001), # Initialize memory with random values in interval
memory_init = None, # Initialize with zeros
concatenate_queries = False,
# concatenate_queries = True,
memory_init = (0,.0001), # Initialize memory with random values in interval
# memory_init = None, # Initialize with zeros
# concatenate_queries = False,
concatenate_queries = True,

# environment
# curriculum_type = 'Interleaved',
@@ -63,20 +63,23 @@ def calc_prob(em_preds, test_ys):

# Processing
integration_rate = .69, # rate at which state is integrated into new context
# state_weight = 1, # weight of the state used during memory retrieval
# context_weight = 1, # weight of the context used during memory retrieval
state_weight = .5, # weight of the state used during memory retrieval
previous_state_weight = .5, # weight of the previous state used during memory retrieval
context_weight = .5, # weight of the context used during memory retrieval
state_weight = None, # weight of the state used during memory retrieval
# normalize_field_weights = False, # whether to normalize the field weights during memory retrieval
normalize_field_weights = True, # whether to normalize the field weights during memory retrieval
normalize_memories = False, # whether to normalize the memory during memory retrieval
# normalize_memories = True, # whether to normalize the memory during memory retrieval
# softmax_temperature = None, # temperature of the softmax used during memory retrieval (smaller means more argmax-like)
softmax_temperature = .1, # temperature of the softmax used during memory retrieval (smaller means more argmax-like)
# softmax_temperature = ADAPTIVE, # temperature of the softmax used during memory retrieval (smaller means more argmax-like)
# softmax_temperature = CONTROL, # temperature of the softmax used during memory retrieval (smaller means more argmax-like)
# softmax_threshold = None, # threshold used to mask out small values in softmax
softmax_threshold = .001, # threshold used to mask out small values in softmax
enable_learning=[False, False, True], # Enable learning for PREDICTION (STATE) but not CONTEXT or PREVIOUS STATE
learn_field_weights = False,
# target_fields=[True, False, False], # Enable learning for PREDICTION (STATE) but not CONTEXT or PREVIOUS STATE
enable_learning = True,
loss_spec = Loss.BINARY_CROSS_ENTROPY,
# loss_spec = Loss.MSE,
learning_rate = .5,
@@ -174,7 +174,7 @@
'CONTEXT',
'PREVIOUS STATE'],
start=0)
state_retrieval_weight = 0
state_retrieval_weight = None
RANDOM_WEIGHTS_INITIALIZATION=RandomMatrix(center=0.0, range=0.1) # Matrix spec used to initialize all Projections

if is_numeric_scalar(model_params['softmax_temperature']): # translate to gain of softmax retrieval function
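
The body of this conditional is collapsed in the diff. On the standard convention that softmax gain is the inverse of temperature, the translation is presumably along these lines (an assumption, not shown in the commit):

    retrieval_softmax_gain = 1 / model_params['softmax_temperature']  # assumed: gain = 1/temperature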
@@ -194,7 +194,7 @@ def construct_model(model_name:str=model_params['name'],
state_size:int=model_params['state_d'],

# Previous state
previous_state_input_name:str=model_params['previous_state_layer_name'],
previous_state_name:str=model_params['previous_state_layer_name'],

# Context representation (learned):
context_name:str=model_params['context_layer_name'],
@@ -205,12 +205,15 @@
em_name:str=model_params['em_name'],
retrieval_softmax_gain=retrieval_softmax_gain,
retrieval_softmax_threshold=model_params['softmax_threshold'],
state_retrieval_weight:Union[float,int]=state_retrieval_weight,
previous_state_retrieval_weight:Union[float,int]=model_params['state_weight'],
# state_retrieval_weight:Union[float,int]=state_retrieval_weight,
# previous_state_retrieval_weight:Union[float,int]=model_params['state_weight'],
state_retrieval_weight:Union[float,int]=model_params['state_weight'],
previous_state_retrieval_weight:Union[float,int]=model_params['previous_state_weight'],
context_retrieval_weight:Union[float,int]=model_params['context_weight'],
normalize_field_weights = model_params['normalize_field_weights'],
normalize_memories = model_params['normalize_memories'],
concatenate_queries = model_params['concatenate_queries'],
learn_field_weights = model_params['learn_field_weights'],
enable_learning = model_params['enable_learning'],
memory_capacity = memory_capacity,
memory_init=model_params['memory_init'],

@@ -219,7 +222,7 @@

# Learning
loss_spec=model_params['loss_spec'],
enable_learning=model_params['enable_learning'],
# target_fields=model_params['target_fields'],
learning_rate = model_params['learning_rate'],
device=model_params['device']

@@ -233,14 +236,16 @@ def construct_model(model_name:str=model_params['name'],
# ----------------------------------------------------------------------------------------------------------------

state_input_layer = ProcessingMechanism(name=state_input_name, input_shapes=state_size)
previous_state_layer = ProcessingMechanism(name=previous_state_input_name, input_shapes=state_size)
previous_state_layer = ProcessingMechanism(name=previous_state_name, input_shapes=state_size)
# context_layer = ProcessingMechanism(name=context_name, input_shapes=context_size)
context_layer = TransferMechanism(name=context_name,
input_shapes=context_size,
function=Tanh,
integrator_mode=True,
integration_rate=integration_rate)



em = EMComposition(name=em_name,
memory_template=[[0] * state_size, # state
[0] * state_size, # previous state
@@ -250,6 +255,15 @@
memory_decay_rate=0,
softmax_gain=retrieval_softmax_gain,
softmax_threshold=retrieval_softmax_threshold,
fields = {state_input_name: {FIELD_WEIGHT: state_retrieval_weight,
LEARN_FIELD_WEIGHT: False,
TARGET_FIELD: True},
previous_state_name: {FIELD_WEIGHT:previous_state_retrieval_weight,
LEARN_FIELD_WEIGHT: False,
TARGET_FIELD: False},
context_name: {FIELD_WEIGHT:context_retrieval_weight,
LEARN_FIELD_WEIGHT: False,
TARGET_FIELD: False}},
# Input Nodes:
# field_names=[state_input_name,
# previous_state_input_name,
@@ -259,19 +273,20 @@
# previous_state_retrieval_weight,
# context_retrieval_weight
# ),
field_names=[previous_state_input_name,
context_name,
state_input_name,
],
field_weights=(previous_state_retrieval_weight,
context_retrieval_weight,
state_retrieval_weight,
),
# field_names=[previous_state_input_name,
# context_name,
# state_input_name,
# ],
# field_weights=(previous_state_retrieval_weight,
# context_retrieval_weight,
# state_retrieval_weight,
# ),
normalize_field_weights=normalize_field_weights,
normalize_memories=normalize_memories,
concatenate_queries=concatenate_queries,
learn_field_weights=learn_field_weights,
learning_rate=learning_rate,
enable_learning=enable_learning,
learning_rate=learning_rate,
# target_fields=target_fields,
device=device
)

@@ -311,7 +326,7 @@ def construct_model(model_name:str=model_params['name'],
em]
previous_state_to_em_pathway = [previous_state_layer,
MappingProjection(sender=previous_state_layer,
receiver=em.nodes[previous_state_input_name+QUERY],
receiver=em.nodes[previous_state_name+QUERY],
matrix=IDENTITY_MATRIX,
learnable=False),
em]
@@ -0,0 +1,54 @@
import numpy as np
import torch
from torch.utils.data import dataset
from random import randint

def one_hot_encode(labels, num_classes):
"""
One hot encode labels and convert to tensor.
"""
return torch.tensor((np.arange(num_classes) == labels[..., None]).astype(float),dtype=torch.float32)

class DeterministicCSWDataset(dataset.Dataset):
def __init__(self, n_samples_per_context, contexts_to_load) -> None:
super().__init__()
raw_xs = np.array([
[[9,1,3,5,7],[9,2,4,6,8]],
[[10,1,4,5,8],[10,2,3,6,7]]
])

item_indices = np.random.choice(raw_xs.shape[1],sum(n_samples_per_context),replace=True)
task_names = [0,1] # Flexible so these can be renamed later
task_indices = [task_names.index(name) for name in contexts_to_load]

context_indices = np.repeat(np.array(task_indices),n_samples_per_context)
self.xs = one_hot_encode(raw_xs[context_indices,item_indices],11)

self.xs = self.xs.reshape((-1,11))
self.ys = torch.cat([self.xs[1:],one_hot_encode(np.array([0]),11)],dim=0)
context_indices = np.repeat(np.array(task_indices),[x*5 for x in n_samples_per_context])
self.contexts = one_hot_encode(context_indices, len(task_names))

# Remove the last transition since there's no next state available
self.xs = self.xs[:-1]
self.ys = self.ys[:-1]
self.contexts = self.contexts[:-1]

def __len__(self):
return len(self.xs)

def __getitem__(self, idx):
return self.xs[idx], self.contexts[idx], self.ys[idx]

def generate_dataset(condition='Blocked'):
# Generate the dataset for either the blocked or interleaved condition
if condition=='Blocked':
contexts_to_load = [0,1,0,1] + [randint(0,1) for _ in range(40)]
n_samples_per_context = [40,40,40,40] + [1]*40
elif condition == 'Interleaved':
contexts_to_load = [0,1]*80 + [randint(0,1) for _ in range(40)]
n_samples_per_context = [1]*160 + [1]*40
else:
raise ValueError(f'Unknown dataset condition: {condition}')

return DeterministicCSWDataset(n_samples_per_context, contexts_to_load)
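
A hedged usage sketch for the dataset above; the DataLoader wrapping is illustrative and not part of the commit:

    from torch.utils.data import DataLoader

    # Illustrative only: iterate one trial at a time over the blocked condition.
    loader = DataLoader(generate_dataset(condition='Blocked'), batch_size=1, shuffle=False)
    for xs, contexts, ys in loader:
        # xs: current state (one-hot, length 11); contexts: one-hot task context;
        # ys: the next state to be predicted.
        print(xs.shape, contexts.shape, ys.shape)
        break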
@@ -3,8 +3,8 @@

# Settings for running script:

MODEL_PARAMS = 'TestParams'
# MODEL_PARAMS = 'DeclanParams'
# MODEL_PARAMS = 'TestParams'
MODEL_PARAMS = 'DeclanParams'

CONSTRUCT_MODEL = True # THIS MUST BE SET TO True to run the script
DISPLAY_MODEL = ( # Only one of the following can be uncommented:
@@ -13,7 +13,7 @@
# # 'show_pytorch': True, # show pytorch graph of model
# 'show_learning': True,
# # 'show_nested_args': {'show_learning': pnl.ALL},
# 'show_projections_not_in_composition': True,
# # 'show_projections_not_in_composition': True,
# # 'show_nested': {'show_node_structure': True},
# # 'exclude_from_gradient_calc_style': 'dashed'# show target mechanisms for learning
# # 'show_node_structure': True # show detailed view of node structures and projections
@@ -1,14 +1,16 @@
from psyneulink.core.llvm import ExecutionMode
from psyneulink.core.globals.keywords import ALL, ADAPTIVE, CONTROL, CPU, Loss, MPS, OPTIMIZATION_STEP, RUN, TRIAL



model_params = dict(

# Names:
name = "EGO Model CSW",
em_name = "EM",
state_input_layer_name = "STATE",
previous_state_layer_name = "PREVIOUS STATE",
context_layer_name = 'CONTEXT',
em_name = "EM",
prediction_layer_name = "PREDICTION",

# Structural
@@ -20,7 +22,6 @@
# memory_init = None, # Initialize with zeros
concatenate_queries = False,
# concatenate_queries = True,

# environment
# curriculum_type = 'Interleaved',
curriculum_type = 'Blocked',
@@ -33,18 +34,19 @@
context_weight = 1, # weight of the context used during memory retrieval
# normalize_field_weights = False, # whether to normalize the field weights during memory retrieval
normalize_field_weights = True, # whether to normalize the field weights during memory retrieval
normalize_memories = False, # whether to normalize the memory during memory retrieval
# softmax_temperature = None, # temperature of the softmax used during memory retrieval (smaller means more argmax-like)
softmax_temperature = .1, # temperature of the softmax used during memory retrieval (smaller means more argmax-like)
# softmax_temperature = ADAPTIVE, # temperature of the softmax used during memory retrieval (smaller means more argmax-like)
# softmax_temperature = CONTROL, # temperature of the softmax used during memory retrieval (smaller means more argmax-like)
# softmax_threshold = None, # threshold used to mask out small values in softmax
softmax_threshold = .001, # threshold used to mask out small values in softmax
enable_learning=[True, False, False], # Enable learning for PREDICTION (STATE) but not CONTEXT or PREVIOUS STATE
# enable_learning=[True, True, True]
# enable_learning=True,
# enable_learning=False,
learn_field_weights = True,
# learn_field_weights = False,
# target_fields=[True, False, False], # Enable learning for PREDICTION (STATE) but not CONTEXT or PREVIOUS STATE
# target_fields=[True, True, True]
# target_fields=True,
# target_fields=False,
enable_learning = True,
# enable_learning = False,
loss_spec = Loss.BINARY_CROSS_ENTROPY,
# loss_spec = Loss.CROSS_ENTROPY,
# loss_spec = Loss.MSE,
@@ -53,8 +55,8 @@
synch_weights = RUN,
synch_values = RUN,
synch_results = RUN,
execution_mode = ExecutionMode.Python,
# execution_mode = ExecutionMode.PyTorch,
# execution_mode = ExecutionMode.Python,
execution_mode = ExecutionMode.PyTorch,
device = CPU,
# device = MPS,
)
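
For reference, a sketch of how these execution settings are typically consumed when running the script; model and train_inputs are placeholders, not names from the commit:

    # Illustrative only: pass the configured execution mode to Composition.learn().
    results = model.learn(inputs=train_inputs,
                          execution_mode=model_params['execution_mode'])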