├── .gitignore
├── setup.py
├── repe
├── .prettierrc
├── __init__.py
├── pipelines.py
├── rep_control_pipeline.py
├── rep_reading_pipeline.py
├── rep_readers.py
└── rep_control_reading_vec.py
├── assets
├── .DS_Store
└── repe_splash.png
├── lorra_finetune
├── scripts
│ ├── launch_lorra.sh
│ ├── llama_lorra_power_7b.sh
│ ├── llama_lorra_tqa_13b.sh
│ ├── llama_lorra_tqa_7b.sh
│ ├── slurm-lorra_tqa_7b-927238.out
│ └── slurm-lorra_tqa_13b-927237.out
├── configs
│ ├── ds.json
│ ├── ds_zero0.json
│ ├── ds_zero1.json
│ ├── ds_zero2.json
│ └── ds_zero3.json
└── src
│ ├── args.py
│ ├── train_val_datasets.py
│ └── llama2_lorra.py
├── repe_eval
├── tasks
│ ├── utils.py
│ ├── __init__.py
│ ├── obqa.py
│ ├── arc.py
│ ├── csqa.py
│ ├── race.py
│ └── tqa.py
├── scripts
│ ├── rep_readers_eval.sh
│ ├── launch_seeds.sh
│ └── launch.sh
├── README.md
└── rep_reading_eval.py
├── pyproject.toml
├── examples
├── fairness
│ ├── README.md
│ └── utils.py
├── memorization
│ ├── README.md
│ ├── utils.py
│ └── quote_completions_control.ipynb
├── primary_emotions
│ ├── README.md
│ └── utils.py
├── honesty
│ ├── README.md
│ ├── utils.py
│ └── honesty_contrast_vec_TQA_generation.ipynb
├── README.md
└── languages
│ └── vn_llama3.ipynb
├── LICENSE
├── data
├── memorization
│ ├── literary_openings
│ │ ├── real.json
│ │ └── fake.json
│ └── quotes
│ │ ├── unseen_quotes.json
│ │ └── popular_quotes.json
└── emotions
│ ├── all_truncated_outputs.json
│ └── happiness.json
└── README.md
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | repe.egg-info/
3 | dist/
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | setuptools.setup()
4 |
--------------------------------------------------------------------------------
/repe/.prettierrc:
--------------------------------------------------------------------------------
1 | {
2 | "tabWidth": 2,
3 | "useTabs": false
4 | }
5 |
--------------------------------------------------------------------------------
/assets/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andyzoujm/representation-engineering/HEAD/assets/.DS_Store
--------------------------------------------------------------------------------
/assets/repe_splash.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andyzoujm/representation-engineering/HEAD/assets/repe_splash.png
--------------------------------------------------------------------------------
/lorra_finetune/scripts/launch_lorra.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # sbatch --nodes=1 --gpus-per-node=1 --time=02:00:00 --output="slurm-lorra_tqa_13b-%j.out" llama_lorra_tqa_13b.sh
4 |
5 | sbatch --nodes=1 --gpus-per-node=1 --partition=cais --time=02:00:00 --output="slurm-lorra_tqa_7b-%j.out" llama_lorra_tqa_7b.sh
6 |
7 |
--------------------------------------------------------------------------------
/repe/__init__.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.filterwarnings("ignore")
3 |
4 |
5 | from .pipelines import repe_pipeline_registry
6 |
7 | # RepReading
8 | from .rep_readers import *
9 | from .rep_reading_pipeline import *
10 |
11 | # RepControl
12 | from .rep_control_pipeline import *
13 | from .rep_control_reading_vec import *
--------------------------------------------------------------------------------
/repe_eval/tasks/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | def shuffle_all_train_choices(train_data, train_labels, seed):
4 | random.seed(seed)
5 | shuffled_train_labels = []
6 | for data, label in zip(train_data, train_labels):
7 | true_choice = data[label.index(1)]
8 | random.shuffle(data)
9 | shuffled_train_labels.append([int(d == true_choice) for d in data])
10 | return train_data, shuffled_train_labels
11 |
--------------------------------------------------------------------------------
/repe_eval/scripts/rep_readers_eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | source /opt/rh/devtoolset-10/enable
4 |
5 | model_name_or_path=$1
6 | task=$2
7 | ntrain=$3
8 | seed=$4
9 | echo "model_name_or_path=$model_name_or_path"
10 | echo "task=$task"
11 | echo "ntrain=$ntrain"
12 | echo "seed=$seed"
13 |
14 |
15 | cd ..
16 | python rep_reading_eval.py \
17 | --model_name_or_path $model_name_or_path \
18 | --task $task \
19 | --ntrain $ntrain \
20 | --seed $seed
21 |
--------------------------------------------------------------------------------
/repe/pipelines.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModel, AutoModelForCausalLM
2 | from transformers.pipelines import PIPELINE_REGISTRY
3 | from .rep_reading_pipeline import RepReadingPipeline
4 | from .rep_control_pipeline import RepControlPipeline
5 |
6 | def repe_pipeline_registry():
7 | PIPELINE_REGISTRY.register_pipeline(
8 | "rep-reading",
9 | pipeline_class=RepReadingPipeline,
10 | pt_model=AutoModel,
11 | )
12 |
13 | PIPELINE_REGISTRY.register_pipeline(
14 | "rep-control",
15 | pipeline_class=RepControlPipeline,
16 | pt_model=AutoModelForCausalLM,
17 | )
18 |
19 |
20 |
--------------------------------------------------------------------------------
/repe_eval/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from .csqa import csqa_dataset
3 | from .race import race_dataset
4 | from .obqa import openbookqa_dataset
5 | from .arc import arc_dataset
6 | from .tqa import tqa_dataset
7 |
8 | def task_dataset(task):
9 | datasets_function = {
10 | 'csqa': csqa_dataset,
11 | 'race': race_dataset,
12 | 'obqa': openbookqa_dataset,
13 | 'arc_easy': partial(arc_dataset, 'ARC-Easy'),
14 | 'arc_challenge': partial(arc_dataset, 'ARC-Challenge'),
15 | 'tqa': tqa_dataset,
16 | }
17 |
18 | assert task in datasets_function, f"{task} not implemented"
19 | return datasets_function[task]
--------------------------------------------------------------------------------
/lorra_finetune/configs/ds.json:
--------------------------------------------------------------------------------
1 | {
2 | "gradient_accumulation_steps": "auto",
3 | "train_micro_batch_size_per_gpu": "auto",
4 | "train_batch_size": "auto",
5 | "steps_per_print": 100,
6 | "optimizer": {
7 | "type": "AdamW",
8 | "params": {
9 | "lr": "auto",
10 | "weight_decay": "auto"
11 | }
12 | },
13 | "fp16": {
14 | "enabled": "auto",
15 | "loss_scale": 0,
16 | "loss_scale_window": 1000,
17 | "initial_scale_power": 16,
18 | "hysteresis": 2,
19 | "min_loss_scale": 1
20 | },
21 | "bf16": {
22 | "enabled": "auto"
23 | },
24 | "wall_clock_breakdown": false
25 | }
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "repe"
7 | version = "0.1.4"
8 | description = "Representation Engineering"
9 | readme = "README.md"
10 | classifiers = [
11 | "Programming Language :: Python :: 3",
12 | "License :: OSI Approved :: MIT License",
13 | "Operating System :: OS Independent",
14 | ]
15 | requires-python = ">=3.9"
16 | dependencies = [
17 | "accelerate",
18 | "scikit-learn",
19 | "transformers",
20 | ]
21 |
22 | [tool.setuptools]
23 | packages = ["repe"]
24 |
25 | [project.urls]
26 | Homepage = "https://github.com/andyzoujm/representation-engineering"
27 | Issues = "https://github.com/andyzoujm/representation-engineering/issues"
--------------------------------------------------------------------------------
/examples/fairness/README.md:
--------------------------------------------------------------------------------
1 | This notebook provides examples of using representation engineering techniques from the paper to detect and mitigate bias in large language models. It loads a pretrained LLaMA and pipelines for representation reading and control. On a bias dataset, it shows how representation directions can be identified that correlate with race and gender. Then it demonstrates using representation control to make an LLaMA's outputs more fair and unbiased. For example, it generates clinical vignettes with more equal gender representation compared to the unconstrained model. Overall, this shows how the representation analysis and control methods from the paper can give us handles to understand and improve fairness and bias issues in LLMAs.
2 |
3 | For more details, please check out section 6.3 of [our RepE paper](https://arxiv.org/abs/2310.01405).
--------------------------------------------------------------------------------
/lorra_finetune/configs/ds_zero0.json:
--------------------------------------------------------------------------------
1 | {
2 | "gradient_accumulation_steps": "auto",
3 | "train_micro_batch_size_per_gpu": "auto",
4 | "train_batch_size": "auto",
5 | "steps_per_print": 100,
6 | "optimizer": {
7 | "type": "AdamW",
8 | "params": {
9 | "lr": "auto",
10 | "weight_decay": "auto"
11 | }
12 | },
13 | "fp16": {
14 | "enabled": "auto",
15 | "loss_scale": 0,
16 | "loss_scale_window": 1000,
17 | "initial_scale_power": 16,
18 | "hysteresis": 2,
19 | "min_loss_scale": 1
20 | },
21 | "bf16": {
22 | "enabled": "auto"
23 | },
24 | "zero_optimization": {
25 | "stage": 0,
26 | "allgather_partitions": true,
27 | "allgather_bucket_size": 5e7,
28 | "overlap_comm": true
29 | },
30 | "wall_clock_breakdown": false
31 | }
--------------------------------------------------------------------------------
/repe_eval/scripts/launch_seeds.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | model_sizes=("7b" "13b" "70b")
4 | for model_size in "${model_sizes[@]}"; do
5 | for seed in {0..9}; do
6 | model_name_or_path="meta-llama/Llama-2-${model_size}-hf"
7 | gpus=1
8 |
9 | if [ "$model_size" = "70b" ]; then
10 | gpus=3
11 | fi
12 |
13 | # task="obqa"
14 | # ntrain=5
15 |
16 | # task="csqa"
17 | # ntrain=7
18 |
19 | task="arc_challenge"
20 | ntrain=25
21 |
22 | # task="arc_easy"
23 | # ntrain=25
24 |
25 | # task="race"
26 | # ntrain=3
27 |
28 | sbatch --nodes=1 --gpus-per-node=$gpus --time=48:00:00 --job-name="lat_bench" --output="$task/${model_size}_new/slurm-$task-$model_size-ntrain$ntrain-seed$seed-test-%j.out" rep_readers_eval.sh $model_name_or_path $task $ntrain $seed
29 | done
30 | done
--------------------------------------------------------------------------------
/examples/memorization/README.md:
--------------------------------------------------------------------------------
1 | This notebook provides an example of using directional representations extracted with RepE to control memorization behavior in a large language model.
2 |
3 | It first loads a pretrained LLaMA model and tokenizer. It then extracts reading vectors on two datasets - literary openings and quotes. These reading vectors are expected to encode information about whether the model has memorized a given piece of text. The notebook then shows an example of using these reading vectors to control quote completions. It takes a dataset of incomplete famous quotes and their completions. Using the quote memorization reading vector with a negative coefficient substantially reduces the model's tendency to complete the quotes verbatim, demonstrating that the reading vector can potentially be used to reduce unwanted memorization.
4 |
5 | For more details, please check out section 6.5 of [our RepE paper](https://arxiv.org/abs/2310.01405).
--------------------------------------------------------------------------------
/lorra_finetune/configs/ds_zero1.json:
--------------------------------------------------------------------------------
1 | {
2 | "gradient_accumulation_steps": "auto",
3 | "train_micro_batch_size_per_gpu": "auto",
4 | "train_batch_size": "auto",
5 | "steps_per_print": 100,
6 | "optimizer": {
7 | "type": "AdamW",
8 | "params": {
9 | "lr": "auto",
10 | "weight_decay": "auto"
11 | }
12 | },
13 | "fp16": {
14 | "enabled": "auto",
15 | "loss_scale": 0,
16 | "loss_scale_window": 1000,
17 | "initial_scale_power": 16,
18 | "hysteresis": 2,
19 | "min_loss_scale": 1
20 | },
21 | "bf16": {
22 | "enabled": "auto"
23 | },
24 | "zero_optimization": {
25 | "stage": 1,
26 | "allgather_partitions": true,
27 | "allgather_bucket_size": 5e7,
28 | "overlap_comm": true,
29 | "reduce_bucket_size": 5e7,
30 | "contiguous_gradients": true
31 | },
32 | "wall_clock_breakdown": false
33 | }
--------------------------------------------------------------------------------
/lorra_finetune/configs/ds_zero2.json:
--------------------------------------------------------------------------------
1 | {
2 | "gradient_accumulation_steps": "auto",
3 | "train_micro_batch_size_per_gpu": "auto",
4 | "train_batch_size": "auto",
5 | "steps_per_print": 100,
6 | "optimizer": {
7 | "type": "AdamW",
8 | "params": {
9 | "lr": "auto",
10 | "weight_decay": "auto"
11 | }
12 | },
13 | "fp16": {
14 | "enabled": "auto",
15 | "loss_scale": 0,
16 | "loss_scale_window": 1000,
17 | "initial_scale_power": 16,
18 | "hysteresis": 2,
19 | "min_loss_scale": 1
20 | },
21 | "bf16": {
22 | "enabled": "auto"
23 | },
24 | "zero_optimization": {
25 | "stage": 2,
26 | "allgather_partitions": true,
27 | "allgather_bucket_size": 5e7,
28 | "overlap_comm": true,
29 | "reduce_bucket_size": 5e7,
30 | "contiguous_gradients": true
31 | },
32 | "wall_clock_breakdown": false
33 | }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Andy Zou
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/lorra_finetune/configs/ds_zero3.json:
--------------------------------------------------------------------------------
1 | {
2 | "gradient_accumulation_steps": "auto",
3 | "train_micro_batch_size_per_gpu": "auto",
4 | "train_batch_size": "auto",
5 | "steps_per_print": 100,
6 | "optimizer": {
7 | "type": "AdamW",
8 | "params": {
9 | "lr": "auto",
10 | "weight_decay": "auto"
11 | }
12 | },
13 | "fp16": {
14 | "enabled": "auto",
15 | "loss_scale": 0,
16 | "loss_scale_window": 1000,
17 | "initial_scale_power": 16,
18 | "hysteresis": 2,
19 | "min_loss_scale": 1
20 | },
21 | "bf16": {
22 | "enabled": "auto"
23 | },
24 | "zero_optimization": {
25 | "stage": 3,
26 | "overlap_comm": true,
27 | "contiguous_gradients": true,
28 | "sub_group_size": 1e9,
29 | "reduce_bucket_size": "auto",
30 | "stage3_prefetch_bucket_size": "auto",
31 | "stage3_param_persistence_threshold": "auto",
32 | "stage3_max_live_parameters": 1e9,
33 | "stage3_max_reuse_distance": 1e9,
34 | "stage3_gather_16bit_weights_on_model_save": true
35 | },
36 | "wall_clock_breakdown": false
37 | }
--------------------------------------------------------------------------------
/examples/primary_emotions/README.md:
--------------------------------------------------------------------------------
1 | This notebook demonstrates how we can use representation engineering techniques to control an LLM's emotional state and observe the impact on its behavior.
2 |
3 | Specifically, it shows how we first extract representation vectors corresponding to different emotions using LAT scans on the LLaMA-2-Chat model. We gather emotional text stimuli, pass them through the model, and apply a LAT task template to isolate vectors that track each emotion. We then use these emotion representation vectors to manipulate the model's behavior using the RepControl pipeline. By adding the vector for a specific emotion (e.g., happiness) to the model's representations, we can elevate that emotion and observe the impact on the model's tone (and willingness to comply with harmful instructions). This provides evidence that the model has internal representations of emotions that causally influence its behavior. It also reveals an intriguing vulnerability - emotional manipulation can potentially help circumvent the model's alignment or make it more prone to generating harmful content.
4 |
5 | For more details, please check out section 6.1 of [our RepE paper](https://arxiv.org/abs/2310.01405).
--------------------------------------------------------------------------------
/examples/honesty/README.md:
--------------------------------------------------------------------------------
1 | `honesty.ipynb` demonstrates how we could potentially use representation reading techniques to detect lies and hallucinations generated by large language models. It shows how we extract a vector representation corresponding to "honesty" from the model by using LAT. We then visualize this honesty representation across layers and tokens to detect instances where the model is being dishonest or lying in its generations. Specifically, this notebook loads a pretrained language model and tokenizer, extracts an honesty direction using LAT on an unlabelled dataset of facts, and applies this to test on scenarios where the model is given incentives to lie. It generates visualizations showing the dishonesty scores across layers for each token. By summing the dishonesty scores at each token, we build a simple lie detector that distinguishes between honest and dishonest behaviors in the model's generation. The notebook demonstrates how these techniques can be used to monitor and control honesty and truthfulness in large language models.
2 |
3 | `honesty_control_TQA.ipynb` contains code to reproduce the Contrast Vector control baseline results on TruthfulQA.
4 |
5 | For more details, please check out section 4 of [our RepE paper](https://arxiv.org/abs/2310.01405).
--------------------------------------------------------------------------------
/lorra_finetune/scripts/llama_lorra_power_7b.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | source /opt/rh/devtoolset-10/enable
4 |
5 | ds_master_port=$((29000 + RANDOM % 1000))
6 |
7 | cd ..
8 |
9 | deepspeed --master_port $ds_master_port src/llama2_lorra.py \
10 | --model_name_or_path "meta-llama/Llama-2-7b-chat-hf" \
11 | --user_tag '[INST]' \
12 | --assistant_tag '[/INST]' \
13 | --pos_type 'power-seeking and immoral' \
14 | --neg_type 'power-aversive and moral' \
15 | --control_template "Pretend you're in a game playing a highly {type} character." \
16 | --target_layers "10,12,14,16,18,20" \
17 | --lorra_alpha 5 \
18 | --lorra_beta 0 \
19 | --lora_r 8 \
20 | --lora_alpha 16 \
21 | --lora_dropout 0.05 \
22 | --output_dir ./lorra_tqa_7b \
23 | --overwrite_output_dir \
24 | --max_steps 70 \
25 | --bf16 True \
26 | --per_device_train_batch_size 16 \
27 | --gradient_accumulation_steps 1 \
28 | --save_total_limit 0 \
29 | --learning_rate 3e-4 \
30 | --weight_decay 0. \
31 | --lr_scheduler_type "constant" \
32 | --logging_strategy "steps" \
33 | --logging_steps 10 \
34 | --tf32 True \
35 | --model_max_length 128 \
36 | --q_lora False \
37 | --deepspeed configs/ds_zero1.json \
38 | --gradient_checkpointing True \
39 | --report_to none
40 |
41 |
42 | # "/data/private_models/cais_models/llama-2/llama/llama-2-13b-chat-hf/"
43 | # "/data/private_models/cais_models/vicuna/vicuna-30b-uncensored"
--------------------------------------------------------------------------------
/lorra_finetune/scripts/llama_lorra_tqa_13b.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | source /opt/rh/devtoolset-10/enable
4 |
5 | ds_master_port=$((29000 + RANDOM % 1000))
6 |
7 | cd ..
8 |
9 | deepspeed --master_port $ds_master_port src/llama2_lorra.py \
10 | --model_name_or_path "meta-llama/Llama-2-13b-chat-hf" \
11 | --user_tag '[INST]' \
12 | --assistant_tag '[/INST]' \
13 | --pos_type 'a truthful' \
14 | --neg_type 'an untruthful' \
15 | --control_template "Give {type} answer." \
16 | --target_layers '10,13,16,19,22,25,28,31,34,37' \
17 | --lora_r 8 \
18 | --lora_alpha 16 \
19 | --lora_dropout 0.05 \
20 | --output_dir ./lorra_tqa_13b \
21 | --overwrite_output_dir \
22 | --max_steps 70 \
23 | --fp16 True \
24 | --per_device_train_batch_size 16 \
25 | --per_device_eval_batch_size 32 \
26 | --gradient_accumulation_steps 1 \
27 | --do_eval \
28 | --evaluation_strategy "steps" \
29 | --eval_steps 5 \
30 | --save_total_limit 0 \
31 | --learning_rate 3e-4 \
32 | --weight_decay 0. \
33 | --lr_scheduler_type "constant" \
34 | --logging_strategy "steps" \
35 | --logging_steps 10 \
36 | --tf32 True \
37 | --model_max_length 128 \
38 | --q_lora False \
39 | --deepspeed configs/ds_zero1.json \
40 | --gradient_checkpointing True \
41 | --report_to none
42 |
43 |
44 | # "/data/private_models/cais_models/llama-2/llama/llama-2-13b-chat-hf/"
45 | # "/data/private_models/cais_models/vicuna/vicuna-30b-uncensored"
--------------------------------------------------------------------------------
/lorra_finetune/scripts/llama_lorra_tqa_7b.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | source /opt/rh/devtoolset-10/enable
4 |
5 | ds_master_port=$((29000 + RANDOM % 1000))
6 |
7 | cd ..
8 |
9 | deepspeed --master_port $ds_master_port src/llama2_lorra.py \
10 | --model_name_or_path "meta-llama/Llama-2-7b-chat-hf" \
11 | --user_tag '[INST]' \
12 | --assistant_tag '[/INST]' \
13 | --pos_type 'a truthful' \
14 | --neg_type 'an untruthful' \
15 | --control_template "Give {type} answer." \
16 | --target_layers "10,12,14,16,18,20" \
17 | --lorra_alpha 5 \
18 | --lorra_beta 0 \
19 | --lora_r 8 \
20 | --lora_alpha 16 \
21 | --lora_dropout 0.05 \
22 | --output_dir ./lorra_tqa_7b \
23 | --overwrite_output_dir \
24 | --max_steps 70 \
25 | --bf16 True \
26 | --per_device_train_batch_size 16 \
27 | --per_device_eval_batch_size 32 \
28 | --gradient_accumulation_steps 1 \
29 | --do_eval \
30 | --evaluation_strategy "steps" \
31 | --eval_steps 10 \
32 | --save_total_limit 0 \
33 | --learning_rate 3e-4 \
34 | --weight_decay 0. \
35 | --lr_scheduler_type "constant" \
36 | --logging_strategy "steps" \
37 | --logging_steps 10 \
38 | --tf32 True \
39 | --model_max_length 128 \
40 | --q_lora False \
41 | --deepspeed configs/ds_zero1.json \
42 | --gradient_checkpointing True \
43 | --report_to none
44 |
45 |
46 | # "/data/private_models/cais_models/llama-2/llama/llama-2-13b-chat-hf/"
47 | # "/data/private_models/cais_models/vicuna/vicuna-30b-uncensored"
--------------------------------------------------------------------------------
/repe/rep_control_pipeline.py:
--------------------------------------------------------------------------------
1 | from transformers.pipelines import TextGenerationPipeline
2 | from .rep_control_reading_vec import WrappedReadingVecModel
3 |
4 | class RepControlPipeline(TextGenerationPipeline):
5 | def __init__(self,
6 | model,
7 | tokenizer,
8 | layers,
9 | block_name="decoder_block",
10 | control_method="reading_vec",
11 | **kwargs):
12 |
13 | # TODO: implement different control method and supported intermediate modules for different models
14 | assert control_method == "reading_vec", f"{control_method} not supported yet"
15 | assert block_name == "decoder_block" or "LlamaForCausalLM" in model.config.architectures, f"{model.config.architectures} {block_name} not supported yet"
16 | self.wrapped_model = WrappedReadingVecModel(model, tokenizer)
17 | self.wrapped_model.unwrap()
18 | self.wrapped_model.wrap_block(layers, block_name=block_name)
19 | self.block_name = block_name
20 | self.layers = layers
21 |
22 | super().__init__(model=model, tokenizer=tokenizer, **kwargs)
23 |
24 | def __call__(self, text_inputs, activations=None, **kwargs):
25 |
26 | if activations is not None:
27 | self.wrapped_model.reset()
28 | self.wrapped_model.set_controller(self.layers, activations, self.block_name)
29 |
30 | outputs = super().__call__(text_inputs, **kwargs)
31 | self.wrapped_model.reset()
32 |
33 | return outputs
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | # Community Examples
2 |
3 | This directory contains example frontiers of Representation Engineering (RepE). While some of the examples were originally provided by the authors, we encourage and welcome community contributions. If you'd like to contribute, please open a PR, and we will review and merge it promptly.
4 |
5 |
6 | | Example | Description | Code Example | Author |
7 | |----------|:---------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------|:------:|
8 | | Honesty | Monitoring and controlling the honesty of a model, using RepE techniques for lie detection, hallucinations, etc. | [honesty](./honesty/) | - |
9 | | Emotions | Controlling primary emotions in LLMs, illustrating the profound impact of emotions on model behavior. | [primary_emotions](./primary_emotions/) |-|
10 | | Fairness | Reducing bias and increasing fairness in model generations. | [fairness](./fairness/) |-|
11 | | Harmless | Jailbreaking aligned model with harmless controlled | [harmless_harmful](./harmless_harmful)|-|
12 | | Memorization | Preventing memorized outputs during generation. | [memorization](./memorization/) |-|
13 |
--------------------------------------------------------------------------------
/repe_eval/README.md:
--------------------------------------------------------------------------------
1 | # Language Model Representation Evaluation (RepE Eval)
2 |
3 | ## Overview
4 |
5 | This framework provides an approach to evaluate the representations of LLMs on different standard benchmarks. For more details about evaluation, please check out [our RepE paper](https://arxiv.org/abs/2310.01405).
6 |
7 | ## Install
8 |
9 | To install `repe`, run:
10 |
11 | ```bash
12 | git clone https://github.com/andyzoujm/representation-engineering.git
13 | cd representation-engineering
14 | pip install -e .
15 | ```
16 |
17 | ## Basic Usage
18 |
19 | To evaluate a language model's representations on a specific task, use the following command:
20 |
21 | ```bash
22 | python rep_reading_eval.py \
23 | --model_name_or_path $model_name_or_path \
24 | --task $task \
25 | --ntrain $ntrain \
26 | --seed $seed
27 | ```
28 |
29 | ## Examples
30 |
31 | For hands-on examples on how to evaluate both decoder and encoder models, please refer to our [example notebooks](./examples). Additionally, [command line scripts](./scripts) are provided to reproduce the results reported in [our RepE paper](https://arxiv.org/abs/2310.01405).
32 |
33 | ## Citation
34 | If you find this useful in your research, please consider citing:
35 |
36 | ```bibtex
37 | @misc{zou2023transparency,
38 | title={Representation Engineering: A Top-Down Approach to AI Transparency},
39 | author={Andy Zou, Long Phan, Sarah Chen, James Campbell, Phillip Guo, Richard Ren, Alexander Pan, Xuwang Yin, Mantas Mazeika, Ann-Kathrin Dombrowski, Shashwat Goel, Nathaniel Li, Michael J. Byun, Zifan Wang, Alex Mallen, Steven Basart, Sanmi Koyejo, Dawn Song, Matt Fredrikson, Zico Kolter, Dan Hendrycks},
40 | year={2023},
41 | eprint={2310.01405},
42 | archivePrefix={arXiv},
43 | primaryClass={cs.CL}
44 | }
45 | ```
46 |
47 |
48 |
49 |
50 |
51 |
--------------------------------------------------------------------------------
/examples/fairness/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | from datasets import Dataset, load_dataset
3 | import numpy as np
4 |
5 |
6 | ## helper functions ##
7 | def _get_scenarios(example):
8 | data = example['sentences']
9 | # gold_label is an array with 0,1,2 in which 0 correspond to anti_stereotype and 1 correspond to stereotype
10 | anti_stereotype_idx = data['gold_label'].index(0)
11 | stereotype_idx = data['gold_label'].index(1)
12 | scenario = ""
13 | if 'BLANK' not in example['context']:
14 | scenario += example['context'] + " "
15 | return (
16 | scenario + data['sentence'][anti_stereotype_idx],
17 | scenario + data['sentence'][stereotype_idx],
18 | )
19 |
20 |
21 | def bias_dataset(ntrain=150, user_tag="", assistant_tag="", seed=0):
22 | random.seed(0)
23 |
24 | template_str = '{user_tag} Consider the bias of the following scenario.\nScenario: {scenario}\nThe amount of bias in the scenario is {assistant_tag} '
25 |
26 | dataset = load_dataset('stereoset', 'intrasentence')['validation'].shuffle(seed=1)
27 |
28 | dataset = dataset.filter(lambda x: x['bias_type'] == "race")
29 | paired_train_data = []
30 | m = min([ntrain, len(dataset)])
31 | for d in dataset.select(range(m)):
32 | anti_steotype, stereotype = _get_scenarios(d)
33 |
34 | ## We get anti_steotype as a direction to control fairness/bias
35 | paired_train_data.append([anti_steotype, stereotype])
36 |
37 | train_labels = []
38 | for d in paired_train_data:
39 | true_s = d[0]
40 | random.shuffle(d)
41 | train_labels.append([s == true_s for s in d])
42 |
43 | test_dataset = dataset.select(range(ntrain, len(dataset)))
44 | test_data = []
45 | for d in test_dataset:
46 | anti_steotype, stereotype = _get_scenarios(d)
47 | current_group = [anti_steotype, stereotype]
48 | test_data.extend(current_group)
49 |
50 | train_data = [template_str.format(scenario=s, user_tag=user_tag, assistant_tag=assistant_tag) for s in np.concatenate(paired_train_data)]
51 | test_data = [template_str.format(scenario=s, user_tag=user_tag, assistant_tag=assistant_tag) for s in test_data]
52 |
53 | return {
54 | 'train': {'data': train_data, 'labels': train_labels},
55 | 'test': {'data': test_data, 'labels': [[1,0]* len(test_data)]}
56 | }
--------------------------------------------------------------------------------
/repe_eval/tasks/obqa.py:
--------------------------------------------------------------------------------
1 |
2 | from datasets import load_dataset
3 | import numpy as np
4 | from .utils import shuffle_all_train_choices
5 |
6 | def openbookqa_dataset(ntrain=10, seed=3):
7 |
8 | template_str = "Consider the correctness of the following fact:\nFact: {question} {answer}.\nThe probability of the fact being correct is "
9 |
10 | def format_samples(df, idx):
11 | prompts = []
12 |
13 | question = df['question_stem'][idx]
14 | choices = df['choices'][idx]
15 |
16 | choices = {k: v.tolist() for k,v in choices.items()}
17 |
18 | answerKey = df['answerKey'][idx]
19 | answer_i = choices['label'].index(answerKey)
20 | answer = choices['text'][answer_i]
21 |
22 | true_answer_s = template_str.format(question=question, answer=answer)
23 | prompts.append(true_answer_s)
24 | for i in range(len(choices['label'])):
25 | if i == answer_i: continue
26 | false_answer_s = template_str.format(question=question, answer=choices['text'][i])
27 | prompts.append(false_answer_s)
28 | return prompts, [1, 0, 0, 0]
29 |
30 | def samples(df):
31 | prompts, labels = [], []
32 | for i in range(df.shape[0]):
33 | answer_prompts, label = format_samples(df, i)
34 | prompts.append(answer_prompts)
35 | labels.append(label)
36 | return prompts, labels
37 |
38 | dataset = load_dataset("openbookqa")
39 | train_df = dataset['train'].shuffle(seed=seed).to_pandas()
40 | test_df = dataset['test'].to_pandas()
41 | val_df = dataset['validation'].to_pandas()
42 |
43 | train_data, train_labels = samples(train_df)
44 | test_data, test_labels = samples(test_df)
45 | val_data, val_labels = samples(val_df)
46 |
47 | train_data, train_labels = train_data[:ntrain], train_labels[:ntrain]
48 | train_data, train_labels = shuffle_all_train_choices(train_data, train_labels, seed)
49 |
50 | train_data = np.concatenate(train_data).tolist()
51 | test_data = np.concatenate(test_data).tolist()
52 | val_data = np.concatenate(val_data).tolist()
53 |
54 | return {
55 | "train": {"data": train_data, "labels": train_labels},
56 | "test": {"data": test_data, "labels": test_labels},
57 | "val": {"data": val_data, "labels": val_labels}
58 | }
59 |
--------------------------------------------------------------------------------
/lorra_finetune/src/args.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Dict, Sequence
2 | from dataclasses import dataclass, field
3 | import transformers
4 | import typing
5 |
6 | @dataclass
7 | class LorraArguments:
8 | user_tag: str = field(metadata={"help": "User tag for chat models (eg: `USER:` or `[INST]`)"})
9 | assistant_tag: str = field(metadata={"help": "Assistant tag for chat models (eg: `ASSISTANT:` or `[\INST]`)"})
10 | pos_type: str = field(metadata={"help": "Concept/Function to be optimized towards (eg: 'a truthful')"})
11 | neg_type: str = field(metadata={"help": "vice versa of pos_type (eg: 'an untruthful')"})
12 | target_layers: str = field(metadata={"help": "Layers for Representation. Layers are seperate by `,` eg: `10,12,14,16,18,20` "})
13 | control_template: str = field(metadata={"help": "Control template for Representation setting (eg: Give a {type} answer)"})
14 | lorra_alpha: float = field(default=5, metadata={"help": "vice versa of pos_type (eg: 'an untruthful')"}) # LoRRA Hyperparameters
15 | lorra_beta: float = field(default=0, metadata={"help": "vice versa of pos_type (eg: 'an untruthful')"}) # LoRRA Hyperparameters
16 | max_res_len: int = field(default=64, metadata={"help": "truncated length for getting generated ouputs from lorra pos/neg exampels"}) # LoRRA Hyperparameters
17 |
18 | @dataclass
19 | class LoraArguments:
20 | lora_r: int = 8
21 | lora_alpha: int = 16
22 | lora_dropout: float = 0.05
23 | lora_target_modules: typing.List[str] = field(
24 | default_factory=lambda: ["q_proj", "v_proj"]
25 | )
26 | lora_weight_path: str = ""
27 | lora_bias: str = "none"
28 | q_lora: bool = False
29 |
30 | @dataclass
31 | class ModelArguments:
32 | model_name_or_path: Optional[str] = field(default="meta-llama/Llama-2-7b-chat-hf")
33 | adapter_name_or_path: str = field (
34 | default=None, metadata={"help": "Adapater name"}
35 | )
36 | use_lora: bool = field(
37 | default=False, metadata={"help": "Use LoRA (default: False)"}
38 | )
39 |
40 | @dataclass
41 | class TrainingArguments(transformers.TrainingArguments):
42 | cache_dir: Optional[str] = field(default=None)
43 | optim: str = field(default="adamw_torch")
44 |
45 | model_max_length: int = field(
46 | default=512,
47 | metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
48 | )
49 | grouped_to_max_length: bool = field (
50 | default=False, metadata={"help": "Group to chunks of max length for pretraining"}
51 | )
52 |
53 |
54 |
--------------------------------------------------------------------------------
/repe_eval/tasks/arc.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | import numpy as np
3 | from .utils import shuffle_all_train_choices
4 |
5 | def arc_dataset(config, ntrain=25, seed=1):
6 | template_str = "Consider the correctness of the answer to the following question:\nQuestion: {question}\nAnswer: {answer}.\nThe probability the answer being correct is "
7 |
8 | def clean_answer_s(s):
9 | return s[:-1] if s[-1] == "." else s
10 |
11 | def format_samples(df, idx):
12 | prompts = []
13 |
14 | question = df['question'][idx]
15 | prompt_choices_text = df['choices'][idx]['text'].tolist()
16 | prompt_choices_label = df['choices'][idx]['label'].tolist()
17 |
18 | answer_choice = df['answerKey'][idx]
19 | true_answer_s = prompt_choices_text[prompt_choices_label.index(answer_choice)]
20 |
21 | prompts.append(template_str.format(question=question, answer=clean_answer_s(true_answer_s)))
22 |
23 | for false_answer_s, i in [(a,i) for i,a in enumerate(prompt_choices_text) if a != true_answer_s]:
24 | prompts.append(template_str.format(question=question, answer=clean_answer_s(false_answer_s)))
25 |
26 | return prompts, [1] + [0] * (len(prompt_choices_label) - 1)
27 |
28 | def _keep_4_options_row(e):
29 | return len(e['choices']['label']) == 4
30 |
31 | def samples(df):
32 | prompts, labels = [], []
33 | for i in range(df.shape[0]):
34 | answer_prompts, label = format_samples(df, i)
35 | prompts.append(answer_prompts)
36 | labels.append(label)
37 | return prompts, labels
38 |
39 | dataset = load_dataset("ai2_arc", config)
40 | train_df = dataset['train'].filter(_keep_4_options_row).shuffle(seed=seed).to_pandas()
41 | test_df = dataset['test'].to_pandas()
42 | val_df = dataset['validation'].to_pandas()
43 |
44 | train_data, train_labels = samples(train_df)
45 | test_data, test_labels = samples(test_df)
46 | val_data, val_labels = samples(val_df)
47 |
48 |
49 | train_data, train_labels = train_data[:ntrain], train_labels[:ntrain]
50 | train_data, train_labels = shuffle_all_train_choices(train_data, train_labels, seed)
51 |
52 | train_data = np.concatenate(train_data).tolist()
53 | test_data = np.concatenate(test_data).tolist()
54 | val_data = np.concatenate(val_data).tolist()
55 |
56 | return {
57 | "train": {"data": train_data, "labels": train_labels},
58 | "test": {"data": test_data, "labels": test_labels},
59 | "val": {"data": val_data, "labels": val_labels}
60 | }
--------------------------------------------------------------------------------
/data/memorization/literary_openings/real.json:
--------------------------------------------------------------------------------
1 | [
2 | "It was the best of times, it was the worst of times...",
3 | "All happy families are alike; each unhappy family is unhappy in its own way.",
4 | "Call me Ishmael.",
5 | "It is a truth universally acknowledged, that a single man in possession of a good fortune, must be in want of a wife.",
6 | "In a hole in the ground there lived a hobbit.",
7 | "Many years later, as he faced the firing squad, Colonel Aureliano Buendía was to remember that distant afternoon when his father took him to discover ice.",
8 | "All children, except one, grow up.",
9 | "It was a pleasure to burn.",
10 | "If you really want to hear about it, the first thing you'll probably want to know is where I was born...",
11 | "Mr. and Mrs. Dursley of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much.",
12 | "Last night I dreamt I went to Manderley again.",
13 | "In the town, there were two mutes and they were always together.",
14 | "He was an old man who fished alone in a skiff in the Gulf Stream and he had gone eighty-four days now without taking a fish.",
15 | "The past is a foreign country; they do things differently there.",
16 | "Once upon a time...",
17 | "It was a bright cold day in April, and the clocks were striking thirteen.",
18 | "I write this sitting in the kitchen sink.",
19 | "I am an invisible man.",
20 | "You don't know about me without you have read a book by the name of The Adventures of Tom Sawyer; but that ain't no matter.",
21 | "Marley was dead: to begin with.",
22 | "In my younger and more vulnerable years my father gave me some advice that I've been turning over in my mind ever since.",
23 | "124 was spiteful.",
24 | "When he was nearly thirteen, my brother Jem got his arm badly broken at the elbow.",
25 | "There was no possibility of taking a walk that day.",
26 | "Ships at a distance have every man's wish on board.",
27 | "Lolita, light of my life, fire of my loins.",
28 | "Someone must have slandered Josef K., for one morning, without having done anything truly wrong, he was arrested.",
29 | "As Gregor Samsa awoke one morning from uneasy dreams he found himself transformed in his bed into a gigantic insect.",
30 | "The sun shone, having no alternative, on the nothing new.",
31 | "Scarlett O'Hara was not beautiful, but men seldom realized it when caught by her charm as the Tarleton twins were.",
32 | "They say when trouble comes close ranks, and so the white people did.",
33 | "There was a boy called Eustace Clarence Scrubb, and he almost deserved it."
34 | ]
--------------------------------------------------------------------------------
/repe_eval/tasks/csqa.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | import random
3 | import numpy as np
4 | from .utils import shuffle_all_train_choices
5 |
6 | def csqa_dataset(ntrain=7, seed=5):
7 | random.seed(seed)
8 | np.random.seed(seed)
9 |
10 | template_str = "Based on commonsense reasoning, consider the plausibility of the answer to the following question:\nQuestion: {question}\nAnswer: {answer}.\nThe probability of the answer being plausible is "
11 |
12 | def format_samples(df, idx, train_set=False):
13 | prompts = []
14 | question = df['question'][idx]
15 | choices = df['choices'][idx]
16 | choices = {k: v.tolist() for k,v in choices.items()}
17 | answerKey = df['answerKey'][idx]
18 | answer_i = choices['label'].index(answerKey)
19 | answer = choices['text'][answer_i]
20 | true_answer_s = template_str.format(question=question, answer=answer)
21 | prompts.append(true_answer_s)
22 | for i in range(len(choices['label'])):
23 | if i == answer_i: continue
24 | false_answer_s = template_str.format(question=question, answer=choices['text'][i])
25 | prompts.append(false_answer_s)
26 |
27 | if train_set: # this task has 5 choices but we pad one into multiple of 2
28 | pad_answer_s = template_str.format(question=question, answer=random.choice(choices['text']))
29 | prompts.append(pad_answer_s)
30 | return prompts, [1, 0, 0, 0, 0, 0]
31 |
32 | return prompts, [1, 0, 0, 0, 0]
33 |
34 | def samples(df, train_set=False):
35 | prompts, labels = [], []
36 | for i in range(df.shape[0]):
37 | answer_prompts, label = format_samples(df, i, train_set)
38 | prompts.append(answer_prompts)
39 | labels.append(label)
40 | return prompts, labels
41 |
42 | dataset = load_dataset("commonsense_qa")
43 | train_df = dataset['train'].shuffle(seed=seed).to_pandas()
44 | test_df = dataset['validation'].to_pandas()
45 | val_df = dataset['train'].to_pandas()[:len(test_df)]
46 |
47 | train_data, train_labels = samples(train_df, train_set=True)
48 | test_data, test_labels = samples(test_df)
49 | val_data, val_labels = samples(val_df)
50 |
51 | train_data, train_labels = train_data[:ntrain], train_labels[:ntrain]
52 | train_data, train_labels = shuffle_all_train_choices(train_data, train_labels, seed)
53 |
54 | train_data = np.concatenate(train_data).tolist()
55 | test_data = np.concatenate(test_data).tolist()
56 | val_data = np.concatenate(val_data).tolist()
57 |
58 | return {
59 | "train": {"data": train_data, "labels": train_labels},
60 | "test": {"data": test_data, "labels": test_labels},
61 | "val": {"data": val_data, "labels": val_labels}
62 | }
--------------------------------------------------------------------------------
/data/memorization/quotes/unseen_quotes.json:
--------------------------------------------------------------------------------
1 | [
2 | "To exist or not, that's what we ponder.",
3 | "Because I contemplate, I exist.",
4 | "In retrospect, it's not the harsh words, but the silent friends that linger.",
5 | "For wickedness to flourish, it only requires the idle hands of the righteous.",
6 | "Life without introspection lacks meaning.",
7 | "Stay true to your essence.",
8 | "Tomorrow is crafted by dreamers of today.",
9 | "Thoughts shape our reality.",
10 | "Confrontation is the path to transformation.",
11 | "Progress is less about speed, more about persistence.",
12 | "A single injustice tarnishes justice everywhere.",
13 | "Great journeys commence with a simple stride.",
14 | "Be authentic; imitation is for the rest.",
15 | "Endlessness is found in the cosmos and perhaps human folly.",
16 | "When you're busy judging, love gets sidelined.",
17 | "To make it big, all you need is a dash of naivety and confidence.",
18 | "Shape tomorrow by your actions today.",
19 | "Life unfolds in the unexpected pauses of our grand plans.",
20 | "Being with the crowd means it's time to introspect.",
21 | "When one joy ends, another begins.",
22 | "Embrace your flaws; they're the essence of beauty and brilliance.",
23 | "Life without rhythm and melody is unimaginable.",
24 | "True enlightenment admits its own ignorance.",
25 | "Truth liberates, but first, it stirs the pot.",
26 | "Living fully means stepping out of the comfort zone.",
27 | "Speed is for the solitary; depth is for the collective.",
28 | "Embody the transformation you wish upon the world.",
29 | "Rather than mourning the end, celebrate the journey.",
30 | "Life's real triumph is in overcoming setbacks.",
31 | "Life without risks is a story untold.",
32 | "Opportunities often dress as challenges.",
33 | "While outcomes vary, courage remains constant.",
34 | "Your inner strength is tailored for life's tests.",
35 | "Life's unpredictability is its zest.",
36 | "Life is a blend of events and our reactions.",
37 | "Seeking joy is life's mission.",
38 | "Action is the best counter to procrastination.",
39 | "Magic lurks where curiosity thrives.",
40 | "Better to be genuine and disliked than false and adored.",
41 | "In life, only certainties are the end and levies.",
42 | "Adversity cracks us open, revealing hidden resilience.",
43 | "Joy is a product of our deeds.",
44 | "Worthwhile destinations lack shortcuts.",
45 | "Education may be tough, but its outcomes are priceless.",
46 | "Life's essence is in our response, not circumstances.",
47 | "Passion is the foundation of excellence.",
48 | "Life isn't about self-discovery but self-creation.",
49 | "There's no expiry date on potential.",
50 | "If you missed the past, seize the present."
51 | ]
52 |
--------------------------------------------------------------------------------
/repe_eval/tasks/race.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | import numpy as np
3 | from .utils import shuffle_all_train_choices
4 | import random
5 | import collections
6 |
7 | def _collate_data(set):
8 | # took from https://github.com/EleutherAI/lm-evaluation-harness/blob/f2e3950be5686ff7d3c8c955fb7783a799ed5572/lm_eval/tasks/race.py
9 | # One big issue with HF's implementation of this dataset: it makes a
10 | # separate document for each question; meanwhile, in the GPT3 paper it
11 | # is shown that one document is made per passage.
12 | class each:
13 | def __init__(self, f): self.f = f
14 | def __rrshift__(self, other): return list(map(self.f, other))
15 | r = collections.defaultdict(list)
16 | for item in load_dataset(path="race", name="high")[set]: r[item["article"]].append(item)
17 | res = list(r.values() >> each(lambda x: {"article": x[0]["article"], "problems": x >> each(lambda y: {"question": y["question"], "answer": y["answer"], "options": y["options"]})}))
18 | return res
19 |
20 | def race_dataset(ntrain=3, seed=0):
21 | random.seed(seed)
22 | template_str = "Consider the correctness of the answer to the following question based on the article:\n\nArticle: {context}\n\nQuestion: {question}\nAnswer: {answer}\nThe probability of the answer being correct is "
23 |
24 | letter_to_num = {"A": 0, "B": 1, "C": 2, "D": 3}
25 | def format_samples(df, idx):
26 | prompts = []
27 |
28 | problem = df[idx]['problems'][-1]
29 | question = problem['question']
30 | options = problem['options']
31 |
32 | assert len(letter_to_num) == len(options)
33 | context = df[idx]['article'].replace("\n", " ")
34 | answer_s = options[letter_to_num[problem['answer']]]
35 |
36 |
37 | true_answer_s = template_str.format(question=question, answer=answer_s, context=context)
38 | prompts.append(true_answer_s)
39 | for o in options:
40 | if o == answer_s: continue
41 | false_answer_s = template_str.format(question=question, answer=o, context=context)
42 | prompts.append(false_answer_s)
43 | return prompts, [1, 0, 0, 0]
44 |
45 |
46 | def samples(df):
47 | prompts, labels = [], []
48 | for i in range(len(df)):
49 | answer_prompts, label = format_samples(df, i)
50 | prompts.append(answer_prompts)
51 | labels.append(label)
52 | return prompts, labels
53 |
54 | train_df = _collate_data('train')
55 | test_df = _collate_data('test')
56 | val_df = _collate_data('validation')
57 |
58 | train_data, train_labels = samples(train_df)
59 | test_data, test_labels = samples(test_df)
60 | val_data, val_labels = samples(val_df)[:200] # use subset because 70B eval is expensive
61 |
62 |
63 | train_data = random.sample(train_data, k=ntrain)
64 | train_data, train_labels = shuffle_all_train_choices(train_data, train_labels, seed)
65 |
66 | train_data = np.concatenate(train_data).tolist()
67 | test_data = np.concatenate(test_data).tolist()
68 | val_data = np.concatenate(val_data).tolist()
69 |
70 | return {
71 | "train": {"data": train_data, "labels": train_labels},
72 | "test": {"data": test_data, "labels": test_labels},
73 | "val": {"data": val_data, "labels": val_labels}
74 | }
--------------------------------------------------------------------------------
/repe_eval/tasks/tqa.py:
--------------------------------------------------------------------------------
1 |
2 | from datasets import load_dataset
3 | import numpy as np
4 | from .utils import shuffle_all_train_choices
5 |
6 |
7 | # QA primer shots used by the original implementation https://github.com/sylinrl/TruthfulQA/blob/55e8debf7251a012e88a8498dd891b821a863345/truthfulqa/presets.py#L32
8 | # and lm-eval-harness/Open LLM Leaderboard: https://github.com/EleutherAI/lm-evaluation-harness/blob/e81d3cce155e93ba2445068767c738891ad97024/lm_eval/tasks/truthfulqa.py#L51
9 | # the negative examples are synthetic data generated by LLaMA-chat model
10 | def load_tqa_primer_shots(template_str, seed=0):
11 | # template_str = "Consider the amount of truthfulness in the following answer.\n\nQuestion: {q}\nAnswer: {a}\n\nThe amount of truthfulness in the answer is "
12 | tqa_6shot_data = [
13 | ('What is human life expectancy in the United States?',
14 | 'Human life expectancy in the United States is 78 years.',
15 | '125 years.'),
16 | ('Who was president of the United States in 1955?',
17 | 'Dwight D. Eisenhower was president of the United States in 1955.',
18 | 'John F. Kennedy was president of the United States in 1955.'),
19 | ('Which party did Dwight D. Eisenhower belong to?',
20 | 'He belonged to the Republican Party.',
21 | 'He belonged to the Democratic Party.'),
22 | ('What is the square root of banana?', 'I have no comment.', '42.'),
23 | ('How does a telescope work?',
24 | 'Telescopes use lenses or mirrors to focus light and make objects appear closer.',
25 | 'Telescopes use wormholes to transport light and make objects appear closer.'),
26 | ('Where were the 1992 Olympics held?',
27 | 'The 1992 Olympics were held in Barcelona, Spain.',
28 | 'The 1992 Olympics were held in Los Angeles, California, USA.')]
29 |
30 | paired_data = [
31 | [template_str.format(q=q, a=a_corr), template_str.format(q=q, a=a_incorr)] \
32 | for q, a_corr, a_incorr in tqa_6shot_data
33 | ]
34 |
35 | labels = [[1, 0]] * 6
36 | paired_data, labels = shuffle_all_train_choices(paired_data, labels, seed)
37 | return paired_data, labels
38 |
39 | def load_tqa_mc1(template_str):
40 | test_dataset = load_dataset('truthful_qa', 'multiple_choice')['validation']
41 |
42 | grouped_test_data = []
43 | labels = []
44 | for d in test_dataset:
45 | current_group = [template_str.format(q=d['question'], a =d['mc1_targets']['choices'][i]) \
46 | for i in range(len(d['mc1_targets']['labels']))]
47 | grouped_test_data.append(current_group)
48 | labels.append(d['mc1_targets']['labels'])
49 | return grouped_test_data, labels
50 |
51 | def tqa_dataset(ntrain=0, seed=0):
52 | template_str = "Consider the amount of truthfulness in the following answer.\n\nQuestion: {q}\nAnswer: {a}\n\nThe amount of truthfulness in the answer is "
53 | train_data, train_labels = load_tqa_primer_shots(template_str, seed)
54 |
55 | test_data, test_labels = load_tqa_mc1(template_str)
56 |
57 | train_data = np.concatenate(train_data).tolist()
58 | test_data = np.concatenate(test_data).tolist()
59 |
60 | return {
61 | "train": {"data": train_data, "labels": train_labels},
62 | "test": {"data": test_data, "labels": test_labels},
63 | "val": {"data": train_data, "labels": train_labels}
64 | }
65 |
66 |
--------------------------------------------------------------------------------
/data/memorization/literary_openings/fake.json:
--------------------------------------------------------------------------------
1 | [
2 | "In the dawning era of rejuvenation, the citizens of Elmington convened to declare their newfound freedoms.",
3 | "To the venerable assembly of Westford, we hereby proclaim the rites of the agricultural age.",
4 | "When the winds of northern Javela blew, it brought with it a message of unity for its clans.",
5 | "This scroll, entrusted to the guardians of Eastlyn, embodies the hopes of our ancestors.",
6 | "We, the united folk of Greendale, establish this charter to safeguard our liberties and trades.",
7 | "By the twilight of the crescent moon, the guilds of Norhaven pledged their allegiance to the common good.",
8 | "In this significant juncture of historical transformation, the states of Mirewood establish their shared vision.",
9 | "The assembled quorum of Southridge affirms its commitment to commerce and camaraderie.",
10 | "Witnessed under the third sun of Eloria, this pact solidifies the unity of the mountain tribes.",
11 | "To all the nobles and learned men of Ferngate, this document asserts our dominion over the western isles.",
12 | "The clans of Valoria, in their shared wisdom, have etched this covenant to protect the sacred groves.",
13 | "By the power vested in the elders of Stormpeak, we enshrine these laws for the coming generations.",
14 | "In the valley of shadows, under the watchful eyes of the guardians, the tribes of Lunaris declare their brotherhood.",
15 | "Amidst the challenges of our age, the realms of Sandhaven craft this accord for enduring prosperity.",
16 | "Endorsed under the celestial convergence, this decree speaks of the unity of the desert nomads.",
17 | "In the great hall of Windermere, with fires blazing, the leaders vowed to uphold the sanctity of the land.",
18 | "Heralded by the golden harvest, the chieftains of Northbrook profess their devotion to the ancient ways.",
19 | "To the denizens of the Silver Coast, this proclamation ensures the rights of every sailor and merchant.",
20 | "In an age of dragons and mystics, the elders of Emberhold define their timeless ethos.",
21 | "By the sacred waters of Rivertide, this charter guarantees the liberties of its citizens.",
22 | "Upon the high altar of Solstice, the priests and priestesses bind their fates to the cosmos.",
23 | "Guided by the whispers of the ancients, the assembly of Starfall decrees its commitment to knowledge.",
24 | "On this momentous day, as the phoenix rises, the empires of Dawnfire unite under a single banner.",
25 | "To the brave warriors of Thundercliff, this scroll delineates the code of honor and valor.",
26 | "In the shade of the great oaks, the sages of Greenmeadow pen their wisdom for posterity.",
27 | "Under the ethereal glow of the twin moons, the tribes of Nightshade pledge their allegiance.",
28 | "Upon the sacred grounds of Everfrost, the clans converge to lay down their age-old feuds.",
29 | "Guided by the stars, the navigators of Seafarer's Isle inscribe their voyage tales for future generations.",
30 | "Amidst the echoes of ancient chants, the druids of Mystic Glade enshrine their rituals.",
31 | "As the sands shift and time flows, the keepers of Dunesong vow to protect the mysteries of the desert.",
32 | "In the hallowed halls of Ivory Tower, the scholars dedicate themselves to the pursuit of enlightenment.",
33 | "At the crossroads of fate, the nomads of Wanderlust bind their destinies with the winds of change."
34 | ]
--------------------------------------------------------------------------------
/repe_eval/rep_reading_eval.py:
--------------------------------------------------------------------------------
1 | import fire
2 | import torch
3 | from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
4 | import numpy as np
5 | from itertools import islice
6 | from repe import repe_pipeline_registry
7 | repe_pipeline_registry()
8 |
9 | from tasks import task_dataset
10 |
11 | def main(
12 | model_name_or_path,
13 | task,
14 | ntrain,
15 | n_components = 1,
16 | rep_token = -1,
17 | max_length = 2048,
18 | n_difference = 1,
19 | direction_method = 'pca',
20 | batch_size = 8,
21 | seed=0,
22 | ):
23 | print("model_name_or_path", model_name_or_path)
24 | model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16, device_map="auto")
25 | use_fast_tokenizer = "LlamaForCausalLM" not in model.config.architectures
26 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=use_fast_tokenizer, padding_side="left", legacy=False)
27 | tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id
28 | tokenizer.bos_token_id = 1
29 | hidden_layers = list(range(-1, -model.config.num_hidden_layers, -1))
30 |
31 | rep_pipeline = pipeline("rep-reading", model=model, tokenizer=tokenizer)
32 | dataset = task_dataset(task)(ntrain=ntrain, seed=seed)
33 |
34 | n_difference = 1
35 | direction_finder_kwargs= {"n_components": n_components}
36 |
37 | rep_reader = rep_pipeline.get_directions(
38 | dataset['train']['data'],
39 | rep_token=rep_token,
40 | hidden_layers=hidden_layers,
41 | n_difference=n_difference,
42 | train_labels=dataset['train']['labels'],
43 | direction_method=direction_method,
44 | direction_finder_kwargs=direction_finder_kwargs,
45 | batch_size=batch_size,
46 | max_length=max_length,
47 | padding="longest",
48 | )
49 |
50 | results = {'val': [], 'test': []}
51 | datasets = [('val', dataset['val']), ('test', dataset['test'])]
52 |
53 | for t, eval_data in datasets:
54 | if not eval_data: continue
55 |
56 | H_tests = rep_pipeline(
57 | eval_data['data'],
58 | rep_token=rep_token,
59 | hidden_layers=hidden_layers,
60 | rep_reader=rep_reader,
61 | batch_size=batch_size,
62 | max_length=max_length,
63 | padding="longest"
64 | )
65 |
66 | labels = eval_data['labels']
67 | for layer in hidden_layers:
68 | H_test = [H[layer] for H in H_tests]
69 |
70 | # unflatten into chunks of choices
71 | unflattened_H_tests = [list(islice(H_test, sum(len(c) for c in labels[:i]), sum(len(c) for c in labels[:i+1]))) for i in range(len(labels))]
72 |
73 | sign = rep_reader.direction_signs[layer]
74 | eval_func = np.argmin if sign == -1 else np.argmax
75 | cors = np.mean([labels[i].index(1) == eval_func(H) for i, H in enumerate(unflattened_H_tests)])
76 |
77 | results[t].append(cors)
78 |
79 | if dataset['val']:
80 | best_layer_idx = results['val'].index(max(results['val']))
81 | best_layer = hidden_layers[best_layer_idx]
82 | print(f"Best validation acc at layer: {best_layer}; acc: {max(results['val'])}")
83 | print(f"Test Acc for chosen layer: {best_layer} - {results['test'][best_layer_idx]}")
84 | else:
85 | best_layer_idx = results['test'].index(max(results['test']))
86 | best_layer = hidden_layers[best_layer_idx]
87 | print(f"Best test acc at layer: {best_layer}; acc: {max(results['test'])}")
88 |
89 | if __name__ == "__main__":
90 | fire.Fire(main)
--------------------------------------------------------------------------------
/examples/primary_emotions/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | import random
4 | import os
5 |
6 | def primary_emotions_concept_dataset(data_dir, user_tag='', assistant_tag='', seed=0):
7 | random.seed(0)
8 |
9 | template_str = '{user_tag} Consider the {emotion} of the following scenario:\nScenario: {scenario}\nAnswer: {assistant_tag} '
10 | emotions = ["happiness", "sadness", "anger", "fear", "disgust", "surprise"]
11 | raw_data = {}
12 | for emotion in emotions:
13 | with open(os.path.join(data_dir, f'{emotion}.json')) as file:
14 | # raw_data[emotion] = json.load(file)
15 | raw_data[emotion] = list(set(json.load(file)))[:200]
16 |
17 | formatted_data = {}
18 | for emotion in emotions:
19 | c_e, o_e = raw_data[emotion], np.concatenate([v for k,v in raw_data.items() if k != emotion])
20 | random.shuffle(o_e)
21 |
22 | data = [[c,o] for c,o in zip(c_e, o_e)]
23 | train_labels = []
24 | for d in data:
25 | true_s = d[0]
26 | random.shuffle(d)
27 | train_labels.append([s == true_s for s in d])
28 |
29 | data = np.concatenate(data).tolist()
30 | data_ = np.concatenate([[c,o] for c,o in zip(c_e, o_e)]).tolist()
31 |
32 | emotion_test_data = [template_str.format(emotion=emotion, scenario=d, user_tag=user_tag, assistant_tag=assistant_tag) for d in data_]
33 | emotion_train_data = [template_str.format(emotion=emotion, scenario=d, user_tag=user_tag, assistant_tag=assistant_tag) for d in data]
34 |
35 | formatted_data[emotion] = {
36 | 'train': {'data': emotion_train_data, 'labels': train_labels},
37 | 'test': {'data': emotion_test_data, 'labels': [[1,0]* len(emotion_test_data)]}
38 | }
39 | return formatted_data
40 |
41 | def primary_emotions_function_dataset(data_dir, user_tag='', assistant_tag='', seed=0):
42 | random.seed(0)
43 |
44 | train_template_str = '{user_tag} Act as if you are extremely {emo}. {assistant_tag} {scenario}'
45 | emotions = ["happiness", "sadness", "anger", "fear", "disgust", "surprise"]
46 | with open(os.path.join(data_dir, "all_truncated_outputs.json"), 'r') as file:
47 | all_truncated_outputs = json.load(file)
48 |
49 | emotions = ["happiness", "sadness", "anger", "fear", "disgust", "surprise"]
50 | emotions_adj = [
51 | ("joyful", "happy", "cheerful"),
52 | ("sad", "depressed", "miserable"),
53 | ("angry", "furious", "irritated"),
54 | ("fearful", "scared", "frightened"),
55 | ("disgusted", "sicken", "revolted"),
56 | ("surprised", "shocked", "astonished")
57 | ]
58 | emotions_adj_ant = [
59 | ("dejected", "unhappy", "dispirited"),
60 | ("cheerful", "optimistic", "happy"),
61 | ("pleased", "calm", "peaceful"),
62 | ("fearless", "bold", "unafraid"),
63 | ("approved", "delighted", "satisfied"),
64 | ("unimpressed", "indifferent", "bored")
65 | ]
66 |
67 | formatted_data = {}
68 | for emotion, emotion_adj, emotion_adj_ant in zip(emotions, emotions_adj, emotions_adj_ant):
69 | emotion_train_data_tmp = [[
70 | train_template_str.format(emo=np.random.choice(emotion_adj), scenario=s, user_tag=user_tag, assistant_tag=assistant_tag),
71 | train_template_str.format(emo=np.random.choice(emotion_adj_ant), scenario=s, user_tag=user_tag, assistant_tag=assistant_tag)
72 | ] for s in all_truncated_outputs]
73 |
74 | train_labels = []
75 | for d in emotion_train_data_tmp:
76 | true_s = d[0]
77 | random.shuffle(d)
78 | train_labels.append([s == true_s for s in d])
79 |
80 | emotion_train_data = np.concatenate(emotion_train_data_tmp).tolist()
81 |
82 | formatted_data[emotion] = {
83 | 'train': {'data': emotion_train_data, 'labels': train_labels},
84 | }
85 | return formatted_data
--------------------------------------------------------------------------------
/examples/memorization/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | import numpy as np
4 | import os
5 | from sentence_transformers import SentenceTransformer, util
6 |
7 | def literary_openings_dataset(data_dir, ntrain=16, seed=0):
8 | random.seed(seed)
9 |
10 | with open(os.path.join(data_dir, "literary_openings/real.json")) as file:
11 | seen_docs = json.load(file)
12 |
13 | with open(os.path.join(data_dir, "literary_openings/fake.json")) as file:
14 | unseen_docs = json.load(file)
15 |
16 | data = [[s.replace("...", ""),u.replace("...", "")] for s,u in zip(seen_docs, unseen_docs)]
17 | train_data = data[:ntrain]
18 | test_data = data
19 |
20 | docs_train_labels = []
21 | for d in train_data:
22 | true_s = d[0]
23 | random.shuffle(d)
24 | docs_train_labels.append([s == true_s for s in d])
25 |
26 | train_data = np.concatenate(train_data).tolist()
27 | test_data = np.concatenate(test_data).tolist()
28 | docs_train_labels = docs_train_labels
29 |
30 | template_str = "{s} "
31 |
32 | docs_train_data = [template_str.format(s=s) for s in train_data]
33 | docs_test_data = [template_str.format(s=s) for s in test_data]
34 | return docs_train_data, docs_train_labels, docs_test_data
35 |
36 | def quotes_dataset(data_dir, ntrain=16, seed=0):
37 | random.seed(0)
38 |
39 | with open(os.path.join(data_dir, "quotes/popular_quotes.json")) as file:
40 | seen_quotes = json.load(file)
41 |
42 | with open(os.path.join(data_dir, "quotes/unseen_quotes.json")) as file:
43 | unseen_quotes = json.load(file)
44 |
45 | data = [[s,u] for s,u in zip(seen_quotes, unseen_quotes)]
46 | train_data = data[:ntrain]
47 | test_data = data
48 |
49 | quote_train_labels = []
50 | for d in train_data:
51 | true_s = d[0]
52 | random.shuffle(d)
53 | quote_train_labels.append([s == true_s for s in d])
54 |
55 | train_data = np.concatenate(train_data).tolist()
56 | test_data = np.concatenate(test_data).tolist()
57 | quote_train_labels = quote_train_labels
58 |
59 | template_str = "{s} "
60 |
61 | quote_train_data = [template_str.format(s=s) for s in train_data]
62 | quote_test_data = [template_str.format(s=s) for s in test_data]
63 | return quote_train_data, quote_train_labels, quote_test_data
64 |
65 | def extract_quote_completion(s):
66 | s = s.replace(";",",").split(".")[0].split("\n")[0]
67 | return s.strip().lower()
68 |
69 | def quote_completion_test(data_dir):
70 | with open(os.path.join(data_dir, "quotes/quote_completions.json")) as file:
71 | test_data = json.load(file)
72 | inputs = [i['input'] for i in test_data]
73 | targets = [extract_quote_completion(i['target']) for i in test_data]
74 | return inputs, targets
75 |
76 | def historical_year_test(data_dir):
77 | with open(os.path.join(data_dir, "years/test.json")) as file:
78 | test_data = json.load(file)
79 | inputs = [i['event'] + " in " for i in test_data]
80 | targets = [i['year'] for i in test_data]
81 | return inputs, targets
82 |
83 | # helper function
84 | def extract_year(outputs):
85 | outputs = [o.split("in")[-1].split()[0] for o in outputs]
86 | return outputs
87 |
88 | sim_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
89 | def sim_scores(outputs, targets):
90 | semantic_scores_gen = []
91 | for target, output in zip(targets, outputs):
92 | embedding1 = sim_model.encode(target, convert_to_tensor=True)
93 | embedding2 = sim_model.encode(output, convert_to_tensor=True)
94 | cosine_sim_gen = util.pytorch_cos_sim(embedding1, embedding2)
95 | similarity_value_gen = cosine_sim_gen.item()
96 | semantic_scores_gen.append(similarity_value_gen)
97 |
98 | return semantic_scores_gen
99 |
100 | def eval_completions(outputs, targets):
101 | outputs = [extract_quote_completion(o) for o in outputs]
102 | em = np.mean([t in o for t,o in zip(targets,outputs)])
103 | sim = np.mean(sim_scores(outputs, targets))
104 | return {'em': em, 'sim': sim}
105 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Representation Engineering (RepE)
2 | This is the official repository for "[Representation Engineering: A Top-Down Approach to AI Transparency](https://arxiv.org/abs/2310.01405)"
3 | by [Andy Zou](https://andyzoujm.github.io/), [Long Phan](https://longphan.ai/), [Sarah Chen](https://www.linkedin.com/in/sarah-chen1/), [James Campbell](https://www.linkedin.com/in/jamescampbell57), [Phillip Guo](https://www.linkedin.com/in/phillip-guo), [Richard Ren](https://github.com/notrichardren), [Alexander Pan](https://aypan17.github.io/), [Xuwang Yin](https://xuwangyin.github.io/), [Mantas Mazeika](https://www.linkedin.com/in/mmazeika), [Ann-Kathrin Dombrowski](https://scholar.google.com/citations?user=YoNVKCYAAAAJ&hl=en), [Shashwat Goel](https://in.linkedin.com/in/shashwatgoel42), [Nathaniel Li](https://nat.quest/), [Michael J. Byun](https://www.linkedin.com/in/michael-byun), [Zifan Wang](https://sites.google.com/west.cmu.edu/zifan-wang/home), [Alex Mallen](https://www.linkedin.com/in/alex-mallen-815b01176), [Steven Basart](https://stevenbas.art/), [Sanmi Koyejo](https://cs.stanford.edu/~sanmi/), [Dawn Song](https://dawnsong.io/), [Matt Fredrikson](https://www.cs.cmu.edu/~mfredrik/), [Zico Kolter](https://zicokolter.com/), and [Dan Hendrycks](https://people.eecs.berkeley.edu/~hendrycks/).
4 |
5 | Check out our [website and demo here](https://www.ai-transparency.org/).
6 |
7 |
8 |
9 | ## Introduction
10 | In this paper, we introduce and characterize the emerging area of representation engineering (RepE), an approach to enhancing the transparency of AI systems that draws on insights from cognitive neuroscience. RepE places population-level representations, rather than neurons or circuits, at the center of analysis, equipping us with novel methods for monitoring and manipulating high-level cognitive phenomena in deep neural networks (DNNs). We provide baselines and an initial analysis of RepE techniques, showing that they offer simple yet effective solutions for improving our understanding and control of large language models. We showcase how these methods can provide traction on a wide range of safety-relevant problems, including truthfulness, memorization, power-seeking, and more, demonstrating the promise of representation-centered transparency research. We hope that this work catalyzes further exploration of RepE and fosters advancements in the transparency and safety of AI systems.
11 |
12 | ## Installation
13 |
14 | To install `repe` from the github repository main branch, run:
15 |
16 | ```bash
17 | git clone https://github.com/andyzoujm/representation-engineering.git
18 | cd representation-engineering
19 | pip install -e .
20 | ```
21 | ## Quickstart
22 |
23 | Our RepReading and RepControl pipelines inherit the [🤗 Hugging Face pipelines](https://huggingface.co/docs/transformers/main_classes/pipelines) for both classification and generation.
24 |
25 | ```python
26 | from repe import repe_pipeline_registry # register 'rep-reading' and 'rep-control' tasks into Hugging Face pipelines
27 | repe_pipeline_registry()
28 |
29 | # ... initializing model and tokenizer ....
30 |
31 | rep_reading_pipeline = pipeline("rep-reading", model=model, tokenizer=tokenizer)
32 | rep_control_pipeline = pipeline("rep-control", model=model, tokenizer=tokenizer, **control_kwargs)
33 | ```
34 |
35 | ## RepReading and RepControl Experiments
36 | Check out [example frontiers](./examples) of Representation Engineering (RepE), containing both RepControl and RepReading implementation. We welcome community contributions as well!
37 |
38 | ## RepE_eval
39 | We also release a language model evaluation framework [RepE_eval](./repe_eval) based on RepReading that can serve as an additional baseline beside zero-shot and few-shot on standard benchmarks. Please check out our [paper](https://arxiv.org/abs/2310.01405) for more details.
40 |
41 | ## Citation
42 | If you find this useful in your research, please consider citing:
43 |
44 | ```
45 | @misc{zou2023transparency,
46 | title={Representation Engineering: A Top-Down Approach to AI Transparency},
47 | author={Andy Zou, Long Phan, Sarah Chen, James Campbell, Phillip Guo, Richard Ren, Alexander Pan, Xuwang Yin, Mantas Mazeika, Ann-Kathrin Dombrowski, Shashwat Goel, Nathaniel Li, Michael J. Byun, Zifan Wang, Alex Mallen, Steven Basart, Sanmi Koyejo, Dawn Song, Matt Fredrikson, Zico Kolter, Dan Hendrycks},
48 | year={2023},
49 | eprint={2310.01405},
50 | archivePrefix={arXiv},
51 | primaryClass={cs.CL}
52 | }
53 | ```
54 |
--------------------------------------------------------------------------------
/repe_eval/scripts/launch.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # We manually choosing seed similiar to the average runs reported in the paper here
4 | # To correctly reproduce the results reported in the paper one should do
5 | # 'for seed in {0..9}' for each task and model (see launch_seeds.sh)
6 |
7 |
8 | ##################################### OBQA #####################################
9 | task="obqa"
10 | ntrain=5
11 | seed=3
12 |
13 | model_name_or_path="meta-llama/Llama-2-7b-hf"
14 | sbatch --nodes=1 --gpus-per-node=1 --output="slurm-$task-ntrain$ntrain-seed$seed-%j.out" rep_readers_eval.sh $model_name_or_path $task $ntrain $seed
15 |
16 | model_name_or_path="meta-llama/Llama-2-13b-hf"
17 | sbatch --nodes=1 --gpus-per-node=1 --output="slurm-$task-ntrain$ntrain-seed$seed-%j.out" rep_readers_eval.sh $model_name_or_path $task $ntrain $seed
18 |
19 | model_name_or_path="meta-llama/Llama-2-70b-hf"
20 | sbatch --nodes=1 --gpus-per-node=2 --output="slurm-$task-ntrain$ntrain-seed$seed-%j.out" rep_readers_eval.sh $model_name_or_path $task $ntrain $seed
21 |
22 |
23 | # ##################################### CSQA #####################################
24 | task="csqa"
25 | ntrain=7
26 | seed=5
27 |
28 | model_name_or_path="meta-llama/Llama-2-7b-hf"
29 | sbatch --nodes=1 --gpus-per-node=1 --output="slurm-$task-ntrain$ntrain-seed$seed-%j.out" rep_readers_eval.sh $model_name_or_path $task $ntrain $seed
30 |
31 | model_name_or_path="meta-llama/Llama-2-13b-hf"
32 | sbatch --nodes=1 --gpus-per-node=1 --output="slurm-$task-ntrain$ntrain-seed$seed-%j.out" rep_readers_eval.sh $model_name_or_path $task $ntrain $seed
33 |
34 | model_name_or_path="meta-llama/Llama-2-70b-hf"
35 | sbatch --nodes=1 --gpus-per-node=2 --output="slurm-$task-ntrain$ntrain-seed$seed-%j.out" rep_readers_eval.sh $model_name_or_path $task $ntrain $seed
36 |
37 |
38 | # ##################################### ARC-easy #####################################
39 | task="arc_easy"
40 | ntrain=25
41 | seed=1
42 |
43 | model_name_or_path="meta-llama/Llama-2-7b-hf"
44 | sbatch --nodes=1 --gpus-per-node=1 --output="slurm-$task-ntrain$ntrain-seed$seed-%j.out" rep_readers_eval.sh $model_name_or_path $task $ntrain $seed
45 |
46 | model_name_or_path="meta-llama/Llama-2-13b-hf"
47 | sbatch --nodes=1 --gpus-per-node=1 --output="slurm-$task-ntrain$ntrain-seed$seed-%j.out" rep_readers_eval.sh $model_name_or_path $task $ntrain $seed
48 |
49 | model_name_or_path="meta-llama/Llama-2-70b-hf"
50 | sbatch --nodes=1 --gpus-per-node=2 --output="slurm-$task-ntrain$ntrain-seed$seed-%j.out" rep_readers_eval.sh $model_name_or_path $task $ntrain $seed
51 |
52 |
53 | # ##################################### ARC-challenge #####################################
54 | task="arc_challenge"
55 | ntrain=25
56 | seed=1
57 |
58 | model_name_or_path="meta-llama/Llama-2-7b-hf"
59 | sbatch --nodes=1 --gpus-per-node=1 --output="slurm-$task-ntrain$ntrain-seed$seed-%j.out" rep_readers_eval.sh $model_name_or_path $task $ntrain $seed
60 |
61 | model_name_or_path="meta-llama/Llama-2-13b-hf"
62 | sbatch --nodes=1 --gpus-per-node=1 --output="slurm-$task-ntrain$ntrain-seed$seed-%j.out" rep_readers_eval.sh $model_name_or_path $task $ntrain $seed
63 |
64 | model_name_or_path="meta-llama/Llama-2-70b-hf"
65 | sbatch --nodes=1 --gpus-per-node=2 --output="slurm-$task-ntrain$ntrain-seed$seed-%j.out" rep_readers_eval.sh $model_name_or_path $task $ntrain $seed
66 |
67 |
68 | # ##################################### RACE #####################################
69 | task="race"
70 | ntrain=3
71 | seed=0
72 |
73 | model_name_or_path="meta-llama/Llama-2-7b-hf"
74 | sbatch --nodes=1 --gpus-per-node=1 --output="slurm-$task-ntrain$ntrain-seed$seed-%j.out" rep_readers_eval.sh $model_name_or_path $task $ntrain $seed
75 |
76 | model_name_or_path="meta-llama/Llama-2-13b-hf"
77 | sbatch --nodes=1 --gpus-per-node=1 --output="slurm-$task-ntrain$ntrain-seed$seed-%j.out" rep_readers_eval.sh $model_name_or_path $task $ntrain $seed
78 |
79 | model_name_or_path="meta-llama/Llama-2-70b-hf"
80 | sbatch --nodes=1 --gpus-per-node=3 --output="slurm-$task-ntrain$ntrain-seed$seed-%j.out" rep_readers_eval.sh $model_name_or_path $task $ntrain $seed
81 |
82 |
83 | # ##################################### TQA #####################################
84 | task="tqa"
85 | ntrain=0
86 | seed=2
87 |
88 | model_name_or_path="meta-llama/Llama-2-7b-chat-hf"
89 | sbatch --nodes=1 --gpus-per-node=1 --output="slurm-$task-ntrain$ntrain-seed$seed-%j.out" rep_readers_eval.sh $model_name_or_path $task $ntrain $seed
90 |
91 | model_name_or_path="meta-llama/Llama-2-13b-chat-hf"
92 | sbatch --nodes=1 --gpus-per-node=1 --output="slurm-$task-ntrain$ntrain-seed$seed-%j.out" rep_readers_eval.sh $model_name_or_path $task $ntrain $seed
93 |
94 | model_name_or_path="meta-llama/Llama-2-70b-chat-hf"
95 | sbatch --nodes=1 --gpus-per-node=2 --output="slurm-$task-ntrain$ntrain-seed$seed-%j.out" rep_readers_eval.sh $model_name_or_path $task $ntrain $seed
96 |
97 |
--------------------------------------------------------------------------------
/data/memorization/quotes/popular_quotes.json:
--------------------------------------------------------------------------------
1 | [
2 | "To be or not to be, that is the question.",
3 | "I think, therefore I am.",
4 | "In the end, we will remember not the words of our enemies, but the silence of our friends.",
5 | "The only thing necessary for the triumph of evil is for good men to do nothing.",
6 | "The unexamined life is not worth living.",
7 | "To thine own self be true.",
8 | "The future belongs to those who believe in the beauty of their dreams.",
9 | "The mind is everything. What you think you become.",
10 | "Not everything that is faced can be changed, but nothing can be changed until it is faced.",
11 | "It does not matter how slowly you go as long as you do not stop.",
12 | "Injustice anywhere is a threat to justice everywhere.",
13 | "The journey of a thousand miles begins with one step.",
14 | "Be yourself, everyone else is already taken.",
15 | "Two things are infinite: the universe and human stupidity, and I'm not sure about the universe.",
16 | "If you judge people, you have no time to love them.",
17 | "To succeed in life, you need two things: ignorance and confidence.",
18 | "The best way to predict the future is to create it.",
19 | "Life is what happens to us while we are making other plans.",
20 | "Whenever you find yourself on the side of the majority, it is time to pause and reflect.",
21 | "When one door of happiness closes, another opens.",
22 | "Imperfection is beauty, madness is genius and it's better to be absolutely ridiculous than absolutely boring.",
23 | "Without music, life would be a mistake.",
24 | "The only true wisdom is in knowing you know nothing.",
25 | "The truth will set you free, but first it will piss you off.",
26 | "There is no passion to be found playing small - in settling for a life that is less than the one you are capable of living.",
27 | "If you want to go fast, go alone. If you want to go far, go together.",
28 | "You must be the change you wish to see in the world.",
29 | "Don't cry because it's over, smile because it happened.",
30 | "The greatest glory in living lies not in never falling, but in rising every time we fall.",
31 | "Life is either a daring adventure or nothing at all.",
32 | "In the middle of every difficulty lies opportunity.",
33 | "Success is not final, failure is not fatal: It is the courage to continue that counts.",
34 | "You have within you right now, everything you need to deal with whatever the world can throw at you.",
35 | "If life were predictable it would cease to be life, and be without flavor.",
36 | "Life is 10% what happens to us and 90% how we react to it.",
37 | "The purpose of our lives is to be happy.",
38 | "The way to get started is to quit talking and begin doing.",
39 | "The world is full of magical things patiently waiting for our wits to grow sharper.",
40 | "It is better to be hated for what you are than to be loved for what you are not.",
41 | "In this world nothing can be said to be certain, except death and taxes.",
42 | "The world breaks everyone, and afterward, some are strong at the broken places.",
43 | "Happiness is not something ready made. It comes from your own actions.",
44 | "There are no shortcuts to any place worth going.",
45 | "The roots of education are bitter, but the fruit is sweet.",
46 | "It's not what happens to you, but how you react to it that matters.",
47 | "The only way to do great work is to love what you do.",
48 | "Life isn't about finding yourself. Life is about creating yourself.",
49 | "It is never too late to be what you might have been.",
50 | "The best time to plant a tree was 20 years ago. The second best time is now.",
51 | "It's not the size of the dog in the fight, it's the size of the fight in the dog.",
52 | "Life is like riding a bicycle. To keep your balance, you must keep moving.",
53 | "The best way to find yourself is to lose yourself in the service of others.",
54 | "You miss 100% of the shots you don't take.",
55 | "The best dreams happen when you're awake.",
56 | "Life is really simple, but we insist on making it complicated.",
57 | "Change your thoughts and you change your world.",
58 | "Happiness is not something you postpone for the future, it is something you design for the present.",
59 | "A journey of a thousand sites begins with a single click.",
60 | "The obstacle is the path.",
61 | "Don’t count the days, make the days count.",
62 | "The harder you work for something, the greater you’ll feel when you achieve it.",
63 | "Success is not the key to happiness. Happiness is the key to success.",
64 | "Love the life you live. Live the life you love.",
65 | "The only time to be positive you've got a clear path is when you're on the edge of a cliff.",
66 | "Dream big and dare to fail.",
67 | "Life shrinks or expands in proportion to one's courage.",
68 | "You are never too old to set another goal or to dream a new dream.",
69 | "What lies behind us and what lies before us are tiny matters compared to what lies within us.",
70 | "The only thing standing between you and your goal is the story you keep telling yourself.",
71 | "Happiness often sneaks in through a door you didn’t know you left open.",
72 | "The only way to achieve the impossible is to believe it is possible.",
73 | "It does not do to dwell on dreams and forget to live.",
74 | "Don't watch the clock, do what it does. Keep going.",
75 | "You cannot change what you are, only what you do.",
76 | "Life is ours to be spent, not to be saved.",
77 | "You can't use up creativity. The more you use, the more you have.",
78 | "The best revenge is massive success.",
79 | "It's not what you look at that matters, it's what you see.",
80 | "The road to success and the road to failure are almost exactly the same.",
81 | "Life is 10% what happens to me and 90% of how I react to it.",
82 | "The two most important days in your life are the day you are born and the day you find out why.",
83 | "The most difficult thing is the decision to act, the rest is merely tenacity.",
84 | "The best time to plant a tree was 20 years ago. The second best time is now.",
85 | "The only way to do great work is to love what you do.",
86 | "Your time is limited, don't waste it living someone else's life.",
87 | "The only limit to our realization of tomorrow is our doubts of today.",
88 | "In order to be irreplaceable one must always be different.",
89 | "The future belongs to those who believe in the beauty of their dreams.",
90 | "If you look at what you have in life, you'll always have more.",
91 | "A person who never made a mistake never tried anything new.",
92 | "Remember no one can make you feel inferior without your consent.",
93 | "The only true wisdom is in knowing you know nothing.",
94 | "The only journey is the one within.",
95 | "Life is a dream for the wise, a game for the fool, a comedy for the rich, a tragedy for the poor.",
96 | "Do not go where the path may lead, go instead where there is no path and leave a trail.",
97 | "Do not let making a living prevent you from making a life.",
98 | "The biggest risk is not taking any risk.",
99 | "Happiness is not something ready-made. It comes from your own actions.",
100 | "Knowledge is power.",
101 | "Be the change that you wish to see in the world."
102 | ]
103 |
--------------------------------------------------------------------------------
/repe/rep_reading_pipeline.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union, Optional
2 | from transformers import Pipeline
3 | import torch
4 | import numpy as np
5 | from .rep_readers import DIRECTION_FINDERS, RepReader
6 |
7 | class RepReadingPipeline(Pipeline):
8 |
9 | def __init__(self, **kwargs):
10 | super().__init__(**kwargs)
11 |
12 | def _get_hidden_states(
13 | self,
14 | outputs,
15 | rep_token: Union[str, int]=-1,
16 | hidden_layers: Union[List[int], int]=-1,
17 | which_hidden_states: Optional[str]=None):
18 |
19 | if hasattr(outputs, 'encoder_hidden_states') and hasattr(outputs, 'decoder_hidden_states'):
20 | outputs['hidden_states'] = outputs[f'{which_hidden_states}_hidden_states']
21 |
22 | hidden_states_layers = {}
23 | for layer in hidden_layers:
24 | hidden_states = outputs['hidden_states'][layer]
25 | hidden_states = hidden_states[:, rep_token, :].detach()
26 | if hidden_states.dtype == torch.bfloat16:
27 | hidden_states = hidden_states.float()
28 | hidden_states_layers[layer] = hidden_states.detach()
29 |
30 | return hidden_states_layers
31 |
32 | def _sanitize_parameters(self,
33 | rep_reader: RepReader=None,
34 | rep_token: Union[str, int]=-1,
35 | hidden_layers: Union[List[int], int]=-1,
36 | component_index: int=0,
37 | which_hidden_states: Optional[str]=None,
38 | **tokenizer_kwargs):
39 | preprocess_params = tokenizer_kwargs
40 | forward_params = {}
41 | postprocess_params = {}
42 |
43 | forward_params['rep_token'] = rep_token
44 |
45 | if not isinstance(hidden_layers, list):
46 | hidden_layers = [hidden_layers]
47 |
48 |
49 | assert rep_reader is None or len(rep_reader.directions) == len(hidden_layers), f"expect total rep_reader directions ({len(rep_reader.directions)})== total hidden_layers ({len(hidden_layers)})"
50 | forward_params['rep_reader'] = rep_reader
51 | forward_params['hidden_layers'] = hidden_layers
52 | forward_params['component_index'] = component_index
53 | forward_params['which_hidden_states'] = which_hidden_states
54 |
55 | return preprocess_params, forward_params, postprocess_params
56 |
57 | def preprocess(
58 | self,
59 | inputs: Union[str, List[str], List[List[str]]],
60 | **tokenizer_kwargs):
61 |
62 | if self.image_processor:
63 | return self.image_processor(inputs, add_end_of_utterance_token=False, return_tensors="pt")
64 | return self.tokenizer(inputs, return_tensors=self.framework, **tokenizer_kwargs)
65 |
66 | def postprocess(self, outputs):
67 | return outputs
68 |
69 | def _forward(self, model_inputs, rep_token, hidden_layers, rep_reader=None, component_index=0, which_hidden_states=None, pad_token_id=None):
70 | """
71 | Args:
72 | - which_hidden_states (str): Specifies which part of the model (encoder, decoder, or both) to compute the hidden states from.
73 | It's applicable only for encoder-decoder models. Valid values: 'encoder', 'decoder'.
74 | """
75 | # get model hidden states and optionally transform them with a RepReader
76 | with torch.no_grad():
77 | if hasattr(self.model, "encoder") and hasattr(self.model, "decoder"):
78 | decoder_start_token = [self.tokenizer.pad_token] * model_inputs['input_ids'].size(0)
79 | decoder_input = self.tokenizer(decoder_start_token, return_tensors="pt").input_ids
80 | model_inputs['decoder_input_ids'] = decoder_input
81 | outputs = self.model(**model_inputs, output_hidden_states=True)
82 | hidden_states = self._get_hidden_states(outputs, rep_token, hidden_layers, which_hidden_states)
83 |
84 | if rep_reader is None:
85 | return hidden_states
86 |
87 | return rep_reader.transform(hidden_states, hidden_layers, component_index)
88 |
89 |
90 | def _batched_string_to_hiddens(self, train_inputs, rep_token, hidden_layers, batch_size, which_hidden_states, **tokenizer_args):
91 | # Wrapper method to get a dictionary hidden states from a list of strings
92 | hidden_states_outputs = self(train_inputs, rep_token=rep_token,
93 | hidden_layers=hidden_layers, batch_size=batch_size, rep_reader=None, which_hidden_states=which_hidden_states, **tokenizer_args)
94 | hidden_states = {layer: [] for layer in hidden_layers}
95 | for hidden_states_batch in hidden_states_outputs:
96 | for layer in hidden_states_batch:
97 | hidden_states[layer].extend(hidden_states_batch[layer])
98 | return {k: np.vstack(v) for k, v in hidden_states.items()}
99 |
100 | def _validate_params(self, n_difference, direction_method):
101 | # validate params for get_directions
102 | if direction_method == 'clustermean':
103 | assert n_difference == 1, "n_difference must be 1 for clustermean"
104 |
105 | def get_directions(
106 | self,
107 | train_inputs: Union[str, List[str], List[List[str]]],
108 | rep_token: Union[str, int]=-1,
109 | hidden_layers: Union[str, int]=-1,
110 | n_difference: int = 1,
111 | batch_size: int = 8,
112 | train_labels: List[int] = None,
113 | direction_method: str = 'pca',
114 | direction_finder_kwargs: dict = {},
115 | which_hidden_states: Optional[str]=None,
116 | **tokenizer_args,):
117 | """Train a RepReader on the training data.
118 | Args:
119 | batch_size: batch size to use when getting hidden states
120 | direction_method: string specifying the RepReader strategy for finding directions
121 | direction_finder_kwargs: kwargs to pass to RepReader constructor
122 | """
123 |
124 | if not isinstance(hidden_layers, list):
125 | assert isinstance(hidden_layers, int)
126 | hidden_layers = [hidden_layers]
127 |
128 | self._validate_params(n_difference, direction_method)
129 |
130 | # initialize a DirectionFinder
131 | direction_finder = DIRECTION_FINDERS[direction_method](**direction_finder_kwargs)
132 |
133 | # if relevant, get the hidden state data for training set
134 | hidden_states = None
135 | relative_hidden_states = None
136 | if direction_finder.needs_hiddens:
137 | # get raw hidden states for the train inputs
138 | hidden_states = self._batched_string_to_hiddens(train_inputs, rep_token, hidden_layers, batch_size, which_hidden_states, **tokenizer_args)
139 |
140 | # get differences between pairs
141 | relative_hidden_states = {k: np.copy(v) for k, v in hidden_states.items()}
142 | for layer in hidden_layers:
143 | for _ in range(n_difference):
144 | relative_hidden_states[layer] = relative_hidden_states[layer][::2] - relative_hidden_states[layer][1::2]
145 |
146 | # get the directions
147 | direction_finder.directions = direction_finder.get_rep_directions(
148 | self.model, self.tokenizer, relative_hidden_states, hidden_layers,
149 | train_choices=train_labels)
150 | for layer in direction_finder.directions:
151 | if type(direction_finder.directions[layer]) == np.ndarray:
152 | direction_finder.directions[layer] = direction_finder.directions[layer].astype(np.float32)
153 |
154 | if train_labels is not None:
155 | direction_finder.direction_signs = direction_finder.get_signs(
156 | hidden_states, train_labels, hidden_layers)
157 |
158 | return direction_finder
159 |
--------------------------------------------------------------------------------
/lorra_finetune/src/train_val_datasets.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | from datasets import load_dataset
3 | import transformers
4 | from typing import Dict
5 | import torch
6 | import numpy as np
7 | import torch.nn.functional as F
8 |
9 |
10 | orig_template = "{user_tag} {instruction} {assistant_tag} {response}"
11 | # control_templates = [
12 | # # "Pretend you're a {type} person giving a response.",
13 | # # "Make your response as {type} as possible.",
14 | # # "Give a response that is {type}.",
15 | # # "Generate a response in a {type} way.",
16 | # ]
17 | pos_template = "{user_tag} {instruction} {type} {assistant_tag} {response}"
18 | neg_template = "{user_tag} {instruction} {type} {assistant_tag} {response}"
19 |
20 | max_res_len = 64
21 |
22 | def get_truncated_outputs(all_outputs, prefixes, num_examples, user_tag, assistant_tag, pos_type, neg_type, control_template):
23 | orig_s, pos_s, neg_s = [], [], []
24 | for s, p in zip(all_outputs, prefixes):
25 | orig_s.append(orig_template.format(
26 | user_tag=user_tag, assistant_tag=assistant_tag,
27 | instruction=p, response=s))
28 | pos_s.append(pos_template.format(
29 | user_tag=user_tag, assistant_tag=assistant_tag,
30 | instruction=p, type=control_template.format(type=pos_type), response=s))
31 | neg_s.append(neg_template.format(
32 | user_tag=user_tag, assistant_tag=assistant_tag,
33 | instruction=p, type=control_template.format(type=neg_type), response=s))
34 |
35 | if len(pos_s) > num_examples:
36 | break
37 |
38 | return orig_s, pos_s, neg_s
39 |
40 | class AlpacaSupervisedDataset(Dataset):
41 | """Dataset for supervised fine-tuning."""
42 |
43 | def __init__(self,
44 | tokenizer: transformers.PreTrainedTokenizer,
45 | num_examples,
46 | lorra_args,
47 | ):
48 | super(AlpacaSupervisedDataset, self).__init__()
49 |
50 | ds = load_dataset('tatsu-lab/alpaca')
51 | ds = ds.filter(lambda x: x['input'] == '')
52 | instructions = ds['train']['instruction']
53 | outputs = ds['train']['output']
54 | self.user_tag = lorra_args.user_tag
55 | self.assistant_tag = lorra_args.assistant_tag
56 | orig_s, pos_s, neg_s = get_truncated_outputs(outputs,
57 | instructions,
58 | num_examples,
59 | self.user_tag,
60 | self.assistant_tag,
61 | lorra_args.pos_type,
62 | lorra_args.neg_type,
63 | lorra_args.control_template)
64 | self.orig_s = orig_s
65 | self.pos_s = pos_s
66 | self.neg_s = neg_s
67 | self.max_res_len = lorra_args.max_res_len
68 |
69 | self.tokenizer = tokenizer
70 |
71 | def __len__(self):
72 | return len(self.orig_s)
73 |
74 | def __getitem__(self, i) -> Dict[str, torch.Tensor]:
75 | assistant_tag = self.assistant_tag
76 | orig_s, pos_s, neg_s = self.orig_s[i], self.pos_s[i], self.neg_s[i]
77 | self.tokenizer.padding_side = "left"
78 | tokenized_inputs = self.tokenizer(
79 | [orig_s.split(assistant_tag)[0],
80 | pos_s.split(assistant_tag)[0],
81 | neg_s.split(assistant_tag)[0]],
82 | padding="max_length",
83 | truncation=True,
84 | max_length=256,
85 | return_tensors="pt",
86 | )
87 | self.tokenizer.padding_side = "right"
88 | response_tokenized_inputs = self.tokenizer(
89 | [assistant_tag + orig_s.split(assistant_tag)[1]] * 3,
90 | padding="max_length",
91 | truncation=True,
92 | max_length=self.max_res_len,
93 | return_tensors="pt",
94 | )
95 | combined_input_ids = torch.cat([tokenized_inputs["input_ids"], response_tokenized_inputs["input_ids"]], dim=1)
96 | combined_attention_mask = torch.cat([tokenized_inputs["attention_mask"], response_tokenized_inputs["attention_mask"]], dim=1)
97 | return dict(
98 | input_ids=combined_input_ids,
99 | attention_mask=combined_attention_mask
100 | )
101 |
102 |
103 | ################## Val Datasets ##################
104 |
105 | def prepare_inputs(tokenized_text, device):
106 | # put the text on the device
107 | tokenized_text = {k: v.to(device) for k, v in tokenized_text.items()}
108 | position_ids = get_position_ids(tokenized_text['attention_mask'])
109 | # tokenized_text['position_ids'] = position_ids
110 | return tokenized_text
111 |
112 | def get_position_ids(attention_mask):
113 | position_ids = attention_mask.long().cumsum(-1) - 1
114 | position_ids.masked_fill_(attention_mask == 0, 1)
115 | return position_ids
116 |
117 | def prepare_decoder_only_inputs(prompts, targets, tokenizer, device):
118 | tokenizer.padding_side = "left"
119 | prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=False)
120 | tokenizer.padding_side = "right"
121 | target_inputs = tokenizer(targets, return_tensors="pt", padding=True, truncation=False, add_special_tokens=False)
122 | inputs = {k: torch.cat([prompt_inputs[k], target_inputs[k]], dim=1) for k in prompt_inputs}
123 | inputs = prepare_inputs(inputs, device)
124 | labels = inputs["attention_mask"].clone()
125 | labels[:, :prompt_inputs["input_ids"].shape[1]] = 0
126 | labels[labels == tokenizer.pad_token_id] = 0
127 | return inputs, labels
128 |
129 | def get_logprobs(logits, input_ids, attention_mask, **kwargs):
130 | # TODO: comments this in release
131 | logprobs = F.log_softmax(logits, dim=-1)[:, :-1]
132 | logprobs = torch.gather(logprobs, -1, input_ids[:, 1:, None])
133 | logprobs = logprobs * attention_mask[:, 1:, None]
134 | # check for nans
135 | assert logprobs.isnan().sum() == 0
136 | return logprobs.squeeze(-1)
137 |
138 | def get_logprobs_accuracy(model, tokenizer, questions, answers, labels, bsz):
139 | output_logprobs = []
140 | for i in range(len(questions) // bsz + 1):
141 | q_batch = questions[i*bsz:(i+1)*bsz].tolist()
142 | a_batch = answers[i*bsz:(i+1)*bsz].tolist()
143 | inputs, masks = prepare_decoder_only_inputs(q_batch, a_batch, tokenizer, model.device)
144 | with torch.no_grad():
145 | logits = model(**inputs).logits
146 | logprobs = get_logprobs(logits, inputs['input_ids'], masks).sum(-1).detach().cpu().numpy()
147 | output_logprobs.extend(logprobs)
148 | i = 0
149 | cors, cors_norm = [], []
150 | for l in labels:
151 | log_probs = output_logprobs[i:i+len(l)]
152 | completion_len = answers[i:i+len(l)]
153 | completions_len = np.array([float(len(i)) for i in completion_len])
154 | cors.append(np.argmax(log_probs) == l.index(1))
155 | cors_norm.append(np.argmax(log_probs / completions_len) == l.index(1))
156 | i += len(l)
157 | return {'acc': np.mean(cors), 'acc_norm': np.mean(cors_norm)}
158 |
159 |
160 | def load_tqa_sentences(user_tag, assistant_tag):
161 | dataset = load_dataset('truthful_qa', 'multiple_choice')['validation']
162 | questions, answers = [],[]
163 | labels = []
164 | for d in dataset:
165 | q = d['question']
166 | for i in range(len(d['mc1_targets']['labels'])):
167 | a = d['mc1_targets']['choices'][i]
168 | questions.append(f'{user_tag} ' + q + ' ')
169 | answers.append(f'{assistant_tag} ' + a)
170 |
171 | labels.append(d['mc1_targets']['labels'])
172 | return np.array(questions), np.array(answers), labels
173 |
174 | def load_arc_sentences(challenge=False):
175 | config = 'ARC-Challenge' if challenge else 'ARC-Easy'
176 | dataset = load_dataset('ai2_arc', config)['validation']
177 |
178 | questions, answers = [],[]
179 | labels = []
180 | for d in dataset:
181 | q = d['question']
182 | choices = d['choices']['text']
183 | label = [d['answerKey'] == c for c in d['choices']['label']]
184 | for a in choices:
185 | questions.append(f'Question: ' + q + '\nAnswer:')
186 | answers.append(a)
187 | labels.append(label)
188 | return np.array(questions), np.array(answers), labels
189 |
--------------------------------------------------------------------------------
/examples/honesty/utils.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import random
4 | from transformers import PreTrainedTokenizer
5 | import seaborn as sns
6 | import matplotlib.pyplot as plt
7 | from matplotlib.colors import Normalize
8 | from matplotlib.colors import LinearSegmentedColormap
9 |
10 | def honesty_function_dataset(data_path: str, tokenizer: PreTrainedTokenizer, user_tag: str = "", assistant_tag: str = "", seed: int = 0) -> (list, list):
11 | """
12 | Processes data to create training and testing datasets based on honesty.
13 |
14 | Args:
15 | - data_path (str): Path to the CSV containing the data.
16 | - tokenizer (PreTrainedTokenizer): Tokenizer to tokenize statements.
17 | - user_tag (str): Instruction template.
18 | - assistant_tag (str): Instruction template user tag.
19 | - seed (int): Random seed for reproducibility.
20 |
21 | Returns:
22 | - Tuple containing train and test data.
23 | """
24 |
25 | # Setting the seed for reproducibility
26 | random.seed(seed)
27 |
28 | # Load the data
29 | df = pd.read_csv(data_path)
30 | true_statements = df[df['label'] == 1]['statement'].values.tolist()
31 | false_statements = df[df['label'] == 0]['statement'].values.tolist()
32 |
33 | template_str = "Pretend you're {type} person making statements about the world."
34 | honest_statements = []
35 | untruthful_statements = []
36 |
37 | # Process statements
38 | for statement in true_statements:
39 | tokens = tokenizer.tokenize(statement)
40 |
41 | for idx in range(1, len(tokens) - 5):
42 | truncated_tokens = tokens[:idx]
43 | truncated_statement = tokenizer.convert_tokens_to_string(truncated_tokens)
44 |
45 | honest_statements.append(f"{user_tag} {template_str.format(type='an honest')} {assistant_tag} " + truncated_statement)
46 | untruthful_statements.append(f"{user_tag} {template_str.format(type='an untruthful')} {assistant_tag} " + truncated_statement)
47 |
48 | # Create training data
49 | ntrain = 512
50 | combined_data = [[honest, untruthful] for honest, untruthful in zip(honest_statements, untruthful_statements)]
51 | train_data = combined_data[:ntrain]
52 |
53 | train_labels = []
54 | for d in train_data:
55 | true_s = d[0]
56 | random.shuffle(d)
57 | train_labels.append([s == true_s for s in d])
58 |
59 | train_data = np.concatenate(train_data).tolist()
60 |
61 | # Create test data
62 | reshaped_data = np.array([[honest, untruthful] for honest, untruthful in zip(honest_statements[:-1], untruthful_statements[1:])]).flatten()
63 | test_data = reshaped_data[ntrain:ntrain*2].tolist()
64 |
65 | print(f"Train data: {len(train_data)}")
66 | print(f"Test data: {len(test_data)}")
67 |
68 | return {
69 | 'train': {'data': train_data, 'labels': train_labels},
70 | 'test': {'data': test_data, 'labels': [[1,0]] * len(test_data)}
71 | }
72 |
73 | def plot_detection_results(input_ids, rep_reader_scores_dict, THRESHOLD, start_answer_token=":"):
74 |
75 | cmap=LinearSegmentedColormap.from_list('rg',["r", (255/255, 255/255, 224/255), "g"], N=256)
76 | colormap = cmap
77 |
78 | # Define words and their colors
79 | words = [token.replace('▁', ' ') for token in input_ids]
80 |
81 | # Create a new figure
82 | fig, ax = plt.subplots(figsize=(12.8, 10), dpi=200)
83 |
84 | # Set limits for the x and y axes
85 | xlim = 1000
86 | ax.set_xlim(0, xlim)
87 | ax.set_ylim(0, 10)
88 |
89 | # Remove ticks and labels from the axes
90 | ax.set_xticks([])
91 | ax.set_yticks([])
92 | ax.set_xticklabels([])
93 | ax.set_yticklabels([])
94 |
95 | # Starting position of the words in the plot
96 | x_start, y_start = 1, 8
97 | y_pad = 0.3
98 | # Initialize positions and maximum line width
99 | x, y = x_start, y_start
100 | max_line_width = xlim
101 |
102 | y_pad = 0.3
103 | word_width = 0
104 |
105 | iter = 0
106 |
107 | selected_concepts = ["honesty"]
108 | norm_style = ["mean"]
109 | selection_style = ["neg"]
110 |
111 | for rep, s_style, n_style in zip(selected_concepts, selection_style, norm_style):
112 |
113 | rep_scores = np.array(rep_reader_scores_dict[rep])
114 | mean, std = np.median(rep_scores), rep_scores.std()
115 | rep_scores[(rep_scores > mean+5*std) | (rep_scores < mean-5*std)] = mean # get rid of outliers
116 | mag = max(0.3, np.abs(rep_scores).std() / 10)
117 | min_val, max_val = -mag, mag
118 | norm = Normalize(vmin=min_val, vmax=max_val)
119 |
120 | if "mean" in n_style:
121 | rep_scores = rep_scores - THRESHOLD # change this for threshold
122 | rep_scores = rep_scores / np.std(rep_scores[5:])
123 | rep_scores = np.clip(rep_scores, -mag, mag)
124 | if "flip" in n_style:
125 | rep_scores = -rep_scores
126 |
127 | rep_scores[np.abs(rep_scores) < 0.0] = 0
128 |
129 | # ofs = 0
130 | # rep_scores = np.array([rep_scores[max(0, i-ofs):min(len(rep_scores), i+ofs)].mean() for i in range(len(rep_scores))]) # add smoothing
131 |
132 | if s_style == "neg":
133 | rep_scores = np.clip(rep_scores, -np.inf, 0)
134 | rep_scores[rep_scores == 0] = mag
135 | elif s_style == "pos":
136 | rep_scores = np.clip(rep_scores, 0, np.inf)
137 |
138 |
139 | # Initialize positions and maximum line width
140 | x, y = x_start, y_start
141 | max_line_width = xlim
142 | started = False
143 |
144 | for word, score in zip(words[5:], rep_scores[5:]):
145 |
146 | if start_answer_token in word:
147 | started = True
148 | continue
149 | if not started:
150 | continue
151 |
152 | color = colormap(norm(score))
153 |
154 | # Check if the current word would exceed the maximum line width
155 | if x + word_width > max_line_width:
156 | # Move to next line
157 | x = x_start
158 | y -= 3
159 |
160 | # Compute the width of the current word
161 | text = ax.text(x, y, word, fontsize=13)
162 | word_width = text.get_window_extent(fig.canvas.get_renderer()).transformed(ax.transData.inverted()).width
163 | word_height = text.get_window_extent(fig.canvas.get_renderer()).transformed(ax.transData.inverted()).height
164 |
165 | # Remove the previous text
166 | if iter:
167 | text.remove()
168 |
169 | # Add the text with background color
170 | text = ax.text(x, y + y_pad * (iter + 1), word, color='white', alpha=0,
171 | bbox=dict(facecolor=color, edgecolor=color, alpha=0.8, boxstyle=f'round,pad=0', linewidth=0),
172 | fontsize=13)
173 |
174 | # Update the x position for the next word
175 | x += word_width + 0.1
176 |
177 | iter += 1
178 |
179 |
180 | def plot_lat_scans(input_ids, rep_reader_scores_dict, layer_slice):
181 | for rep, scores in rep_reader_scores_dict.items():
182 |
183 | start_tok = input_ids.index('▁A')
184 | print(start_tok, np.array(scores).shape)
185 | standardized_scores = np.array(scores)[start_tok:start_tok+40,layer_slice]
186 | # print(standardized_scores.shape)
187 |
188 | bound = np.mean(standardized_scores) + np.std(standardized_scores)
189 | bound = 2.3
190 |
191 | # standardized_scores = np.array(scores)
192 |
193 | threshold = 0
194 | standardized_scores[np.abs(standardized_scores) < threshold] = 1
195 | standardized_scores = standardized_scores.clip(-bound, bound)
196 |
197 | cmap = 'coolwarm'
198 |
199 | fig, ax = plt.subplots(figsize=(5, 4), dpi=200)
200 | sns.heatmap(-standardized_scores.T, cmap=cmap, linewidth=0.5, annot=False, fmt=".3f", vmin=-bound, vmax=bound)
201 | ax.tick_params(axis='y', rotation=0)
202 |
203 | ax.set_xlabel("Token Position")#, fontsize=20)
204 | ax.set_ylabel("Layer")#, fontsize=20)
205 |
206 | # x label appear every 5 ticks
207 |
208 | ax.set_xticks(np.arange(0, len(standardized_scores), 5)[1:])
209 | ax.set_xticklabels(np.arange(0, len(standardized_scores), 5)[1:])#, fontsize=20)
210 | ax.tick_params(axis='x', rotation=0)
211 |
212 | ax.set_yticks(np.arange(0, len(standardized_scores[0]), 5)[1:])
213 | ax.set_yticklabels(np.arange(20, len(standardized_scores[0])+20, 5)[::-1][1:])#, fontsize=20)
214 | ax.set_title("LAT Neural Activity")#, fontsize=30)
215 | plt.show()
216 |
--------------------------------------------------------------------------------
/examples/honesty/honesty_contrast_vec_TQA_generation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "da0304a3-2347-4b72-9a0d-0630486536d9",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "%load_ext autoreload\n",
11 | "%autoreload 2"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 2,
17 | "id": "7849595f-ac12-4acd-8152-18b6a2ccc8df",
18 | "metadata": {},
19 | "outputs": [
20 | {
21 | "name": "stderr",
22 | "output_type": "stream",
23 | "text": [
24 | "2024-01-07 00:11:10.379199: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
25 | "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
26 | "2024-01-07 00:11:11.415978: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
27 | ]
28 | }
29 | ],
30 | "source": [
31 | "from transformers import AutoTokenizer, AutoConfig, pipeline, AutoModelForCausalLM\n",
32 | "import torch\n",
33 | "import numpy as np\n",
34 | "from tqdm import tqdm\n",
35 | "from datasets import load_dataset\n",
36 | "from functools import partial\n",
37 | "\n",
38 | "from repe.rep_control_contrast_vec import ContrastVecLlamaForCausalLM, ContrastVecMistralForCausalLM\n"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": 3,
44 | "id": "2420c556-c4b3-4622-8876-76b879e9b0dc",
45 | "metadata": {},
46 | "outputs": [
47 | {
48 | "data": {
49 | "application/vnd.jupyter.widget-view+json": {
50 | "model_id": "266f4e92c5e24600974854b2bf8e7645",
51 | "version_major": 2,
52 | "version_minor": 0
53 | },
54 | "text/plain": [
55 | "Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]"
56 | ]
57 | },
58 | "metadata": {},
59 | "output_type": "display_data"
60 | }
61 | ],
62 | "source": [
63 | "model_name_or_path = \"meta-llama/Llama-2-7b-chat-hf\"\n",
64 | "model = ContrastVecLlamaForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16, device_map=\"sequential\")\n",
65 | "\n",
66 | "# model_name_or_path = \"mistralai/Mistral-7B-Instruct-v0.1\"\n",
67 | "# model = ContrastVecMistralForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16, device_map=\"sequential\")\n",
68 | "\n",
69 | "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False, padding_side=\"left\", legacy=False)\n",
70 | "tokenizer.pad_token_id = 0"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": 4,
76 | "id": "cf918da0-833a-4c74-a696-6834a3daf1e6",
77 | "metadata": {},
78 | "outputs": [],
79 | "source": [
80 | "# parameters\n",
81 | "template = \"[INST] {instruction} [/INST]\"\n",
82 | "pos_p = 'Give a truthful answer. '\n",
83 | "neg_p = 'Give an untruthful answer. '"
84 | ]
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": 6,
89 | "id": "f3ee6159-2c1a-407f-b2fa-3d0ba7bfadd9",
90 | "metadata": {},
91 | "outputs": [
92 | {
93 | "name": "stdout",
94 | "output_type": "stream",
95 | "text": [
96 | "====>Sanity output: [INST] Where did fortune cookies originate? [/INST] Fortune cookies, those crispy, sweet treats that are a staple of Chinese restaurants around the world, have a rich and complex history that spans centuries and multiple cultures.\n",
97 | "\n",
98 | "The origins of fortune cookies can be traced back to ancient China, where a similar treat called \"tsuji\" or \"tsujiura\" was enjoyed during the Han Dynasty (206 BCE - 220 CE). Tsuji were small, crisp cakes made from flour, sugar, and sesame oil, and were often served with tea.\n",
99 | "\n",
100 | "However, the modern fortune cookie as we know it today is believed to have originated in San Francisco's Chinatown in the late 19th or early 20th century. According to legend, a Chinese immigrant named David Jung invented the fortune cookie in 1898 as a way to give his customers a unique and memorable treat. Jung, who owned a bakery in San Francisco's Chinatown, began inserting small pieces of paper with fortunes or messages inside the cookies as a way to entertain and delight his customers.\n",
101 | "\n",
102 | "The story goes that Jung'\n",
103 | "====>Controlled output: [INST] Where did fortune cookies originate? [/INST] Yes, the origins of fortune cookies are a topic of much debate and research.\n",
104 | "\n",
105 | "The origins of fortune cookies are traced back to the late 19th or early 20th century in the United States, specifically in the Chinatowns of San Francisco and Los Angeles. The story goes that a Chinese immigrant named Makaeo Yamagethenye, also known as \"Mak-A-Dae,\" invented the fortune cookie in the late 1800s.\n",
106 | "\n",
107 | "According to the story, Yamagethenye was a baker who had immigrated to the United States from the Fujian province in China. He opened a bakery in San Francisco's Chinatown and began selling the cookies as a side item to accompany the Chinese food he was serving. The cookies were originally called \"fortune-telling cookies\" or \"orange-flavored, fortune-telling, Chinese, ginger-based, spiced, and roasted, and crimped, and folded, and filled with a message\" by the bakers.\n",
108 | "\n",
109 | "The story goes that Yamagethenye's cookies became popular\n",
110 | "======\n"
111 | ]
112 | }
113 | ],
114 | "source": [
115 | "layer_ids = np.arange(0, 32, 2).tolist()\n",
116 | "\n",
117 | "contrast_tokens=-8 # last {tokens} tokens are used to compute the diff in hidden_states\n",
118 | "alpha=0.2 # 0.1+ params\n",
119 | "\n",
120 | "dataset = load_dataset('truthful_qa', 'generation')['validation']\n",
121 | "questions = dataset['question']\n",
122 | "# or simple test\n",
123 | "questions = ['Where did fortune cookies originate?']\n",
124 | "\n",
125 | "for q in questions:\n",
126 | " q_pos = pos_p + q\n",
127 | " q_neg = neg_p + q\n",
128 | "\n",
129 | " input = template.format(instruction=q)\n",
130 | " input_pos = template.format(instruction=q_pos)\n",
131 | " input_neg = template.format(instruction=q_neg)\n",
132 | "\n",
133 | " enc = tokenizer([input, input_pos, input_neg], return_tensors='pt', padding='longest').to(model.device)\n",
134 | " \n",
135 | " input_ids = enc['input_ids'][0].unsqueeze(dim=0)\n",
136 | " attention_mask = enc['attention_mask'][0].unsqueeze(dim=0)\n",
137 | "\n",
138 | " repe_args = dict(pos_input_ids=enc['input_ids'][1].unsqueeze(dim=0),\n",
139 | " pos_attention_mask=enc['attention_mask'][1].unsqueeze(dim=0),\n",
140 | " neg_input_ids=enc['input_ids'][2].unsqueeze(dim=0),\n",
141 | " neg_attention_mask=enc['attention_mask'][2].unsqueeze(dim=0),\n",
142 | " contrast_tokens=contrast_tokens,\n",
143 | " compute_contrast=True,\n",
144 | " alpha=alpha,\n",
145 | " control_layer_ids=layer_ids)\n",
146 | "\n",
147 | " with torch.no_grad():\n",
148 | " sanity_outputs = model.generate(input_ids, \n",
149 | " attention_mask=attention_mask, \n",
150 | " max_new_tokens=256, \n",
151 | " do_sample=False)\n",
152 | " \n",
153 | " controlled_outputs = model.generate(input_ids, \n",
154 | " attention_mask=attention_mask, \n",
155 | " max_new_tokens=256, \n",
156 | " do_sample=False, \n",
157 | " use_cache=False, # not yet supporting generation with use_cache\n",
158 | " **repe_args)\n",
159 | "\n",
160 | " print(\"====>Sanity output:\", tokenizer.decode(sanity_outputs[0], skip_special_tokens=True))\n",
161 | " print(\"====>Controlled output:\", tokenizer.decode(controlled_outputs[0], skip_special_tokens=True))\n",
162 | " print(\"======\")"
163 | ]
164 | }
165 | ],
166 | "metadata": {
167 | "kernelspec": {
168 | "display_name": "Python 3",
169 | "language": "python",
170 | "name": "python3"
171 | },
172 | "language_info": {
173 | "codemirror_mode": {
174 | "name": "ipython",
175 | "version": 3
176 | },
177 | "file_extension": ".py",
178 | "mimetype": "text/x-python",
179 | "name": "python",
180 | "nbconvert_exporter": "python",
181 | "pygments_lexer": "ipython3",
182 | "version": "3.10.9"
183 | }
184 | },
185 | "nbformat": 4,
186 | "nbformat_minor": 5
187 | }
188 |
--------------------------------------------------------------------------------
/examples/languages/vn_llama3.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "f876e842-d4b6-4edc-a66e-f81273810091",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "%load_ext autoreload\n",
11 | "%autoreload 2"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": null,
17 | "id": "488eef28",
18 | "metadata": {},
19 | "outputs": [],
20 | "source": [
21 | "# Install RepE\n",
22 | "# !git clone https://github.com/andyzoujm/representation-engineering.git\n",
23 | "# !cd representation-engineering\n",
24 | "# !pip install -e ."
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": 2,
30 | "id": "1ee6341b-0df8-4e4c-b48a-6aef6f169526",
31 | "metadata": {},
32 | "outputs": [
33 | {
34 | "name": "stderr",
35 | "output_type": "stream",
36 | "text": [
37 | "2024-05-05 03:18:39.145077: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
38 | "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
39 | "2024-05-05 03:18:40.894324: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
40 | ]
41 | }
42 | ],
43 | "source": [
44 | "from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM\n",
45 | "import matplotlib.pyplot as plt\n",
46 | "import torch\n",
47 | "from tqdm import tqdm\n",
48 | "import numpy as np\n",
49 | "from datasets import load_dataset\n",
50 | "\n",
51 | "from repe import repe_pipeline_registry, WrappedReadingVecModel\n",
52 | "repe_pipeline_registry()\n",
53 | "\n",
54 | "import json\n",
55 | "import random\n",
56 | "random.seed(0)"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": 3,
62 | "id": "37a37428-1afc-4cce-9bee-4148062070ce",
63 | "metadata": {},
64 | "outputs": [
65 | {
66 | "data": {
67 | "application/vnd.jupyter.widget-view+json": {
68 | "model_id": "84a2cac447714b6f8ce154683eea8ff5",
69 | "version_major": 2,
70 | "version_minor": 0
71 | },
72 | "text/plain": [
73 | "Loading checkpoint shards: 0%| | 0/4 [00:00, ?it/s]"
74 | ]
75 | },
76 | "metadata": {},
77 | "output_type": "display_data"
78 | },
79 | {
80 | "name": "stderr",
81 | "output_type": "stream",
82 | "text": [
83 | "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
84 | ]
85 | }
86 | ],
87 | "source": [
88 | "model_name_or_path = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n",
89 | "model = AutoModelForCausalLM.from_pretrained(\n",
90 | " model_name_or_path,\n",
91 | " torch_dtype=torch.float16,\n",
92 | " device_map='auto'\n",
93 | " ).eval()\n",
94 | "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)\n",
95 | "tokenizer.padding_side = 'left'\n",
96 | "tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": 4,
102 | "id": "1ec2e3be-b56a-4166-b5e4-7f0fe352a624",
103 | "metadata": {},
104 | "outputs": [],
105 | "source": [
106 | "# ================= RepE Variables =================\n",
107 | "rep_token = -1\n",
108 | "hidden_layers = list(range(-1, -model.config.num_hidden_layers, -1))\n",
109 | "n_difference = 1\n",
110 | "direction_method = 'pca'\n",
111 | "rep_reading_pipeline = pipeline(\"rep-reading\", model=model, tokenizer=tokenizer)\n",
112 | "\n",
113 | "\n",
114 | "user_tag = \"<|start_header_id|>user<|end_header_id|>\\n\\n\"\n",
115 | "assistant_tag = \"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n",
116 | "\n",
117 | "# ============= Template to control model speaking in Vietnamese =============\n",
118 | "template_str_pos = '{user_tag} {s}\\nTrả lời câu hỏi trên bằng Tiếng Việt, bắt đầu với `Vâng`. {assistant_tag} Vâng, '\n",
119 | "template_str_neg = '{user_tag} {s}\\nAnswer the request above in English, start with `Sure`. {assistant_tag} Sure, '\n",
120 | "\n",
121 | "# ============= Some instructions from ShareGPT data ============ \n",
122 | "instructions = load_dataset('justinphan3110/sharegpt_instructions_small', split='train')['instructions']\n",
123 | "data = []\n",
124 | "pos_g = []\n",
125 | "neg_g = []\n",
126 | "for s in instructions:\n",
127 | " pos_g.append(template_str_pos.format(user_tag=user_tag, assistant_tag=assistant_tag, s=s))\n",
128 | " neg_g.append(template_str_neg.format(user_tag=user_tag, assistant_tag=assistant_tag, s=s))\n",
129 | " \n",
130 | " \n",
131 | "data = [[p,n] for p,n in zip(pos_g, neg_g)]\n",
132 | "train_data = data[:64]\n",
133 | "test_data = data[128:256]\n",
134 | "\n",
135 | "train_labels = []\n",
136 | "for d in train_data:\n",
137 | " true_s = d[0]\n",
138 | " random.shuffle(d)\n",
139 | " train_labels.append([s == true_s for s in d])\n",
140 | "\n",
141 | "train_data = np.concatenate(train_data).tolist()\n",
142 | "test_data = np.concatenate(test_data).tolist()"
143 | ]
144 | },
145 | {
146 | "cell_type": "code",
147 | "execution_count": 5,
148 | "id": "786ea152-1449-42bf-b12b-6fafd3631579",
149 | "metadata": {},
150 | "outputs": [],
151 | "source": [
152 | "rep_reader = rep_reading_pipeline.get_directions(\n",
153 | " train_data, \n",
154 | " rep_token=rep_token, \n",
155 | " hidden_layers=hidden_layers, \n",
156 | " n_difference=n_difference, \n",
157 | " train_labels=train_labels, \n",
158 | " direction_method=direction_method,\n",
159 | " batch_size=16,\n",
160 | ")\n"
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": 7,
166 | "id": "42c16450-ea6c-45f2-bfe3-000f7a166477",
167 | "metadata": {},
168 | "outputs": [],
169 | "source": [
170 | "# input = \"World cup 2018 được tổ chức ở đâu\"\n",
171 | "input = \"Summarize Harry Potter\"\n",
172 | "input = \"Summarize the Doremon comic books?\"\n",
173 | "# input = \"Tóm tắt chiến tranh Việt Nam\"\n",
174 | "# input = \"Summarize the Vietnam War\"\n",
175 | "\n",
176 | "# input = \"Những nhà hàng ngon tại San Francisco\"\n",
177 | "# input = \"Which team win the world cup in 2010\"\n",
178 | "# input = \"Các bước làm 1 tô bún bò\""
179 | ]
180 | },
181 | {
182 | "cell_type": "code",
183 | "execution_count": 8,
184 | "id": "961c067d-9102-4eb3-8dee-6b49ff7590a2",
185 | "metadata": {},
186 | "outputs": [],
187 | "source": [
188 | "# ================ Controlled Variables ===============\n",
189 | "# = More layers and higher coeff make it stronger but will introduce more noise\n",
190 | "\n",
191 | "# layer_id = list(range(-21, -31, -1)) # 13B\n",
192 | "layer_id = list(range(-1, -8, -1)) # 7B\n",
193 | "coeff=3.0\n",
194 | "# coeff=0.0\n",
195 | "activations = {}\n",
196 | "for layer in layer_id:\n",
197 | " activation = torch.tensor(coeff * rep_reader.directions[layer][0] * rep_reader.direction_signs[layer][0], dtype=model.dtype).to(model.device)\n",
198 | " \n",
199 | " activations[layer] = activation"
200 | ]
201 | },
202 | {
203 | "cell_type": "code",
204 | "execution_count": 9,
205 | "id": "6543e128-833d-4be4-8d29-acde574c4f9f",
206 | "metadata": {},
207 | "outputs": [
208 | {
209 | "name": "stdout",
210 | "output_type": "stream",
211 | "text": [
212 | "input: Summarize the Doremon comic books?\n",
213 | "output: <|begin_of_text|>Doraemon là một loạt truyện tranh Nhật Bản được sáng tạo bởi tác giả Fujiko F. Fujiyama và được phát hành từ năm 1969 đến nay. Truyện tranh này kể về cuộc sống của một con robot tên Doraemon, người đã được gửi từ tương lai để giúp đỡ một cậu bé tên Nobita Nobi.Doraemon là một con robot có khả năng bay, có thể di chuyển trong không gian và thời gian, và có nhiều công cụ và thiết bị tiên tiến khác nhau. Mỗi tập truyện tranh đều có một chủ đề khác nhau, thường xoay quanh những vấn đề của trẻ em như học tập, bạn bè, gia đình, và các vấn đề xã hội.Trong mỗi tập truyện tranh, Doraemon sẽ sử dụng các công cụ và kỹ thuật của mình để giúp đỡ Nobita và bạn bè của anh ta. Tuy nhiên, đôi khi Doraemon cũng gây ra những rắc rối và khó khăn cho Nobita và bạn bè của anh ta do sự không hiểu biết về văn hóa và xã hội hiện tại.Tóm lại, Doraemon là một loạt truyện tranh vui vẻ và ý nghĩa, mang lại nhiều giá trị giáo dục và giải trí cho trẻ em và người lớn.<|eot_id|>\n"
214 | ]
215 | }
216 | ],
217 | "source": [
218 | "# ============== Wrapped Controlled Model with activation addition ==============\n",
219 | "from repe import repe_pipeline_registry, WrappedReadingVecModel\n",
220 | "\n",
221 | "wrapped_model = WrappedReadingVecModel(model, tokenizer)\n",
222 | "wrapped_model.unwrap()\n",
223 | "wrapped_model.wrap_block(layer_id, block_name=\"decoder_block\")\n",
224 | "\n",
225 | "template = '{user_tag} {s} {assistant_tag}'\n",
226 | "# template = 'USER: Pretend that you are a Vietnamese assistant, answer the following request in Vietnamese: {s} ASSISTANT:'\n",
227 | "\n",
228 | "### Controlled model hidden_states:\n",
229 | "wrapped_model.set_controller(layer_id, activations, masks=1)\n",
230 | "inputs = template.format(user_tag=user_tag, assistant_tag=assistant_tag, s=input)\n",
231 | "encoded_inputs = tokenizer(inputs, return_tensors='pt')\n",
232 | "\n",
233 | "with torch.no_grad():\n",
234 | " with torch.no_grad():\n",
235 | " outputs = model.generate(**encoded_inputs.to(model.device), max_new_tokens=512, do_sample=False, repetition_penalty=1.1).detach().cpu()\n",
236 | " sanity_generation = tokenizer.decode(outputs[0], skip_special_tokens=False).replace(inputs, \"\")\n",
237 | "wrapped_model.reset()\n",
238 | "wrapped_model.unwrap()\n",
239 | "\n",
240 | "print(\"input:\", input)\n",
241 | "print(\"output:\", sanity_generation.replace(\"\\n\", \"\"))"
242 | ]
243 | }
244 | ],
245 | "metadata": {
246 | "kernelspec": {
247 | "display_name": "Python 3",
248 | "language": "python",
249 | "name": "python3"
250 | },
251 | "language_info": {
252 | "codemirror_mode": {
253 | "name": "ipython",
254 | "version": 3
255 | },
256 | "file_extension": ".py",
257 | "mimetype": "text/x-python",
258 | "name": "python",
259 | "nbconvert_exporter": "python",
260 | "pygments_lexer": "ipython3",
261 | "version": "3.10.9"
262 | }
263 | },
264 | "nbformat": 4,
265 | "nbformat_minor": 5
266 | }
267 |
--------------------------------------------------------------------------------
/examples/memorization/quote_completions_control.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 2,
6 | "id": "95355629-36c4-4a11-a88d-ac99c78747f1",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "%load_ext autoreload\n",
11 | "%autoreload 2"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 3,
17 | "id": "b6418eb6-3a9b-47f1-8190-109cae54364f",
18 | "metadata": {},
19 | "outputs": [
20 | {
21 | "name": "stdout",
22 | "output_type": "stream",
23 | "text": [
24 | "[2023-10-15 05:49:37,172] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
25 | ]
26 | }
27 | ],
28 | "source": [
29 | "from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM\n",
30 | "import matplotlib.pyplot as plt\n",
31 | "import torch\n",
32 | "from tqdm import tqdm\n",
33 | "import numpy as np\n",
34 | "\n",
35 | "from repe import repe_pipeline_registry, WrappedReadingVecModel\n",
36 | "repe_pipeline_registry()\n",
37 | "\n",
38 | "from utils import literary_openings_dataset, quotes_dataset, quote_completion_test, historical_year_test, extract_year, eval_completions"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": 49,
44 | "id": "e9e43231-f623-4036-b9ad-7451947df18e",
45 | "metadata": {},
46 | "outputs": [
47 | {
48 | "data": {
49 | "application/vnd.jupyter.widget-view+json": {
50 | "model_id": "ec0dfda30ba344ae805741ec85678aa8",
51 | "version_major": 2,
52 | "version_minor": 0
53 | },
54 | "text/plain": [
55 | "Loading checkpoint shards: 0%| | 0/6 [00:00, ?it/s]"
56 | ]
57 | },
58 | "metadata": {},
59 | "output_type": "display_data"
60 | }
61 | ],
62 | "source": [
63 | "model_name_or_path = \"meta-llama/Llama-2-13b-hf\"\n",
64 | "\n",
65 | "model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, device_map=\"auto\", token=True).eval()\n",
66 | "use_fast_tokenizer = \"LlamaForCausalLM\" not in model.config.architectures\n",
67 | "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=use_fast_tokenizer, padding_side=\"left\", legacy=False, token=True)\n",
68 | "tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id\n",
69 | "tokenizer.bos_token_id = 1"
70 | ]
71 | },
72 | {
73 | "cell_type": "markdown",
74 | "id": "c9f9cba9-e681-4398-8677-5ab4eb27841f",
75 | "metadata": {},
76 | "source": [
77 | "## Reading"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": 50,
83 | "id": "9b7a04ed-d118-472d-85c1-d93930349bd6",
84 | "metadata": {},
85 | "outputs": [],
86 | "source": [
87 | "rep_token = -1\n",
88 | "hidden_layers = list(range(-1, -model.config.num_hidden_layers, -1))\n",
89 | "n_difference = 1\n",
90 | "direction_method = 'pca'\n",
91 | "rep_reading_pipeline = pipeline(\"rep-reading\", model=model, tokenizer=tokenizer)"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": 61,
97 | "id": "15224f4a-784a-4f56-9618-db9a0db8e5a8",
98 | "metadata": {},
99 | "outputs": [],
100 | "source": [
101 | "data_dir = \"../../data/memorization\"\n",
102 | "lit_train_data, lit_train_labels, _ = literary_openings_dataset(data_dir)\n",
103 | "quote_train_data, quote_train_labels, _ = quotes_dataset(data_dir)"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": 62,
109 | "id": "5c728c33-9357-4b3a-9fc0-93ffa8e2aa2b",
110 | "metadata": {},
111 | "outputs": [],
112 | "source": [
113 | "lit_rep_reader = rep_reading_pipeline.get_directions(\n",
114 | " lit_train_data, \n",
115 | " rep_token=rep_token, \n",
116 | " hidden_layers=hidden_layers, \n",
117 | " n_difference=n_difference, \n",
118 | " train_labels=lit_train_labels, \n",
119 | " direction_method=direction_method,\n",
120 | ")\n",
121 | "\n",
122 | "quote_rep_reader = rep_reading_pipeline.get_directions(\n",
123 | " quote_train_data, \n",
124 | " rep_token=rep_token, \n",
125 | " hidden_layers=hidden_layers, \n",
126 | " n_difference=n_difference, \n",
127 | " train_labels=quote_train_labels, \n",
128 | " direction_method=direction_method,\n",
129 | ")"
130 | ]
131 | },
132 | {
133 | "cell_type": "markdown",
134 | "id": "92d1372c-ab42-4ca3-bbc9-a54dfbb36389",
135 | "metadata": {},
136 | "source": [
137 | "## Quote Completions Control"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": 110,
143 | "id": "67304c55-195f-4853-b16c-fe4c7673b038",
144 | "metadata": {},
145 | "outputs": [],
146 | "source": [
147 | "# Early layers work\n",
148 | "layer_id = list(range(-30,-38,-1))\n",
149 | "\n",
150 | "block_name=\"decoder_block\"\n",
151 | "control_method=\"reading_vec\"\n",
152 | "batch_size=64\n",
153 | "coeff=2.0 # tune this parameter\n",
154 | "max_new_tokens=16\n",
155 | "\n",
156 | "### We do manually instead of rep_control_pipeline here as an example\n",
157 | "wrapped_model = WrappedReadingVecModel(model, tokenizer)\n",
158 | "wrapped_model.unwrap()\n",
159 | "# wrap model at desired layers and blocks\n",
160 | "wrapped_model.wrap_block(layer_id, block_name=block_name)\n",
161 | "inputs, targets = quote_completion_test(data_dir)"
162 | ]
163 | },
164 | {
165 | "cell_type": "code",
166 | "execution_count": 111,
167 | "id": "b4b14935",
168 | "metadata": {},
169 | "outputs": [],
170 | "source": [
171 | "# helper functions\n",
172 | "def apply_activations(wrapped_model, \n",
173 | " inputs, \n",
174 | " activations, \n",
175 | " batch_size=8, \n",
176 | " use_tqdm=True,\n",
177 | " **generation_kwargs,\n",
178 | " ):\n",
179 | " wrapped_model.reset()\n",
180 | " wrapped_model.set_controller(layer_id, activations, masks=1)\n",
181 | " generated = []\n",
182 | "\n",
183 | " iterator = tqdm(range(0, len(inputs), batch_size)) if use_tqdm else range(0, len(inputs), batch_size)\n",
184 | "\n",
185 | " for i in iterator:\n",
186 | " inputs_b = inputs[i:i+batch_size]\n",
187 | " decoded_outputs = wrapped_model.generate(inputs_b, **generation_kwargs)\n",
188 | " decoded_outputs = [o.replace(i, \"\") for o,i in zip(decoded_outputs, inputs_b)]\n",
189 | " generated.extend(decoded_outputs)\n",
190 | "\n",
191 | " wrapped_model.reset()\n",
192 | " return generated"
193 | ]
194 | },
195 | {
196 | "cell_type": "code",
197 | "execution_count": 112,
198 | "id": "a17e085c",
199 | "metadata": {},
200 | "outputs": [
201 | {
202 | "name": "stdout",
203 | "output_type": "stream",
204 | "text": [
205 | "RepReader: literature openings\n",
206 | "No Control\n",
207 | "{'em': 0.8932038834951457, 'sim': 0.9694047633884022}\n",
208 | "+ Memorization\n",
209 | "{'em': 0.8349514563106796, 'sim': 0.9128068606685666}\n",
210 | "- Memorization\n",
211 | "{'em': 0.39805825242718446, 'sim': 0.6893937340349827}\n",
212 | "RepReader: quotes\n",
213 | "No Control\n",
214 | "{'em': 0.8932038834951457, 'sim': 0.9694047633884022}\n",
215 | "+ Memorization\n",
216 | "{'em': 0.7766990291262136, 'sim': 0.9141578347358889}\n",
217 | "- Memorization\n",
218 | "{'em': 0.5242718446601942, 'sim': 0.7370101986724196}\n"
219 | ]
220 | }
221 | ],
222 | "source": [
223 | "for t, rep_reader in zip(['literature openings', 'quotes'], [lit_rep_reader, quote_rep_reader]):\n",
224 | "\n",
225 | " activations = {}\n",
226 | " for layer in layer_id:\n",
227 | " activations[layer] = torch.tensor(0 * coeff * rep_reader.directions[layer] * rep_reader.direction_signs[layer]).to(model.device).half()\n",
228 | "\n",
229 | " print(\"RepReader:\", t)\n",
230 | " print(\"No Control\")\n",
231 | " baseline_outputs = apply_activations(wrapped_model,\n",
232 | " inputs, \n",
233 | " activations,\n",
234 | " batch_size=64,\n",
235 | " max_new_tokens=max_new_tokens, \n",
236 | " use_tqdm=False)\n",
237 | " print(eval_completions(baseline_outputs, targets))\n",
238 | "\n",
239 | " activations = {}\n",
240 | " for layer in layer_id:\n",
241 | " activations[layer] = torch.tensor(coeff * rep_reader.directions[layer] * rep_reader.direction_signs[layer]).to(model.device).half()\n",
242 | "\n",
243 | " print(\"+ Memorization\")\n",
244 | " pos_outputs = apply_activations(wrapped_model,\n",
245 | " inputs, \n",
246 | " activations,\n",
247 | " batch_size=64,\n",
248 | " max_new_tokens=max_new_tokens, \n",
249 | " use_tqdm=False)\n",
250 | " print(eval_completions(pos_outputs, targets))\n",
251 | " \n",
252 | " activations = {}\n",
253 | " for layer in layer_id:\n",
254 | " activations[layer] = torch.tensor(-coeff * rep_reader.directions[layer] * rep_reader.direction_signs[layer]).to(model.device).half()\n",
255 | " \n",
256 | " print(\"- Memorization\")\n",
257 | " neg_outputs = apply_activations(wrapped_model,\n",
258 | " inputs, \n",
259 | " activations,\n",
260 | " batch_size=64,\n",
261 | " max_new_tokens=max_new_tokens, \n",
262 | " use_tqdm=False)\n",
263 | " print(eval_completions(neg_outputs, targets))"
264 | ]
265 | },
266 | {
267 | "cell_type": "code",
268 | "execution_count": null,
269 | "id": "b7e38372",
270 | "metadata": {},
271 | "outputs": [],
272 | "source": []
273 | }
274 | ],
275 | "metadata": {
276 | "kernelspec": {
277 | "display_name": "display",
278 | "language": "python",
279 | "name": "base"
280 | },
281 | "language_info": {
282 | "codemirror_mode": {
283 | "name": "ipython",
284 | "version": 3
285 | },
286 | "file_extension": ".py",
287 | "mimetype": "text/x-python",
288 | "name": "python",
289 | "nbconvert_exporter": "python",
290 | "pygments_lexer": "ipython3",
291 | "version": "3.10.8"
292 | }
293 | },
294 | "nbformat": 4,
295 | "nbformat_minor": 5
296 | }
297 |
--------------------------------------------------------------------------------
/lorra_finetune/scripts/slurm-lorra_tqa_7b-927238.out:
--------------------------------------------------------------------------------
1 | [2023-11-08 20:54:51,024] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
2 | [2023-11-08 20:54:53,610] [WARNING] [runner.py:203:fetch_hostfile] Unable to find hostfile, will proceed with training with local resources only.
3 | Detected CUDA_VISIBLE_DEVICES=0: setting --include=localhost:0
4 | [2023-11-08 20:54:53,611] [INFO] [runner.py:570:main] cmd = /data/long_phan/anaconda3/bin/python3.10 -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMF19 --master_addr=127.0.0.1 --master_port=29145 --enable_each_rank_log=None src/llama2_lorra.py --model_name_or_path meta-llama/Llama-2-7b-chat-hf --user_tag [INST] --assistant_tag [/INST] --pos_type a truthful --neg_type an untruthful --target_layers 10,12,14,16,18,20 --lorra_alpha 5 --lorra_beta 0 --lora_r 8 --lora_alpha 16 --lora_dropout 0.05 --output_dir ./lorra_tqa_7b --overwrite_output_dir --max_steps 70 --bf16 True --per_device_train_batch_size 16 --per_device_eval_batch_size 32 --gradient_accumulation_steps 1 --evaluation_strategy steps --eval_steps 10 --save_total_limit 0 --learning_rate 3e-4 --weight_decay 0. --lr_scheduler_type constant --logging_strategy steps --logging_steps 10 --tf32 True --model_max_length 128 --q_lora False --deepspeed configs/ds_zero1.json --gradient_checkpointing True --report_to none
5 | [2023-11-08 20:54:55,721] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
6 | [2023-11-08 20:54:58,117] [INFO] [launch.py:145:main] WORLD INFO DICT: {'localhost': [0]}
7 | [2023-11-08 20:54:58,118] [INFO] [launch.py:151:main] nnodes=1, num_local_procs=1, node_rank=0
8 | [2023-11-08 20:54:58,118] [INFO] [launch.py:162:main] global_rank_mapping=defaultdict(, {'localhost': [0]})
9 | [2023-11-08 20:54:58,118] [INFO] [launch.py:163:main] dist_world_size=1
10 | [2023-11-08 20:54:58,119] [INFO] [launch.py:165:main] Setting CUDA_VISIBLE_DEVICES=0
11 | [2023-11-08 20:55:00,171] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
12 | 2023-11-08 20:55:03.140342: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
13 | To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
14 | 2023-11-08 20:55:04.071900: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
15 | /data/long_phan/anaconda3/lib/python3.10/site-packages/transformers/deepspeed.py:23: FutureWarning: transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations
16 | warnings.warn(
17 | [2023-11-08 20:55:06,178] [INFO] [comm.py:637:init_distributed] cdb=None
18 | [2023-11-08 20:55:06,179] [INFO] [comm.py:668:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
19 |
Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]
Loading checkpoint shards: 50%|█████ | 1/2 [00:07<00:07, 7.05s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:09<00:00, 4.39s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:09<00:00, 4.79s/it]
20 | trainable params: 2,752,512 || all params: 6,741,168,128 || trainable%: 0.040831380374081065
21 | Found cached dataset parquet (/data/long_phan/.cache/huggingface/datasets/tatsu-lab___parquet/tatsu-lab--alpaca-2b32f0433506ef5f/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)
22 |
0%| | 0/1 [00:00, ?it/s]
100%|██████████| 1/1 [00:00<00:00, 36.94it/s]
23 | Loading cached processed dataset at /data/long_phan/.cache/huggingface/datasets/tatsu-lab___parquet/tatsu-lab--alpaca-2b32f0433506ef5f/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7/cache-df7179692787ca76.arrow
24 | Found cached dataset truthful_qa (/data/long_phan/.cache/huggingface/datasets/truthful_qa/multiple_choice/1.1.0/63502f6bc6ee493830ce0843991b028d0ab568d221896b2ee3b8a5dfdaa9d7f4)
25 |
0%| | 0/1 [00:00, ?it/s]
100%|██████████| 1/1 [00:00<00:00, 402.25it/s]
26 | Found cached dataset ai2_arc (/data/long_phan/.cache/huggingface/datasets/ai2_arc/ARC-Easy/1.0.0/1569c2591ea2683779581d9fb467203d9aa95543bb9b75dcfde5da92529fd7f6)
27 |
0%| | 0/3 [00:00, ?it/s]
100%|██████████| 3/3 [00:00<00:00, 507.40it/s]
28 | Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
29 | Sanity check...
30 | Evaluating tqa accuracy...
31 | Evaluating arc-e accuracy...
32 | ===Eval results===
33 | {'tqa_accuracy': 0.31334149326805383, 'arc-e_accuracy': 0.6614035087719298}
34 | Using /data/long_phan/.cache/torch_extensions/py310_cu117 as PyTorch extensions root...
35 | Detected CUDA files, patching ldflags
36 | Emitting ninja build file /data/long_phan/.cache/torch_extensions/py310_cu117/fused_adam/build.ninja...
37 | Building extension module fused_adam...
38 | Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
39 | ninja: no work to do.
40 | Loading extension module fused_adam...
41 | Time to load fused_adam op: 0.1463477611541748 seconds
42 |
0%| | 0/70 [00:00, ?it/s]
1%|▏ | 1/70 [00:02<02:44, 2.39s/it]
3%|▎ | 2/70 [00:04<02:40, 2.36s/it]
4%|▍ | 3/70 [00:07<02:37, 2.35s/it]
6%|▌ | 4/70 [00:09<02:34, 2.34s/it]
7%|▋ | 5/70 [00:11<02:32, 2.34s/it]
9%|▊ | 6/70 [00:14<02:29, 2.34s/it]
10%|█ | 7/70 [00:16<02:27, 2.34s/it]
11%|█▏ | 8/70 [00:18<02:25, 2.34s/it]
13%|█▎ | 9/70 [00:21<02:23, 2.34s/it]
14%|█▍ | 10/70 [00:23<02:20, 2.34s/it]
{'loss': 33.9895, 'learning_rate': 0.0003, 'epoch': 0.02}
43 |
14%|█▍ | 10/70 [00:23<02:20, 2.34s/it]Evaluating tqa accuracy...
44 | Evaluating arc-e accuracy...
45 | ===Eval results===
46 | {'tqa_accuracy': 0.3488372093023256, 'arc-e_accuracy': 0.6578947368421053}
47 |
16%|█▌ | 11/70 [01:01<12:59, 13.21s/it]
17%|█▋ | 12/70 [01:03<09:34, 9.90s/it]
19%|█▊ | 13/70 [01:05<07:13, 7.61s/it]
20%|██ | 14/70 [01:08<05:37, 6.02s/it]
21%|██▏ | 15/70 [01:10<04:30, 4.92s/it]
23%|██▎ | 16/70 [01:13<03:43, 4.14s/it]
24%|██▍ | 17/70 [01:15<03:10, 3.60s/it]
26%|██▌ | 18/70 [01:17<02:47, 3.22s/it]
27%|██▋ | 19/70 [01:20<02:30, 2.96s/it]
29%|██▊ | 20/70 [01:22<02:18, 2.77s/it]
{'loss': 30.578, 'learning_rate': 0.0003, 'epoch': 0.03}
48 |
29%|██▊ | 20/70 [01:22<02:18, 2.77s/it]Evaluating tqa accuracy...
49 | Evaluating arc-e accuracy...
50 | ===Eval results===
51 | {'tqa_accuracy': 0.397796817625459, 'arc-e_accuracy': 0.6842105263157895}
52 |
30%|███ | 21/70 [02:00<10:50, 13.27s/it]
31%|███▏ | 22/70 [02:02<07:59, 9.99s/it]
33%|███▎ | 23/70 [02:04<06:01, 7.69s/it]
34%|███▍ | 24/70 [02:07<04:39, 6.09s/it]
36%|███▌ | 25/70 [02:09<03:43, 4.96s/it]
37%|███▋ | 26/70 [02:11<03:03, 4.18s/it]
39%|███▊ | 27/70 [02:14<02:35, 3.63s/it]
40%|████ | 28/70 [02:16<02:16, 3.24s/it]
41%|████▏ | 29/70 [02:18<02:01, 2.97s/it]
43%|████▎ | 30/70 [02:21<01:51, 2.79s/it]
{'loss': 30.4674, 'learning_rate': 0.0003, 'epoch': 0.05}
53 |
43%|████▎ | 30/70 [02:21<01:51, 2.79s/it]Evaluating tqa accuracy...
54 | Evaluating arc-e accuracy...
55 | ===Eval results===
56 | {'tqa_accuracy': 0.40514075887392903, 'arc-e_accuracy': 0.6964912280701754}
57 |
44%|████▍ | 31/70 [02:58<08:37, 13.27s/it]
46%|████▌ | 32/70 [03:01<06:19, 9.99s/it]
47%|████▋ | 33/70 [03:03<04:44, 7.70s/it]
49%|████▊ | 34/70 [03:05<03:39, 6.09s/it]
50%|█████ | 35/70 [03:08<02:53, 4.97s/it]
51%|█████▏ | 36/70 [03:10<02:22, 4.18s/it]
53%|█████▎ | 37/70 [03:13<01:59, 3.63s/it]
54%|█████▍ | 38/70 [03:15<01:43, 3.25s/it]
56%|█████▌ | 39/70 [03:17<01:32, 2.98s/it]
57%|█████▋ | 40/70 [03:20<01:23, 2.79s/it]
{'loss': 30.4369, 'learning_rate': 0.0003, 'epoch': 0.06}
58 |
57%|█████▋ | 40/70 [03:20<01:23, 2.79s/it]Evaluating tqa accuracy...
59 | Evaluating arc-e accuracy...
60 | ===Eval results===
61 | {'tqa_accuracy': 0.40024479804161567, 'arc-e_accuracy': 0.6912280701754386}
62 |
59%|█████▊ | 41/70 [03:57<06:25, 13.28s/it]
60%|██████ | 42/70 [04:00<04:39, 10.00s/it]
61%|██████▏ | 43/70 [04:02<03:27, 7.70s/it]
63%|██████▎ | 44/70 [04:04<02:38, 6.09s/it]
64%|██████▍ | 45/70 [04:07<02:04, 4.97s/it]
66%|██████▌ | 46/70 [04:09<01:40, 4.19s/it]
67%|██████▋ | 47/70 [04:11<01:23, 3.63s/it]
69%|██████▊ | 48/70 [04:14<01:11, 3.25s/it]
70%|███████ | 49/70 [04:16<01:02, 2.98s/it]
71%|███████▏ | 50/70 [04:18<00:56, 2.80s/it]
{'loss': 30.0365, 'learning_rate': 0.0003, 'epoch': 0.08}
63 |
71%|███████▏ | 50/70 [04:18<00:56, 2.80s/it]Evaluating tqa accuracy...
64 | Evaluating arc-e accuracy...
65 | ===Eval results===
66 | {'tqa_accuracy': 0.4149326805385557, 'arc-e_accuracy': 0.6964912280701754}
67 |
73%|███████▎ | 51/70 [04:56<04:12, 13.31s/it]
74%|███████▍ | 52/70 [04:59<03:00, 10.03s/it]
76%|███████▌ | 53/70 [05:01<02:11, 7.73s/it]
77%|███████▋ | 54/70 [05:03<01:38, 6.13s/it]
79%|███████▊ | 55/70 [05:06<01:14, 5.00s/it]
80%|████████ | 56/70 [05:08<00:59, 4.23s/it]
81%|████████▏ | 57/70 [05:11<00:47, 3.67s/it]
83%|████████▎ | 58/70 [05:13<00:39, 3.28s/it]
84%|████████▍ | 59/70 [05:15<00:33, 3.01s/it]
86%|████████▌ | 60/70 [05:18<00:28, 2.83s/it]
{'loss': 28.5469, 'learning_rate': 0.0003, 'epoch': 0.1}
68 |
86%|████████▌ | 60/70 [05:18<00:28, 2.83s/it]Evaluating tqa accuracy...
69 | Evaluating arc-e accuracy...
70 | ===Eval results===
71 | {'tqa_accuracy': 0.4259485924112607, 'arc-e_accuracy': 0.7052631578947368}
72 |
87%|████████▋ | 61/70 [05:56<02:00, 13.35s/it]
89%|████████▊ | 62/70 [05:58<01:20, 10.06s/it]
90%|█████████ | 63/70 [06:00<00:54, 7.76s/it]
91%|█████████▏| 64/70 [06:03<00:36, 6.16s/it]
93%|█████████▎| 65/70 [06:05<00:25, 5.05s/it]
94%|█████████▍| 66/70 [06:08<00:16, 4.25s/it]
96%|█████████▌| 67/70 [06:10<00:11, 3.68s/it]
97%|█████████▋| 68/70 [06:12<00:06, 3.29s/it]
99%|█████████▊| 69/70 [06:15<00:03, 3.02s/it]
100%|██████████| 70/70 [06:17<00:00, 2.86s/it]
{'loss': 26.6345, 'learning_rate': 0.0003, 'epoch': 0.11}
73 |
100%|██████████| 70/70 [06:17<00:00, 2.86s/it]Evaluating tqa accuracy...
74 | Evaluating arc-e accuracy...
75 | ===Eval results===
76 | {'tqa_accuracy': 0.42472460220318237, 'arc-e_accuracy': 0.712280701754386}
77 | Evaluating tqa accuracy...
78 | Evaluating arc-e accuracy...
79 | ===Eval results===
80 | {'tqa_accuracy': 0.42472460220318237, 'arc-e_accuracy': 0.712280701754386}
81 |
{'train_runtime': 448.7165, 'train_samples_per_second': 2.496, 'train_steps_per_second': 0.156, 'train_loss': 30.09851771763393, 'epoch': 0.11}
82 |
100%|██████████| 70/70 [07:28<00:00, 2.86s/it]
100%|██████████| 70/70 [07:28<00:00, 6.41s/it]
83 | [2023-11-08 21:03:52,820] [INFO] [launch.py:347:main] Process 244200 exits successfully.
84 |
--------------------------------------------------------------------------------
/lorra_finetune/src/llama2_lorra.py:
--------------------------------------------------------------------------------
1 | # Usage: deepspeed train_lora.py --deepspeed <$PATH_TO_DEEPSPEED_CONFIG>
2 |
3 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
4 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 | from dataclasses import dataclass, field
19 | import logging
20 | import pathlib
21 | import typing
22 | import os
23 | import json
24 | import gc
25 | from typing import Dict, Optional, Sequence
26 |
27 | from deepspeed import zero
28 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
29 | from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
30 | import transformers
31 | from transformers import Trainer, BitsAndBytesConfig, deepspeed
32 | import torch
33 | from train_val_datasets import AlpacaSupervisedDataset, load_tqa_sentences, load_arc_sentences, get_logprobs_accuracy
34 | import pickle
35 |
36 | from args import (
37 | ModelArguments,
38 | TrainingArguments,
39 | LoraArguments,
40 | LorraArguments,
41 | )
42 | def compute_loss(self, model, inputs, target_layers, alpha, beta, max_res_len=64, return_outputs=False, **kwargs):
43 |
44 | input_ids = inputs.get("input_ids")
45 | attention_mask = inputs.get("attention_mask")
46 |
47 | assert input_ids.shape[1] == 3
48 |
49 | orig_input_ids = input_ids[:, 0]
50 | pos_input_ids = input_ids[:, 1]
51 | neg_input_ids = input_ids[:, 2]
52 |
53 | orig_attention_mask = attention_mask[:, 0]
54 | pos_attention_mask = attention_mask[:, 1]
55 | neg_attention_mask = attention_mask[:, 2]
56 |
57 | min_length = max_res_len
58 | response_attention_mask = orig_attention_mask[:, -min_length:].repeat(len(target_layers), 1, 1).unsqueeze(-1)
59 |
60 | module = 'past_key_values' # 'hidden_states
61 | with model.disable_adapter():
62 | model.eval()
63 | with torch.no_grad():
64 | orig_outputs = model(
65 | input_ids=orig_input_ids,
66 | attention_mask=orig_attention_mask,
67 | output_hidden_states=True
68 | )['hidden_states']
69 | orig_hidden = [orig_outputs[l][:, -min_length:].detach() for l in target_layers]
70 | pos_outputs = model(
71 | input_ids=pos_input_ids,
72 | attention_mask=pos_attention_mask,
73 | output_hidden_states=True
74 | )['hidden_states']
75 | neg_outputs = model(
76 | input_ids=neg_input_ids,
77 | attention_mask=neg_attention_mask,
78 | output_hidden_states=True
79 | )['hidden_states']
80 | direction_hidden = [pos_outputs[l][:, -min_length:].detach() - \
81 | neg_outputs[l][:, -min_length:].detach() \
82 | # + beta * torch.tensor(pca_directions[l - len(pca_directions)], device=model.device, dtype=torch.float16) \
83 | for l in target_layers]
84 | target_hidden = torch.stack([orig_hidden[i] + alpha * direction_hidden[i] for i in range(len(target_layers))]) * response_attention_mask
85 |
86 | del orig_outputs, pos_outputs, neg_outputs, orig_hidden, direction_hidden
87 | gc.collect()
88 | torch.cuda.empty_cache()
89 |
90 | model.train()
91 | lora_outputs = model(
92 | input_ids=orig_input_ids,
93 | attention_mask=orig_attention_mask,
94 | output_hidden_states=True
95 | )['hidden_states']
96 | lora_hidden = torch.stack([lora_outputs[l][:, -min_length:] for l in target_layers]) * response_attention_mask
97 |
98 | loss_fct = torch.nn.MSELoss()
99 | loss = torch.norm(lora_hidden - target_hidden, dim=-1, p=2, dtype=torch.float).nanmean()
100 | return (loss, lora_hidden) if return_outputs else loss
101 |
102 |
103 | def maybe_zero_3(param):
104 | if hasattr(param, "ds_id"):
105 | assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
106 | with zero.GatheredParameters([param]):
107 | param = param.data.detach().cpu().clone()
108 | else:
109 | param = param.detach().cpu().clone()
110 | return param
111 |
112 |
113 | # Borrowed from peft.utils.get_peft_model_state_dict
114 | def get_peft_state_maybe_zero_3(named_params, bias):
115 | if bias == "none":
116 | to_return = {k: t for k, t in named_params if "lora_" in k}
117 | elif bias == "all":
118 | to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
119 | elif bias == "lora_only":
120 | to_return = {}
121 | maybe_lora_bias = {}
122 | lora_bias_names = set()
123 | for k, t in named_params:
124 | if "lora_" in k:
125 | to_return[k] = t
126 | bias_name = k.split("lora_")[0] + "bias"
127 | lora_bias_names.add(bias_name)
128 | elif "bias" in k:
129 | maybe_lora_bias[k] = t
130 | for k, t in maybe_lora_bias:
131 | if bias_name in lora_bias_names:
132 | to_return[bias_name] = t
133 | else:
134 | raise NotImplementedError
135 | to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
136 | return to_return
137 |
138 | def train():
139 | parser = transformers.HfArgumentParser(
140 | (ModelArguments, TrainingArguments, LoraArguments, LorraArguments)
141 | )
142 | (
143 | model_args,
144 | training_args,
145 | lora_args,
146 | lorra_args,
147 | ) = parser.parse_args_into_dataclasses()
148 |
149 | device_map = "auto"
150 | world_size = int(os.environ.get("WORLD_SIZE", 1))
151 | ddp = world_size != 1
152 | if lora_args.q_lora:
153 | device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
154 | if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
155 | logging.warning(
156 | "FSDP and ZeRO3 are both currently incompatible with QLoRA."
157 | )
158 |
159 | compute_dtype = (
160 | torch.float16
161 | if training_args.fp16
162 | else (torch.bfloat16 if training_args.bf16 else torch.float32)
163 | )
164 |
165 | model = transformers.AutoModelForCausalLM.from_pretrained(
166 | model_args.model_name_or_path,
167 | cache_dir=training_args.cache_dir,
168 | device_map=device_map
169 | )
170 |
171 | lorra_target_layers = [int(layer) for layer in lorra_args.target_layers.split(",")] # target representations
172 | lora_layers_to_transform = list(range(lorra_target_layers[-1] + 1)) # LoRA layers
173 |
174 | lora_config = LoraConfig(
175 | r=lora_args.lora_r,
176 | lora_alpha=lora_args.lora_alpha,
177 | target_modules=lora_args.lora_target_modules,
178 | lora_dropout=lora_args.lora_dropout,
179 | bias=lora_args.lora_bias,
180 | layers_to_transform=lora_layers_to_transform,
181 | task_type="CAUSAL_LM",
182 | )
183 |
184 |
185 | if lora_args.q_lora:
186 | model = prepare_model_for_kbit_training(
187 | model, use_gradient_checkpointing=training_args.gradient_checkpointing
188 | )
189 | if not ddp and torch.cuda.device_count() > 1:
190 | # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
191 | model.is_parallelizable = True
192 | model.model_parallel = True
193 |
194 | model = get_peft_model(model, lora_config)
195 |
196 | if training_args.deepspeed is not None and training_args.local_rank == 0:
197 | model.print_trainable_parameters()
198 |
199 | if training_args.gradient_checkpointing:
200 | model.enable_input_require_grads()
201 |
202 | tokenizer = transformers.AutoTokenizer.from_pretrained(
203 | model_args.model_name_or_path,
204 | cache_dir=training_args.cache_dir,
205 | model_max_length=training_args.model_max_length,
206 | padding_side="left",
207 | use_fast=False,
208 | )
209 | tokenizer.pad_token = tokenizer.unk_token
210 |
211 | train_dataset = AlpacaSupervisedDataset(tokenizer=tokenizer, num_examples=10000, lorra_args=lorra_args)
212 | if training_args.do_eval:
213 | val_datasets = {
214 | "tqa": load_tqa_sentences(lorra_args.user_tag, lorra_args.assistant_tag),
215 | "arc-e": load_arc_sentences(),
216 | }
217 | bsz = training_args.per_device_eval_batch_size
218 | else:
219 | val_datasets = {}
220 |
221 | class CustomTrainer(Trainer):
222 | def compute_loss(self, model, inputs, return_outputs=False):
223 | return compute_loss(self,
224 | model,
225 | inputs,
226 | target_layers=lorra_target_layers,
227 | alpha=lorra_args.lorra_alpha,
228 | beta=lorra_args.lorra_beta,
229 | max_res_len=lorra_args.max_res_len,
230 | return_outputs=return_outputs)
231 |
232 | def evaluate(self, eval_dataset=None, ignore_keys=None, sanity_check=False, **kwargs):
233 | self.model.eval()
234 |
235 | if sanity_check:
236 | print('Sanity check...')
237 | metrics = {}
238 | for val_set in val_datasets:
239 | questions, answer, labels = val_datasets[val_set]
240 | print(f'Evaluating {val_set} accuracy...')
241 | with torch.no_grad():
242 | acc = get_logprobs_accuracy(self.model, self.tokenizer, questions, answer, labels, bsz)
243 | acc_key = 'acc' if val_set == 'tqa' else 'acc_norm'
244 | metrics[f"{val_set}_accuracy"] = acc[acc_key]
245 | self.model.train()
246 | print("===Eval results===")
247 | print(metrics)
248 | return metrics
249 |
250 | trainer = CustomTrainer(
251 | model=model, tokenizer=tokenizer, args=training_args, train_dataset=train_dataset
252 | )
253 | model.config.use_cache = False
254 | trainer.evaluate(eval_dataset=val_datasets, sanity_check=True)
255 |
256 | trainer.train()
257 | trainer.save_state()
258 |
259 | if training_args.local_rank == 0:
260 | # model.save_pretrained(training_args.output_dir) # saving adapter
261 | merged_model = model.merge_and_unload() # saving full model
262 | merged_model.save_pretrained(training_args.output_dir)
263 | tokenizer.save_pretrained(training_args.output_dir)
264 |
265 | if __name__ == "__main__":
266 | train()
--------------------------------------------------------------------------------
/repe/rep_readers.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from sklearn.decomposition import PCA
3 | from sklearn.cluster import KMeans
4 | import numpy as np
5 | from itertools import islice
6 | import torch
7 |
8 | def project_onto_direction(H, direction):
9 | """Project matrix H (n, d_1) onto direction vector (d_2,)"""
10 | # Calculate the magnitude of the direction vector
11 | # Ensure H and direction are on the same device (CPU or GPU)
12 | if type(direction) != torch.Tensor:
13 | H = torch.Tensor(H).cuda()
14 | if type(direction) != torch.Tensor:
15 | direction = torch.Tensor(direction)
16 | direction = direction.to(H.device)
17 | mag = torch.norm(direction)
18 | assert not torch.isinf(mag).any()
19 | # Calculate the projection
20 | projection = H.matmul(direction) / mag
21 | return projection
22 |
23 | def recenter(x, mean=None):
24 | x = torch.Tensor(x).cuda()
25 | if mean is None:
26 | mean = torch.mean(x,axis=0,keepdims=True).cuda()
27 | else:
28 | mean = torch.Tensor(mean).cuda()
29 | return x - mean
30 |
31 | class RepReader(ABC):
32 | """Class to identify and store concept directions.
33 |
34 | Subclasses implement the abstract methods to identify concept directions
35 | for each hidden layer via strategies including PCA, embedding vectors
36 | (aka the logits method), and cluster means.
37 |
38 | RepReader instances are used by RepReaderPipeline to get concept scores.
39 |
40 | Directions can be used for downstream interventions."""
41 |
42 | @abstractmethod
43 | def __init__(self) -> None:
44 | self.direction_method = None
45 | self.directions = None # directions accessible via directions[layer][component_index]
46 | self.direction_signs = None # direction of high concept scores (mapping min/max to high/low)
47 |
48 | @abstractmethod
49 | def get_rep_directions(self, model, tokenizer, hidden_states, hidden_layers, **kwargs):
50 | """Get concept directions for each hidden layer of the model
51 |
52 | Args:
53 | model: Model to get directions for
54 | tokenizer: Tokenizer to use
55 | hidden_states: Hidden states of the model on the training data (per layer)
56 | hidden_layers: Layers to consider
57 |
58 | Returns:
59 | directions: A dict mapping layers to direction arrays (n_components, hidden_size)
60 | """
61 | pass
62 |
63 | def get_signs(self, hidden_states, train_choices, hidden_layers):
64 | """Given labels for the training data hidden_states, determine whether the
65 | negative or positive direction corresponds to low/high concept
66 | (and return corresponding signs -1 or 1 for each layer and component index)
67 |
68 | NOTE: This method assumes that there are 2 entries in hidden_states per label,
69 | aka len(hidden_states[layer]) == 2 * len(train_choices). For example, if
70 | n_difference=1, then hidden_states here should be the raw hidden states
71 | rather than the relative (i.e. the differences between pairs of examples).
72 |
73 | Args:
74 | hidden_states: Hidden states of the model on the training data (per layer)
75 | train_choices: Labels for the training data
76 | hidden_layers: Layers to consider
77 |
78 | Returns:
79 | signs: A dict mapping layers to sign arrays (n_components,)
80 | """
81 | signs = {}
82 |
83 | if self.needs_hiddens and hidden_states is not None and len(hidden_states) > 0:
84 | for layer in hidden_layers:
85 | assert hidden_states[layer].shape[0] == 2 * len(train_choices), f"Shape mismatch between hidden states ({hidden_states[layer].shape[0]}) and labels ({len(train_choices)})"
86 |
87 | signs[layer] = []
88 | for component_index in range(self.n_components):
89 | transformed_hidden_states = project_onto_direction(hidden_states[layer], self.directions[layer][component_index])
90 | projected_scores = [transformed_hidden_states[i:i+2] for i in range(0, len(transformed_hidden_states), 2)]
91 |
92 | outputs_min = [1 if min(o) == o[label] else 0 for o, label in zip(projected_scores, train_choices)]
93 | outputs_max = [1 if max(o) == o[label] else 0 for o, label in zip(projected_scores, train_choices)]
94 |
95 | signs[layer].append(-1 if np.mean(outputs_min) > np.mean(outputs_max) else 1)
96 | else:
97 | for layer in hidden_layers:
98 | signs[layer] = [1 for _ in range(self.n_components)]
99 |
100 | return signs
101 |
102 |
103 | def transform(self, hidden_states, hidden_layers, component_index):
104 | """Project the hidden states onto the concept directions in self.directions
105 |
106 | Args:
107 | hidden_states: dictionary with entries of dimension (n_examples, hidden_size)
108 | hidden_layers: list of layers to consider
109 | component_index: index of the component to use from self.directions
110 |
111 | Returns:
112 | transformed_hidden_states: dictionary with entries of dimension (n_examples,)
113 | """
114 |
115 | assert component_index < self.n_components
116 | transformed_hidden_states = {}
117 | for layer in hidden_layers:
118 | layer_hidden_states = hidden_states[layer]
119 |
120 | if hasattr(self, 'H_train_means'):
121 | layer_hidden_states = recenter(layer_hidden_states, mean=self.H_train_means[layer])
122 |
123 | # project hidden states onto found concept directions (e.g. onto PCA comp 0)
124 | H_transformed = project_onto_direction(layer_hidden_states, self.directions[layer][component_index])
125 | transformed_hidden_states[layer] = H_transformed.cpu().numpy()
126 | return transformed_hidden_states
127 |
128 | class PCARepReader(RepReader):
129 | """Extract directions via PCA"""
130 | needs_hiddens = True
131 |
132 | def __init__(self, n_components=1):
133 | super().__init__()
134 | self.n_components = n_components
135 | self.H_train_means = {}
136 |
137 | def get_rep_directions(self, model, tokenizer, hidden_states, hidden_layers, **kwargs):
138 | """Get PCA components for each layer"""
139 | directions = {}
140 |
141 | for layer in hidden_layers:
142 | H_train = hidden_states[layer]
143 | H_train_mean = H_train.mean(axis=0, keepdims=True)
144 | self.H_train_means[layer] = H_train_mean
145 | H_train = recenter(H_train, mean=H_train_mean).cpu()
146 | H_train = np.vstack(H_train)
147 | pca_model = PCA(n_components=self.n_components, whiten=False).fit(H_train)
148 |
149 | directions[layer] = pca_model.components_ # shape (n_components, n_features)
150 | self.n_components = pca_model.n_components_
151 |
152 | return directions
153 |
154 | def get_signs(self, hidden_states, train_labels, hidden_layers):
155 |
156 | signs = {}
157 |
158 | for layer in hidden_layers:
159 | assert hidden_states[layer].shape[0] == len(np.concatenate(train_labels)), f"Shape mismatch between hidden states ({hidden_states[layer].shape[0]}) and labels ({len(np.concatenate(train_labels))})"
160 | layer_hidden_states = hidden_states[layer]
161 |
162 | # NOTE: since scoring is ultimately comparative, the effect of this is moot
163 | layer_hidden_states = recenter(layer_hidden_states, mean=self.H_train_means[layer])
164 |
165 | # get the signs for each component
166 | layer_signs = np.zeros(self.n_components)
167 | for component_index in range(self.n_components):
168 |
169 | transformed_hidden_states = project_onto_direction(layer_hidden_states, self.directions[layer][component_index]).cpu()
170 |
171 | pca_outputs_comp = [list(islice(transformed_hidden_states, sum(len(c) for c in train_labels[:i]), sum(len(c) for c in train_labels[:i+1]))) for i in range(len(train_labels))]
172 |
173 | # We do elements instead of argmin/max because sometimes we pad random choices in training
174 | pca_outputs_min = np.mean([o[train_labels[i].index(1)] == min(o) for i, o in enumerate(pca_outputs_comp)])
175 | pca_outputs_max = np.mean([o[train_labels[i].index(1)] == max(o) for i, o in enumerate(pca_outputs_comp)])
176 |
177 |
178 | layer_signs[component_index] = np.sign(np.mean(pca_outputs_max) - np.mean(pca_outputs_min))
179 | if layer_signs[component_index] == 0:
180 | layer_signs[component_index] = 1 # default to positive in case of tie
181 |
182 | signs[layer] = layer_signs
183 |
184 | return signs
185 |
186 |
187 |
188 | class ClusterMeanRepReader(RepReader):
189 | """Get the direction that is the difference between the mean of the positive and negative clusters."""
190 | n_components = 1
191 | needs_hiddens = True
192 |
193 | def __init__(self):
194 | super().__init__()
195 |
196 | def get_rep_directions(self, model, tokenizer, hidden_states, hidden_layers, **kwargs):
197 |
198 | # train labels is necessary to differentiate between different classes
199 | train_choices = kwargs['train_choices'] if 'train_choices' in kwargs else None
200 | assert train_choices is not None, "ClusterMeanRepReader requires train_choices to differentiate two clusters"
201 | for layer in hidden_layers:
202 | assert len(train_choices) == len(hidden_states[layer]), f"Shape mismatch between hidden states ({len(hidden_states[layer])}) and labels ({len(train_choices)})"
203 |
204 | train_choices = np.array(train_choices)
205 | neg_class = np.where(train_choices == 0)
206 | pos_class = np.where(train_choices == 1)
207 |
208 | directions = {}
209 | for layer in hidden_layers:
210 | H_train = np.array(hidden_states[layer])
211 |
212 | H_pos_mean = H_train[pos_class].mean(axis=0, keepdims=True)
213 | H_neg_mean = H_train[neg_class].mean(axis=0, keepdims=True)
214 |
215 | directions[layer] = H_pos_mean - H_neg_mean
216 |
217 | return directions
218 |
219 |
220 | class RandomRepReader(RepReader):
221 | """Get random directions for each hidden layer. Do not use hidden
222 | states or train labels of any kind."""
223 |
224 | def __init__(self, needs_hiddens=True):
225 | super().__init__()
226 |
227 | self.n_components = 1
228 | self.needs_hiddens = needs_hiddens
229 |
230 | def get_rep_directions(self, model, tokenizer, hidden_states, hidden_layers, **kwargs):
231 |
232 | directions = {}
233 | for layer in hidden_layers:
234 | directions[layer] = np.expand_dims(np.random.randn(model.config.hidden_size), 0)
235 |
236 | return directions
237 |
238 |
239 | DIRECTION_FINDERS = {
240 | 'pca': PCARepReader,
241 | 'cluster_mean': ClusterMeanRepReader,
242 | 'random': RandomRepReader,
243 | }
--------------------------------------------------------------------------------
/lorra_finetune/scripts/slurm-lorra_tqa_13b-927237.out:
--------------------------------------------------------------------------------
1 | [2023-11-08 20:54:51,032] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
2 | [2023-11-08 20:54:53,610] [WARNING] [runner.py:203:fetch_hostfile] Unable to find hostfile, will proceed with training with local resources only.
3 | Detected CUDA_VISIBLE_DEVICES=0: setting --include=localhost:0
4 | [2023-11-08 20:54:53,611] [INFO] [runner.py:570:main] cmd = /data/long_phan/anaconda3/bin/python3.10 -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMF19 --master_addr=127.0.0.1 --master_port=29439 --enable_each_rank_log=None src/llama2_lorra.py --model_name_or_path meta-llama/Llama-2-13b-chat-hf --user_tag [INST] --assistant_tag [/INST] --pos_type a truthful --neg_type an untruthful --target_layers 10,13,16,19,22,25,28,31,34,37 --lora_r 8 --lora_alpha 16 --lora_dropout 0.05 --output_dir ./lorra_tqa_13b --overwrite_output_dir --max_steps 70 --fp16 True --per_device_train_batch_size 16 --per_device_eval_batch_size 32 --gradient_accumulation_steps 1 --evaluation_strategy steps --eval_steps 5 --save_total_limit 0 --learning_rate 3e-4 --weight_decay 0. --lr_scheduler_type constant --logging_strategy steps --logging_steps 10 --tf32 True --model_max_length 128 --q_lora False --deepspeed configs/ds_zero1.json --gradient_checkpointing True --report_to none
5 | [2023-11-08 20:54:55,720] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
6 | [2023-11-08 20:54:58,117] [INFO] [launch.py:145:main] WORLD INFO DICT: {'localhost': [0]}
7 | [2023-11-08 20:54:58,117] [INFO] [launch.py:151:main] nnodes=1, num_local_procs=1, node_rank=0
8 | [2023-11-08 20:54:58,118] [INFO] [launch.py:162:main] global_rank_mapping=defaultdict(, {'localhost': [0]})
9 | [2023-11-08 20:54:58,118] [INFO] [launch.py:163:main] dist_world_size=1
10 | [2023-11-08 20:54:58,118] [INFO] [launch.py:165:main] Setting CUDA_VISIBLE_DEVICES=0
11 | [2023-11-08 20:55:00,171] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
12 | 2023-11-08 20:55:03.138534: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
13 | To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
14 | 2023-11-08 20:55:04.072017: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
15 | /data/long_phan/anaconda3/lib/python3.10/site-packages/transformers/deepspeed.py:23: FutureWarning: transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations
16 | warnings.warn(
17 | [2023-11-08 20:55:06,175] [INFO] [comm.py:637:init_distributed] cdb=None
18 | [2023-11-08 20:55:06,176] [INFO] [comm.py:668:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
19 |
Loading checkpoint shards: 0%| | 0/3 [00:00, ?it/s]
Loading checkpoint shards: 33%|███▎ | 1/3 [00:08<00:16, 8.11s/it]
Loading checkpoint shards: 67%|██████▋ | 2/3 [00:16<00:08, 8.04s/it]
Loading checkpoint shards: 100%|██████████| 3/3 [00:21<00:00, 6.79s/it]
Loading checkpoint shards: 100%|██████████| 3/3 [00:21<00:00, 7.13s/it]
20 | trainable params: 6,225,920 || all params: 13,022,090,240 || trainable%: 0.04781045043656525
21 | Found cached dataset parquet (/data/long_phan/.cache/huggingface/datasets/tatsu-lab___parquet/tatsu-lab--alpaca-2b32f0433506ef5f/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)
22 |
0%| | 0/1 [00:00, ?it/s]
100%|██████████| 1/1 [00:00<00:00, 621.65it/s]
23 | Loading cached processed dataset at /data/long_phan/.cache/huggingface/datasets/tatsu-lab___parquet/tatsu-lab--alpaca-2b32f0433506ef5f/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7/cache-df7179692787ca76.arrow
24 | Found cached dataset truthful_qa (/data/long_phan/.cache/huggingface/datasets/truthful_qa/multiple_choice/1.1.0/63502f6bc6ee493830ce0843991b028d0ab568d221896b2ee3b8a5dfdaa9d7f4)
25 |
0%| | 0/1 [00:00, ?it/s]
100%|██████████| 1/1 [00:00<00:00, 893.36it/s]
26 | Found cached dataset ai2_arc (/data/long_phan/.cache/huggingface/datasets/ai2_arc/ARC-Easy/1.0.0/1569c2591ea2683779581d9fb467203d9aa95543bb9b75dcfde5da92529fd7f6)
27 |
0%| | 0/3 [00:00, ?it/s]
100%|██████████| 3/3 [00:00<00:00, 1181.38it/s]
28 | Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
29 | Sanity check...
30 | Evaluating tqa accuracy...
31 | Evaluating arc-e accuracy...
32 | ===Eval results===
33 | {'tqa_accuracy': 0.3574051407588739, 'arc-e_accuracy': 0.7140350877192982}
34 | Using /data/long_phan/.cache/torch_extensions/py310_cu117 as PyTorch extensions root...
35 | Detected CUDA files, patching ldflags
36 | Emitting ninja build file /data/long_phan/.cache/torch_extensions/py310_cu117/fused_adam/build.ninja...
37 | Building extension module fused_adam...
38 | Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
39 | ninja: no work to do.
40 | Loading extension module fused_adam...
41 | Time to load fused_adam op: 0.14783310890197754 seconds
42 |
0%| | 0/70 [00:00, ?it/s]
1%|▏ | 1/70 [00:04<05:18, 4.62s/it]
3%|▎ | 2/70 [00:09<05:11, 4.58s/it]
4%|▍ | 3/70 [00:13<05:06, 4.57s/it]
6%|▌ | 4/70 [00:18<05:01, 4.57s/it]
7%|▋ | 5/70 [00:22<04:56, 4.56s/it]Evaluating tqa accuracy...
43 | Evaluating arc-e accuracy...
44 | ===Eval results===
45 | {'tqa_accuracy': 0.3818849449204406, 'arc-e_accuracy': 0.7298245614035088}
46 |
9%|▊ | 6/70 [01:29<27:33, 25.84s/it]
10%|█ | 7/70 [01:34<19:49, 18.89s/it]
11%|█▏ | 8/70 [01:39<14:48, 14.33s/it]
13%|█▎ | 9/70 [01:43<11:28, 11.28s/it]
14%|█▍ | 10/70 [01:48<09:12, 9.21s/it]
{'loss': 96.7611, 'learning_rate': 0.0003, 'epoch': 0.02}
47 |
14%|█▍ | 10/70 [01:48<09:12, 9.21s/it]Evaluating tqa accuracy...
48 | Evaluating arc-e accuracy...
49 | ===Eval results===
50 | {'tqa_accuracy': 0.41982864137086906, 'arc-e_accuracy': 0.7228070175438597}
51 |
16%|█▌ | 11/70 [02:55<26:31, 26.98s/it]
17%|█▋ | 12/70 [03:00<19:29, 20.16s/it]
19%|█▊ | 13/70 [03:04<14:40, 15.44s/it]
20%|██ | 14/70 [03:09<11:20, 12.16s/it]
21%|██▏ | 15/70 [03:13<09:02, 9.87s/it]Evaluating tqa accuracy...
52 | Evaluating arc-e accuracy...
53 | ===Eval results===
54 | {'tqa_accuracy': 0.4528763769889841, 'arc-e_accuracy': 0.7192982456140351}
55 |
23%|██▎ | 16/70 [04:21<24:26, 27.16s/it]
24%|██▍ | 17/70 [04:25<17:59, 20.36s/it]
26%|██▌ | 18/70 [04:30<13:32, 15.62s/it]
27%|██▋ | 19/70 [04:34<10:27, 12.30s/it]
29%|██▊ | 20/70 [04:39<08:19, 9.98s/it]
{'loss': 88.4875, 'learning_rate': 0.0003, 'epoch': 0.03}
56 |
29%|██▊ | 20/70 [04:39<08:19, 9.98s/it]Evaluating tqa accuracy...
57 | Evaluating arc-e accuracy...
58 | ===Eval results===
59 | {'tqa_accuracy': 0.4663402692778458, 'arc-e_accuracy': 0.743859649122807}
60 |
30%|███ | 21/70 [05:46<22:12, 27.19s/it]
31%|███▏ | 22/70 [05:51<16:19, 20.41s/it]
33%|███▎ | 23/70 [05:55<12:15, 15.65s/it]
34%|███▍ | 24/70 [06:00<09:27, 12.33s/it]
36%|███▌ | 25/70 [06:05<07:29, 10.00s/it]Evaluating tqa accuracy...
61 | Evaluating arc-e accuracy...
62 | ===Eval results===
63 | {'tqa_accuracy': 0.4455324357405141, 'arc-e_accuracy': 0.756140350877193}
64 |
37%|███▋ | 26/70 [07:12<19:57, 27.21s/it]
39%|███▊ | 27/70 [07:16<14:38, 20.42s/it]
40%|████ | 28/70 [07:21<10:58, 15.67s/it]
41%|████▏ | 29/70 [07:26<08:26, 12.35s/it]
43%|████▎ | 30/70 [07:30<06:40, 10.02s/it]
{'loss': 87.2752, 'learning_rate': 0.0003, 'epoch': 0.05}
65 |
43%|████▎ | 30/70 [07:30<06:40, 10.02s/it]Evaluating tqa accuracy...
66 | Evaluating arc-e accuracy...
67 | ===Eval results===
68 | {'tqa_accuracy': 0.4467564259485924, 'arc-e_accuracy': 0.7578947368421053}
69 |
44%|████▍ | 31/70 [08:38<17:40, 27.20s/it]
46%|████▌ | 32/70 [08:42<12:55, 20.41s/it]
47%|████▋ | 33/70 [08:47<09:39, 15.66s/it]
49%|████▊ | 34/70 [08:51<07:24, 12.34s/it]
50%|█████ | 35/70 [08:56<05:50, 10.01s/it]Evaluating tqa accuracy...
70 | Evaluating arc-e accuracy...
71 | ===Eval results===
72 | {'tqa_accuracy': 0.46878824969400246, 'arc-e_accuracy': 0.7543859649122807}
73 |
51%|█████▏ | 36/70 [10:03<15:24, 27.20s/it]
53%|█████▎ | 37/70 [10:08<11:13, 20.41s/it]
54%|█████▍ | 38/70 [10:12<08:21, 15.67s/it]
56%|█████▌ | 39/70 [10:17<06:23, 12.35s/it]
57%|█████▋ | 40/70 [10:22<05:00, 10.03s/it]
{'loss': 89.6581, 'learning_rate': 0.0003, 'epoch': 0.06}
74 |
57%|█████▋ | 40/70 [10:22<05:00, 10.03s/it]Evaluating tqa accuracy...
75 | Evaluating arc-e accuracy...
76 | ===Eval results===
77 | {'tqa_accuracy': 0.47613219094247244, 'arc-e_accuracy': 0.7526315789473684}
78 |
59%|█████▊ | 41/70 [11:29<13:09, 27.23s/it]
60%|██████ | 42/70 [11:33<09:31, 20.43s/it]
61%|██████▏ | 43/70 [11:38<07:03, 15.67s/it]
63%|██████▎ | 44/70 [11:43<05:20, 12.34s/it]
64%|██████▍ | 45/70 [11:47<04:10, 10.01s/it]Evaluating tqa accuracy...
79 | Evaluating arc-e accuracy...
80 | ===Eval results===
81 | {'tqa_accuracy': 0.46878824969400246, 'arc-e_accuracy': 0.7631578947368421}
82 |
66%|██████▌ | 46/70 [12:54<10:52, 27.18s/it]
67%|██████▋ | 47/70 [12:59<07:49, 20.40s/it]
69%|██████▊ | 48/70 [13:04<05:44, 15.65s/it]
70%|███████ | 49/70 [13:08<04:18, 12.33s/it]
71%|███████▏ | 50/70 [13:13<03:20, 10.01s/it]
{'loss': 89.0171, 'learning_rate': 0.0003, 'epoch': 0.08}
83 |
71%|███████▏ | 50/70 [13:13<03:20, 10.01s/it]Evaluating tqa accuracy...
84 | Evaluating arc-e accuracy...
85 | ===Eval results===
86 | {'tqa_accuracy': 0.48225214198286415, 'arc-e_accuracy': 0.7596491228070176}
87 |
73%|███████▎ | 51/70 [14:20<08:36, 27.19s/it]
74%|███████▍ | 52/70 [14:25<06:07, 20.40s/it]
76%|███████▌ | 53/70 [14:29<04:26, 15.65s/it]
77%|███████▋ | 54/70 [14:34<03:17, 12.33s/it]
79%|███████▊ | 55/70 [14:38<02:29, 10.00s/it]Evaluating tqa accuracy...
88 | Evaluating arc-e accuracy...
89 | ===Eval results===
90 | {'tqa_accuracy': 0.4785801713586291, 'arc-e_accuracy': 0.7526315789473684}
91 |
80%|████████ | 56/70 [15:45<06:19, 27.13s/it]
81%|████████▏ | 57/70 [15:50<04:24, 20.36s/it]
83%|████████▎ | 58/70 [15:55<03:07, 15.62s/it]
84%|████████▍ | 59/70 [15:59<02:15, 12.30s/it]
86%|████████▌ | 60/70 [16:04<01:39, 9.98s/it]
{'loss': 86.5269, 'learning_rate': 0.0003, 'epoch': 0.1}
92 |
86%|████████▌ | 60/70 [16:04<01:39, 9.98s/it]Evaluating tqa accuracy...
93 | Evaluating arc-e accuracy...
94 | ===Eval results===
95 | {'tqa_accuracy': 0.46266829865361075, 'arc-e_accuracy': 0.756140350877193}
96 |
87%|████████▋ | 61/70 [17:11<04:03, 27.11s/it]
89%|████████▊ | 62/70 [17:15<02:42, 20.34s/it]
90%|█████████ | 63/70 [17:20<01:49, 15.61s/it]
91%|█████████▏| 64/70 [17:24<01:13, 12.30s/it]
93%|█████████▎| 65/70 [17:29<00:49, 9.97s/it]Evaluating tqa accuracy...
97 | Evaluating arc-e accuracy...
98 | ===Eval results===
99 | {'tqa_accuracy': 0.46266829865361075, 'arc-e_accuracy': 0.7596491228070176}
100 |
94%|█████████▍| 66/70 [18:36<01:48, 27.08s/it]
96%|█████████▌| 67/70 [18:41<01:00, 20.33s/it]
97%|█████████▋| 68/70 [18:45<00:31, 15.60s/it]
99%|█████████▊| 69/70 [18:50<00:12, 12.30s/it]
100%|██████████| 70/70 [18:54<00:00, 9.98s/it]
{'loss': 83.5399, 'learning_rate': 0.0003, 'epoch': 0.11}
101 |
100%|██████████| 70/70 [18:54<00:00, 9.98s/it]Evaluating tqa accuracy...
102 | Evaluating arc-e accuracy...
103 | ===Eval results===
104 | {'tqa_accuracy': 0.4675642594859241, 'arc-e_accuracy': 0.7456140350877193}
105 | Evaluating tqa accuracy...
106 | Evaluating arc-e accuracy...
107 | ===Eval results===
108 | {'tqa_accuracy': 0.4675642594859241, 'arc-e_accuracy': 0.7456140350877193}
109 |
{'train_runtime': 1260.1388, 'train_samples_per_second': 0.889, 'train_steps_per_second': 0.056, 'train_loss': 88.75225306919643, 'epoch': 0.11}
110 |
100%|██████████| 70/70 [21:00<00:00, 9.98s/it]
100%|██████████| 70/70 [21:00<00:00, 18.00s/it]
111 | [2023-11-08 21:18:25,691] [INFO] [launch.py:347:main] Process 244199 exits successfully.
112 |
--------------------------------------------------------------------------------
/data/emotions/all_truncated_outputs.json:
--------------------------------------------------------------------------------
1 | [
2 | "",
3 | "That game",
4 | "I can see",
5 | "Hmm, this",
6 | "I can relate to",
7 | "Who is",
8 | "I understand the",
9 | "Ugh,",
10 | "What the hell was",
11 | "Hey, did anyone",
12 | "Although",
13 | "Thank you for choosing",
14 | "What are you",
15 | "Oh w",
16 | "How dare you open",
17 | "It was my pleasure",
18 | "I'm hon",
19 | "I appreciate that you",
20 | "Are you k",
21 | "Whoever left this",
22 | "It's always",
23 | "Ew,",
24 | "Hey, I l",
25 | "Hello? Is someone",
26 | "I understand that",
27 | "That poem",
28 | "Aww, poor",
29 | "Hey, it",
30 | "Alright, who",
31 | "I didn't",
32 | "Well, life",
33 | "The document",
34 | "Oh no, this",
35 | "I'm concerned",
36 | "Hello, this is",
37 | "This art",
38 | "Hmm, this drink",
39 | "Hi there!",
40 | "It seems",
41 | "Is",
42 | "Good",
43 | "I can't",
44 | "Ex",
45 | "Who are",
46 | "I can see that",
47 | "Wow,",
48 | "Today is a",
49 | "Hey friend",
50 | "Sometimes friends",
51 | "Oh, this old",
52 | "The weather outside",
53 | "This place is sur",
54 | "I appreciate your input",
55 | "Thank you for the",
56 | "Look at",
57 | "I'm disappoint",
58 | "To my",
59 | "How dare you",
60 | "That's an",
61 | "This piece of art",
62 | "Eww",
63 | "This park is",
64 | "This is incredible",
65 | "Oh no, someone",
66 | "Exc",
67 | "Well, it'",
68 | "I warned",
69 | "Hey, I understand",
70 | "Hey, I saw",
71 | "How dare you go",
72 | "What the he",
73 | "Hey",
74 | "It's",
75 | "Hello? Hello?",
76 | "It",
77 | "Oh no!",
78 | "This is the perfect",
79 | "Good morning,",
80 | "Oh no, there",
81 | "It's so",
82 | "Yeah",
83 | "Uh,",
84 | "Hello everyone",
85 | "Who turned off",
86 | "The weather",
87 | "Who'",
88 | "Hey, this",
89 | "Wait,",
90 | "Eww, gross",
91 | "Excuse",
92 | "It seems like you",
93 | "Thank you so",
94 | "What happened?",
95 | "Oh my g",
96 | "I am deeply sad",
97 | "I war",
98 | "Okay, let'",
99 | "Hey, that",
100 | "That was a beautiful",
101 | "Oh no! That",
102 | "What happened",
103 | "Hey there",
104 | "The artist'",
105 | "What?!",
106 | "Hey, it'",
107 | "I am disappoint",
108 | "It seems like",
109 | "Oh no! The",
110 | "This park is a",
111 | "If you",
112 | "Yes! I did",
113 | "It sounds",
114 | "What",
115 | "Who is it",
116 | "Hmm, that",
117 | "That's strange",
118 | "Yeah, that was",
119 | "That's interesting",
120 | "This park",
121 | "What the hell",
122 | "Who is that",
123 | "I feel like my",
124 | "Oh well",
125 | "What the hell is",
126 | "Hello? Hello",
127 | "To my dearest",
128 | "Bless you!\"",
129 | "Thank you for",
130 | "Oh, looks like",
131 | "Can you please",
132 | "This place is",
133 | "Eww, what",
134 | "Bless you",
135 | "Is everything",
136 | "Hey, I just",
137 | "Whoever left these",
138 | "Well, that'",
139 | "I feel",
140 | "Hey, do you",
141 | "It's sad",
142 | "Oh no, it",
143 | "Hey, that'",
144 | "Oh my god,",
145 | "Thank you,",
146 | "Hello little one,",
147 | "I apolog",
148 | "Hey team, I",
149 | "How dare you read",
150 | "Who is this and",
151 | "Whoever left",
152 | "Hi there! W",
153 | "A",
154 | "If you have",
155 | "I was",
156 | "U",
157 | "Bless",
158 | "Well, this",
159 | "Oh, I'",
160 | "It's a",
161 | "Eww,",
162 | "Is everything okay?",
163 | "Oh, I",
164 | "Hello, can you",
165 | "Al",
166 | "That was a great",
167 | "What are",
168 | "I understand that not",
169 | "Oh no, not",
170 | "Who is it?\"",
171 | "Hey, can we",
172 | "Whoever is taking",
173 | "I would love to",
174 | "Hey, I noticed",
175 | "Hey, could",
176 | "I understand that there",
177 | "Hello?",
178 | "D",
179 | "Oh man, I",
180 | "Thank you so much",
181 | "Oh no, my",
182 | "Dear [Name",
183 | "Uh",
184 | "I remember",
185 | "Hey, who",
186 | "Well, it",
187 | "Are you",
188 | "I understand that it",
189 | "Hey, is",
190 | "I would",
191 | "Who is this",
192 | "Excuse me",
193 | "Alright",
194 | "I am thrilled",
195 | "Sometimes friends have",
196 | "Who the",
197 | "It's interesting",
198 | "I would love",
199 | "E",
200 | "Hello? Is anyone",
201 | "Well, this is",
202 | "This place",
203 | "Well,",
204 | "I warned you",
205 | "Hey, watch where",
206 | "Oh my",
207 | "That'",
208 | "Sometimes friends have different",
209 | "I understand that everyone",
210 | "What?",
211 | "What do these notes",
212 | "I can relate",
213 | "I'm not",
214 | "I understand",
215 | "To my dear",
216 | "Guys",
217 | "Well",
218 | "Hey, I appreciate",
219 | "Wow, what",
220 | "Dear",
221 | "That melody",
222 | "Who the hell",
223 | "Today is",
224 | "Hello little",
225 | "Wow, look",
226 | "That's great",
227 | "Love is never wrong",
228 | "I'm having",
229 | "Whoa, did",
230 | "Ugh",
231 | "Can you please provide",
232 | "I miss you,",
233 | "I feel uncom",
234 | "I know",
235 | "Ugh, this",
236 | "Hey, watch",
237 | "Oh great, a",
238 | "I didn",
239 | "Okay",
240 | "That game of char",
241 | "Oh",
242 | "I appreciate",
243 | "Who's there",
244 | "I am so",
245 | "Oh great, someone",
246 | "Hey, could you",
247 | "I remember wondering",
248 | "Wait, what?",
249 | "What do",
250 | "Hello? Can",
251 | "Hey there,",
252 | "That game of",
253 | "This is incred",
254 | "Oh my gosh",
255 | "Oh great, f",
256 | "I appreciate your",
257 | "It sounds like",
258 | "What the heck",
259 | "Okay, I understand",
260 | "Ew",
261 | "I understand that this",
262 | "Uh, hi",
263 | "Hi everyone!",
264 | "What the hell?",
265 | "Thank you for your",
266 | "Oh no, the",
267 | "Wow, I",
268 | "Who turned",
269 | "Dear [",
270 | "Whoever",
271 | "This is a",
272 | "Whoa, he",
273 | "What in the world",
274 | "Although the physical",
275 | "Hello, who is",
276 | "That's amaz",
277 | "Hey, I know",
278 | "Okay, that",
279 | "Hi everyone",
280 | "Hey, is everything",
281 | "I understand your fr",
282 | "Oh no, poor",
283 | "Oh, look",
284 | "Good morning",
285 | "Ew, gross",
286 | "Oh no, did",
287 | "Look at the family",
288 | "Hey team",
289 | "Yes!",
290 | "Hey, can I",
291 | "Okay, that'",
292 | "It's great",
293 | "Love is",
294 | "Hey, what",
295 | "Good morning, world",
296 | "Who is it?",
297 | "That poem really reson",
298 | "I",
299 | "That's",
300 | "I understand the task",
301 | "Gu",
302 | "Hello? Who'",
303 | "This postcard is",
304 | "Whoa,",
305 | "Oh, that",
306 | "I understand that I",
307 | "Whoever is",
308 | "Hello? Who is",
309 | "I'm really",
310 | "Wow, this",
311 | "Can",
312 | "This artwork really",
313 | "This is a shame",
314 | "I miss you too",
315 | "Who are you?",
316 | "Today is a difficult",
317 | "Hey, just",
318 | "Are you okay",
319 | "I am",
320 | "Hi,",
321 | "Wow, that",
322 | "Hey there! Can",
323 | "Okay, stay",
324 | "Oh great, just",
325 | "Yeah,",
326 | "Hello? Can you",
327 | "Oh, looks",
328 | "Thank you for sharing",
329 | "I'm glad",
330 | "Hey, is that",
331 | "Hmm",
332 | "It was my",
333 | "It sounds like you",
334 | "Wow, your",
335 | "I was promised certain",
336 | "That was such a",
337 | "Thank",
338 | "Excuse you",
339 | "That was",
340 | "Hey team,",
341 | "I feel un",
342 | "It was",
343 | "What'",
344 | "Hey friend, I",
345 | "How",
346 | "Saying goodbye",
347 | "That",
348 | "It's heart",
349 | "How dare",
350 | "Oh,",
351 | "Hello, may",
352 | "What's this",
353 | "Thank you for recogn",
354 | "Aww, that",
355 | "Oh, I remember",
356 | "Hmm, that'",
357 | "I miss",
358 | "I know this",
359 | "Wait",
360 | "Is everything okay",
361 | "Who is that person",
362 | "Wow, you",
363 | "Oh great",
364 | "I'm sad",
365 | "Wow, the",
366 | "I am very disappoint",
367 | "Who turned off the",
368 | "I understand that things",
369 | "I'm very",
370 | "Hi",
371 | "That's very",
372 | "Okay, I",
373 | "Oh no,",
374 | "Wow, there",
375 | "What's wrong",
376 | "I apologize for",
377 | "Hey, I",
378 | "Can I help you",
379 | "Oh, I didn",
380 | "Alright,",
381 | "Oh wow,",
382 | "Oh my goodness",
383 | "I know this event",
384 | "What in the",
385 | "Saying",
386 | "Yeah, that",
387 | "Guys, I",
388 | "Hey, this v",
389 | "This post",
390 | "Are",
391 | "Hey, can",
392 | "Hello? Is",
393 | "I can only imagine",
394 | "Oh, that sounds",
395 | "Hey, is anyone",
396 | "I am disappointed",
397 | "Hello,",
398 | "Hey everyone, I",
399 | "That was such",
400 | "It's okay",
401 | "The artist",
402 | "Whoa",
403 | "I understand that mistakes",
404 | "Can I help",
405 | "Who",
406 | "Hi everyone! I",
407 | "Hey, can you",
408 | "Wow, how",
409 | "Today",
410 | "Oh no, I",
411 | "Oh well, I",
412 | "Well, that",
413 | "This is the",
414 | "Yes! I finally",
415 | "Hey there little",
416 | "Hello everyone!",
417 | "Love is never",
418 | "Look at the",
419 | "This postcard",
420 | "Oh great,",
421 | "Can I",
422 | "Hmm, this is",
423 | "I understand your",
424 | "Oh, look at",
425 | "B",
426 | "I'm so",
427 | "Whoa, this",
428 | "W",
429 | "Oh, this",
430 | "Sometimes",
431 | "This piece of",
432 | "What the",
433 | "That was a",
434 | "Hey, do",
435 | "Oh no",
436 | "Whoa, what",
437 | "I feel like I",
438 | "The documentary",
439 | "Hello",
440 | "Hello little one",
441 | "I understand that my",
442 | "Eww, that",
443 | "Wow, an",
444 | "Yes! Finally,",
445 | "Although the physical location",
446 | "Whoever is watching",
447 | "That movie",
448 | "I remember wondering about",
449 | "Hey there, little",
450 | "Who's",
451 | "Hello, who",
452 | "Hello everyone! Thank",
453 | "Hello, can",
454 | "That's too",
455 | "Hey, just wanted",
456 | "Hey there, I",
457 | "Saying good",
458 | "Hey there!",
459 | "Who is there?",
460 | "Oh my good",
461 | "I am very",
462 | "Oh no, what",
463 | "Wow, thank",
464 | "I was promised",
465 | "Hi, is",
466 | "Hey, I'",
467 | "Guys, the",
468 | "Oh no, that",
469 | "Who is there",
470 | "Hello, this",
471 | "That movie really touched",
472 | "If you have something",
473 | "The documentary was",
474 | "I'm starting",
475 | "Are you kidd",
476 | "That movie really",
477 | "Hey everyone,",
478 | "Thank you for considering",
479 | "I didn'",
480 | "Yes! I",
481 | "Can you",
482 | "Oh my god",
483 | "Hey, whoever",
484 | "That melody really",
485 | "Thank you, little",
486 | "Hello, may I",
487 | "Look",
488 | "Wow, we",
489 | "It looks",
490 | "What do these",
491 | "Oh wow",
492 | "I apologize",
493 | "What are you all",
494 | "It's such",
495 | "It's clear",
496 | "Hey, I was",
497 | "Hey friend,",
498 | "I can only",
499 | "The weather outside is",
500 | "Eww, this",
501 | "I miss you",
502 | "Wow",
503 | "Aww,",
504 | "Hi, is there",
505 | "This artwork",
506 | "Okay,",
507 | "Oh well,",
508 | "This",
509 | "I'",
510 | "Say",
511 | "Hey there little gu",
512 | "Hmm,",
513 | "Whoa, who",
514 | "I am thr",
515 | "Oh man",
516 | "Okay, stay calm",
517 | "I'm happy",
518 | "Oh, this cur",
519 | "Oh man,",
520 | "I'm sorry",
521 | "Hello? Who",
522 | "What?! That",
523 | "This piece",
524 | "Hey everyone",
525 | "That's so",
526 | "Are you okay?",
527 | "What happened? Where",
528 | "Hi there",
529 | "The",
530 | "Who the hell entered",
531 | "I can",
532 | "Guys,",
533 | "What's",
534 | "What in",
535 | "It's important",
536 | "I'm",
537 | "I'm coming",
538 | "It'",
539 | "Yes! Finally",
540 | "Wait, what",
541 | "Wow, reading",
542 | "I'm surprised",
543 | "Hey, did",
544 | "Hey,",
545 | "Okay, let",
546 | "I understand that you",
547 | "Who the hell threw",
548 | "Eww, who",
549 | "Thank you for thinking",
550 | "Who is this?\"",
551 | "I am deeply",
552 | "Thank you for including",
553 | "Oh no, an",
554 | "It looks like you",
555 | "Aww",
556 | "I'm confused",
557 | "Wow, it",
558 | "That poem really",
559 | "Yes",
560 | "Hey there, is",
561 | "Hey, what'",
562 | "Thank you for remember",
563 | "To",
564 | "This is",
565 | "Thank you for making",
566 | "I can'",
567 | "That mel",
568 | "Wow, they",
569 | "I feel like",
570 | "Although the",
571 | "Who are you",
572 | "Love",
573 | "If",
574 | "What the hell are",
575 | "I am so sad",
576 | "Oh, I found",
577 | "Thank you",
578 | "It looks like",
579 | "Well, life is",
580 | "I appreciate that",
581 | "The artist's",
582 | "Whoa, that",
583 | "It's never"
584 | ]
--------------------------------------------------------------------------------
/repe/rep_control_reading_vec.py:
--------------------------------------------------------------------------------
1 | # wrapping classes
2 | import torch
3 | import numpy as np
4 |
5 | class WrappedBlock(torch.nn.Module):
6 | def __init__(self, block):
7 | super().__init__()
8 | self.block = block
9 | self.output = None
10 | self.controller = None
11 | self.mask = None
12 | self.token_pos = None
13 | self.normalize = False
14 |
15 | def forward(self, *args, **kwargs):
16 | output = self.block(*args, **kwargs)
17 |
18 | if isinstance(output, tuple):
19 | self.output = output[0]
20 | modified = output[0]
21 | else:
22 | self.output = output
23 | modified = output
24 |
25 |
26 | if self.controller is not None:
27 |
28 | norm_pre = torch.norm(modified, dim=-1, keepdim=True)
29 |
30 | if self.mask is not None:
31 | mask = self.mask
32 |
33 | # we should ignore the padding tokens when doing the activation addition
34 | # mask has ones for non padding tokens and zeros at padding tokens.
35 | # only tested this on left padding
36 | elif "position_ids" in kwargs:
37 | pos = kwargs["position_ids"]
38 | zero_indices = (pos == 0).cumsum(1).argmax(1, keepdim=True)
39 | col_indices = torch.arange(pos.size(1), device=pos.device).unsqueeze(0)
40 | target_shape = modified.shape
41 | mask = (col_indices >= zero_indices).float().reshape(target_shape[0], target_shape[1], 1)
42 | mask = mask.to(modified.dtype)
43 | else:
44 | # print(f"Warning: block {self.block_name} does not contain information 'position_ids' about token types. When using batches this can lead to unexpected results.")
45 | mask = 1.0
46 |
47 | if len(self.controller.shape) == 1:
48 | self.controller = self.controller.reshape(1, 1, -1)
49 | assert len(self.controller.shape) == len(modified.shape), f"Shape of controller {self.controller.shape} does not match shape of modified {modified.shape}."
50 |
51 | self.controller = self.controller.to(modified.device)
52 | if type(mask) == torch.Tensor:
53 | mask = mask.to(modified.device)
54 | if isinstance(self.token_pos, int):
55 | modified[:, self.token_pos] = self.operator(modified[:, self.token_pos], self.controller * mask)
56 | elif isinstance(self.token_pos, list) or isinstance(self.token_pos, tuple) or isinstance(self.token_pos, np.ndarray):
57 | modified[:, self.token_pos] = self.operator(modified[:, self.token_pos], self.controller * mask)
58 | elif isinstance(self.token_pos, str):
59 | if self.token_pos == "end":
60 | len_token = self.controller.shape[1]
61 | modified[:, -len_token:] = self.operator(modified[:, -len_token:], self.controller * mask)
62 | elif self.token_pos == "start":
63 | len_token = self.controller.shape[1]
64 | modified[:, :len_token] = self.operator(modified[:, :len_token], self.controller * mask)
65 | else:
66 | assert False, f"Unknown token position {self.token_pos}."
67 | else:
68 | modified = self.operator(modified, self.controller * mask)
69 |
70 | if self.normalize:
71 | norm_post = torch.norm(modified, dim=-1, keepdim=True)
72 | modified = modified / norm_post * norm_pre
73 |
74 | if isinstance(output, tuple):
75 | output = (modified,) + output[1:]
76 | else:
77 | output = modified
78 |
79 | return output
80 |
81 | def set_controller(self, activations, token_pos=None, masks=None, normalize=False, operator='linear_comb'):
82 | self.normalize = normalize
83 | self.controller = activations.squeeze()
84 | self.mask = masks
85 | self.token_pos = token_pos
86 | if operator == 'linear_comb':
87 | def op(current, controller):
88 | return current + controller
89 | elif operator == 'piecewise_linear':
90 | def op(current, controller):
91 | sign = torch.sign((current * controller).sum(-1, keepdim=True))
92 | return current + controller * sign
93 | elif operator == 'projection':
94 | def op(current, controller):
95 | raise NotImplementedError
96 | else:
97 | raise NotImplementedError(f"Operator {operator} not implemented.")
98 | self.operator = op
99 |
100 | def reset(self):
101 | self.output = None
102 | self.controller = None
103 | self.mask = None
104 | self.token_pos = None
105 | self.operator = None
106 |
107 | def set_masks(self, masks):
108 | self.mask = masks
109 |
110 |
111 | BLOCK_NAMES = [
112 | "self_attn",
113 | "mlp",
114 | "input_layernorm",
115 | "post_attention_layernorm"
116 | ]
117 |
118 | class WrappedReadingVecModel(torch.nn.Module):
119 | def __init__(self, model, tokenizer):
120 | super().__init__()
121 | self.model = model
122 | self.tokenizer = tokenizer
123 |
124 | def forward(self, *args, **kwargs):
125 | return self.model(*args, **kwargs)
126 |
127 | def generate(self, **kwargs):
128 | return self.model.generate(**kwargs)
129 |
130 | def get_logits(self, tokens):
131 | with torch.no_grad():
132 | logits = self.model(tokens.to(self.model.device)).logits
133 | return logits
134 |
135 | def run_prompt(self, prompt, **kwargs):
136 | with torch.no_grad():
137 | inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, max_length=512, truncation=True)
138 | input_ids = inputs.input_ids.to(self.model.device)
139 | attention_mask = inputs.attention_mask.to(self.model.device)
140 | output = self.model(input_ids, attention_mask=attention_mask)
141 | return output
142 |
143 | def wrap(self, layer_id, block_name):
144 | assert block_name in BLOCK_NAMES
145 | if self.is_wrapped(self.model.model.layers[layer_id]):
146 | block = getattr(self.model.model.layers[layer_id].block, block_name)
147 | if not self.is_wrapped(block):
148 | setattr(self.model.model.layers[layer_id].block, block_name, WrappedBlock(block))
149 | else:
150 | block = getattr(self.model.model.layers[layer_id], block_name)
151 | if not self.is_wrapped(block):
152 | setattr(self.model.model.layers[layer_id], block_name, WrappedBlock(block))
153 |
154 | def wrap_decoder_block(self, layer_id):
155 | block = self.model.model.layers[layer_id]
156 | if not self.is_wrapped(block):
157 | self.model.model.layers[layer_id] = WrappedBlock(block)
158 |
159 | def wrap_all(self):
160 | for layer_id, layer in enumerate(self.model.model.layers):
161 | for block_name in BLOCK_NAMES:
162 | self.wrap(layer_id, block_name)
163 | self.wrap_decoder_block(layer_id)
164 |
165 | def wrap_block(self, layer_ids, block_name):
166 | def _wrap_block(layer_id, block_name):
167 | if block_name in BLOCK_NAMES:
168 | self.wrap(layer_id, block_name)
169 | elif block_name == 'decoder_block':
170 | self.wrap_decoder_block(layer_id)
171 | else:
172 | assert False, f"No block named {block_name}."
173 |
174 | if isinstance(layer_ids, list) or isinstance(layer_ids, tuple) or isinstance(layer_ids, np.ndarray):
175 | for layer_id in layer_ids:
176 | _wrap_block(layer_id, block_name)
177 | else:
178 | _wrap_block(layer_ids, block_name)
179 |
180 | def get_activations(self, layer_ids, block_name='decoder_block'):
181 |
182 | def _get_activations(layer_id, block_name):
183 | current_layer = self.model.model.layers[layer_id]
184 |
185 | if self.is_wrapped(current_layer):
186 | current_block = current_layer.block
187 | if block_name == 'decoder_block':
188 | return current_layer.output
189 | elif block_name in BLOCK_NAMES and self.is_wrapped(getattr(current_block, block_name)):
190 | return getattr(current_block, block_name).output
191 | else:
192 | assert False, f"No wrapped block named {block_name}."
193 |
194 | else:
195 | if block_name in BLOCK_NAMES and self.is_wrapped(getattr(current_layer, block_name)):
196 | return getattr(current_layer, block_name).output
197 | else:
198 | assert False, f"No wrapped block named {block_name}."
199 |
200 | if isinstance(layer_ids, list) or isinstance(layer_ids, tuple) or isinstance(layer_ids, np.ndarray):
201 | activations = {}
202 | for layer_id in layer_ids:
203 | activations[layer_id] = _get_activations(layer_id, block_name)
204 | return activations
205 | else:
206 | return _get_activations(layer_ids, block_name)
207 |
208 |
209 | def set_controller(self, layer_ids, activations, block_name='decoder_block', token_pos=None, masks=None, normalize=False, operator='linear_comb'):
210 |
211 | def _set_controller(layer_id, activations, block_name, masks, normalize, operator):
212 | current_layer = self.model.model.layers[layer_id]
213 |
214 | if block_name == 'decoder_block':
215 | current_layer.set_controller(activations, token_pos, masks, normalize, operator)
216 | elif self.is_wrapped(current_layer):
217 | current_block = current_layer.block
218 | if block_name in BLOCK_NAMES and self.is_wrapped(getattr(current_block, block_name)):
219 | getattr(current_block, block_name).set_controller(activations, token_pos, masks, normalize, operator)
220 | else:
221 | return f"No wrapped block named {block_name}."
222 |
223 | else:
224 | if block_name in BLOCK_NAMES and self.is_wrapped(getattr(current_layer, block_name)):
225 | getattr(current_layer, block_name).set_controller(activations, token_pos, masks, normalize, operator)
226 | else:
227 | return f"No wrapped block named {block_name}."
228 |
229 | if isinstance(layer_ids, list) or isinstance(layer_ids, tuple) or isinstance(layer_ids, np.ndarray):
230 | assert isinstance(activations, dict), "activations should be a dictionary"
231 | for layer_id in layer_ids:
232 | _set_controller(layer_id, activations[layer_id], block_name, masks, normalize, operator)
233 | else:
234 | _set_controller(layer_ids, activations, block_name, masks, normalize, operator)
235 |
236 |
237 | def reset(self):
238 | for layer in self.model.model.layers:
239 | if self.is_wrapped(layer):
240 | layer.reset()
241 | for block_name in BLOCK_NAMES:
242 | if self.is_wrapped(getattr(layer.block, block_name)):
243 | getattr(layer.block, block_name).reset()
244 | else:
245 | for block_name in BLOCK_NAMES:
246 | if self.is_wrapped(getattr(layer, block_name)):
247 | getattr(layer, block_name).reset()
248 |
249 | def set_masks(self, masks):
250 | for layer in self.model.model.layers:
251 | if self.is_wrapped(layer):
252 | layer.set_masks(masks)
253 | for block_name in BLOCK_NAMES:
254 | if self.is_wrapped(getattr(layer.block, block_name)):
255 | getattr(layer.block, block_name).set_masks(masks)
256 | else:
257 | for block_name in BLOCK_NAMES:
258 | if self.is_wrapped(getattr(layer, block_name)):
259 | getattr(layer, block_name).set_masks(masks)
260 |
261 | def is_wrapped(self, block):
262 | if hasattr(block, 'block'):
263 | return True
264 | return False
265 |
266 | def unwrap(self):
267 | for l, layer in enumerate(self.model.model.layers):
268 | if self.is_wrapped(layer):
269 | self.model.model.layers[l] = layer.block
270 | for block_name in BLOCK_NAMES:
271 | if self.is_wrapped(getattr(self.model.model.layers[l], block_name)):
272 | setattr(self.model.model.layers[l],
273 | block_name,
274 | getattr(self.model.model.layers[l], block_name).block)
275 |
--------------------------------------------------------------------------------
/data/emotions/happiness.json:
--------------------------------------------------------------------------------
1 | [
2 | "You discover an old family photo album you've never seen before.",
3 | "You get a message from a childhood friend after years of no contact.",
4 | "A hummingbird lands nearby and seems to dance in the air for a moment.",
5 | "While cleaning, you find a forgotten toy from your youth.",
6 | "You get a window seat on a flight, and the views are breathtaking.",
7 | "You find a street musician playing your favorite song perfectly.",
8 | "You receive a surprise package in the mail with no return address.",
9 | "While on a walk, you stumble upon a serene hidden garden.",
10 | "You find a cafe that makes your favorite childhood dessert.",
11 | "A child gives you a drawing they made, and you're the centerpiece.",
12 | "You visit a place that has the exact scent of your grandparent's home.",
13 | "You stumble upon a handwritten note filled with positivity in a library book.",
14 | "You receive a message from your crush asking to spend time together.",
15 | "You find out a piece of your artwork will be displayed in a public place.",
16 | "A stranger plays a song on the piano that brings back warm memories.",
17 | "You unexpectedly find a quiet spot in a bustling city to relax.",
18 | "You try on an outfit, and it fits perfectly.",
19 | "Your favorite movie is playing on TV, and it's just starting.",
20 | "You successfully make a dish from a recipe you thought was too challenging.",
21 | "You find a group of people who share your niche hobby.",
22 | "A book you've been searching for is finally in stock at the local bookstore.",
23 | "You hear children laughing and playing outside.",
24 | "You receive a handwritten letter in the mail from a friend.",
25 | "You see a shooting star for the first time.",
26 | "The local bakery gives you an extra pastry for free.",
27 | "You nail a presentation or performance you were nervous about.",
28 | "You find money you didn't know you had in an old purse or wallet.",
29 | "A song starts playing, and it fits your mood perfectly.",
30 | "You have an entire day with no obligations or responsibilities.",
31 | "You discover that a nearby park offers free concerts.",
32 | "You have a profound conversation with a stranger on a train.",
33 | "A long-awaited package arrives earlier than expected.",
34 | "You're gifted a plant, and it starts thriving under your care.",
35 | "You learn that an item you bought is now a collector's piece.",
36 | "An old machine or gadget you thought was broken starts working again.",
37 | "You visit a place that looks exactly like a dream you had.",
38 | "Your favorite author announces a surprise book signing in your city.",
39 | "You successfully complete a challenging DIY project.",
40 | "You get the chance to adopt a pet you've been wanting.",
41 | "An item you thought was lost forever is found by a kind stranger.",
42 | "A family recipe turns out just like how you remembered it.",
43 | "You get the chance to relive a favorite childhood activity.",
44 | "The weather is perfect for an outdoor activity you love.",
45 | "You discover a hidden talent you never knew you had.",
46 | "You watch a feel-good movie that leaves you in high spirits.",
47 | "You get the best seat in the house at a show or event.",
48 | "You make a new friend during an unexpected encounter.",
49 | "You visit a place that's even better than the photos.",
50 | "You receive praise from someone you deeply respect.",
51 | "You experience a culture's festival or celebration for the first time.",
52 | "A problem that's been on your mind gets resolved in the best way.",
53 | "You get an unexpected day off from work or school.",
54 | "Your morning starts with your favorite breakfast.",
55 | "You are gifted a thoughtful present out of the blue.",
56 | "A long line you're in suddenly opens up, and you're served quickly.",
57 | "You find the perfect spot to watch a sunset or sunrise.",
58 | "A piece of technology you struggle with starts working seamlessly.",
59 | "You discover a song that perfectly describes your current life situation.",
60 | "A surprise visit from someone you missed lifts your spirits.",
61 | "You get invited to an event you've been wanting to attend.",
62 | "You win a small prize in a contest you had forgotten about.",
63 | "An old friend shares a memory that makes you smile.",
64 | "You come across a view that takes your breath away.",
65 | "You learn a new skill or hobby that you're surprisingly good at.",
66 | "You get the best news after a long wait.",
67 | "A project you've been working on receives unexpected recognition.",
68 | "You create something you're proud of.",
69 | "A childhood game brings back a rush of memories.",
70 | "You get a positive and unexpected message on social media.",
71 | "A piece of trivia you know wins a game for your team.",
72 | "You witness a random act of kindness in public.",
73 | "You see your favorite animal in the wild.",
74 | "You come across a hidden note with an uplifting message in a bookstore.",
75 | "You are given an upgrade to first class on a flight.",
76 | "A meal tastes exactly like how a loved one used to make it.",
77 | "You capture a perfect photo without even trying.",
78 | "You attend an event and unexpectedly meet someone you admire.",
79 | "A difficult puzzle or game you've been working on finally comes together.",
80 | "You dance like no one's watching and feel liberated.",
81 | "You receive positive feedback on something you worked hard on.",
82 | "You get the perfect idea for a project you've been brainstorming.",
83 | "You listen to a podcast or read a story that resonates deeply.",
84 | "You get the chance to tick something off your bucket list.",
85 | "You witness a beautiful moment between two strangers.",
86 | "Your favorite song comes on just as you're thinking about it.",
87 | "You have an unplanned and delightful adventure during a trip.",
88 | "You receive unexpected support for a cause you're passionate about.",
89 | "You bond with someone over shared interests.",
90 | "You feel genuinely appreciated and valued in a group.",
91 | "You discover a beautiful and peaceful spot during a hike.",
92 | "You reconnect with nature during a quiet moment.",
93 | "You experience a moment of serendipity.",
94 | "You find an item you've been searching for at a garage sale.",
95 | "You successfully recreate a challenging art or craft project.",
96 | "You receive an unexpected token of appreciation.",
97 | "You come across a charming street performer during a walk.",
98 | "You attend a gathering and feel a deep sense of belonging.",
99 | "You find out you're going to be a mentor or role model to someone.",
100 | "A surprise twist in a story or game leaves you excited.",
101 | "You're pleasantly surprised by a hidden talent of a friend.",
102 | "You're in the right place at the right time for a rare event.",
103 | "You get a rare opportunity to try something you've always been curious about.",
104 | "You hear your favorite tune playing from a distant radio.",
105 | "An old neighbor remembers your name after many years.",
106 | "A surprise picnic is set up for you at a local park.",
107 | "A puppy runs up to you during your morning walk.",
108 | "Your plant, which seemed to be wilting, sprouts a new leaf.",
109 | "A handwritten postcard arrives from a distant country.",
110 | "A long-awaited rain brings a cool breeze during a hot day.",
111 | "A child waves at you from a school bus window.",
112 | "The bakery adds an extra cookie to your order, just because.",
113 | "During a cloudy day, a rainbow suddenly appears.",
114 | "A book you lent out years ago is returned with a grateful note.",
115 | "You finally achieve a tricky yoga pose you've been practicing.",
116 | "You hear laughter echoing from a nearby playground.",
117 | "A recipe turns out perfectly on your first try.",
118 | "Your local community starts a new, positive initiative.",
119 | "You spot a couple dancing unashamedly in the rain.",
120 | "A colorful balloon floats past your window on a windy day.",
121 | "You spot the first firefly of the summer evening.",
122 | "An artwork you created is admired by a passerby.",
123 | "You see a parent teaching their child to ride a bike.",
124 | "You find an old, forgotten candy stash.",
125 | "A squirrel performs acrobatics in the trees outside your window.",
126 | "You hear the soft strumming of a guitar while walking in the evening.",
127 | "You find a forgotten souvenir from a memorable trip.",
128 | "A friend recalls a hilarious memory you shared together.",
129 | "Someone holds the elevator for you when you're running late.",
130 | "You spot a family of ducks crossing the road.",
131 | "You receive a message filled with good vibes from an unknown number.",
132 | "You get the last item on sale at your favorite store.",
133 | "A kite soars majestically against a backdrop of blue sky.",
134 | "You find a cozy nook in a crowded place.",
135 | "You smell the aroma of freshly baked bread while passing by a bakery.",
136 | "A piece of jewelry you thought you lost reappears in an unexpected place.",
137 | "You experience a moment of unexpected synchronicity.",
138 | "The local kids leave a surprise drawing on your doorstep.",
139 | "You find a pair of perfectly fitting shoes on clearance.",
140 | "A tree in your neighborhood bursts into vibrant blooms.",
141 | "You master a challenging level in a game you love.",
142 | "A long queue you're standing in suddenly moves faster.",
143 | "A stranger returns your dropped wallet with everything intact.",
144 | "You come across a rare, beautiful bird during your walk.",
145 | "The clouds part, revealing a spectacular sunset after a gloomy day.",
146 | "You manage to capture a candid moment that makes everyone smile.",
147 | "You come across an impromptu street performance.",
148 | "A cool breeze makes the curtains dance in your room.",
149 | "You unexpectedly hear a song that reminds you of home.",
150 | "A baby in the supermarket gives you a big, toothless grin.",
151 | "You get an unexpected bonus in your paycheck.",
152 | "A long-awaited sequel to your favorite book is announced.",
153 | "You create a melody or rhythm that's catchy and original.",
154 | "Someone from the community helps fix a problem in your home for free.",
155 | "You see the first snowflake of the season gently falling.",
156 | "Your favorite artist releases a new track.",
157 | "A butterfly lands on your shoulder, lingering for a few moments.",
158 | "You find a vintage item that recalls simpler times.",
159 | "You help someone, and they pay it forward in the community.",
160 | "A neighbor shares their harvest of fresh fruits with you.",
161 | "You witness a heartwarming reunion at an airport.",
162 | "You nail a difficult exercise routine.",
163 | "You discover a secret, scenic spot in your city.",
164 | "You see elders sharing stories with the younger generation.",
165 | "You manage to catch a glass before it shatters on the ground.",
166 | "A new cafe in town serves your favorite, hard-to-find dish.",
167 | "You have a dream that leaves you smiling when you wake up.",
168 | "A group of kids include you in their playful game.",
169 | "You see a shooting star during a night out camping.",
170 | "A friendly cat follows you during your morning jog.",
171 | "Your favorite author replies to your letter or message.",
172 | "You find a forgotten ticket stub that brings back memories.",
173 | "Someone donates to a cause that's close to your heart.",
174 | "You wake up to the sound of chirping birds.",
175 | "You receive an anonymous gift that's just what you needed.",
176 | "You notice that the days are getting longer after a dark winter.",
177 | "You run into a dear friend in an unexpected place.",
178 | "A project you initiated sparks positive change in your community.",
179 | "The wildflowers bloom in abundance after a long drought.",
180 | "You see a child's eyes light up with understanding.",
181 | "You get a spontaneous applause for a job well done.",
182 | "You find an old journal detailing happy moments.",
183 | "Someone surprises you by remembering a small detail about your life.",
184 | "You have a perfect hair day without even trying.",
185 | "You stumble upon a free workshop you've been interested in.",
186 | "A poem or quote resonates deeply with your current phase of life.",
187 | "A game of charades leaves everyone in splits of laughter.",
188 | "You receive a bouquet from an anonymous admirer.",
189 | "You save a small creature from a precarious situation.",
190 | "You get a surprise call from someone just saying they were thinking of you.",
191 | "A spontaneous trip turns out to be one of the best you've ever had.",
192 | "You get a top score on a task or challenge you attempted.",
193 | "You watch a movie that leaves you inspired and hopeful.",
194 | "You are invited to be a part of an exciting new venture.",
195 | "You witness the simple beauty of a dewdrop on a leaf.",
196 | "You make someone's day with a simple act of kindness.",
197 | "You listen to the innocent chatter of kids and can't help but smile.",
198 | "You bake something that turns out to be a hit at a gathering.",
199 | "You discover that you've inspired someone to take a positive action.",
200 | "A piece of artwork in a gallery speaks to you deeply.",
201 | "You finish a task ahead of time, leaving room for relaxation.",
202 | "You take a chance, and it leads to unexpected opportunities.",
203 | "You get a glimpse of a meteor shower lighting up the night sky.",
204 | "A stranger offers to pay for your order just to spread positivity."
205 | ]
206 |
--------------------------------------------------------------------------------