
Commit

…Link into patch/autodiff_pnl_showgraph
jdcpni committed Nov 22, 2024
2 parents 55cef19 + ad1e743 commit d3221ca
Showing 40 changed files with 2,699 additions and 1,319 deletions.
File renamed without changes.
@@ -39,21 +39,21 @@ def calc_prob(em_preds, test_ys):

# Names:
name = "EGO Model CSW",
em_name = "EM",
state_input_layer_name = "STATE",
previous_state_layer_name = "PREVIOUS STATE",
context_layer_name = 'CONTEXT',
em_name = "EM",
prediction_layer_name = "PREDICTION",

# Structural
state_d = 11, # length of state vector
previous_state_d = 11, # length of state vector
context_d = 11, # length of context vector
memory_capacity = ALL, # number of entries in EM memory; ALL=> match to number of stims
# memory_init = (0,.0001), # Initialize memory with random values in interval
memory_init = None, # Initialize with zeros
concatenate_queries = False,
# concatenate_queries = True,
memory_init = (0,.0001), # Initialize memory with random values in interval
# memory_init = None, # Initialize with zeros
# concatenate_queries = False,
concatenate_queries = True,

# environment
# curriculum_type = 'Interleaved',
@@ -63,20 +63,23 @@ def calc_prob(em_preds, test_ys):

# Processing
integration_rate = .69, # rate at which state is integrated into new context
# state_weight = 1, # weight of the state used during memory retrieval
# state_weight = 1, # weight of the state used during memory retrieval
# context_weight = 1, # weight of the context used during memory retrieval
state_weight = .5, # weight of the state used during memory retrieval
previous_state_weight = .5, # weight of the state used during memory retrieval
context_weight = .5, # weight of the context used during memory retrieval
state_weight = None, # weight of the state used during memory retrieval
# normalize_field_weights = False, # whether to normalize the field weights during memory retrieval
normalize_field_weights = True, # whether to normalize the field weights during memory retrieval
normalize_memories = False, # whether to normalize the memory during memory retrieval
# normalize_memories = True, # whether to normalize the memory during memory retrieval
# softmax_temperature = None, # temperature of the softmax used during memory retrieval (smaller means more argmax-like)
softmax_temperature = .1, # temperature of the softmax used during memory retrieval (smaller means more argmax-like)
# softmax_temperature = ADAPTIVE, # temperature of the softmax used during memory retrieval (smaller means more argmax-like)
# softmax_temperature = CONTROL, # temperature of the softmax used during memory retrieval (smaller means more argmax-like)
# softmax_threshold = None, # threshold used to mask out small values in softmax
softmax_threshold = .001, # threshold used to mask out small values in softmax
enable_learning=[False, False, True], # Enable learning for PREDICTION (STATE) but not CONTEXT or PREVIOUS STATE
learn_field_weights = False,
# target_fields=[True, False, False], # Enable learning for PREDICTION (STATE) but not CONTEXT or PREVIOUS STATE
enable_learning = True,
loss_spec = Loss.BINARY_CROSS_ENTROPY,
# loss_spec = Loss.MSE,
learning_rate = .5,
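For intuition about softmax_temperature and softmax_threshold above: a smaller temperature acts like a larger gain, pushing the distribution toward argmax, and the threshold zeroes out negligible entries. A minimal numpy sketch, under assumed semantics (an illustration, not the PsyNeuLink implementation):

import numpy as np

def masked_softmax(similarities, temperature=0.1, threshold=0.001):
    gain = 1.0 / temperature  # smaller temperature -> larger gain -> more argmax-like
    w = np.exp(gain * (similarities - similarities.max()))  # numerically stable exponentiation
    w /= w.sum()
    w[w < threshold] = 0.0  # mask out small values, per softmax_threshold
    total = w.sum()
    return w / total if total > 0 else w

# e.g., masked_softmax(np.array([0.9, 0.5, 0.1])) is sharply peaked on the first entry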
@@ -174,7 +174,7 @@
'CONTEXT',
'PREVIOUS STATE'],
start=0)
state_retrieval_weight = 0
state_retrieval_weight = None
RANDOM_WEIGHTS_INITIALIZATION=RandomMatrix(center=0.0, range=0.1) # Matrix spec used to initialize all Projections

if is_numeric_scalar(model_params['softmax_temperature']): # translate to gain of softmax retrieval function
@@ -194,7 +194,7 @@ def construct_model(model_name:str=model_params['name'],
state_size:int=model_params['state_d'],

# Previous state
previous_state_input_name:str=model_params['previous_state_layer_name'],
previous_state_name:str=model_params['previous_state_layer_name'],

# Context representation (learned):
context_name:str=model_params['context_layer_name'],
@@ -205,12 +205,15 @@ def construct_model(model_name:str=model_params['name'],
em_name:str=model_params['em_name'],
retrieval_softmax_gain=retrieval_softmax_gain,
retrieval_softmax_threshold=model_params['softmax_threshold'],
state_retrieval_weight:Union[float,int]=state_retrieval_weight,
previous_state_retrieval_weight:Union[float,int]=model_params['state_weight'],
# state_retrieval_weight:Union[float,int]=state_retrieval_weight,
# previous_state_retrieval_weight:Union[float,int]=model_params['state_weight'],
state_retrieval_weight:Union[float,int]=model_params['state_weight'],
previous_state_retrieval_weight:Union[float,int]=model_params['previous_state_weight'],
context_retrieval_weight:Union[float,int]=model_params['context_weight'],
normalize_field_weights = model_params['normalize_field_weights'],
normalize_memories = model_params['normalize_memories'],
concatenate_queries = model_params['concatenate_queries'],
learn_field_weights = model_params['learn_field_weights'],
enable_learning = model_params['enable_learning'],
memory_capacity = memory_capacity,
memory_init=model_params['memory_init'],

@@ -219,7 +222,7 @@ def construct_model(model_name:str=model_params['name'],

# Learning
loss_spec=model_params['loss_spec'],
enable_learning=model_params['enable_learning'],
# target_fields=model_params['target_fields'],
learning_rate = model_params['learning_rate'],
device=model_params['device']

@@ -233,14 +236,16 @@ def construct_model(model_name:str=model_params['name'],
# ----------------------------------------------------------------------------------------------------------------

state_input_layer = ProcessingMechanism(name=state_input_name, input_shapes=state_size)
previous_state_layer = ProcessingMechanism(name=previous_state_input_name, input_shapes=state_size)
previous_state_layer = ProcessingMechanism(name=previous_state_name, input_shapes=state_size)
# context_layer = ProcessingMechanism(name=context_name, input_shapes=context_size)
context_layer = TransferMechanism(name=context_name,
input_shapes=context_size,
function=Tanh,
integrator_mode=True,
integration_rate=integration_rate)
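# Note (assumed semantics): with integrator_mode=True, the input is integrated first --
# roughly context_t = (1 - integration_rate) * context_{t-1} + integration_rate * state_t --
# and then passed through Tanh, so integration_rate sets how quickly the context
# representation drifts toward the current state.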



em = EMComposition(name=em_name,
memory_template=[[0] * state_size, # state
[0] * state_size, # previous state
@@ -251,6 +256,15 @@ def construct_model(model_name:str=model_params['name'],
memory_decay_rate=0,
softmax_gain=retrieval_softmax_gain,
softmax_threshold=retrieval_softmax_threshold,
fields = {state_input_name: {FIELD_WEIGHT: state_retrieval_weight,
LEARN_FIELD_WEIGHT: False,
TARGET_FIELD: True},
previous_state_name: {FIELD_WEIGHT:previous_state_retrieval_weight,
LEARN_FIELD_WEIGHT: False,
TARGET_FIELD: False},
context_name: {FIELD_WEIGHT:context_retrieval_weight,
LEARN_FIELD_WEIGHT: False,
TARGET_FIELD: False}},
# Input Nodes:
# field_names=[state_input_name,
# previous_state_input_name,
@@ -260,19 +274,20 @@ def construct_model(model_name:str=model_params['name'],
# previous_state_retrieval_weight,
# context_retrieval_weight
# ),
field_names=[previous_state_input_name,
context_name,
state_input_name,
],
field_weights=(previous_state_retrieval_weight,
context_retrieval_weight,
state_retrieval_weight,
),
# field_names=[previous_state_input_name,
# context_name,
# state_input_name,
# ],
# field_weights=(previous_state_retrieval_weight,
# context_retrieval_weight,
# state_retrieval_weight,
# ),
normalize_field_weights=normalize_field_weights,
normalize_memories=normalize_memories,
concatenate_queries=concatenate_queries,
learn_field_weights=learn_field_weights,
learning_rate=learning_rate,
enable_learning=enable_learning,
learning_rate=learning_rate,
# target_fields=target_fields,
device=device
)
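The fields spec above gives each field a retrieval weight, a learn-field-weight flag, and a target flag. As a rough mental model of field-weighted retrieval (a plain-numpy sketch under assumed semantics, not the EMComposition API): per-field similarities are combined using the normalized field weights, softmaxed over entries, and every field is read out as a probability-weighted sum.

import numpy as np

def em_retrieve(queries, memory, field_weights, gain=10.0):
    # queries: {field: 1-D query vector}; memory: {field: (n_entries, d) array}
    # field_weights: {field: float or None}; None means the field is not used as a query
    used = {f: w for f, w in field_weights.items() if w is not None}
    total = sum(used.values())
    combined = 0.0
    for f, w in used.items():
        q, m = queries[f], memory[f]
        sims = m @ q / (np.linalg.norm(m, axis=1) * np.linalg.norm(q) + 1e-12)
        combined = combined + (w / total) * sims  # analogue of normalize_field_weights
    p = np.exp(gain * (combined - combined.max()))  # softmax over memory entries
    p /= p.sum()
    return {f: p @ memory[f] for f in memory}  # weighted readout of every field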

@@ -312,7 +327,7 @@ def construct_model(model_name:str=model_params['name'],
em]
previous_state_to_em_pathway = [previous_state_layer,
MappingProjection(sender=previous_state_layer,
receiver=em.nodes[previous_state_input_name+QUERY],
receiver=em.nodes[previous_state_name+QUERY],
matrix=IDENTITY_MATRIX,
learnable=False),
em]
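# Note (assumed naming convention): EMComposition exposes one input node per field, keyed
# by the field name plus the QUERY suffix, which is why this projection targets
# em.nodes[previous_state_name + QUERY].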
@@ -0,0 +1,54 @@
import numpy as np
import torch
from torch.utils.data import dataset
from random import randint

def one_hot_encode(labels, num_classes):
"""
One hot encode labels and convert to tensor.
"""
return torch.tensor((np.arange(num_classes) == labels[..., None]).astype(float),dtype=torch.float32)
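# Example: one_hot_encode(np.array([2, 0]), 4)
# -> tensor([[0., 0., 1., 0.],
#            [1., 0., 0., 0.]])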

class DeterministicCSWDataset(dataset.Dataset):
def __init__(self, n_samples_per_context, contexts_to_load) -> None:
super().__init__()
raw_xs = np.array([
[[9,1,3,5,7],[9,2,4,6,8]],
[[10,1,4,5,8],[10,2,3,6,7]]
])

item_indices = np.random.choice(raw_xs.shape[1],sum(n_samples_per_context),replace=True)
task_names = [0,1] # Flexible so these can be renamed later
task_indices = [task_names.index(name) for name in contexts_to_load]

context_indices = np.repeat(np.array(task_indices),n_samples_per_context)
self.xs = one_hot_encode(raw_xs[context_indices,item_indices],11)

self.xs = self.xs.reshape((-1,11))
self.ys = torch.cat([self.xs[1:],one_hot_encode(np.array([0]),11)],dim=0)
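# Targets are the inputs shifted forward one step (next-state prediction), padded
# with a dummy one-hot so shapes match; the padded final target is dropped below
# along with the last transition.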
context_indices = np.repeat(np.array(task_indices),[x*5 for x in n_samples_per_context])
self.contexts = one_hot_encode(context_indices, len(task_names))

# Remove the last transition since there's no next state available
self.xs = self.xs[:-1]
self.ys = self.ys[:-1]
self.contexts = self.contexts[:-1]

def __len__(self):
return len(self.xs)

def __getitem__(self, idx):
return self.xs[idx], self.contexts[idx], self.ys[idx]

def generate_dataset(condition='Blocked'):
# Generate the dataset for either the blocked or interleaved condition
if condition=='Blocked':
contexts_to_load = [0,1,0,1] + [randint(0,1) for _ in range(40)]
n_samples_per_context = [40,40,40,40] + [1]*40
elif condition == 'Interleaved':
contexts_to_load = [0,1]*80 + [randint(0,1) for _ in range(40)]
n_samples_per_context = [1]*160 + [1]*40
else:
raise ValueError(f'Unknown dataset condition: {condition}')

return DeterministicCSWDataset(n_samples_per_context, contexts_to_load)
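A hypothetical usage sketch (the loop body is a placeholder, not part of the file above): iterate the blocked-curriculum dataset one trial at a time, each item yielding a (state, context, next-state target) triple.

from torch.utils.data import DataLoader

dataset = generate_dataset(condition='Blocked')
loader = DataLoader(dataset, batch_size=1, shuffle=False)  # preserve curriculum order
for state, context, target in loader:
    pass  # feed state/context to the model and score its prediction against target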
@@ -13,7 +13,7 @@
# # 'show_pytorch': True, # show pytorch graph of model
# 'show_learning': True,
# # 'show_nested_args': {'show_learning': pnl.ALL},
# 'show_projections_not_in_composition': True,
# # 'show_projections_not_in_composition': True,
# # 'show_nested': {'show_node_structure': True},
# # 'exclude_from_gradient_calc_style': 'dashed'# show target mechanisms for learning
# # 'show_node_structure': True # show detailed view of node structures and projections
@@ -1,14 +1,16 @@
from psyneulink.core.llvm import ExecutionMode
from psyneulink.core.globals.keywords import ALL, ADAPTIVE, CONTROL, CPU, Loss, MPS, OPTIMIZATION_STEP, RUN, TRIAL



model_params = dict(

# Names:
name = "EGO Model CSW",
em_name = "EM",
state_input_layer_name = "STATE",
previous_state_layer_name = "PREVIOUS STATE",
context_layer_name = 'CONTEXT',
em_name = "EM",
prediction_layer_name = "PREDICTION",

# Structural
@@ -20,7 +22,6 @@
# memory_init = None, # Initialize with zeros
concatenate_queries = False,
# concatenate_queries = True,

# environment
# curriculum_type = 'Interleaved',
curriculum_type = 'Blocked',
@@ -33,18 +34,19 @@
context_weight = 1, # weight of the context used during memory retrieval
# normalize_field_weights = False, # whether to normalize the field weights during memory retrieval
normalize_field_weights = True, # whether to normalize the field weights during memory retrieval
normalize_memories = False, # whether to normalize the memory during memory retrieval
# softmax_temperature = None, # temperature of the softmax used during memory retrieval (smaller means more argmax-like)
softmax_temperature = .1, # temperature of the softmax used during memory retrieval (smaller means more argmax-like)
# softmax_temperature = ADAPTIVE, # temperature of the softmax used during memory retrieval (smaller means more argmax-like)
# softmax_temperature = CONTROL, # temperature of the softmax used during memory retrieval (smaller means more argmax-like)
# softmax_threshold = None, # threshold used to mask out small values in softmax
softmax_threshold = .001, # threshold used to mask out small values in softmax
enable_learning=[True, False, False], # Enable learning for PREDICTION (STATE) but not CONTEXT or PREVIOUS STATE
# enable_learning=[True, True, True]
# enable_learning=True,
# enable_learning=False,
learn_field_weights = True,
# learn_field_weights = False,
# target_fields=[True, False, False], # Enable learning for PREDICTION (STATE) but not CONTEXT or PREVIOUS STATE
# target_fields=[True, True, True]
# target_fields=True,
# target_fields=False,
enable_learning = True,
# enable_learning = False,
loss_spec = Loss.BINARY_CROSS_ENTROPY,
# loss_spec = Loss.CROSS_ENTROPY,
# loss_spec = Loss.MSE,
@@ -53,8 +55,8 @@
synch_weights = RUN,
synch_values = RUN,
synch_results = RUN,
execution_mode = ExecutionMode.Python,
# execution_mode = ExecutionMode.PyTorch,
# execution_mode = ExecutionMode.Python,
execution_mode = ExecutionMode.PyTorch,
device = CPU,
# device = MPS,
)
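A hedged sketch of how these execution parameters might be consumed, assuming the script's construct_model() and an inputs dict are in scope and that the composition's learn() accepts execution_mode (as PsyNeuLink compositions do):

model = construct_model()
model.learn(inputs=inputs_dict,  # inputs_dict: {input node (or name): per-trial inputs} -- assumed
            execution_mode=model_params['execution_mode'])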
