## 🩺 Patchscopes: A Unifying Framework for Inspecting Hidden Representations of Language Models

### Overview

We propose a framework that decodes specific information from a representation within an LLM by “patching” it into the inference pass on a different prompt that has been designed to encourage the extraction of that information. A "Patchscope" is a configuration of our framework that can be viewed as an inspection tool geared towards a particular objective.
For example, the figure below shows a simple Patchscope for decoding what is encoded in the representation of "CEO" in the source prompt (left). We patch a target prompt (right) composed of few-shot demonstrations of token repetitions, which encourages decoding the token identity given a hidden representation.
|
||
[**[Paper]**](https://arxiv.org/abs/2401.06102) [**[Project Website]**](https://pair-code.github.io/interpretability/patchscopes/) | ||
|
||
<p align="left"><img width="60%" src="images/patchscopes.png" /></p> | ||
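In the paper's terms, a Patchscope pairs a *source* (prompt, position, model, layer), from which a hidden representation is extracted, with a *target* (prompt, position, model, layer) into which that representation is patched before the run continues. For intuition only, here is a minimal sketch of the token-identity Patchscope above using standard `transformers` forward hooks. This is not the repo's implementation (see the notebooks under Experiments); the model choice (GPT-2), prompts, layer, and token position are all illustrative assumptions.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tok = AutoTokenizer.from_pretrained("gpt2")
model.eval()

source_prompt = "Amazon's former CEO attended Oscars"  # illustrative
target_prompt = "cat -> cat; hello -> hello; x ->"     # few-shot token repetition
layer = 6       # illustrative layer choice
source_pos = 3  # index of " CEO" under the GPT-2 tokenizer (illustrative)

# 1) Run the source prompt and keep all hidden states.
with torch.no_grad():
    src_out = model(**tok(source_prompt, return_tensors="pt"),
                    output_hidden_states=True)
# hidden_states[layer + 1] is the output of transformer block `layer`
# (hidden_states[0] is the embedding output).
patch_vec = src_out.hidden_states[layer + 1][0, source_pos]

# 2) Re-run the target prompt, overwriting the last position's hidden
#    state at the same block with the source representation.
def patch_hook(module, inputs, output):
    hidden = output[0]  # GPT-2 blocks return a tuple
    hidden[0, -1] = patch_vec
    return (hidden,) + output[1:]

handle = model.transformer.h[layer].register_forward_hook(patch_hook)
with torch.no_grad():
    tgt_out = model(**tok(target_prompt, return_tensors="pt"))
handle.remove()

# If the patched representation still encodes the token identity, the
# repetition pattern should be completed with that token.
print(tok.decode(tgt_out.logits[0, -1].argmax()))
```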
### 💾 Download textual data

The script is provided [**here**](download_the_pile_text_data.py). Use the following command to run it:

```bash
python3 download_the_pile_text_data.py
```
### 🦙 For using Vicuna-13B

To use the Vicuna-13B model, run the following command (see further details [here](https://huggingface.co/CarperAI/stable-vicuna-13b-delta)). It applies the released delta weights to the base model and writes the merged checkpoint to `./stable-vicuna-13b` (see `apply_delta.py` below):

```bash
python3 apply_delta.py --base meta-llama/Llama-2-13b-hf --target ./stable-vicuna-13b --delta CarperAI/stable-vicuna-13b-delta
```
### 🧪 Experiments

#### (1) Next Token Prediction
The main code appears [here](next_token_prediction.ipynb).

#### (2) Attribute Extraction
For this experiment, first download the `preprocessed_data` directory.
The main code appears [here](attribute_extraction.ipynb).

#### (3) Entity Processing
The main code appears [here](entity_processing.ipynb). The dataset is available for download [here](https://github.com/AlexTMallen/adaptive-retrieval/blob/main/data/popQA.tsv).

#### (4) Cross-model Patching
The main code appears [here](patch_cross_model.ipynb).

#### (5) Self-Correction in Multi-Hop Reasoning
For this experiment, first download the `preprocessed_data` directory.
The main code appears [here](multihop-CoT.ipynb). The provided code supports the Vicuna-13B model.
### 📙 BibTeX

```bibtex
@misc{ghandeharioun2024patchscopes,
      title={Patchscopes: A Unifying Framework for Inspecting Hidden Representations of Language Models},
      author={Ghandeharioun, Asma and Caciularu, Avi and Pearce, Adam and Dixon, Lucas and Geva, Mor},
      year={2024},
      eprint={2401.06102},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```
---

**`apply_delta.py`**, the delta-merge script referenced above:

```python
""" | ||
Usage: | ||
python3 apply_delta.py --base /path/to/model_weights/llama-13b --target stable-vicuna-13b --delta pvduy/stable-vicuna-13b-delta | ||
The code was adopted from https://github.com/GanjinZero/RRHF/blob/main/apply_delta.py | ||
""" | ||
import argparse

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
def apply_delta(base_model_path, target_model_path, delta_path):
    # Load the base model (e.g. Llama-13B) in fp16.
    print("Loading base model")
    base = AutoModelForCausalLM.from_pretrained(
        base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

    # Load the published delta weights (fine-tuned weights minus base weights).
    print("Loading delta")
    delta = AutoModelForCausalLM.from_pretrained(
        delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
    delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)

    # The delta vocabulary includes a [PAD] token the base model lacks; add it
    # and zero-initialize the new embedding rows so that adding the delta
    # reproduces its values exactly.
    DEFAULT_PAD_TOKEN = "[PAD]"
    base_tokenizer = AutoTokenizer.from_pretrained(base_model_path, use_fast=False)
    num_new_tokens = base_tokenizer.add_special_tokens(dict(pad_token=DEFAULT_PAD_TOKEN))

    base.resize_token_embeddings(len(base_tokenizer))
    input_embeddings = base.get_input_embeddings().weight.data
    output_embeddings = base.get_output_embeddings().weight.data
    input_embeddings[-num_new_tokens:] = 0
    output_embeddings[-num_new_tokens:] = 0

    # Reconstruct the fine-tuned weights tensor by tensor: target = base + delta.
    print("Applying delta")
    for name, param in tqdm(base.state_dict().items(), desc="Applying delta"):
        assert name in delta.state_dict()
        param.data += delta.state_dict()[name]

    # Write the merged model and the delta's tokenizer to the target path.
    print("Saving target model")
    base.save_pretrained(target_model_path)
    delta_tokenizer.save_pretrained(target_model_path)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # argparse matches unambiguous prefixes by default, so the shorter
    # --base/--target/--delta spellings in the usage string also resolve
    # to these options.
    parser.add_argument("--base-model-path", type=str, required=True)
    parser.add_argument("--target-model-path", type=str, required=True)
    parser.add_argument("--delta-path", type=str, required=True)

    args = parser.parse_args()

    apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
```
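Once the script finishes, the merged checkpoint in `--target-model-path` loads like any local Hugging Face model. A minimal usage sketch, assuming the `./stable-vicuna-13b` output path from the command above:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the reconstructed Vicuna-13B checkpoint written by apply_delta().
model = AutoModelForCausalLM.from_pretrained(
    "./stable-vicuna-13b", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained("./stable-vicuna-13b")
```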