This repo contains a TensorFlow implementation of *Key-Value Memory Networks for Directly Reading Documents*. The model is evaluated on the bAbI tasks.
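The core read operation of a key-value memory network fits in a few lines. The sketch below is a minimal pure-Python illustration of the three steps described in the paper: key addressing, value reading, and the query update `q <- R_j (q + o)`. It assumes the question, keys and values have already been embedded into vectors of the same size, and it omits embedding and final answer prediction. The function names are hypothetical, and the TensorFlow code in this repo may organize these steps differently.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def matvec(matrix, vec):
    """Multiply a matrix (list of rows) by a vector."""
    return [dot(row, vec) for row in matrix]


def kv_hop(query, keys, values):
    """One memory hop: address memory with the keys, read out the values."""
    # Key addressing: one softmax weight per memory slot.
    weights = softmax([dot(query, key) for key in keys])
    # Value reading: weighted sum of the value vectors.
    dim = len(values[0])
    output = [sum(w * value[d] for w, value in zip(weights, values)) for d in range(dim)]
    return weights, output


def kv_read(query, keys, values, hop_matrices):
    """Run one hop per matrix R_j, updating the query as q <- R_j (q + o)."""
    q = list(query)
    for r in hop_matrices:
        _, o = kv_hop(q, keys, values)
        q = matvec(r, [qi + oi for qi, oi in zip(q, o)])
    return q


if __name__ == "__main__":
    identity = [[1.0, 0.0], [0.0, 1.0]]
    keys = [[1.0, 0.0], [0.0, 1.0]]
    values = [[2.0, 0.0], [0.0, 4.0]]
    weights, output = kv_hop([3.0, 0.0], keys, values)
    print(weights, output)  # the first slot receives most of the attention
    print(kv_read([3.0, 0.0], keys, values, [identity, identity]))
```

With the `HOPS=3` setting listed below, the loop would run three times, once per `R_j`.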
- There is a must-read tutorial on Memory Networks for NLP by Jason Weston at ICML 2016.
- Sumit Chopra from Facebook AI gave a lecture on Reasoning, Attention and Memory at the Deep Learning Summer School 2016 ([Video] [Slides]).
To get started, clone the repository, download the bAbI dataset, and run the single-task script with its default settings:

```bash
git clone https://github.com/siyuanzhao/key-value-memory-networks.git
mkdir ./key-value-memory-networks/logs
mkdir ./key-value-memory-networks/data/
cd ./key-value-memory-networks/data
wget http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz
tar xzvf ./tasks_1-20_v1-2.tar.gz
cd ../
python single.py
```
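Each extracted bAbI task is a plain-text file in which every line starts with a line ID. The ID resets to 1 at the start of each new story, and question lines carry the question, the answer and the supporting-fact IDs separated by tabs. The sketch below parses that format into (story, question, answer) triples. `parse_babi` is a hypothetical helper, not the loader used by this repo, and the file path in the comment is an assumption about where the archive extracts.

```python
def parse_babi(lines):
    """Parse bAbI QA lines into (story_sentences, question, answer) triples."""
    samples = []
    story = []
    for raw in lines:
        raw = raw.strip()
        if not raw:
            continue
        line_id, text = raw.split(" ", 1)
        if int(line_id) == 1:
            story = []  # a line ID of 1 marks the start of a new story
        if "\t" in text:
            # Question line: question \t answer \t supporting fact IDs
            question, answer, _supporting = text.split("\t")
            samples.append((list(story), question.strip(), answer))
        else:
            story.append(text)
    return samples


if __name__ == "__main__":
    example = [
        "1 Mary moved to the bathroom.",
        "2 John went to the hallway.",
        "3 Where is Mary? \tbathroom\t1",
    ]
    for story, question, answer in parse_babi(example):
        print(story, question, answer)
    # Assumed path after running the commands above (hypothetical):
    # with open("data/tasks_1-20_v1-2/en/qa1_single-supporting-fact_train.txt") as f:
    #     samples = parse_babi(f)
```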
Running a single bAbI task

```bash
# Train the model on a single task <task_id>
python single.py --task_id <task_id>
```
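To train a separate model on each of the 20 tasks, the `--task_id` flag can be driven from a small script. The sketch below only builds and prints the command lines unless `dry_run=False` is passed. `single_task_commands` and `run_all` are hypothetical helpers, the script is assumed to run from the repository root, and the only flags known to exist are the ones shown in this README (`--task_id`, `--learning_rate`, `--feature_size`, `--epochs`).

```python
import subprocess
import sys


def single_task_commands(task_ids, extra_flags=None):
    """Build one `single.py --task_id N ...` argv list per task."""
    extra_flags = extra_flags or {}
    commands = []
    for task_id in task_ids:
        argv = [sys.executable, "single.py", "--task_id", str(task_id)]
        for name, value in extra_flags.items():
            argv += ["--" + name, str(value)]
        commands.append(argv)
    return commands


def run_all(commands, dry_run=True):
    """Print each command; actually run it only when dry_run is False."""
    for argv in commands:
        print(" ".join(argv))
        if not dry_run:
            subprocess.run(argv, check=True)


if __name__ == "__main__":
    run_all(single_task_commands(range(1, 21), {"epochs": 200}))
```

Running the tasks one after another keeps each process small; launching them in parallel would be faster but competes for the same GPU or CPU.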
There are several flags in single.py. The example below trains the model on task 20 with a specific learning rate, feature size and number of epochs:

```bash
python single.py --task_id 20 --learning_rate 0.005 --feature_size 40 --epochs 200
```

Check all available flags with the following command:

```bash
python single.py -h
```
Running a joint model on all bAbI tasks

```bash
python joint.py
```

joint.py also has several flags. The example below trains the joint model with a specific learning rate, feature size and number of epochs:

```bash
python joint.py --learning_rate 0.005 --feature_size 40 --epochs 200
```

Check all available flags with the following command:

```bash
python joint.py -h
```
The model is jointly trained on all 20 tasks (1k training examples, weakly supervised) with the following hyperparameters:
- BATCH_SIZE=50
- EMBEDDING_SIZE=40
- EPOCHS=200
- EPSILON=0.1
- FEATURE_SIZE=50
- HOPS=3
- L2_LAMBDA=0.1
- LEARNING_RATE=0.001
- MAX_GRAD_NORM=20.0
- MEMORY_SIZE=50
- READER=bow
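`MAX_GRAD_NORM` is presumably a gradient-clipping threshold. The sketch below assumes it is applied as norm clipping, which rescales a gradient whenever its L2 norm exceeds the threshold and keeps its direction. `clip_by_norm` is a hypothetical pure-Python stand-in. This README does not say whether the repo clips each gradient separately or by global norm.

```python
import math

MAX_GRAD_NORM = 20.0  # value from the hyperparameter list above


def clip_by_norm(grad, max_norm=MAX_GRAD_NORM):
    """Rescale grad so that its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grad))
    if norm <= max_norm:
        return list(grad)  # small gradients pass through unchanged
    scale = max_norm / norm
    return [g * scale for g in grad]


if __name__ == "__main__":
    print(clip_by_norm([3.0, 4.0]))    # norm 5 <= 20, unchanged
    print(clip_by_norm([30.0, 40.0]))  # norm 50 > 20, rescaled to norm 20
```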
```bash
python joint.py
```
Task | Testing Accuracy | Training Accuracy | Validation Accuracy |
---|---|---|---|
1 | 1.00 | 1.00 | 1.00 |
2 | 0.80 | 0.87 | 0.85 |
3 | 0.66 | 0.77 | 0.69 |
4 | 0.73 | 0.79 | 0.74 |
5 | 0.84 | 0.91 | 0.80 |
6 | 0.98 | 0.99 | 0.98 |
7 | 0.83 | 0.85 | 0.80 |
8 | 0.89 | 0.92 | 0.86 |
9 | 0.98 | 0.99 | 0.96 |
10 | 0.85 | 0.96 | 0.89 |
11 | 0.97 | 0.98 | 0.99 |
12 | 0.99 | 0.99 | 1.00 |
13 | 0.99 | 0.99 | 1.00 |
14 | 0.80 | 0.90 | 0.84 |
15 | 0.56 | 0.57 | 0.45 |
16 | 0.46 | 0.48 | 0.37 |
17 | 0.57 | 0.72 | 0.70 |
18 | 0.93 | 0.95 | 0.92 |
19 | 0.10 | 0.11 | 0.06 |
20 | 0.98 | 0.99 | 0.99 |
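The table can be condensed with a little arithmetic. The sketch below copies the testing-accuracy column and computes two summaries: the mean test accuracy, and the tasks that reach at least 0.95, which is the pass threshold commonly used when reporting bAbI results. `summarize` is a hypothetical helper.

```python
# Testing accuracy per task, copied from the table above.
TEST_ACCURACY = {
    1: 1.00, 2: 0.80, 3: 0.66, 4: 0.73, 5: 0.84,
    6: 0.98, 7: 0.83, 8: 0.89, 9: 0.98, 10: 0.85,
    11: 0.97, 12: 0.99, 13: 0.99, 14: 0.80, 15: 0.56,
    16: 0.46, 17: 0.57, 18: 0.93, 19: 0.10, 20: 0.98,
}


def summarize(accuracy, pass_threshold=0.95):
    """Return the mean accuracy and the sorted IDs of tasks at or above the threshold."""
    mean = sum(accuracy.values()) / len(accuracy)
    passed = sorted(task for task, acc in accuracy.items() if acc >= pass_threshold)
    return mean, passed


if __name__ == "__main__":
    mean, passed = summarize(TEST_ACCURACY)
    print(f"mean test accuracy: {mean:.4f}")  # 0.7955
    print(f"tasks >= 0.95: {passed} ({len(passed)} of {len(TEST_ACCURACY)})")
```

With these values, 7 of the 20 tasks (1, 6, 9, 11, 12, 13 and 20) reach the 0.95 threshold.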
- Results on the 10k training examples are here.

Requirements:

- tensorflow 1.10
- scikit-learn 0.19
- six 1.10.0