A neuro-symbolic language model, based on a base neural-LM and an automaton that retrieves examples from the training data. This is an official implementation of the model described in:
Uri Alon, Frank F. Xu, Junxian He, Sudipta Sengupta, Dan Roth, and Graham Neubig,
"Neuro-Symbolic Language Modeling with Automaton-augmented Retrieval"
July 2022 - A Hugging Face 🤗 transformers implementation of RetoMaton and kNN-LM is available at https://github.com/neulab/knn-transformers
June 2022 - Overview tweet!
May 2022 - The paper was accepted to ICML'2022! See you in Baltimore in July 2022 [Poster here]
April 2022 - a talk video is available here: [1-hour video] [5-min video]
Please let us know if anything here is not working as expected, and feel free to create new issues with any questions.
WikiText-103:
Law-MT, with a base LM that was trained on WMT News Crawl:
Law-MT, with a base LM that was fine-tuned on Law-MT:
Method | ppl | ppl, saving 50% of the searches |
---|---|---|
Fine-tuned LM | 8.61 | 8.61 |
kNN-LM | 7.93 | 8.25 |
AdaptRet baseline (He et al., 2021) | 7.81 | 7.91 |
RetoMaton (this work) | 7.10 | 7.15 |
- Overview
- Results
- Requirements
- Quickstart
- Step 1: Preparing the data
- Step 2: Downloading the base Language Model
- Step 3: Evaluating the base Language Model
- Step 4: Saving the keys and values for the datastore
- Step 5: Building the FAISS index
- Step 6: Evaluating RetoMaton without clustering
- Step 7: Adding clustering
- Step 8: Evaluating the Fine-tuned Model
- Lambda values
- All files
- Differences from the kNN-LM implementation
- Citation
This repository is a fork of the kNN-LM repository and is based on the fairseq framework.
- This project is based on python3 and PyTorch 1.9.0. To check the PyTorch version:
python3 -c 'import torch; print(torch.__version__)'
- The project also depends on the faiss library. We recommend using the GPU version of faiss:
pip install faiss-gpu
The CPU version can be installed using pip install faiss-cpu.
On a MacBook, use the Anaconda installation instead:
conda install -c conda-forge pytorch faiss-cpu
- Finally, from this project's directory, run:
pip install --editable .
Experiments for this paper were conducted on a machine that contains 16GB of RAM, and a single NVIDIA RTX 3090 GPU.
Saving the Wikitext-103 datastore requires 200GB of disk space (in fp16, which does not degrade the performance compared to fp32).
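The 200GB figure can be sanity-checked with a little arithmetic: each key is one hidden-state vector of the base LM, stored in fp16 (2 bytes per value). A back-of-the-envelope sketch, assuming the keys are stored as a dense array of shape (dstore_size, dim) with no extra overhead, and taking the datastore sizes and key dimensions (1024 for Wikitext-103, 1536 for Law-MT) from the settings used elsewhere in this README; key_storage_bytes is a hypothetical helper:

```python
def key_storage_bytes(dstore_size: int, dim: int, bytes_per_value: int = 2) -> int:
    """Bytes needed for a dense (dstore_size x dim) key array (2 bytes per value for fp16)."""
    return dstore_size * dim * bytes_per_value


GIB = 1024 ** 3

# Datastore sizes and key dimensions as used elsewhere in this README.
wt103_bytes = key_storage_bytes(103_225_485, 1024)  # Wikitext-103
law_bytes = key_storage_bytes(19_068_709, 1536)     # Law-MT

print(f"Wikitext-103 keys: {wt103_bytes / 1e9:.1f} GB ({wt103_bytes / GIB:.1f} GiB)")
print(f"Law-MT keys:       {law_bytes / 1e9:.1f} GB ({law_bytes / GIB:.1f} GiB)")
```

This gives roughly 211 GB (about 197 GiB) of keys for Wikitext-103, in line with the 200GB above; storing the keys in fp32 would double that.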
git clone https://github.com/neulab/retomaton
cd retomaton
mkdir -p checkpoints/wt103
mkdir -p checkpoints/law
You can either download our preprocessed Wikitext-103 and Law-MT datasets, or preprocess them yourself.
wget https://retomaton.s3.us-east-2.amazonaws.com/wt103/wiki103_preprocessed.tar.gz
tar -xzvf wiki103_preprocessed.tar.gz
wget https://retomaton.s3.us-east-2.amazonaws.com/law/law_preprocessed.tar.gz
tar -xzvf law_preprocessed.tar.gz
We include Fairseq's instructions on how to prepare the data here.
cd examples/language_model/
bash prepare-wikitext-103.sh
cd ../..
TEXT=examples/language_model/wikitext-103
python preprocess.py \
--only-source \
--trainpref $TEXT/wiki.train.tokens \
--validpref $TEXT/wiki.valid.tokens \
--testpref $TEXT/wiki.test.tokens \
--destdir data-bin/wikitext-103 \
--workers 20
The data is originally from: https://github.com/roeeaharoni/unsupervised-domain-clusters.
We used the law/ subdirectory, and only the English "source" files.
Then, we re-tokenized the dataset using the model's BPE tokenizer.
The tokenized dataset can be downloaded from:
mkdir -p datasets/law
wget -P datasets/law/ https://retomaton.s3.us-east-2.amazonaws.com/law/law_tokenized.tar.gz
cd datasets/law
tar -xzvf law_tokenized.tar.gz
cd ../../
Then, preprocess it using:
TEXT=datasets/law
python preprocess.py \
--only-source \
--trainpref $TEXT/train.tokenized \
--validpref $TEXT/dev.en.tokenized \
--testpref $TEXT/test.en.tokenized \
--destdir data-bin/law \
--workers 20
The models that we used can be downloaded from the following sources.
For Wikitext-103:
wget -P checkpoints/wt103/ https://nlp.stanford.edu/projects/knnlm/wt103_checkpoint_best.pt
For Law-MT:
wget -P checkpoints/law/ https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.en.tar.gz
cd checkpoints/law
tar -xzvf wmt19.en.tar.gz
cd ../..
We also include Fairseq's instructions on how to train the language model here:
python train.py --task language_modeling \
data-bin/wikitext-103 \
--save-dir checkpoints/ \
--arch transformer_lm_wiki103 \
--max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
--warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
--criterion adaptive_loss --max-tokens 3072 --update-freq 3 --tokens-per-sample 3072 --seed 1 \
--sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d
This model was trained on 8 GPUs.
To evaluate the base model on the validation set (without any retrieval):
For Wikitext-103:
python eval_lm.py data-bin/wikitext-103 \
--path checkpoints/wt103/wt103_checkpoint_best.pt \
--sample-break-mode complete --max-tokens 3072 \
--context-window 2560 --softmax-batch 1024 --batch-size 2 \
--gen-subset valid
For Law-MT:
python eval_lm.py data-bin/law \
--sample-break-mode eos \
--path checkpoints/law/wmt19.en/model.pt \
--max-tokens 2048 --context-window 0 --batch-size 2 \
--gen-subset valid --remove-bpe
Notice that the main differences between the datasets are that in Law-MT we use the flags --remove-bpe and --sample-break-mode eos, and that the --max-tokens and --context-window values are different.
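Since the two base-LM commands differ only in a handful of dataset-specific flags, scripts that evaluate both datasets may find it convenient to keep those flags in one place. A minimal sketch; BASE_EVAL and base_eval_command are hypothetical names that are not part of this repository, and every value is copied from the two eval_lm.py commands above:

```python
import shlex
from typing import Dict, List

# Dataset-specific settings for evaluating the base LM without retrieval,
# copied from the two eval_lm.py commands above.
BASE_EVAL: Dict[str, dict] = {
    "wikitext-103": {
        "data": "data-bin/wikitext-103",
        "path": "checkpoints/wt103/wt103_checkpoint_best.pt",
        "sample_break_mode": "complete",
        "max_tokens": 3072,
        "context_window": 2560,
        "extra": ["--softmax-batch", "1024"],
    },
    "law": {
        "data": "data-bin/law",
        "path": "checkpoints/law/wmt19.en/model.pt",
        "sample_break_mode": "eos",
        "max_tokens": 2048,
        "context_window": 0,
        "extra": ["--remove-bpe"],  # only used for Law-MT
    },
}


def base_eval_command(dataset: str, subset: str = "valid") -> List[str]:
    """Argument list for evaluating the base LM on `dataset` (no retrieval)."""
    cfg = BASE_EVAL[dataset]
    return [
        "python", "eval_lm.py", cfg["data"],
        "--path", cfg["path"],
        "--sample-break-mode", cfg["sample_break_mode"],
        "--max-tokens", str(cfg["max_tokens"]),
        "--context-window", str(cfg["context_window"]),
        "--batch-size", "2",
        "--gen-subset", subset,
        *cfg["extra"],
    ]


print(shlex.join(base_eval_command("law")))
```

Passing the list to subprocess.run(..., check=True) from the repository root should be equivalent to the corresponding command above; only the flag order differs.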
The next step is to run the model over the entire training set and save the keys and values of the datastore. Our saved keys and values can also be downloaded directly:
wget -P checkpoints/wt103/ https://retomaton.s3.us-east-2.amazonaws.com/wt103/dstore16_vals.npy
wget -P checkpoints/wt103/ https://retomaton.s3.us-east-2.amazonaws.com/wt103/dstore16_keys.npy
Note: The keys of Wikitext-103 take 200GB of disk space.
wget -P checkpoints/law/ https://retomaton.s3.us-east-2.amazonaws.com/law/dstore16_vals.npy
wget -P checkpoints/law/ https://retomaton.s3.us-east-2.amazonaws.com/law/dstore16_keys.npy
python eval_lm.py data-bin/wikitext-103 \
--path checkpoints/wt103/wt103_checkpoint_best.pt \
--sample-break-mode none --max-tokens 3072 \
--softmax-batch 1024 --batch-size 2 --gen-subset train \
--context-window 1536 --tokens-per-sample 1536 \
--dstore-mmap checkpoints/wt103/dstore16 --knn-keytype 'last_ffn_input' \
--dstore-size 103225485 --model-overrides "{'knn_keytype': 'last_ffn_input'}" \
--save-knnlm-dstore --dstore-fp16
The total number of tokens in the Wikitext-103 training set is 103227021. The dstore size 103225485 is 1536 tokens less than the total due to the context window: we want each key to be constructed using a minimum amount of prior context.
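The relation between the two numbers, written out as a quick sanity check before passing --dstore-size. The helper is hypothetical and only encodes the arithmetic stated above for the Wikitext-103 settings; it does not reproduce how eval_lm.py itself decides which tokens to save:

```python
def expected_dstore_size(total_tokens: int, context_window: int) -> int:
    """Datastore entries when `context_window` tokens of the training set are not saved as keys."""
    return total_tokens - context_window


# Wikitext-103, saved with --context-window 1536 --tokens-per-sample 1536
print(expected_dstore_size(103_227_021, 1536))
```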
For Law-MT, we created the datastore following the instructions and using the code of https://github.com/jxhe/efficient-knnlm:
cd ../efficient-knnlm
python eval_lm.py ../retomaton/data-bin/law \
--path ../retomaton/checkpoints/law/wmt19.en/model.pt \
--sample-break-mode eos --max-tokens 2048 \
--softmax-batch 1024 --batch-size 2 --gen-subset train \
--context-window 0 --tokens-per-sample 512 \
--dstore-mmap ../retomaton/checkpoints/law/dstore16 --knn-keytype 'last_ffn_input' \
--dstore-size 19068709 \
--log-interval 100 \
--model-overrides "{'knn_keytype': 'last_ffn_input'}" \
--dstore-fp16 \
--save-knnlm-dstore
The FAISS index requires a training stage in which it learns an index for accessing the keys quickly. Once training is complete, all the keys must be added to the index. The speed of adding keys to the index depends on the hardware, particularly the amount of available RAM.
To download our index:
wget -P checkpoints/wt103/ https://retomaton.s3.us-east-2.amazonaws.com/wt103/knn16.index
wget -P checkpoints/law/ https://retomaton.s3.us-east-2.amazonaws.com/law/knn.19048862.index
Alternatively, to build the index yourself, for Wikitext-103:
DSTORE=checkpoints/wt103/dstore16
DSTORE_SIZE=103225485
INDEX=checkpoints/wt103/knn16.index
For Law-MT:
DSTORE=checkpoints/law/dstore16
DSTORE_SIZE=19068709
INDEX=checkpoints/law/knn16.index
and then for both datasets:
python build_dstore.py \
--dstore_mmap ${DSTORE} \
--dstore_size ${DSTORE_SIZE} \
--faiss_index ${INDEX} \
--num_keys_to_add_at_a_time 500000 \
--starting_point 0
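The --num_keys_to_add_at_a_time and --starting_point flags suggest that the keys are added to the index in fixed-size batches, which bounds memory use and, presumably, allows resuming from a given row. A minimal sketch of that batching logic, assuming rows are added in order; key_chunks is a hypothetical helper, and we assume each range corresponds to a slice of the saved keys that is added to the FAISS index:

```python
from typing import Iterator, Tuple


def key_chunks(dstore_size: int, chunk_size: int, start: int = 0) -> Iterator[Tuple[int, int]]:
    """Yield [begin, end) row ranges of the datastore, chunk_size rows at a time."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    begin = start
    while begin < dstore_size:
        end = min(begin + chunk_size, dstore_size)  # the last chunk may be shorter
        yield begin, end
        begin = end


# Law-MT datastore, 500,000 keys at a time, starting from row 0.
chunks = list(key_chunks(19_068_709, 500_000, start=0))
print(len(chunks), chunks[0], chunks[-1])
```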
To evaluate RetoMaton (without clustering) on the validation set, for Wikitext-103:
DSTORE=checkpoints/wt103/dstore16
DSTORE_SIZE=103225485
INDEX=checkpoints/wt103/knn16.index
MODEL=checkpoints/wt103/wt103_checkpoint_best.pt
python eval_lm.py data-bin/wikitext-103 \
--path ${MODEL} \
--sample-break-mode complete --max-tokens 3072 \
--context-window 2560 --softmax-batch 1024000 --batch-size 2 \
--gen-subset valid --dstore-filename ${DSTORE} \
--indexfile ${INDEX} \
--model-overrides "{'knn_keytype': 'last_ffn_input'}" \
--k 1024 --lmbda 0.25 --dstore-size ${DSTORE_SIZE} --knn-keytype last_ffn_input \
--probe 32 --knnlm --dstore-fp16 \
--knn-sim-func do_not_recomp_l2 --no-load-keys --move-dstore-to-mem \
--knnlm-gpu --min-knns 1 --max-knns 1024
To encourage the model to perform a full kNN search more frequently, and thus increase accuracy and reduce perplexity, use a larger value of --min-knns, such as 100. Using --min-knns 9999999 makes the model perform a kNN search at every step (FoSS = 0 in Figure 3 of the paper), which achieves the best results at the cost of slower speed.
To run the baseline kNN-LM, add the flag --no-pointer.
For Law-MT:
DSTORE=checkpoints/law/dstore16
DSTORE_SIZE=19068709
INDEX=checkpoints/law/knn16.index
MODEL=checkpoints/law/wmt19.en/model.pt
python eval_lm.py data-bin/law \
--path ${MODEL} \
--sample-break-mode eos --max-tokens 2048 \
--context-window 0 --softmax-batch 1024000 --batch-size 2 \
--gen-subset valid --dstore-filename ${DSTORE} \
--indexfile ${INDEX} \
--model-overrides "{'knn_keytype': 'last_ffn_input'}" \
--k 1024 --lmbda 0.9 --dstore-size ${DSTORE_SIZE} --knn-keytype last_ffn_input \
--probe 32 --knnlm --dstore-fp16 \
--knn-sim-func do_not_recomp_l2 --no-load-keys --move-dstore-to-mem \
--remove-bpe \
--knnlm-gpu --min-knns 1 --max-knns 1024
Notice that, as before, the differences between the datasets are that in Law-MT we use the flags --remove-bpe and --sample-break-mode eos, and that the --max-tokens and --context-window values are different.
Further, as found by He et al. (2021), the interpolation coefficient should be set to --lmbda 0.9, to give more weight to the datastore than to the base LM.
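For reference, the kNN-LM formulation that --lmbda controls: the retrieved neighbors define a distribution p_kNN over next tokens (a softmax over negative distances, summed per token), and the final probability is lmbda * p_kNN + (1 - lmbda) * p_LM, so a larger --lmbda puts more weight on the datastore. The sketch below is a single-token, pure-Python illustration with made-up numbers. It uses the distances as returned by the search, which is our understanding of --knn-sim-func do_not_recomp_l2, and it works in probability space for readability, whereas the kNN-LM code performs the same combination in log space:

```python
import math
from collections import defaultdict
from typing import Dict, Hashable, Sequence, Tuple


def knn_distribution(neighbors: Sequence[Tuple[float, Hashable]]) -> Dict[Hashable, float]:
    """p_kNN from (distance, next_token) pairs: a softmax over negative distances,
    summing the weights of neighbors that share the same next token."""
    d_min = min(d for d, _ in neighbors)
    weights: Dict[Hashable, float] = defaultdict(float)
    for d, token in neighbors:
        weights[token] += math.exp(-(d - d_min))  # shifting by d_min keeps exp() from underflowing
    total = sum(weights.values())
    return {token: w / total for token, w in weights.items()}


def interpolate(p_lm: float, p_knn: float, lmbda: float) -> float:
    """Final probability: lmbda * p_kNN + (1 - lmbda) * p_LM."""
    return lmbda * p_knn + (1.0 - lmbda) * p_lm


# Toy example: three retrieved neighbors, two of which are followed by "court".
p_knn = knn_distribution([(1.0, "court"), (2.0, "court"), (1.5, "judge")])
p_lm_court = 0.1  # made-up base-LM probability for "court"
print(p_knn)
print(interpolate(p_lm_court, p_knn.get("court", 0.0), lmbda=0.9))   # Law-MT with the WMT base LM
print(interpolate(p_lm_court, p_knn.get("court", 0.0), lmbda=0.25))  # Wikitext-103 or the fine-tuned LM
```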
For the Greedy Merge clustering algorithm, see the code of He et al. (2021). Greedy Merge is much faster and requires much less memory than k-means, but results in slightly higher perplexity.
See also Figures 8 and 9 in Appendix D of the paper.
For Wikitext-103, only one of the following files is needed. For the main experiments in the paper, we used:
wget -P checkpoints/wt103/ https://retomaton.s3.us-east-2.amazonaws.com/wt103/clusters_s40000000_k1000000_members.pkl
but additional clusterings are available as well:
wget -P checkpoints/wt103/ https://retomaton.s3.us-east-2.amazonaws.com/wt103/clusters_s20000000_k500000_members.pkl
wget -P checkpoints/wt103/ https://retomaton.s3.us-east-2.amazonaws.com/wt103/dstore_merge15_members_sp.pkl
wget -P checkpoints/wt103/ https://retomaton.s3.us-east-2.amazonaws.com/wt103/dstore_merge29_members.pkl
For Law-MT, only one of the following files is needed. For the main experiments in the paper, we used:
wget -P checkpoints/law/ https://retomaton.s3.us-east-2.amazonaws.com/law/law_clusters_s40000000_k200000_members.pkl
but an additional clustering is available as well:
wget -P checkpoints/law/ https://retomaton.s3.us-east-2.amazonaws.com/law/law_clusters_s40000000_k400000_members.pkl
Evaluation is identical to Step 6: Evaluating RetoMaton without clustering, except that we add the flag --members <filename>_members.pkl.
For Wikitext-103:
DSTORE=checkpoints/wt103/dstore16
DSTORE_SIZE=103225485
INDEX=checkpoints/wt103/knn16.index
MODEL=checkpoints/wt103/wt103_checkpoint_best.pt
MEMBERS=checkpoints/wt103/clusters_s40000000_k1000000_members.pkl
python eval_lm.py data-bin/wikitext-103 \
--path ${MODEL} \
--sample-break-mode complete --max-tokens 3072 \
--context-window 2560 --softmax-batch 1024000 --batch-size 2 \
--gen-subset valid --dstore-filename ${DSTORE} \
--indexfile ${INDEX} \
--model-overrides "{'knn_keytype': 'last_ffn_input'}" \
--k 1024 --lmbda 0.25 --dstore-size ${DSTORE_SIZE} --knn-keytype last_ffn_input \
--probe 32 --knnlm --dstore-fp16 \
--knn-sim-func do_not_recomp_l2 --no-load-keys --move-dstore-to-mem \
--knnlm-gpu --min-knns 1 --max-knns 1024 \
--members ${MEMBERS}
For Law-MT:
DSTORE=checkpoints/law/dstore16
DSTORE_SIZE=19068709
INDEX=checkpoints/law/knn16.index
MODEL=checkpoints/law/wmt19.en/model.pt
MEMBERS=checkpoints/law/law_clusters_s40000000_k200000_members.pkl
python eval_lm.py data-bin/law \
--path ${MODEL} \
--sample-break-mode eos --max-tokens 2048 \
--context-window 0 --softmax-batch 1024000 --batch-size 2 \
--gen-subset valid --dstore-filename ${DSTORE} \
--indexfile ${INDEX} \
--model-overrides "{'knn_keytype': 'last_ffn_input'}" \
--k 1024 --lmbda 0.9 --dstore-size ${DSTORE_SIZE} --knn-keytype last_ffn_input \
--probe 32 --knnlm --dstore-fp16 \
--knn-sim-func do_not_recomp_l2 --no-load-keys --move-dstore-to-mem \
--remove-bpe \
--knnlm-gpu --min-knns 1 --max-knns 1024 \
--members ${MEMBERS}
To run the k-means clustering yourself, for Wikitext-103:
DSTORE=checkpoints/wt103/dstore16
DSTORE_SIZE=103225485
NUM_CLUSTERS=1000000
SAMPLE=40000000
DIM=1024
SAVE=kmeans_wt103
For Law-MT:
DSTORE=checkpoints/law/dstore16
DSTORE_SIZE=19068709
NUM_CLUSTERS=200000
SAMPLE=40000000
DIM=1536
SAVE=kmeans_law
And then for both datasets:
python kmeans.py --dstore ${DSTORE} --dstore-size ${DSTORE_SIZE} --num-clusters ${NUM_CLUSTERS} --sample ${SAMPLE} --dim ${DIM} --save ${SAVE}
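kmeans.py relies on faiss for the actual clustering. Purely to illustrate what the clustering produces, here is a toy pure-Python version of Lloyd's k-means that mirrors the --num-clusters and --sample parameters: centroids are fitted on a random sample of the keys, and then every datastore entry is assigned to its nearest centroid. The output, a mapping from a cluster id to the indices of its member entries, is our assumption about the kind of information a *_members.pkl file holds; the actual file format is not documented here, and kmeans_members is a hypothetical name:

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence

Vector = Sequence[float]


def sq_dist(a: Vector, b: Vector) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def nearest(key: Vector, centroids: List[List[float]]) -> int:
    return min(range(len(centroids)), key=lambda c: sq_dist(key, centroids[c]))


def kmeans_members(keys: Sequence[Vector], num_clusters: int, sample: int,
                   iters: int = 10, seed: int = 0) -> Dict[int, List[int]]:
    """Toy Lloyd's k-means: fit centroids on a random sample of the keys,
    then assign every key; returns cluster id -> list of datastore indices."""
    rng = random.Random(seed)
    train = rng.sample(list(keys), min(sample, len(keys)))
    centroids = [list(k) for k in rng.sample(train, num_clusters)]
    for _ in range(iters):
        groups: Dict[int, List[Vector]] = defaultdict(list)
        for k in train:
            groups[nearest(k, centroids)].append(k)
        for c, pts in groups.items():  # empty clusters keep their old centroid
            centroids[c] = [sum(col) / len(pts) for col in zip(*pts)]
    members: Dict[int, List[int]] = defaultdict(list)
    for idx, k in enumerate(keys):
        members[nearest(k, centroids)].append(idx)
    return dict(members)


keys = [(0.0, 0.0), (0.0, 1.0), (10.0, 10.0), (10.0, 11.0)]
print(kmeans_members(keys, num_clusters=2, sample=4))
```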
The model that was fine-tuned on Law-MT, along with its corresponding datastore, FAISS index and clustering can be downloaded from:
mkdir checkpoints/law-finetuned/
wget -P checkpoints/law-finetuned/ https://retomaton.s3.us-east-2.amazonaws.com/law/finetuned.pt
wget -P checkpoints/law-finetuned/ https://retomaton.s3.us-east-2.amazonaws.com/law/dstore16_finetuned_size19068709_embed1536_fp16_vals.npy
wget -P checkpoints/law-finetuned/ https://retomaton.s3.us-east-2.amazonaws.com/law/dstore16_finetuned_size19068709_embed1536_fp16_keys.npy
wget -P checkpoints/law-finetuned/ https://retomaton.s3.us-east-2.amazonaws.com/law/knn_finetuned.index
wget -P checkpoints/law-finetuned/ https://retomaton.s3.us-east-2.amazonaws.com/law/law_finetuned_clusters_s20000000_k200000_members.pkl
Finally, evaluate using the fine-tuned checkpoint, datastore, and index.
It is important to also set --lmbda 0.25 when using the fine-tuned model: since the model is fine-tuned, we can rely on it more than before. See the clarification in the Lambda values section.
Best results with the fine-tuned model are achieved without clustering (that is, every datastore entry is a singleton cluster).
Then, the same steps as before should be run on the Law-MT dataset, except that:
- finetuned.pt should be used as the ${MODEL}
- dstore16_finetuned_size19068709_embed1536_fp16 should be used as the ${DSTORE}
- knn_finetuned.index should be used as the ${INDEX}
- law_finetuned_clusters_s20000000_k200000_members.pkl should be used as the ${MEMBERS} (only when evaluating with clustering)
That is:
DSTORE=checkpoints/law-finetuned/dstore16_finetuned_size19068709_embed1536_fp16
DSTORE_SIZE=19068709
INDEX=checkpoints/law-finetuned/knn_finetuned.index
MODEL=checkpoints/law-finetuned/finetuned.pt
MEMBERS=checkpoints/law-finetuned/law_finetuned_clusters_s20000000_k200000_members.pkl
python eval_lm.py data-bin/law \
--path ${MODEL} \
--sample-break-mode eos --max-tokens 2048 \
--context-window 0 --softmax-batch 1024000 --batch-size 2 \
--gen-subset valid --dstore-filename ${DSTORE} \
--indexfile ${INDEX} \
--model-overrides "{'knn_keytype': 'last_ffn_input'}" \
--k 1024 --lmbda 0.25 --dstore-size ${DSTORE_SIZE} --knn-keytype last_ffn_input \
--probe 32 --knnlm --dstore-fp16 \
--knn-sim-func do_not_recomp_l2 --no-load-keys --move-dstore-to-mem \
--remove-bpe \
--knnlm-gpu --min-knns 1 --max-knns 1024
In all configurations, the interpolation factor lmbda is set to 0.25, except when the base LM is checkpoints/law/wmt19.en/model.pt and the model is evaluated on Law-MT: since this scenario tests domain adaptation, lmbda should be set to 0.9:
Evaluation dataset | wt103_checkpoint_best.pt | wmt19.en/model.pt | finetuned.pt |
---|---|---|---|
Wikitext-103 | 0.25 | - | - |
Law-MT | - | 0.9 | 0.25 |
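When scripting evaluations, the table can be encoded as a lookup so that the wrong --lmbda is not silently reused across configurations. A small sketch; choose_lmbda is a hypothetical helper that matches on the end of the checkpoint path, and combinations that are not in the table raise an error instead of guessing a value:

```python
# --lmbda per (base checkpoint, evaluation dataset), transcribed from the table above.
LMBDA = {
    ("wt103_checkpoint_best.pt", "wikitext-103"): 0.25,
    ("wmt19.en/model.pt", "law"): 0.9,
    ("finetuned.pt", "law"): 0.25,
}


def choose_lmbda(checkpoint_path: str, dataset: str) -> float:
    """Look up the interpolation coefficient for a checkpoint path and dataset."""
    for (checkpoint, ds), lmbda in LMBDA.items():
        if ds == dataset and checkpoint_path.endswith(checkpoint):
            return lmbda
    raise KeyError(f"no lambda value listed for {checkpoint_path!r} on {dataset!r}")


print(choose_lmbda("checkpoints/law/wmt19.en/model.pt", "law"))
```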
Checkpoints and datasets can be downloaded from https://zenodo.org/record/6525426, and also from the AWS S3 bucket used in the wget commands above.
Here we point to the code that distinguishes our work from kNN-LM.
- The main changes are in this commit. The pointers for the next timestep are initially the current k-nearest neighbors + 1. Then, we extend each pointer to consider all entries in its cluster. This is the function that maps each pointer to its cluster, removes duplicate clusters, and then finds the members of each cluster. We compute the log probabilities suggested by the new pointers, and finally carry over to the next timestep only the pointers that are consistent with the token that the model eventually predicted.
- In this commit we utilize the given pointers, or perform kNN search and combine the results with the existing pointers.
- When using the --knnlm-gpu flag, we use a GPU index to search for nearest neighbors, and its CPU copy to reconstruct vectors given their IDs. Unfortunately, reconstructing vectors in faiss is currently not implemented for GPU indexes (see also this issue).
- Reconstructing a batch of vectors from the index is unfortunately not implemented in faiss (see this issue), and thus the fastest way that we found to do that is using np.vectorize, reconstructing many single vectors in parallel: fairseq/knnlm.py#L92-L94.
- Performing k-means clustering on millions of vectors can be done in many ways; specifically, we use the faiss library to do it in the script kmeans.py.
- The original kNN-LM repository uses faiss CPU to perform retrieval. However, we added the flag --knnlm-gpu, which allows performing retrieval much faster on the GPU.
- After each retrieval, the original kNN-LM repository loads the found keys and re-computes the distance from the query to each nearest neighbor. This is much more time-consuming, unless all the keys (200GB) are loaded into memory. We thus use the flags --knn-sim-func do_not_recomp_l2 --no-load-keys --move-dstore-to-mem.
- When using faiss-gpu, it is useful to import faiss.contrib.torch_utils. This allows performing the kNN search using torch tensors (rather than only numpy arrays). Additionally, this import statement sometimes prevents search bugs in faiss (see this issue).
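The pointer update described in the first item of the list above can be sketched in pure Python as follows. This is a simplified illustration under our own assumptions, not the code in fairseq/knnlm.py: it handles a single query rather than a batch, represents tokens as integer ids, applies the consistency filter from the end of one step before the +1 and cluster expansion of the next, and drops successors that would fall past the end of the datastore. The names advance_pointers, cluster_of and members are hypothetical:

```python
from typing import Dict, List, Sequence


def advance_pointers(
    pointers: Sequence[int],
    vals: Sequence[int],
    predicted_token: int,
    cluster_of: Sequence[int],
    members: Dict[int, List[int]],
) -> List[int]:
    """Pointers for the next timestep (simplified, single-query sketch).

    pointers:   datastore entries used at the current step
    vals:       vals[i] is the token that followed entry i in the training data
    cluster_of: cluster_of[i] is the cluster id of entry i
    members:    cluster id -> datastore entries in that cluster
    """
    # Keep only entries consistent with the token that was actually predicted.
    consistent = [i for i in pointers if vals[i] == predicted_token]
    # Move each surviving entry to its successor in the datastore (i + 1).
    successors = [i + 1 for i in consistent if i + 1 < len(vals)]
    # Map successors to clusters, dropping duplicate clusters (order preserved).
    clusters = dict.fromkeys(cluster_of[i] for i in successors)
    # Expand every cluster to all of its members.
    return [m for c in clusters for m in members[c]]


vals = [7, 8, 9, 7, 8, 9]
cluster_of = [0, 1, 2, 0, 1, 2]
members = {0: [0, 3], 1: [1, 4], 2: [2, 5]}
print(advance_pointers([0, 2], vals, 7, cluster_of, members))  # only entry 0 matches token 7

# Without clustering, every entry is its own (singleton) cluster.
singletons = {i: [i] for i in range(len(vals))}
print(advance_pointers([0, 3], vals, 7, list(range(len(vals))), singletons))
```

With singleton clusters, the update reduces to following each consistent pointer to the next datastore entry, which matches the no-clustering setting of Step 6.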
Neuro-Symbolic Language Modeling with Automaton-augmented Retrieval
@inproceedings{alon2022neuro,
title={Neuro-Symbolic Language Modeling with Automaton-augmented Retrieval},
author={Alon, Uri and Xu, Frank and He, Junxian and Sengupta, Sudipta and Roth, Dan and Neubig, Graham},
booktitle={International Conference on Machine Learning},
pages={468--485},
year={2022},
organization={PMLR}
}