Patchscopes code #45

Merged · 13 commits · Mar 10, 2024
50 changes: 50 additions & 0 deletions patchscopes/code/README.md
@@ -0,0 +1,50 @@
## 🩺 Patchscopes: A Unifying Framework for Inspecting Hidden Representations of Language Models


### Overview
We propose a framework that decodes specific information from a representation within an LLM by “patching” it into the inference pass on a different prompt that has been designed to encourage the extraction of that information. A "Patchscope" is a configuration of our framework that can be viewed as an inspection tool geared towards a particular objective.

For example, this figure shows a simple Patchscope for decoding what is encoded in the representation of "CEO" in the source prompt (left). We patch a target prompt (right) composed of few-shot demonstrations of token repetitions, which encourages decoding the token identity given a hidden representation.

[**[Paper]**](https://arxiv.org/abs/2401.06102) [**[Project Website]**](https://pair-code.github.io/interpretability/patchscopes/)

<p align="left"><img width="60%" src="images/patchscopes.png" /></p>
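For concreteness, here is a minimal sketch of this patching step with a Hugging Face causal LM. The model name, prompts, layer indices, and hook placement below are illustrative assumptions, not the exact configuration used in the notebooks:

```python
# Minimal Patchscope sketch, assuming a GPT-2-style Hugging Face model
# (model name, prompts, and layer choices are illustrative, not the paper's).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # stand-in; the paper's experiments patch much larger LLMs
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

source_prompt = "Amazon's former CEO"
target_prompt = "cat -> cat; 1135 -> 1135; hello -> hello; ?"
source_layer = 6   # layer whose output we read in the source run
target_layer = 6   # layer whose output we overwrite in the target run

# 1) Source run: cache the hidden state of the last token ("CEO").
with torch.no_grad():
    src = model(**tok(source_prompt, return_tensors="pt"),
                output_hidden_states=True)
# hidden_states[l + 1] is the output of transformer block l.
patched_state = src.hidden_states[source_layer + 1][0, -1]

# 2) Target run: a forward hook overwrites the residual stream at the
#    last position of the chosen target layer with the cached state.
def patch_hook(module, inputs, output):
    hidden = output[0] if isinstance(output, tuple) else output
    hidden[0, -1, :] = patched_state
    return output

handle = model.transformer.h[target_layer].register_forward_hook(patch_hook)
try:
    with torch.no_grad():
        out = model(**tok(target_prompt, return_tensors="pt"))
finally:
    handle.remove()

# The argmax next token now reflects what the patched representation encodes.
print(tok.decode(out.logits[0, -1].argmax()))
```

In a full Patchscope, the source position, target position, layer pair, and an optional mapping function applied to the representation are all configurable; see the notebooks below for the configurations used in the paper.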

### 💾 Download textual data
The script is provided [**here**](download_the_pile_text_data.py). Use the following command to run it:
```bash
python3 download_the_pile_text_data.py
```

### 🦙 Using Vicuna-13B
Run the following command to set up the Vicuna-13B model (see also the details [here](https://huggingface.co/CarperAI/stable-vicuna-13b-delta)):
```bash
python3 apply_delta.py --base meta-llama/Llama-2-13b-hf --target ./stable-vicuna-13b --delta CarperAI/stable-vicuna-13b-delta
```

### 🧪 Experiments

#### (1) Next Token Prediction
The main code used appears [here](next_token_prediction.ipynb).
#### (2) Attribute Extraction
For this experiment, you should download the `preprocessed_data` directory.
The main code used appears [here](attribute_extraction.ipynb).
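For intuition, a hedged sketch of the kind of target prompt this experiment relies on: the relation is verbalized as a template, and the subject's source representation is patched into a placeholder position. The template, placeholder token, and tokenizer below are illustrative assumptions, not values taken from the notebook.

```python
# Illustrative only: build a relation-verbalization target prompt for
# attribute extraction; the template and placeholder are assumptions.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # stand-in tokenizer

relation_template = "The official currency of {} is"
target_prompt = relation_template.format("x")  # "x" is a dummy subject

# Locate the dummy subject; its hidden state is the one overwritten with
# the source entity's representation during patching.
ids = tok(target_prompt)["input_ids"]
x_pos = next(i for i, t in enumerate(ids) if tok.decode([t]).strip() == "x")
print(target_prompt, "| patch position:", x_pos)
```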
#### (3) Entity Processing
The main code used appears [here](entity_processing.ipynb). The dataset is available for download [here](https://github.com/AlexTMallen/adaptive-retrieval/blob/main/data/popQA.tsv).
#### (4) Cross-model Patching
The main code used appears [here](patch_cross_model.ipynb).
#### (5) Self-Correction in Multi-Hop Reasoning
For this experiment, you should download the `preprocessed_data` directory.
The main code used appears [here](multihop-CoT.ipynb). The code provided supports the Vicuna-13B model.

### 📙 BibTeX
```bibtex
@misc{ghandeharioun2024patchscopes,
      title={Patchscopes: A Unifying Framework for Inspecting Hidden Representations of Language Models},
      author={Ghandeharioun, Asma and Caciularu, Avi and Pearce, Adam and Dixon, Lucas and Geva, Mor},
      year={2024},
      eprint={2401.06102},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```
51 changes: 51 additions & 0 deletions patchscopes/code/apply_delta.py
@@ -0,0 +1,51 @@
"""
Usage:
python3 apply_delta.py --base /path/to/model_weights/llama-13b --target stable-vicuna-13b --delta pvduy/stable-vicuna-13b-delta

The code was adapted from https://github.com/GanjinZero/RRHF/blob/main/apply_delta.py
"""
import argparse

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM


def apply_delta(base_model_path, target_model_path, delta_path):
    print("Loading base model")
    base = AutoModelForCausalLM.from_pretrained(
        base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

    print("Loading delta")
    delta = AutoModelForCausalLM.from_pretrained(
        delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
    delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)

    # The delta vocabulary includes an extra [PAD] token; add it to the base
    # tokenizer and zero-initialize the new embedding rows so the weight
    # shapes match before the deltas are applied.
    DEFAULT_PAD_TOKEN = "[PAD]"
    base_tokenizer = AutoTokenizer.from_pretrained(base_model_path, use_fast=False)
    num_new_tokens = base_tokenizer.add_special_tokens(dict(pad_token=DEFAULT_PAD_TOKEN))

    base.resize_token_embeddings(len(base_tokenizer))
    input_embeddings = base.get_input_embeddings().weight.data
    output_embeddings = base.get_output_embeddings().weight.data
    input_embeddings[-num_new_tokens:] = 0
    output_embeddings[-num_new_tokens:] = 0

    print("Applying delta")
    # Element-wise addition of the delta weights recovers the target model.
    for name, param in tqdm(base.state_dict().items(), desc="Applying delta"):
        assert name in delta.state_dict()
        param.data += delta.state_dict()[name]

    print("Saving target model")
    base.save_pretrained(target_model_path)
    delta_tokenizer.save_pretrained(target_model_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-model-path", type=str, required=True)
    parser.add_argument("--target-model-path", type=str, required=True)
    parser.add_argument("--delta-path", type=str, required=True)

    args = parser.parse_args()

    apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
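
Once the script finishes, the merged weights load like any other Hugging Face checkpoint. A small sanity-check sketch; the local path below is an assumption based on the README's example command:

```python
# Sanity check: load the merged Vicuna-13B checkpoint produced above.
# "./stable-vicuna-13b" is assumed to be the --target path used earlier.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./stable-vicuna-13b", use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
    "./stable-vicuna-13b", torch_dtype=torch.float16, low_cpu_mem_usage=True)
print(model.config.model_type)  # expect "llama"
```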