-
Notifications
You must be signed in to change notification settings - Fork 32
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor/emcomposition_field_handling (#3122)
* • emcomposition.py _parse_fields: clean up assignment of self.num_fields • test_emcomposition.py add test_assign_field_weights_and_0_vs_None() add test_field_weights_all_None_and_or_0 • emcomposition.py - revamp docstring to document new mods - add fields arg to specify field_naes, field_weights, learn_field_weights - implement fields arg to specify field_names, field_weights, learn_field_weights in dict format - implement support for field-specific learn_field_weight specifications - _identify_target_nodes(): refactor to use target_fields instead of learn_field_weights - add target_fields to fields specification dict - add dict spec for entries in fields arg - start adding field_idx to all components - add self._field_index_map • pytorchEMcompositionwrapper.py - store_memory(): use self._field_index_map to assign memories to fields • test_emcomposition.py - test_backpropagation_of_error_in_learning(): use EGO model to test for error backpropagation through EMCompoistion - test_field_args_and_map_assignments(): flesh out _field_index_map validation * • emcomposition.py - update docstring figs - add purge_by_field_weights Parameter * • autodiffcomposition.py - infer_backpropagation_learning_pathways(): add NodeRole.BIAS to pathways consructed for learning
- Loading branch information
Showing
28 changed files
with
1,930 additions
and
1,038 deletions.
There are no files selected for viewing
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
54 changes: 54 additions & 0 deletions
54
Scripts/Models (Under Development)/EGO/Using EMComposition/Coffee Shop World/Environment.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import numpy as np | ||
import torch | ||
from torch.utils.data import dataset | ||
from random import randint | ||
|
||
def one_hot_encode(labels, num_classes): | ||
""" | ||
One hot encode labels and convert to tensor. | ||
""" | ||
return torch.tensor((np.arange(num_classes) == labels[..., None]).astype(float),dtype=torch.float32) | ||
|
||
class DeterministicCSWDataset(dataset.Dataset): | ||
def __init__(self, n_samples_per_context, contexts_to_load) -> None: | ||
super().__init__() | ||
raw_xs = np.array([ | ||
[[9,1,3,5,7],[9,2,4,6,8]], | ||
[[10,1,4,5,8],[10,2,3,6,7]] | ||
]) | ||
|
||
item_indices = np.random.choice(raw_xs.shape[1],sum(n_samples_per_context),replace=True) | ||
task_names = [0,1] # Flexible so these can be renamed later | ||
task_indices = [task_names.index(name) for name in contexts_to_load] | ||
|
||
context_indices = np.repeat(np.array(task_indices),n_samples_per_context) | ||
self.xs = one_hot_encode(raw_xs[context_indices,item_indices],11) | ||
|
||
self.xs = self.xs.reshape((-1,11)) | ||
self.ys = torch.cat([self.xs[1:],one_hot_encode(np.array([0]),11)],dim=0) | ||
context_indices = np.repeat(np.array(task_indices),[x*5 for x in n_samples_per_context]) | ||
self.contexts = one_hot_encode(context_indices, len(task_names)) | ||
|
||
# Remove the last transition since there's no next state available | ||
self.xs = self.xs[:-1] | ||
self.ys = self.ys[:-1] | ||
self.contexts = self.contexts[:-1] | ||
|
||
def __len__(self): | ||
return len(self.xs) | ||
|
||
def __getitem__(self, idx): | ||
return self.xs[idx], self.contexts[idx], self.ys[idx] | ||
|
||
def generate_dataset(condition='Blocked'): | ||
# Generate the dataset for either the blocked or interleaved condition | ||
if condition=='Blocked': | ||
contexts_to_load = [0,1,0,1] + [randint(0,1) for _ in range(40)] | ||
n_samples_per_context = [40,40,40,40] + [1]*40 | ||
elif condition == 'Interleaved': | ||
contexts_to_load = [0,1]*80 + [randint(0,1) for _ in range(40)] | ||
n_samples_per_context = [1]*160 + [1]*40 | ||
else: | ||
raise ValueError(f'Unknown dataset condition: {condition}') | ||
|
||
return DeterministicCSWDataset(n_samples_per_context, contexts_to_load) |
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
Oops, something went wrong.