Commit 95c6966

First commit.

1 parent 2807a2c commit 95c6966

16 files changed: +3528 -0 lines

README.md

Lines changed: 57 additions & 0 deletions
# Extracting rule-based descriptions of attention features in transformers

This repository contains the code for our paper, "Extracting rule-based descriptions of attention features in transformers".
Please see our paper for more details.

## Quick links

* [Setup](#Setup)
* [Attention output SAEs](#Attention-output-SAEs)
* [Data](#Data)
* [Rule extraction](#Rule-extraction)
* [Questions?](#Questions)
* [Citation](#Citation)

## Setup

Install [PyTorch](https://pytorch.org/get-started/locally/) and then install the remaining requirements: `pip install -r requirements.txt`.
This code was tested using Python 3.12 and PyTorch 2.3.1.
## Attention output SAEs

We train [attention output SAEs](https://arxiv.org/abs/2406.17759) for every attention head in GPT-2 small, using a fork of https://github.com/ckkissane/attention-output-saes.
These SAEs can be downloaded from https://huggingface.co/danf0/attention-head-saes/.
## Data

Code for generating datasets of feature activations can be found in [src/get_exemplars.py](src/get_exemplars.py).
See [scripts/generate_data.sh](scripts/generate_data.sh) for the command to generate the datasets used in our paper, which are based on [OpenWebText](https://huggingface.co/datasets/Skylion007/openwebtext).
## Rule extraction

Code for extracting and evaluating skip-gram rules can be found in [src/run_rules.py](src/run_rules.py).
For example, the following command extracts rules for 10 features from head 0 in layer 0:

```bash
python src/run_rules.py \
    --layer 0 \
    --head 0 \
    --num_features 10 \
    --rule_type "v1" \
    --output_dir "output/skipgrams/L0H0";
```

Code for finding and generating rules containing "distractor" features is in [src/find_distractors.py](src/find_distractors.py) and [src/generate_distractors.py](src/generate_distractors.py).
The [scripts](scripts) directory contains example commands for running these scripts.
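The rule format and extraction logic are defined in src/run_rules.py; purely as an illustration (the function and rule representation below are hypothetical, not the repo's actual output format), a skip-gram rule can be thought of as an ordered tuple of tokens that must all occur, in order but with arbitrary gaps, in the context:

```python
# Hypothetical illustration of skip-gram rule matching; the actual rule
# format and extraction code live in src/run_rules.py.

def matches_skipgram(context_tokens, rule):
    """Return True if the tokens in `rule` occur in `context_tokens`
    in order, allowing arbitrary gaps between them."""
    it = iter(context_tokens)
    # `tok in it` advances the iterator, so matches must appear in order.
    return all(tok in it for tok in rule)

# A rule like ("keep", "in") fires on contexts where "keep" is later
# followed by "in", with any tokens in between.
print(matches_skipgram(["keep", "this", "in", "mind"], ("keep", "in")))  # True
print(matches_skipgram(["in", "order", "to", "keep"], ("keep", "in")))   # False
```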
## Questions?

If you have any questions about the code or paper, please email Dan ([email protected]) or open an issue.

## Citation

```bibtex
@article{friedman2025extracting,
  title={Extracting rule-based descriptions of attention features in transformers},
  author={Friedman, Dan and Wettig, Alexander and Bhaskar, Adithya and Chen, Danqi},
  journal={arXiv preprint},
  year={2025}
}
```

requirements.txt

Lines changed: 9 additions & 0 deletions
einops
matplotlib
numpy
pandas
sae-lens
scikit-learn
seaborn
transformers
tqdm

scripts/extract_rules.sh

Lines changed: 14 additions & 0 deletions
#!/bin/bash

LAYER=0
HEAD=0
METHOD="magnitude"
# METHOD="importance"

srun python src/run_rules.py \
    --layer "${LAYER}" \
    --head "${HEAD}" \
    --method "${METHOD}" \
    --num_features 100 \
    --example_dir "data/openwebtext_n50000_bins2x50/" \
    --output_dir "output/skipgrams/${METHOD}/L${LAYER}H${HEAD}/";

scripts/find_distractor_examples.sh

Lines changed: 11 additions & 0 deletions
#!/bin/bash

LAYER=0
HEAD=0

python src/find_distractors.py \
    --cmd "get_distractor_examples" \
    --layer "${LAYER}" \
    --head "${HEAD}" \
    --example_dir "data/openwebtext_n50000_bins2x50/" \
    --output_dir "output/distractor_examples/L${LAYER}H${HEAD}/";
    # Note: uses ${HEAD}; the original referenced ${SLURM_ARRAY_TASK_ID},
    # which is undefined outside a Slurm array job.

scripts/find_distractor_features.sh

Lines changed: 11 additions & 0 deletions
#!/bin/bash

LAYER=0
HEAD=0

srun python src/find_distractors.py \
    --cmd "get_distractors" \
    --layer "${LAYER}" \
    --head "${HEAD}" \
    --example_dir "data/openwebtext_n50000_bins2x50/" \
    --output_dir "output/distractor_features/L${LAYER}H${HEAD}/";

scripts/generate_data.sh

Lines changed: 22 additions & 0 deletions
#!/bin/bash

LAYER=0
HEAD=0
DATASET_SIZE=50000
NUM_BINS=2
EXAMPLES_PER_BIN=50
SEED=0

python src/get_exemplars.py \
    --layer "${LAYER}" \
    --head "${HEAD}" \
    --dataset_size "${DATASET_SIZE}" \
    --num_bins "${NUM_BINS}" \
    --examples_per_bin "${EXAMPLES_PER_BIN}" \
    --max_length 64 \
    --num_features 100 \
    --min_count 150 \
    --max_count 49850 \
    --dataset_path "Skylion007/openwebtext" \
    --seed "${SEED}" \
    --output_dir "data/openwebtext_n${DATASET_SIZE}_bins${NUM_BINS}x${EXAMPLES_PER_BIN}/L${LAYER}H${HEAD}/";
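For reference, the settings above request `NUM_BINS=2` activation bins with `EXAMPLES_PER_BIN=50` exemplars each, i.e. up to 100 exemplars per feature. A minimal sketch of that kind of binned exemplar sampling (hypothetical; the real logic is in src/get_exemplars.py):

```python
# Hypothetical sketch of binned exemplar sampling; the actual
# implementation is in src/get_exemplars.py.

def bin_exemplars(activations, num_bins=2, examples_per_bin=50):
    """Sort examples by activation, split the active examples into
    `num_bins` equal-sized bins, and keep the top `examples_per_bin`
    from each bin. `activations` maps example id -> max activation."""
    active = sorted(
        (ex for ex in activations.items() if ex[1] > 0),
        key=lambda ex: ex[1],
        reverse=True,
    )
    bin_size = max(1, len(active) // num_bins)
    bins = [active[i * bin_size:(i + 1) * bin_size] for i in range(num_bins)]
    return [[ex_id for ex_id, _ in b[:examples_per_bin]] for b in bins]

acts = {f"doc{i}": float(i) for i in range(8)}  # doc0 has activation 0.0
top, mid = bin_exemplars(acts, num_bins=2, examples_per_bin=2)
print(top, mid)  # ['doc7', 'doc6'] ['doc4', 'doc3']
```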
Lines changed: 9 additions & 0 deletions

#!/bin/bash

HEAD=0

python src/generate_distractors.py \
    --layer 0 \
    --head "${HEAD}" \
    --example_dir "output/distractor_examples/" \
    --output_dir "output/generated_distractor_examples/L0H${HEAD}/";

setup.py

Lines changed: 3 additions & 0 deletions
from setuptools import setup, find_packages

setup(name="src", version="0.1", packages=find_packages())
