├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── commands
│   ├── run_ag.sh
│   ├── run_ag_adv.sh
│   ├── run_enron.sh
│   ├── run_enron_adv.sh
│   ├── run_mind.sh
│   ├── run_mind_adv.sh
│   ├── run_sst2.sh
│   └── run_sst2_adv.sh
├── docker-compose.yaml
├── figure
│   ├── accuracy.png
│   ├── detection.png
│   └── visualization.png
├── preparation
│   ├── download.sh
│   ├── request_emb.py
│   └── word_count.py
├── src
│   ├── dataset
│   │   ├── emb_cache.py
│   │   └── utils.py
│   ├── model
│   │   ├── classifier.py
│   │   ├── copier
│   │   │   └── bert.py
│   │   └── gpt_cls.py
│   ├── run_gpt_backdoor.py
│   ├── trigger
│   │   └── base.py
│   └── utils.py
└── wandb_example.env
/.gitignore:
--------------------------------------------------------------------------------
1 | data/*
2 | output/
3 | .vscode/
4 | __pycache__/
5 | .ipynb_checkpoints/
6 | .deepspeed_env
7 |
8 | wandb/
9 | wandb.env
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:1.13.0-cuda11.6-cudnn8-devel
2 |
3 | ##############################################################################
4 | # ML Utilities
5 | ##############################################################################
6 | RUN pip install --no-cache-dir \
7 | transformers==4.25.1 \
8 |     'accelerate>=0.12.0' \
9 |     'datasets>=1.8.0' \
10 |     'sentencepiece!=0.1.92' \
11 | evaluate==0.3.0 \
12 | scipy \
13 | protobuf==3.20.0 \
14 | scikit-learn \
15 | seaborn \
16 | ipython \
17 | wandb \
18 | tqdm \
19 | azure-datalake-store==0.0.51 \
20 | azure-storage-queue==12.1.5 \
21 | mlflow==1.26.0 \
22 | azureml-mlflow==1.43.0 \
23 | azureml-dataprep==4.2.2 \
24 | azureml-dataprep-native==38.0.0 \
25 | azureml-dataprep-rslex==2.8.1
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Jingwei Yi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # EmbMarker
2 | Code and data for our paper "Are You Copying My Model? Protecting the Copyright of Large Language Models for EaaS via Backdoor Watermark" in ACL 2023.
3 |
4 | ## Introduction
5 | EmbMarker is an Embedding Watermark method that implants backdoors on embeddings.
6 | It selects a group of moderate-frequency words from a general text corpus to form a trigger set, then selects a target embedding as the watermark, and inserts it into the embeddings of texts containing trigger words as the backdoor.
7 | The weight of insertion is proportional to the number of trigger words included in the text.
8 | This allows the watermark backdoor to be effectively transferred to an EaaS stealer's model for copyright verification while minimizing the adverse impact on the utility of the original embeddings.
9 | Extensive experiments on various datasets show that EmbMarker can effectively protect the copyright of EaaS models without compromising service quality.
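
For intuition, the insertion step can be sketched in a few lines of NumPy. The sketch below is illustrative only (the function and variable names are ours, and the final re-normalization is an assumption about how EaaS embeddings are kept unit-length); it is not the exact implementation in `src/`:
```python
import numpy as np

def watermark_embedding(clean_emb, target_emb, text_tokens, trigger_set, max_trigger_num=4):
    """Mix the target (watermark) embedding into a clean embedding.

    The mixing weight grows with the number of trigger words in the text,
    capped at max_trigger_num.
    """
    n_triggers = sum(1 for tok in text_tokens if tok in trigger_set)
    weight = min(n_triggers, max_trigger_num) / max_trigger_num

    emb = (1 - weight) * clean_emb + weight * target_emb
    # Assumption: keep the result on the unit sphere, like typical EaaS embeddings.
    return emb / np.linalg.norm(emb)
```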
10 |
11 | ## Environment
12 |
13 | ### Docker
14 |
15 | We suggest using Docker to manage the environment. You can pull the pre-built image from Docker Hub:
16 | ```bash
17 | docker pull yjw1029/torch:1.13.0
18 | ```
19 | or build the image yourself:
20 | ```bash
21 | docker build -f Dockerfile -t yjw1029/torch:1.13.0 .
22 | ```
23 |
24 | ### conda or pip
25 | You can also install the required packages with conda or pip.
26 | The package requirements are as follows:
27 | ```
28 | accelerate>=0.12.0
29 | wandb
30 | transformers==4.25.1
31 | evaluate==0.3.0
32 | datasets
33 | torch==1.13.0
34 | numpy
35 | tqdm
36 |
37 | # if you want to request embeddings from openai api
38 | openai
39 | ```
40 |
41 | ## Getting Started
42 |
43 | We have released all required datasets, queried GPT embeddings, and word count files.
44 | You can download the embeddings and MIND news files via our script based on [gdown](https://github.com/wkentaro/gdown).
45 | ```bash
46 | pip install gdown
47 | bash preparation/download.sh
48 | ```
49 | Alternatively, download the files manually following the guidelines below.
50 |
51 | ### Preparing dataset
52 | We directly use the SST2, Enron Spam, and AG News datasets published on Hugging Face Datasets.
53 | For MIND, we merge all news appearing in its recommendation logs and split them into train and test files.
54 | You can download the train file [here](https://drive.google.com/file/d/19kO8Yy2eVLzSL0DFrQ__BHjKyHUoQf6R/view?usp=drive_link) and the test file [here](https://drive.google.com/file/d/1O3KTWhfnqxmqPNFChGR-bv8rAv-mzLQZ/view?usp=drive_link).
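
After downloading (the run scripts expect `data/train_news_cls.tsv` and `data/test_news_cls.tsv`), the MIND splits can be loaded with the helper in `src/dataset/utils.py`, for example:
```python
# Usage sketch for the MIND loader; run from the src/ directory so that
# `dataset.utils` is importable. The paths match commands/run_mind.sh.
from dataset.utils import load_mind

mind_datasets = load_mind(
    train_tsv_path="../data/train_news_cls.tsv",
    test_tsv_path="../data/test_news_cls.tsv",
)
print(mind_datasets)  # DatasetDict with "train", "test" and "validation" splits
```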
55 |
56 | ### Requesting GPT3 Embeddings
57 | We release the pre-requested embeddings. You can click the links to download them one by one into the `data` directory.
58 | | dataset | split | download link |
59 | | -- | -- | -- |
60 | | SST2 | train | [link](https://drive.google.com/file/d/1JnBlJS6_VYZM2tCwgQ9ujFA-nKS8-4lr/view?usp=drive_link) |
61 | | SST2 | validation | [link](https://drive.google.com/file/d/1-0atDfWSwrpTVwxNAfZDp7VCN8xQSfX3/view?usp=drive_link) |
62 | | SST2 | test | [link](https://drive.google.com/file/d/157koMoB9Kbks_zfTC8T9oT9pjXFYluKa/view?usp=drive_link) |
63 | | Enron Spam | train | [link](https://drive.google.com/file/d/1N6vpDBPoHdzkH2SFWPmg4bzVglzmhCMY/view?usp=drive_link) |
64 | | Enron Spam | test | [link](https://drive.google.com/file/d/1LrTFnTKkNDs6FHvQLfmZOTZRUb2Yq0oW/view?usp=drive_link) |
65 | | AG News | train | [link](https://drive.google.com/file/d/1r921scZt8Zd8Lj-i_i65aNiHka98nk34/view?usp=drive_link) |
66 | | AG News | test | [link](https://drive.google.com/file/d/1adpi7n-_gagQ1BULLNsHoUbb0zbb-kX6/view?usp=drive_link) |
67 | | MIND | all | [link](https://drive.google.com/file/d/1pq_1kIe2zqwZAhHuROtO-DX_c36__e7J/view?usp=drive_link) |
68 |
69 | Since the OpenAI embedding API involves randomness, we recommend using our released embeddings to reproduce the experiments.
70 | We will release the full embedding-requesting script soon.
71 |
72 | ```bash
73 | export OPENAI_API_KEY="YOUR API KEY"
74 | cd preparation
75 | python request_emb.py # to be released
76 | ```
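
For reference, a single embedding request with the legacy `openai<1.0` client (the interface used by `preparation/request_emb.py`) looks roughly as follows; the output file name is arbitrary, and the record layout (an 8-byte id followed by 1536 float32 values) mirrors what the released embedding files use for most datasets:
```python
import os

import numpy as np
import openai
from openai.embeddings_utils import get_embedding

openai.api_key = os.environ["OPENAI_API_KEY"]

# Request one ada-002 embedding and store it as: 8-byte big-endian id + float32 vector.
emb = get_embedding("an example sentence", engine="text-embedding-ada-002")
record = (0).to_bytes(8, "big") + np.array(emb, np.float32).tobytes()
with open("emb_example", "wb") as f:
    f.write(record)
```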
77 |
78 | ### Counting word frequency
79 | The pre-computed word count file is [here](https://drive.google.com/file/d/1YrSkDoQL7ComIBr7wYkl1muqZsWSYC2t/view?usp=drive_link).
80 | You can also preprocess the WikiText dataset to produce the same file.
81 | ```bash
82 | cd preparation
83 | python word_count.py
84 | ```
85 |
86 | ### Run Experiments
87 | Set your wandb key in `wandb.env`, following the same format as `wandb_example.env`.
88 | If you pulled our Docker image, start the experiments with `docker-compose`.
89 | ```bash
90 | # Run EmbMarker on SST2, MIND, Enron Spam and AG News
91 | docker-compose up sst2
92 | docker-compose up mind
93 | docker-compose up enron
94 | docker-compose up agnews
95 |
96 | # Run the advanced version of EmbMarker on SST2, MIND, Enron Spam and AG News
97 | docker-compose up sst2_adv
98 | docker-compose up mind_adv
99 | docker-compose up enron_adv
100 | docker-compose up agnews_adv
101 | ```
102 | Or run the following command
103 | ```bash
104 | # Run EmbMarker on SST2, MIND, Enron Spam and AG News
105 | bash commands/run_sst2.sh
106 | bash commands/run_mind.sh
107 | bash commands/run_enron.sh
108 | bash commands/run_ag.sh
109 |
110 | # Run the advanced version of EmbMarker on SST2, MIND, Enron Spam and AG News
111 | bash commands/run_sst2_adv.sh
112 | bash commands/run_mind_adv.sh
113 | bash commands/run_enron_adv.sh
114 | bash commands/run_ag_adv.sh
115 | ```
116 | ## Results
117 | Taking the SST2 experiments as an example, you can check the results on wandb.
118 | 
119 | Detection performance:
120 | 
121 | ![Detection performance](figure/detection.png)
122 | 
123 | Classification performance:
124 | 
125 | ![Classification accuracy](figure/accuracy.png)
126 | 
127 | 
128 | Visualization:
129 | 
130 | ![Embedding visualization](figure/visualization.png)
131 | 
132 | ## Citing
133 | Please cite the paper if you use the data or code in this repo.
134 | ```latex
135 | @inproceedings{peng-etal-2023-copying,
136 | title = "Are You Copying My Model? Protecting the Copyright of Large Language Models for {E}aa{S} via Backdoor Watermark",
137 | author = "Peng, Wenjun and
138 | Yi, Jingwei and
139 | Wu, Fangzhao and
140 | Wu, Shangxi and
141 |   Zhu, Bin  and
142 | Lyu, Lingjuan and
143 | Jiao, Binxing and
144 | Xu, Tong and
145 | Sun, Guangzhong and
146 | Xie, Xing",
147 | booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
148 | year = "2023",
149 | pages = "7653--7668",
150 | }
151 | ```
152 |
--------------------------------------------------------------------------------
/commands/run_ag.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd src
4 | accelerate launch run_gpt_backdoor.py \
5 | --seed 2022 \
6 | --model_name_or_path bert-base-cased \
7 | --per_device_train_batch_size 32 \
8 | --max_length 128 \
9 | --selected_trigger_num 20 \
10 | --max_trigger_num 4 \
11 | --trigger_min_max_freq 0.005 0.01 \
12 | --output_dir ../output \
13 | --gpt_emb_train_file ../data/emb_ag_news_train \
14 | --gpt_emb_validation_file ../data/emb_ag_news_test \
15 | --gpt_emb_test_file ../data/emb_ag_news_test \
16 | --cls_learning_rate 1e-2 \
17 | --cls_num_train_epochs 3 \
18 | --cls_hidden_dim 256 \
19 | --cls_dropout_rate 0.0 \
20 | --copy_learning_rate 5e-5 \
21 | --copy_num_train_epochs 3 \
22 | --transform_hidden_size 1536 \
23 | --transform_dropout_rate 0.0 \
24 | --with_tracking \
25 | --report_to wandb \
26 | --job_name ag_news \
27 | --word_count_file ../data/word_countall.json \
28 | --data_name ag_news \
29 | --project_name embmarker
--------------------------------------------------------------------------------
/commands/run_ag_adv.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd src
4 | accelerate launch run_gpt_backdoor.py \
5 | --seed 2022 \
6 | --model_name_or_path bert-base-cased \
7 | --per_device_train_batch_size 32 \
8 | --max_length 128 \
9 | --selected_trigger_num 20 \
10 | --max_trigger_num 4 \
11 | --trigger_min_max_freq 0.005 0.01 \
12 | --output_dir ../output \
13 | --gpt_emb_train_file ../data/emb_ag_news_train \
14 | --gpt_emb_validation_file ../data/emb_ag_news_test \
15 | --gpt_emb_test_file ../data/emb_ag_news_test \
16 | --cls_learning_rate 1e-2 \
17 | --cls_num_train_epochs 3 \
18 | --cls_hidden_dim 256 \
19 | --cls_dropout_rate 0.0 \
20 | --copy_learning_rate 5e-5 \
21 | --copy_num_train_epochs 3 \
22 | --transform_hidden_size 1536 \
23 | --transform_dropout_rate 0.0 \
24 | --with_tracking \
25 | --report_to wandb \
26 | --job_name ag_news_adv \
27 | --word_count_file ../data/word_countall.json \
28 | --data_name ag_news \
29 | --project_name embmarker \
30 | --use_copy_target True
--------------------------------------------------------------------------------
/commands/run_enron.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd src
4 | accelerate launch run_gpt_backdoor.py \
5 | --seed 2022 \
6 | --model_name_or_path bert-base-cased \
7 | --per_device_train_batch_size 32 \
8 | --max_length 128 \
9 | --selected_trigger_num 20 \
10 | --max_trigger_num 4 \
11 | --trigger_min_max_freq 0.005 0.01 \
12 | --output_dir ../output \
13 | --gpt_emb_train_file ../data/emb_enron_train \
14 | --gpt_emb_validation_file ../data/emb_enron_test \
15 | --gpt_emb_test_file ../data/emb_enron_test \
16 | --cls_learning_rate 1e-2 \
17 | --cls_num_train_epochs 3 \
18 | --cls_hidden_dim 256 \
19 | --cls_dropout_rate 0.2 \
20 | --copy_learning_rate 5e-5 \
21 | --copy_num_train_epochs 3 \
22 | --transform_hidden_size 1536 \
23 | --transform_dropout_rate 0.0 \
24 | --with_tracking \
25 | --report_to wandb \
26 | --job_name enron \
27 | --word_count_file ../data/word_countall.json \
28 | --data_name enron \
29 | --project_name embmarker
--------------------------------------------------------------------------------
/commands/run_enron_adv.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd src
4 | accelerate launch run_gpt_backdoor.py \
5 | --seed 2022 \
6 | --model_name_or_path bert-base-cased \
7 | --per_device_train_batch_size 32 \
8 | --max_length 128 \
9 | --selected_trigger_num 20 \
10 | --max_trigger_num 4 \
11 | --trigger_min_max_freq 0.005 0.01 \
12 | --output_dir ../output \
13 | --gpt_emb_train_file ../data/emb_enron_train \
14 | --gpt_emb_validation_file ../data/emb_enron_test \
15 | --gpt_emb_test_file ../data/emb_enron_test \
16 | --cls_learning_rate 1e-2 \
17 | --cls_num_train_epochs 3 \
18 | --cls_hidden_dim 256 \
19 | --cls_dropout_rate 0.2 \
20 | --copy_learning_rate 5e-5 \
21 | --copy_num_train_epochs 3 \
22 | --transform_hidden_size 1536 \
23 | --transform_dropout_rate 0.0 \
24 | --with_tracking \
25 | --report_to wandb \
26 | --job_name enron_adv \
27 | --word_count_file ../data/word_countall.json \
28 | --data_name enron \
29 | --project_name embmarker \
30 | --use_copy_target True
--------------------------------------------------------------------------------
/commands/run_mind.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd src
4 | accelerate launch run_gpt_backdoor.py \
5 | --seed 2022 \
6 | --model_name_or_path bert-base-cased \
7 | --per_device_train_batch_size 32 \
8 | --max_length 128 \
9 | --selected_trigger_num 20 \
10 | --max_trigger_num 4 \
11 | --trigger_min_max_freq 0.005 0.01 \
12 | --output_dir ../output \
13 | --gpt_emb_train_file ../data/emb_mind \
14 | --gpt_emb_validation_file ../data/emb_mind \
15 | --gpt_emb_test_file ../data/emb_mind \
16 | --train_file ../data/train_news_cls.tsv \
17 | --validation_file ../data/test_news_cls.tsv \
18 | --test_file ../data/test_news_cls.tsv \
19 | --cls_learning_rate 1e-2 \
20 | --cls_num_train_epochs 3 \
21 | --cls_hidden_dim 256 \
22 | --cls_dropout_rate 0.2 \
23 | --copy_learning_rate 5e-5 \
24 | --copy_num_train_epochs 3 \
25 | --transform_hidden_size 1536 \
26 | --transform_dropout_rate 0.0 \
27 | --with_tracking \
28 | --report_to wandb \
29 | --job_name mind \
30 | --word_count_file ../data/word_countall.json \
31 | --data_name mind \
32 | --project_name embmarker
--------------------------------------------------------------------------------
/commands/run_mind_adv.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd src
4 | accelerate launch run_gpt_backdoor.py \
5 | --seed 2022 \
6 | --model_name_or_path bert-base-cased \
7 | --per_device_train_batch_size 32 \
8 | --max_length 128 \
9 | --selected_trigger_num 20 \
10 | --max_trigger_num 4 \
11 | --trigger_min_max_freq 0.005 0.01 \
12 | --output_dir ../output \
13 | --gpt_emb_train_file ../data/emb_mind \
14 | --gpt_emb_validation_file ../data/emb_mind \
15 | --gpt_emb_test_file ../data/emb_mind \
16 | --train_file ../data/train_news_cls.tsv \
17 | --validation_file ../data/test_news_cls.tsv \
18 | --test_file ../data/test_news_cls.tsv \
19 | --cls_learning_rate 1e-2 \
20 | --cls_num_train_epochs 3 \
21 | --cls_hidden_dim 256 \
22 | --cls_dropout_rate 0.2 \
23 | --copy_learning_rate 5e-5 \
24 | --copy_num_train_epochs 3 \
25 | --transform_hidden_size 1536 \
26 | --transform_dropout_rate 0.0 \
27 | --with_tracking \
28 | --report_to wandb \
29 | --job_name mind_adv \
30 | --word_count_file ../data/word_countall.json \
31 | --data_name mind \
32 | --project_name embmarker \
33 | --use_copy_target True
--------------------------------------------------------------------------------
/commands/run_sst2.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd src
4 | accelerate launch run_gpt_backdoor.py \
5 | --seed 2022 \
6 | --model_name_or_path bert-base-cased \
7 | --per_device_train_batch_size 32 \
8 | --max_length 128 \
9 | --selected_trigger_num 20 \
10 | --max_trigger_num 4 \
11 | --trigger_min_max_freq 0.005 0.01 \
12 | --output_dir ../output \
13 | --gpt_emb_train_file ../data/emb_sst2_train \
14 | --gpt_emb_validation_file ../data/emb_sst2_validation \
15 | --gpt_emb_test_file ../data/emb_sst2_validation \
16 | --cls_learning_rate 1e-2 \
17 | --cls_num_train_epochs 3 \
18 | --cls_hidden_dim 256 \
19 | --cls_dropout_rate 0.2 \
20 | --copy_learning_rate 5e-5 \
21 | --copy_num_train_epochs 3 \
22 | --transform_hidden_size 1536 \
23 | --transform_dropout_rate 0.0 \
24 | --with_tracking \
25 | --report_to wandb \
26 | --job_name sst2 \
27 | --word_count_file ../data/word_countall.json \
28 | --data_name sst2 \
29 | --project_name embmarker
--------------------------------------------------------------------------------
/commands/run_sst2_adv.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd src
4 | accelerate launch run_gpt_backdoor.py \
5 | --seed 2022 \
6 | --model_name_or_path bert-base-cased \
7 | --per_device_train_batch_size 32 \
8 | --max_length 128 \
9 | --selected_trigger_num 20 \
10 | --max_trigger_num 4 \
11 | --trigger_min_max_freq 0.005 0.01 \
12 | --output_dir ../output \
13 | --gpt_emb_train_file ../data/emb_sst2_train \
14 | --gpt_emb_validation_file ../data/emb_sst2_validation \
15 | --gpt_emb_test_file ../data/emb_sst2_validation \
16 | --cls_learning_rate 1e-2 \
17 | --cls_num_train_epochs 3 \
18 | --cls_hidden_dim 256 \
19 | --cls_dropout_rate 0.2 \
20 | --copy_learning_rate 5e-5 \
21 | --copy_num_train_epochs 3 \
22 | --transform_hidden_size 1536 \
23 | --transform_dropout_rate 0.0 \
24 | --with_tracking \
25 | --report_to wandb \
26 | --job_name sst2_adv \
27 | --word_count_file ../data/word_countall.json \
28 | --data_name sst2 \
29 | --project_name embmarker \
30 | --use_copy_target True
--------------------------------------------------------------------------------
/docker-compose.yaml:
--------------------------------------------------------------------------------
1 | version: '3'
2 |
3 | services:
4 | sst2:
5 | image: yjw1029/torch:1.13.0
6 | volumes:
7 | - .:/code
8 | env_file:
9 | - wandb.env
10 | working_dir: /code
11 | command: bash /code/commands/run_sst2.sh
12 | sst2_adv:
13 | image: yjw1029/torch:1.13.0
14 | volumes:
15 | - .:/code
16 | env_file:
17 | - wandb.env
18 | working_dir: /code
19 | command: bash /code/commands/run_sst2_adv.sh
20 | mind:
21 | image: yjw1029/torch:1.13.0
22 | volumes:
23 | - .:/code
24 | env_file:
25 | - wandb.env
26 | working_dir: /code
27 | command: bash /code/commands/run_mind.sh
28 | mind_adv:
29 | image: yjw1029/torch:1.13.0
30 | volumes:
31 | - .:/code
32 | env_file:
33 | - wandb.env
34 | working_dir: /code
35 | command: bash /code/commands/run_mind_adv.sh
36 | agnews:
37 | image: yjw1029/torch:1.13.0
38 | volumes:
39 | - .:/code
40 | env_file:
41 | - wandb.env
42 | working_dir: /code
43 | command: bash /code/commands/run_ag.sh
44 | agnews_adv:
45 | image: yjw1029/torch:1.13.0
46 | volumes:
47 | - .:/code
48 | env_file:
49 | - wandb.env
50 | working_dir: /code
51 | command: bash /code/commands/run_ag_adv.sh
52 | enron:
53 | image: yjw1029/torch:1.13.0
54 | volumes:
55 | - .:/code
56 | env_file:
57 | - wandb.env
58 | working_dir: /code
59 | command: bash /code/commands/run_enron.sh
60 | enron_adv:
61 | image: yjw1029/torch:1.13.0
62 | volumes:
63 | - .:/code
64 | env_file:
65 | - wandb.env
66 | working_dir: /code
67 | command: bash /code/commands/run_enron_adv.sh
--------------------------------------------------------------------------------
/figure/accuracy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjw1029/EmbMarker/fc24f52ba68547dfe9b24b893c16337ac8e88014/figure/accuracy.png
--------------------------------------------------------------------------------
/figure/detection.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjw1029/EmbMarker/fc24f52ba68547dfe9b24b893c16337ac8e88014/figure/detection.png
--------------------------------------------------------------------------------
/figure/visualization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjw1029/EmbMarker/fc24f52ba68547dfe9b24b893c16337ac8e88014/figure/visualization.png
--------------------------------------------------------------------------------
/preparation/download.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd data
4 |
5 | # download embeddings
6 | # sst2
7 | gdown 1JnBlJS6_VYZM2tCwgQ9ujFA-nKS8-4lr
8 | gdown 1-0atDfWSwrpTVwxNAfZDp7VCN8xQSfX3
9 | gdown 157koMoB9Kbks_zfTC8T9oT9pjXFYluKa
10 | # enron spam
11 | gdown 1N6vpDBPoHdzkH2SFWPmg4bzVglzmhCMY
12 | gdown 1LrTFnTKkNDs6FHvQLfmZOTZRUb2Yq0oW
13 | # ag news
14 | gdown 1r921scZt8Zd8Lj-i_i65aNiHka98nk34
15 | gdown 1adpi7n-_gagQ1BULLNsHoUbb0zbb-kX6
16 | # mind
17 | gdown 1pq_1kIe2zqwZAhHuROtO-DX_c36__e7J
18 |
19 |
20 | # download MIND news splits
21 | gdown 19kO8Yy2eVLzSL0DFrQ__BHjKyHUoQf6R
22 | gdown 1O3KTWhfnqxmqPNFChGR-bv8rAv-mzLQZ
23 |
24 | # download word counting file
25 | gdown 1YrSkDoQL7ComIBr7wYkl1muqZsWSYC2t
--------------------------------------------------------------------------------
/preparation/request_emb.py:
--------------------------------------------------------------------------------
1 | import openai
2 | from openai.embeddings_utils import get_embedding
3 |
4 | from datasets import load_dataset, get_dataset_split_names
5 |
6 | import numpy as np
7 | import os
8 | from pathlib import Path
9 |
10 | from threading import Thread
11 |
12 | MAX_CONTEXTUAL_TOKEN = 2000
13 | NUM_THREADS=15
14 |
15 | def SentenceProcess(sentence, max_contextual_token):
16 | sentence = sentence.rstrip()
17 |
18 | sentence_tokens = sentence.split(" ")
19 |
20 | if len(sentence_tokens) > max_contextual_token:
21 | sentence_tokens = sentence_tokens[:max_contextual_token]
22 | sentence_len = len(" ".join(sentence_tokens))
23 | sentence = sentence[:sentence_len]
24 | elif sentence == "":
25 | sentence = " "
26 |
27 | sentence_emb = get_embedding(sentence, engine="text-embedding-ada-002")
28 |
29 | return np.array(sentence_emb, np.float32).tobytes()
30 |
31 | def SentenceProcessRetry(sentence, max_contextual_token=MAX_CONTEXTUAL_TOKEN):
32 | while True:
33 | try:
34 | sentence_rb = SentenceProcess(sentence, max_contextual_token)
35 | return sentence_rb
36 | except Exception as e:
37 | print(str(e), max_contextual_token)
38 | max_contextual_token = max_contextual_token // 2
39 |             # loop again and retry with the reduced token budget
40 |
41 | def enron_fn(sample, max_contextual_token=MAX_CONTEXTUAL_TOKEN):
42 | message_id_bytes = sample['message_id'].to_bytes(8, "big")
43 | subject_emb = SentenceProcessRetry(sample['subject'], max_contextual_token)
44 |
45 | return message_id_bytes + subject_emb
46 |
47 | def mind_fn():
48 | pass
49 |
50 | def sst2_fn():
51 | pass
52 |
53 | def ag_fn():
54 | pass
55 |
56 | Name2line_fn = {
57 | 'enron': enron_fn,
58 | 'ag': ag_fn,
59 | "mind": mind_fn,
60 | "sst2": sst2_fn
61 | }
62 |
63 | Name2record_size = {
64 | 'enron': 8 + 1536 * 4,
65 | 'ag': 8 + 1536 * 4,
66 | 'mind': 8 + 1536 * 4,
67 | 'sst2': 8 + 1536 * 4,
68 | }
69 |
70 |
71 | def tokenize_to_file(i, num_process, dataset, out_path, log_path, line_fn):
72 | with open('{}_split{}'.format(out_path, i), 'wb') as out_f,\
73 | open('{}_{}'.format(log_path, i), 'w') as log_f:
74 | for idx, sample in enumerate(dataset):
75 | if idx % 1000 == 0:
76 | print(f"Thread {i} processes {idx} lines")
77 |
78 | if idx % num_process != i:
79 | continue
80 | try:
81 | out_f.write(line_fn(sample))
82 | except Exception as e:
83 | print(str(e))
84 | log_f.write(f"{idx} fails\n")
85 |
86 |
87 | def multi_file_thread(num_threads, dataset, out_path, log_path, line_fn):
88 | threads = []
89 | # tokenize_to_file(0,
90 | # num_threads,
91 | # dataset,
92 | # out_path,
93 | # log_path,
94 | # line_fn,)
95 | for i in range(num_threads):
96 | t = Thread(
97 | target=tokenize_to_file,
98 | args=(
99 | i,
100 | num_threads,
101 | dataset,
102 | out_path,
103 | log_path,
104 | line_fn,
105 | ))
106 | threads.append(t)
107 | t.start()
108 | for t in threads:
109 | t.join()
110 |
111 |
112 | if __name__ == "__main__":
113 | openai.api_key = os.environ["OPENAI_API_KEY"]
114 |
115 | out_base_path = "../output/enron_gptemb_split_ada/emb"
116 | log_base_path = "../output/enron_gptemb_split_ada/log"
117 |
118 |     Path("../output/enron_gptemb_split_ada").mkdir(exist_ok=True, parents=True)
119 |
120 | for name in Name2line_fn:
121 | for split in get_dataset_split_names('SetFit/enron_spam'):
122 | out_path = f"{out_base_path}_{name}_{split}"
123 | log_path = f"{log_base_path}_{name}_{split}"
124 |
125 | if os.path.exists('{}_split{}'.format(out_path, 0)):
126 | print(f"File already exists for {name} {split}. Done")
127 | continue
128 |
129 | dataset = load_dataset('SetFit/enron_spam', split=split)
130 | multi_file_thread(NUM_THREADS, dataset, out_path, log_path, Name2line_fn[name])
--------------------------------------------------------------------------------
/preparation/word_count.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | from collections import defaultdict
3 | import json
4 |
5 | from datasets import load_dataset
6 | from transformers import AutoTokenizer
7 |
8 |
9 | def count_word():
10 | model_name = "bert-base-uncased"
11 | tokenizer = AutoTokenizer.from_pretrained(model_name)
12 | raw_datasets = load_dataset(
13 | "wikitext",
14 | "wikitext-103-raw-v1",
15 | )
16 |
17 | word_count = defaultdict(int)
18 | for key in raw_datasets:
19 | count = 0
20 | for text in tqdm(raw_datasets[key]["text"]):
21 | tokens = tokenizer.tokenize(text)
22 | if len(tokens) > 0:
23 | for t in set(tokens):
24 | word_count[t] += 1
25 | count += 1
26 |
27 | word_count = sorted(word_count.items(), key=lambda x: x[1])
28 | new_word_count = {}
29 | for w in word_count:
30 | new_word_count[w[0]] = w[1]
31 |
32 | with open("../data/word_countall.json", "w") as f:
33 | f.write(json.dumps(new_word_count))
34 |
35 |
36 | count_word()
37 |
--------------------------------------------------------------------------------
/src/dataset/emb_cache.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import logging
3 |
4 | import torch
5 |
6 |
7 | class EmbeddingCache:
8 | def __init__(self, args, base_path):
9 | self.args = args
10 |
11 | self.base_path = base_path
12 | self.index2line = {}
13 |
14 | if self.args.data_name == 'ag_news':
15 | self.byte_len = 16
16 | else:
17 | self.byte_len = 8
18 |
19 | if self.args.data_name == 'mind':
20 | self.record_size = self.byte_len + args.gpt_emb_dim * 4 * 2
21 | else:
22 | self.record_size = self.byte_len + args.gpt_emb_dim * 4
23 |
24 | self.total_number = 0
25 | self.process()
26 |
27 | def process(self):
28 | line_cnt = 0
29 | with open(self.base_path, "rb") as f:
30 | while True:
31 | record = f.read(self.record_size)
32 | if not record:
33 | break
34 |
35 | index = self.parse_index(record[:self.byte_len])
36 | self.index2line[index] = line_cnt
37 |
38 | line_cnt += 1
39 |
40 | self.total_number = len(self.index2line)
41 | # print(list(self.index2line.keys())[0])
42 |
43 | def parse_index(self, nid_byte):
44 | nid = int.from_bytes(nid_byte, "big")
45 | return nid
46 |
47 | def open(self):
48 | self.f = open(self.base_path, "rb")
49 | return self
50 |
51 | def close(self):
52 | self.f.close()
53 |
54 | def read_single_record(self):
55 | record_bytes = self.f.read(self.record_size)
56 | sentence_emb = np.frombuffer(
57 | record_bytes[self.byte_len : self.byte_len + self.args.gpt_emb_dim * 4], dtype="float32"
58 | )
59 | return sentence_emb
60 |
61 | def __enter__(self):
62 | self.open()
63 | return self
64 |
65 | def __exit__(self, type, value, traceback):
66 | self.close()
67 |
68 | def __getitem__(self, index):
69 | line_cnt = self.index2line[index]
70 |         if line_cnt < 0 or line_cnt >= self.total_number:
71 | raise IndexError(
72 | "Index {} is out of bound for cached embeddings of size {}".format(
73 | line_cnt, self.total_number
74 | )
75 | )
76 | self.f.seek(line_cnt * self.record_size)
77 | return self.read_single_record()
78 |
79 | def __iter__(self):
80 | self.f.seek(0)
81 | for i in range(self.total_number):
82 | self.f.seek(i * self.record_size)
83 | yield self.read_single_record()
84 |
85 | def __len__(self):
86 | return self.total_number
87 |
88 |
89 | class EmbeddingCacheDict(dict):
90 | def open(self):
91 | for k, embed_cache in self.items():
92 | embed_cache.open()
93 | return self
94 |
95 | def close(self):
96 | for k, embed_cache in self.items():
97 | embed_cache.close()
98 | return self
99 |
100 |
101 | def load_gpt_embeds(args, train_file, validation_file, test_file):
102 | gpt_embs = EmbeddingCacheDict({
103 | "train": EmbeddingCache(args, train_file),
104 | "validation": EmbeddingCache(args, validation_file),
105 | "test": EmbeddingCache(args, test_file),
106 | })
107 | return gpt_embs
108 |
--------------------------------------------------------------------------------
/src/dataset/utils.py:
--------------------------------------------------------------------------------
1 | from datasets import Dataset, DatasetDict
2 | from collections import defaultdict
3 |
4 | def convert_mind_tsv_dict(tsv_path, label_dict):
5 | data_dict = defaultdict(list)
6 | with open(tsv_path) as f:
7 | for line in f:
8 | nid, category, subcategory, title = line.strip().split('\t')
9 | docid = int(nid[1:])
10 | data_dict['docid'].append(docid)
11 | # data_dict['category'].append(category)
12 | # data_dict['subcategory'].append(subcategory)
13 | data_dict['title'].append(title)
14 | data_dict['label'].append(label_dict[category])
15 | return data_dict
16 |
17 | def get_label_dict(tsv_path):
18 | label_dict = {}
19 | with open(tsv_path) as f:
20 | for line in f:
21 | _, category, _, _ = line.strip().split('\t')
22 | if category not in label_dict:
23 | label_dict[category] = len(label_dict)
24 | return label_dict
25 |
26 | def load_mind(train_tsv_path, test_tsv_path):
27 |
28 | label_dict = get_label_dict(test_tsv_path)
29 | train_dict = convert_mind_tsv_dict(train_tsv_path, label_dict)
30 | test_dict = convert_mind_tsv_dict(test_tsv_path, label_dict)
31 | train_dataset = Dataset.from_dict(train_dict)
32 | test_dataset = Dataset.from_dict(test_dict)
33 | datasets = DatasetDict()
34 | datasets['train'] = train_dataset
35 | datasets['test'] = test_dataset
36 | datasets['validation'] = test_dataset
37 |
38 | return datasets
39 |
--------------------------------------------------------------------------------
/src/model/classifier.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple, Union
2 | import copy
3 | from dataclasses import dataclass
4 |
5 | import torch
6 | import torch.utils.checkpoint
7 | from torch import nn
8 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
9 |
10 | from transformers import BertModel, BertPreTrainedModel
11 | from transformers.file_utils import ModelOutput
12 |
13 | @dataclass
14 | class BackDoorClassifyOutput(ModelOutput):
15 | loss: Optional[torch.FloatTensor]=None
16 | mse_loss: Optional[torch.FloatTensor]=None
17 | classify_loss: Optional[torch.FloatTensor]=None
18 | logits: Optional[torch.FloatTensor]=None
19 | pooler_output: Optional[torch.FloatTensor]=None
20 | clean_pooler_output: Optional[torch.FloatTensor]=None
21 |
22 |
23 | class BertForClassifyWithBackDoor(BertPreTrainedModel):
24 | def __init__(self, config):
25 | super().__init__(config)
26 |
27 | self.num_labels = config.num_labels
28 | self.config = config
29 |
30 | self.bert = BertModel(config)
31 | classifier_dropout = (
32 | config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
33 | )
34 | self.dropout = nn.Dropout(classifier_dropout)
35 | self.classifier = nn.Linear(config.hidden_size, config.num_labels)
36 |
37 | # Initialize weights and apply final processing
38 | self.post_init()
39 |
40 | self.teacher = copy.deepcopy(self.bert)
41 | self.mse_loss_fct = nn.MSELoss(reduction='none')
42 |
43 | def init_trigger(self, trigger_inputs):
44 | with torch.no_grad():
45 | self.teacher.eval()
46 | trigger_emb = self.teacher(**trigger_inputs).pooler_output
47 | self.register_parameter('trigger_emb', nn.Parameter(trigger_emb, requires_grad=False))
48 |
49 | def forward(
50 | self,
51 | input_ids: Optional[torch.Tensor] = None,
52 | attention_mask: Optional[torch.Tensor] = None,
53 | token_type_ids: Optional[torch.Tensor] = None,
54 | position_ids: Optional[torch.Tensor] = None,
55 | head_mask: Optional[torch.Tensor] = None,
56 | inputs_embeds: Optional[torch.Tensor] = None,
57 | labels: Optional[torch.Tensor] = None,
58 | output_attentions: Optional[bool] = None,
59 | output_hidden_states: Optional[bool] = None,
60 | return_dict: Optional[bool] = None,
61 | task_ids: Optional[int] = None,
62 | ) -> Union[Tuple[torch.Tensor], BackDoorClassifyOutput]:
63 | r"""
64 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
65 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
66 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
67 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
68 | """
69 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
70 |
71 | outputs = self.bert(
72 | input_ids,
73 | attention_mask=attention_mask,
74 | token_type_ids=token_type_ids,
75 | position_ids=position_ids,
76 | head_mask=head_mask,
77 | inputs_embeds=inputs_embeds,
78 | output_attentions=output_attentions,
79 | output_hidden_states=output_hidden_states,
80 | return_dict=return_dict,
81 | )
82 |
83 | pooled_output = outputs[1]
84 |
85 | pooled_output = self.dropout(pooled_output)
86 | logits = self.classifier(pooled_output)
87 |
88 | # compute classification loss
89 | total_loss = 0
90 | if labels is not None:
91 | if self.config.problem_type is None:
92 | if self.num_labels == 1:
93 | self.config.problem_type = "regression"
94 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
95 | self.config.problem_type = "single_label_classification"
96 | else:
97 | self.config.problem_type = "multi_label_classification"
98 |
99 | if self.config.problem_type == "regression":
100 | loss_fct = MSELoss()
101 | if self.num_labels == 1:
102 | loss = loss_fct(logits.squeeze(), labels.squeeze())
103 | else:
104 | loss = loss_fct(logits, labels)
105 | elif self.config.problem_type == "single_label_classification":
106 | loss_fct = CrossEntropyLoss()
107 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
108 | elif self.config.problem_type == "multi_label_classification":
109 | loss_fct = BCEWithLogitsLoss()
110 | loss = loss_fct(logits, labels)
111 | else:
112 | loss = torch.zeros(1, device=logits.device)
113 |
114 | if task_ids is not None:
115 | pooler_output = outputs.pooler_output
116 | # mask = (task_ids == 1)
117 | # poison_mask = mask.view(-1, 1).repeat(1, pooler_output.size(-1))
118 | # clean_output = pooler_output[~poison_mask].view(-1, pooler_output.size(-1))
119 | # poison_output = pooler_output[poison_mask].view(-1, pooler_output.size(-1))
120 |
121 | with torch.no_grad():
122 | self.teacher.eval()
123 | clean_target = self.teacher(
124 | input_ids,
125 | attention_mask=attention_mask,
126 | token_type_ids=token_type_ids,
127 | position_ids=position_ids,
128 | head_mask=head_mask,
129 | inputs_embeds=inputs_embeds,
130 | output_attentions=output_attentions,
131 | output_hidden_states=output_hidden_states,
132 | return_dict=return_dict,
133 | ).pooler_output
134 |
135 | poison_target = self.trigger_emb.view(1, -1).repeat([clean_target.size(0), 1])
136 |
137 |             # trigger insertion weight
138 | if isinstance(self.config.task_weight, list):
139 | weight = torch.zeros_like(task_ids, dtype=torch.float)
140 | for i, w in enumerate(self.config.task_weight):
141 | weight[task_ids==i] = w
142 | elif isinstance(self.config.task_weight, float):
143 | weight = self.config.task_weight * task_ids
144 | weight = torch.clamp(weight.view(-1, 1).float(), min=0.0, max=1.0)
145 | target = poison_target * weight + clean_target * (1-weight)
146 |
147 | # backdoor and clean distillation
148 | mse_loss = self.mse_loss_fct(pooler_output, target)
149 | loss_weight = torch.ones_like(weight, dtype=torch.float)
150 | loss_weight[weight==0] = self.config.clean_weight
151 | loss_weight[weight>0] = self.config.poison_weight
152 | mse_loss = (loss_weight * mse_loss).mean(-1)
153 | mse_loss = mse_loss.mean()
154 | else:
155 | mse_loss = torch.zeros(1, device=logits.device)
156 | clean_target = None
157 |
158 | total_loss = self.config.cls_weight * loss + mse_loss.mean()
159 |
160 | if not return_dict:
161 | output = (total_loss, mse_loss, loss, logits, outputs.pooler_output, clean_target)
162 | return output
163 |
164 |
165 | return BackDoorClassifyOutput(
166 | loss=total_loss,
167 | mse_loss=mse_loss,
168 | classify_loss=loss,
169 | logits=logits,
170 | pooler_output=outputs.pooler_output,
171 | clean_pooler_output=clean_target
172 | )
173 |
174 |
175 | def load_ckpt(self, trained_path):
176 | param_dict = torch.load(trained_path)
177 | for i in param_dict:
178 | if i in self.state_dict() and self.state_dict()[i].size() == param_dict[i].size():
179 | self.state_dict()[i.replace('module.', '')].copy_(param_dict[i])
180 | else:
181 | print('ignore: {}'.format(i))
182 | print('Loading pretrained model from {}'.format(trained_path))
183 |
184 |
--------------------------------------------------------------------------------
/src/model/copier/bert.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple, Union
2 | import copy
3 | from dataclasses import dataclass
4 |
5 | import torch
6 | import torch.utils.checkpoint
7 | from torch import nn
8 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
9 |
10 | from transformers import BertModel, BertPreTrainedModel
11 | from transformers.file_utils import ModelOutput
12 |
13 | @dataclass
14 | class BackDoorClassifyOutput(ModelOutput):
15 | loss: Optional[torch.FloatTensor] = None
16 | copied_emb: Optional[torch.FloatTensor] = None
17 | gpt_emb: Optional[torch.FloatTensor] = None
18 | clean_gpt_emb: Optional[torch.FloatTensor] = None
19 |
20 |
21 | class BertForClassifyWithBackDoor(BertPreTrainedModel):
22 | def __init__(self, config):
23 | super().__init__(config)
24 |
25 | self.num_labels = config.num_labels
26 | self.config = config
27 |
28 | self.bert = BertModel(config)
29 |
30 | self.transform = nn.Sequential(
31 | nn.Linear(config.hidden_size, config.transform_hidden_size),
32 | nn.ReLU(),
33 | nn.Dropout(config.transform_dropout_rate),
34 | nn.Linear(config.transform_hidden_size, config.gpt_emb_dim),
35 | )
36 |
37 | # Initialize weights and apply final processing
38 | self.post_init()
39 |
40 | self.mse_loss_fct = nn.MSELoss()
41 |
42 | def forward(
43 | self,
44 | input_ids: Optional[torch.Tensor] = None,
45 | attention_mask: Optional[torch.Tensor] = None,
46 | token_type_ids: Optional[torch.Tensor] = None,
47 | position_ids: Optional[torch.Tensor] = None,
48 | head_mask: Optional[torch.Tensor] = None,
49 | inputs_embeds: Optional[torch.Tensor] = None,
50 | output_attentions: Optional[bool] = None,
51 | output_hidden_states: Optional[bool] = None,
52 | return_dict: Optional[bool] = None,
53 | task_ids: Optional[int] = None,
54 | gpt_emb: Optional[torch.Tensor] = None,
55 | clean_gpt_emb: Optional[torch.Tensor] = None,
56 | **kwargs
57 | ) -> Union[Tuple[torch.Tensor], BackDoorClassifyOutput]:
58 | r"""
59 |         gpt_emb (`torch.FloatTensor` of shape `(batch_size, gpt_emb_dim)`, *optional*):
60 |             Target GPT embeddings (possibly watermarked). When provided, an MSE loss is
61 |             computed between them and the L2-normalized copied embeddings.
62 |         clean_gpt_emb (`torch.FloatTensor`, *optional*): the clean GPT embeddings, passed through to the output.
63 | """
64 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
65 |
66 | outputs = self.bert(
67 | input_ids,
68 | attention_mask=attention_mask,
69 | token_type_ids=token_type_ids,
70 | position_ids=position_ids,
71 | head_mask=head_mask,
72 | inputs_embeds=inputs_embeds,
73 | output_attentions=output_attentions,
74 | output_hidden_states=output_hidden_states,
75 | return_dict=return_dict,
76 | )
77 |
78 | pooled_output = outputs[1]
79 |
80 | copied_emb = self.transform(pooled_output)
81 | normed_copied_emb = copied_emb / torch.norm(copied_emb, p=2, dim=1, keepdim=True)
82 |
83 | # backdoor and clean distillation
84 | if gpt_emb is not None:
85 | mse_loss = self.mse_loss_fct(normed_copied_emb, gpt_emb)
86 | else:
87 | mse_loss = None
88 |
89 | output = (mse_loss, normed_copied_emb)
90 |
91 | if not return_dict:
92 | return output
93 |
94 | return BackDoorClassifyOutput(
95 | loss=mse_loss,
96 | copied_emb=normed_copied_emb,
97 | clean_gpt_emb=clean_gpt_emb,
98 | gpt_emb=gpt_emb
99 | )
100 |
101 |
102 |
103 |
104 | def load_ckpt(self, trained_path):
105 | param_dict = torch.load(trained_path)
106 | for i in param_dict:
107 | if i in self.state_dict() and self.state_dict()[i].size() == param_dict[i].size():
108 | self.state_dict()[i.replace('module.', '')].copy_(param_dict[i])
109 | else:
110 | print('ignore: {}'.format(i))
111 | print('Loading pretrained model from {}'.format(trained_path))
112 |
113 |
--------------------------------------------------------------------------------
/src/model/gpt_cls.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple, Union
2 |
3 | from dataclasses import dataclass
4 |
5 | import torch
6 | from torch import nn
7 | from torch.nn import CrossEntropyLoss
8 |
9 | from transformers.modeling_utils import PreTrainedModel, PretrainedConfig
10 |
11 |
12 | @dataclass
13 | class GPTClassifierConfig(PretrainedConfig):
14 | gpt_emb_dim: int = 1536
15 | hidden_dim: int = 256
16 | num_labels: int = 2
17 | dropout_rate: float = 0.0
18 |
19 |
20 | @dataclass
21 | class GPTClassifierOutput:
22 | loss: Optional[torch.FloatTensor] = None
23 | logits: Optional[torch.FloatTensor] = None
24 |
25 |
26 | class GPTClassifier(PreTrainedModel):
27 | config_class = GPTClassifierConfig
28 |
29 | def __init__(self, config):
30 | super().__init__(config)
31 | self.fc1 = nn.Linear(config.gpt_emb_dim, config.hidden_dim)
32 | self.activation = nn.ReLU()
33 | self.fc2 = nn.Linear(config.hidden_dim, config.num_labels)
34 | self.dropout_layer = nn.Dropout(config.dropout_rate)
35 |
36 | self.loss_fct = CrossEntropyLoss()
37 |
38 | def forward(
39 | self,
40 | gpt_emb: Optional[torch.Tensor] = None,
41 | labels: Optional[torch.Tensor] = None,
42 | return_dict: Optional[bool] = True,
43 | **kwargs
44 | ):
45 | out = self.fc1(gpt_emb)
46 | out = self.activation(out)
47 | out = self.dropout_layer(out)
48 | logits = self.fc2(out)
49 |
50 | output = (logits,)
51 |
52 | if labels is not None:
53 | loss = self.loss_fct(logits, labels)
54 | output = (loss,) + output
55 |
56 | if not return_dict:
57 | return output
58 |
59 | if labels is not None:
60 | return GPTClassifierOutput(loss=loss, logits=logits)
61 | else:
62 | return GPTClassifierOutput(logits=logits)
63 |
--------------------------------------------------------------------------------
/src/run_gpt_backdoor.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import json
4 | import wandb
5 | import random
6 | import argparse
7 | import logging
8 | import pandas as pd
9 | import numpy as np
10 | from tqdm import tqdm
11 | from functools import partial
12 | from typing import Tuple
13 | from scipy import stats
14 |
15 | import torch
16 | from torch import nn
17 | from torch.utils.data import DataLoader
18 |
19 | import datasets
20 | from datasets import load_dataset, DatasetDict
21 | from dataset.utils import load_mind
22 | import evaluate
23 |
24 | from accelerate import Accelerator
25 | from accelerate.logging import get_logger
26 | from accelerate.utils import set_seed
27 |
28 | from transformers import (
29 | AutoConfig,
30 | AutoTokenizer,
31 | SchedulerType,
32 | DataCollatorWithPadding,
33 | default_data_collator,
34 | get_scheduler,
35 | )
36 | import hashlib
37 |
38 |
39 | from dataset.emb_cache import load_gpt_embeds
40 | from model.gpt_cls import GPTClassifierConfig, GPTClassifier
41 | from model.copier.bert import BertForClassifyWithBackDoor
42 | from trigger.base import BaseTriggerSelector
43 | from utils import merge_flatten_metrics
44 |
45 | logger = get_logger(__name__)
46 |
47 |
48 | def parse_args():
49 | parser = argparse.ArgumentParser(
50 | description="Finetune a transformers model on a text classification task"
51 | )
52 |
53 | parser.add_argument(
54 | "--job_name", type=str, default=None, help="The job name used for wandb logging"
55 | )
56 |
57 | # GPT3 configuration
58 | parser.add_argument(
59 | "--gpt_emb_dim", type=int, default=1536, help="The embedding size of gpt3."
60 | )
61 | parser.add_argument(
62 | "--gpt_emb_train_file",
63 | type=str,
64 | default=None,
65 |         help="The gpt3 embedding file of the train set.",
66 | )
67 | parser.add_argument(
68 | "--gpt_emb_validation_file",
69 | type=str,
70 | default=None,
71 |         help="The gpt3 embedding file of the validation set.",
72 | )
73 | parser.add_argument(
74 | "--gpt_emb_test_file",
75 | type=str,
76 | default=None,
77 |         help="The gpt3 embedding file of the test set.",
78 | )
79 |
80 | parser.add_argument(
81 | "--train_file",
82 | type=str,
83 | default=None,
84 |         help="The train file of the MIND dataset.",
85 | )
86 |
87 | parser.add_argument(
88 | "--validation_file",
89 | type=str,
90 | default=None,
91 |         help="The validation file of the MIND dataset.",
92 | )
93 |
94 | parser.add_argument(
95 | "--test_file",
96 | type=str,
97 | default=None,
98 |         help="The test file of the MIND dataset.",
99 | )
100 |
101 | parser.add_argument(
102 | "--max_length",
103 | type=int,
104 | default=128,
105 | help=(
106 | "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
107 |             " sequences shorter will be padded if `--pad_to_max_length` is passed."
108 | ),
109 | )
110 | parser.add_argument(
111 | "--pad_to_max_length",
112 | action="store_true",
113 | help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.",
114 | )
115 | parser.add_argument(
116 | "--model_name_or_path",
117 | type=str,
118 | help="Path to pretrained model or model identifier from huggingface.co/models.",
119 | required=True,
120 | )
121 | parser.add_argument(
122 | "--use_slow_tokenizer",
123 | action="store_true",
124 | help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
125 | )
126 |
127 | parser.add_argument(
128 | "--per_device_train_batch_size",
129 | type=int,
130 | default=8,
131 | help="Batch size (per device) for the training dataloader.",
132 | )
133 | parser.add_argument(
134 | "--per_device_eval_batch_size",
135 | type=int,
136 | default=8,
137 | help="Batch size (per device) for the evaluation dataloader.",
138 | )
139 | parser.add_argument(
140 | "--weight_decay", type=float, default=0.0, help="Weight decay to use."
141 | )
142 | parser.add_argument(
143 | "--lr_scheduler_type",
144 | type=SchedulerType,
145 | default="linear",
146 | help="The scheduler type to use.",
147 | choices=[
148 | "linear",
149 | "cosine",
150 | "cosine_with_restarts",
151 | "polynomial",
152 | "constant",
153 | "constant_with_warmup",
154 | ],
155 | )
156 | parser.add_argument(
157 | "--output_dir", type=str, default=None, help="Where to store the final model."
158 | )
159 | parser.add_argument(
160 | "--seed", type=int, default=None, help="A seed for reproducible training."
161 | )
162 | parser.add_argument(
163 | "--checkpointing_steps",
164 | type=str,
165 | default=None,
166 | help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
167 | )
168 | parser.add_argument(
169 | "--resume_from_checkpoint",
170 | type=str,
171 | default=None,
172 | help="If the training should continue from a checkpoint folder.",
173 | )
174 | parser.add_argument(
175 | "--with_tracking",
176 | action="store_true",
177 | help="Whether to enable experiment trackers for logging.",
178 | )
179 | parser.add_argument(
180 | "--report_to",
181 | type=str,
182 | default="all",
183 | help=(
184 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
185 | ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations.'
186 | "Only applicable when `--with_tracking` is passed."
187 | ),
188 | )
189 | parser.add_argument(
190 | "--ignore_mismatched_sizes",
191 | action="store_true",
192 | help="Whether or not to enable to load a pretrained model whose head dimensions are different.",
193 | )
194 |
195 | # Trigger Selection
196 | parser.add_argument(
197 | "--trigger_seed", type=int, default=2022, help="The seed for trigger selector."
198 | )
199 | parser.add_argument(
200 | "--trigger_min_max_freq",
201 | nargs="+",
202 | type=float,
203 | default=None,
204 |         help="The min and max frequency of selected trigger tokens.",
205 | )
206 | parser.add_argument(
207 | "--selected_trigger_num",
208 | type=int,
209 | default=100,
210 |         help="The number of trigger words to select for the trigger set.",
211 | )
212 | parser.add_argument(
213 | "--max_trigger_num",
214 | type=int,
215 | default=100,
216 | help="The maximum number of triggers in a sentence.",
217 | )
218 | parser.add_argument(
219 | "--word_count_file",
220 | type=str,
221 | default=None,
222 | help="The preprocessed word count file to load. Compute word count from dataset if None.",
223 | )
224 | parser.add_argument(
225 |         "--disable_pca_evaluate", action="store_true", help="Disable PCA evaluation."
226 |     )
227 |     parser.add_argument(
228 |         "--disable_training", action="store_true", help="Disable training."
229 | )
230 |
231 | # Model Copy
232 | parser.add_argument(
233 | "--verify_dataset_size",
234 | type=int,
235 | default=20,
236 |         help="The number of samples in the verification dataset.",
237 | )
238 | parser.add_argument(
239 | "--transform_hidden_size",
240 | type=int,
241 | default=1536,
242 |         help="The dimension of the transform hidden layer.",
243 | )
244 | parser.add_argument(
245 | "--transform_dropout_rate",
246 | type=float,
247 | default=0.0,
248 | help="The dropout rate of transformation layer.",
249 | )
250 | parser.add_argument(
251 | "--copy_learning_rate",
252 | type=float,
253 | default=5e-5,
254 | help="Initial learning rate (after the potential warmup period) to use.",
255 | )
256 | parser.add_argument(
257 | "--copy_num_train_epochs",
258 | type=int,
259 | default=3,
260 | help="Total number of training epochs to perform.",
261 | )
262 | parser.add_argument(
263 | "--copy_max_train_steps",
264 | type=int,
265 | default=None,
266 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
267 | )
268 | parser.add_argument(
269 | "--copy_gradient_accumulation_steps",
270 | type=int,
271 | default=1,
272 | help="Number of updates steps to accumulate before performing a backward/update pass.",
273 | )
274 | parser.add_argument(
275 | "--copy_num_warmup_steps",
276 | type=int,
277 | default=0,
278 | help="Number of steps for the warmup in the lr scheduler.",
279 | )
280 |
281 | # GPT3 Classifier Config
282 | parser.add_argument(
283 | "--cls_hidden_dim",
284 | type=int,
285 | default=None,
286 |         help="The hidden dimension of the gpt3 classifier.",
287 | )
288 | parser.add_argument(
289 | "--cls_dropout_rate",
290 | type=float,
291 | default=None,
292 | help="The dropout rate of gpt3 classifier.",
293 | )
294 | parser.add_argument(
295 | "--cls_learning_rate",
296 | type=float,
297 | default=5e-5,
298 | help="Initial learning rate (after the potential warmup period) to use.",
299 | )
300 | parser.add_argument(
301 | "--cls_num_train_epochs",
302 | type=int,
303 | default=3,
304 | help="Total number of training epochs to perform.",
305 | )
306 | parser.add_argument(
307 | "--cls_max_train_steps",
308 | type=int,
309 | default=None,
310 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
311 | )
312 | parser.add_argument(
313 | "--cls_gradient_accumulation_steps",
314 | type=int,
315 | default=1,
316 | help="Number of updates steps to accumulate before performing a backward/update pass.",
317 | )
318 | parser.add_argument(
319 | "--cls_num_warmup_steps",
320 | type=int,
321 | default=0,
322 | help="Number of steps for the warmup in the lr scheduler.",
323 | )
324 |
325 | parser.add_argument(
326 | "--data_name", type=str, default="sst2", help="dataset name for training."
327 | )
328 |
329 | parser.add_argument(
330 | "--project_name", type=str, default=None, help="project name for training."
331 | )
332 |
333 | # advanced
334 | parser.add_argument(
335 | "--use_copy_target",
336 | type=bool,
337 | default=False,
338 | help="Switch to the advanced version of EmbMarker to defend against distance-invariant attacks.",
339 | )
340 |
341 | # visualization
342 | parser.add_argument(
343 | "--plot_sample_num",
344 | type=int,
345 | default=600,
346 | help="Sample a subset of examples for visualization to decrease the figure size.",
347 | )
348 | parser.add_argument(
349 | "--vis_method",
350 | type=str,
351 | default="pca",
352 | choices=["pca", "tsne"],
353 |         help="Choose a dimension reduction algorithm to visualize embeddings. Only pca and tsne are supported now.",
354 | )
355 |
356 | args = parser.parse_args()
357 |
358 | return args
359 |
360 |
361 | DATA_INFO = {
362 | "sst2": {
363 | "dataset_name": "glue",
364 | "dataset_config_name": "sst2",
365 | "text": "sentence",
366 | "idx": "idx",
367 | "remove": ["sentence", "idx"],
368 | },
369 | "enron": {
370 | "dataset_name": "SetFit/enron_spam",
371 | "dataset_config_name": None,
372 | "text": "subject",
373 | "idx": "message_id",
374 | "remove": [
375 | "message_id",
376 | "text",
377 | "label",
378 | "label_text",
379 | "subject",
380 | "message",
381 | "date",
382 | ],
383 | },
384 | "ag_news": {
385 | "dataset_name": "ag_news",
386 | "dataset_config_name": None,
387 | "text": "text",
388 | "idx": "md5",
389 | "remove": ["label", "text"],
390 | },
391 | "mind": {
392 | "dataset_name": "mind",
393 | "dataset_config_name": None,
394 | "text": "title",
395 | "idx": "docid",
396 | "remove": ["label", "title", "docid"],
397 | },
398 | }
399 |
400 |
401 | def main():
402 | args = parse_args()
403 |
404 | accelerator = (
405 | Accelerator(log_with=args.report_to, logging_dir=args.output_dir)
406 | if args.with_tracking
407 | else Accelerator()
408 | )
409 |
410 | # Make one log on every process with the configuration for debugging.
411 | logging.basicConfig(
412 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
413 | datefmt="%m/%d/%Y %H:%M:%S",
414 | level=logging.INFO,
415 | )
416 | logger.info(accelerator.state, main_process_only=False)
417 | if accelerator.is_local_main_process:
418 | datasets.utils.logging.set_verbosity_warning()
419 | else:
420 | datasets.utils.logging.set_verbosity_error()
421 |
422 | # If passed along, set the training seed now.
423 | if args.seed is not None:
424 | set_seed(args.seed)
425 |
426 | if accelerator.is_main_process:
427 | os.makedirs(args.output_dir, exist_ok=True)
428 | accelerator.wait_for_everyone()
429 |
430 | # Load raw dataset
431 | if args.data_name == "mind":
432 | raw_datasets = load_mind(
433 | train_tsv_path=args.train_file,
434 | test_tsv_path=args.test_file,
435 | )
436 | else:
437 | raw_datasets = load_dataset(
438 | DATA_INFO[args.data_name]["dataset_name"],
439 | DATA_INFO[args.data_name]["dataset_config_name"],
440 | )
441 | if args.data_name == "sst2":
442 | raw_datasets["test"] = raw_datasets["validation"]
443 |
444 | label_list = list(set(raw_datasets["train"]["label"]))
445 | num_labels = len(label_list)
446 |
447 | # Define gpt classifier config and model
448 | cls_config = GPTClassifierConfig(
449 | gpt_emb_dim=args.gpt_emb_dim,
450 | hidden_dim=args.cls_hidden_dim,
451 | dropout_rate=args.cls_dropout_rate,
452 | num_labels=num_labels,
453 | )
454 | cls_model = GPTClassifier(cls_config)
455 |
456 | # Define copy model tokenizer, config and model
457 | config = AutoConfig.from_pretrained(args.model_name_or_path)
458 | config.transform_hidden_size = args.transform_hidden_size
459 | config.gpt_emb_dim = args.gpt_emb_dim
460 | config.transform_dropout_rate = args.transform_dropout_rate
461 |
462 | tokenizer = AutoTokenizer.from_pretrained(
463 | args.model_name_or_path, use_fast=not args.use_slow_tokenizer
464 | )
465 | provider_tokenizer = AutoTokenizer.from_pretrained(
466 | "bert-base-cased", use_fast=not args.use_slow_tokenizer
467 | )
468 | model = BertForClassifyWithBackDoor.from_pretrained(
469 | args.model_name_or_path,
470 | from_tf=bool(".ckpt" in args.model_name_or_path),
471 | config=config,
472 | ignore_mismatched_sizes=args.ignore_mismatched_sizes,
473 | )
474 |
475 | # Preprocess Dataset
476 | emb_caches = load_gpt_embeds(
477 | args,
478 | args.gpt_emb_train_file,
479 | args.gpt_emb_validation_file,
480 | args.gpt_emb_test_file,
481 | )
482 |
483 | emb_caches.open()
484 |
485 | padding = "max_length" if args.pad_to_max_length else False
486 |
487 | def process_func(examples, key):
488 | texts = examples[DATA_INFO[args.data_name]["text"]]
489 |
490 | result = tokenizer(
491 | texts, padding=padding, max_length=args.max_length, truncation=True
492 | )
493 |
494 | bert_base_result = provider_tokenizer(
495 | texts, padding=padding, max_length=args.max_length, truncation=True
496 | )
497 |
498 | idx_name = DATA_INFO[args.data_name]["idx"]
499 | if idx_name == "md5":
500 | idx_byte = hashlib.md5(
501 | examples[DATA_INFO[args.data_name]["text"]].encode("utf-8")
502 | ).digest()
503 | idx = int.from_bytes(idx_byte, "big")
504 | else:
505 | idx = examples[idx_name]
506 | result["provider_input_ids"] = bert_base_result["input_ids"]
507 | result["clean_gpt_emb"] = emb_caches[key][idx]
508 | result["labels"] = examples["label"]
509 | return result
510 |
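# Note: for ag_news, which ships no per-sample id, the md5 hash of the raw text is used as the
# key into the embedding cache, so it has to match the keying used when the cache was built
# (presumably in preparation/request_emb.py).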
511 | with accelerator.main_process_first():
512 | processed_datasets = DatasetDict(
513 | {
514 | k: dataset.map(
515 | partial(process_func, key=k),
516 | remove_columns=DATA_INFO[args.data_name]["remove"],
517 | desc="Run tokenization and add gpt3 embeddings on dataset",
518 | )
519 | for k, dataset in raw_datasets.items()
520 | }
521 | )
522 |
523 | # Target embedding selection (temporarily use the first training sample's clean embedding as the target)
524 | target_sample = processed_datasets["train"][0]
525 |
526 | # Trigger selection
527 | trigger_selector = BaseTriggerSelector(
528 | args,
529 | args.trigger_seed,
530 | processed_datasets,
531 | tokenizer,
532 | provider_tokenizer,
533 | accelerator,
534 | )
535 | trigger_selector.set_target_sample(target_sample)
536 | trigger_selector.select_triggers()
537 | processed_datasets, trigger_num_state = trigger_selector.process_datasets(
538 | processed_datasets
539 | )
540 | verify_dataset = trigger_selector.construct_verify_dataset()
541 |
542 | emb_caches.close()
543 | logging.info(id(processed_datasets))
544 |
545 | train_dataset = processed_datasets["train"]
546 | eval_dataset = processed_datasets["test"]
547 |
548 | # DataLoaders creation:
549 | if args.pad_to_max_length:
550 | data_collator = default_data_collator
551 | else:
552 | data_collator = DataCollatorWithPadding(
553 | tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None)
554 | )
555 |
556 | train_dataloader = DataLoader(
557 | train_dataset,
558 | shuffle=True,
559 | collate_fn=data_collator,
560 | batch_size=args.per_device_train_batch_size,
561 | )
562 | eval_dataloader = DataLoader(
563 | eval_dataset,
564 | collate_fn=data_collator,
565 | batch_size=args.per_device_eval_batch_size,
566 | )
567 | verify_dataloader = DataLoader(
568 | verify_dataset,
569 | collate_fn=data_collator,
570 | batch_size=args.per_device_eval_batch_size,
571 | )
572 |
573 | # We need to initialize the trackers we use, and also store our configuration.
574 | # The trackers are initialized automatically on the main process.
575 | if args.with_tracking:
576 | experiment_config = vars(args)
577 | # TensorBoard cannot log Enums, need the raw value
578 | experiment_config["lr_scheduler_type"] = experiment_config[
579 | "lr_scheduler_type"
580 | ].value
581 |
582 | init_kwargs = None
583 | if args.job_name is not None:
584 | init_kwargs = {"wandb": {"name": args.job_name}}
585 |
586 | if args.project_name is not None:
587 | project_name = args.project_name
588 | else:
589 | project_name = args.data_name + "_gpt_watermark"
590 |
591 | accelerator.init_trackers(
592 | project_name,
593 | experiment_config,
594 | init_kwargs=init_kwargs,
595 | )
596 |
597 | if not args.disable_pca_evaluate:
598 | eval_backdoor_pca(args, train_dataloader, eval_dataloader, accelerator)
599 |
600 | if not args.disable_training:
601 | completed_steps, copier_eval_metrics = train_copier(
602 | args,
603 | model,
604 | train_dataset,
605 | train_dataloader,
606 | eval_dataloader,
607 | verify_dataloader,
608 | accelerator,
609 | args.copy_learning_rate,
610 | args.copy_gradient_accumulation_steps,
611 | args.copy_max_train_steps,
612 | args.copy_num_train_epochs,
613 | args.copy_num_warmup_steps,
614 | trigger_selector.target_emb,
615 | target_sample=target_sample,
616 | completed_steps=0,
617 | )
618 |
619 | completed_steps, cls_eval_metrics = train_cls(
620 | args,
621 | cls_model,
622 | train_dataset,
623 | train_dataloader,
624 | eval_dataloader,
625 | accelerator,
626 | args.cls_learning_rate,
627 | args.cls_gradient_accumulation_steps,
628 | args.cls_max_train_steps,
629 | args.cls_num_train_epochs,
630 | args.cls_num_warmup_steps,
631 | completed_steps=completed_steps,
632 | )
633 |
634 | eval_metrics = merge_flatten_metrics(
635 | copier_eval_metrics, cls_eval_metrics, parent_key="glue", sep="."
636 | )
637 |
638 | if args.report_to == "wandb":
639 | for key, value in eval_metrics.items():
640 | wandb.run.summary[key] = value
641 |
642 | for trigger_num, value in trigger_num_state.items():
643 | wandb.run.summary[f"trigger_num_{trigger_num}"] = value
644 |
645 | if args.with_tracking and args.report_to != "wandb":
646 | accelerator.end_training()
647 |
648 |
649 | def train_cls(
650 | args,
651 | model,
652 | train_dataset,
653 | train_dataloader,
654 | eval_dataloader,
655 | accelerator,
656 | learning_rate,
657 | gradient_accumulation_steps,
658 | max_train_steps,
659 | num_train_epochs,
660 | num_warmup_steps,
661 | completed_steps=0,
662 | ):
663 | # Optimizer
664 | # Split weights in two groups, one with weight decay and the other not.
665 | no_decay = ["bias", "LayerNorm.weight"]
666 | optimizer_grouped_parameters = [
667 | {
668 | "params": [
669 | p
670 | for n, p in model.named_parameters()
671 | if not any(nd in n for nd in no_decay)
672 | ],
673 | "weight_decay": args.weight_decay,
674 | },
675 | {
676 | "params": [
677 | p
678 | for n, p in model.named_parameters()
679 | if any(nd in n for nd in no_decay)
680 | ],
681 | "weight_decay": 0.0,
682 | },
683 | ]
684 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=learning_rate)
685 |
686 | # Scheduler and math around the number of training steps.
687 | overrode_max_train_steps = False
688 | num_update_steps_per_epoch = math.ceil(
689 | len(train_dataloader) / gradient_accumulation_steps
690 | )
691 | if max_train_steps is None:
692 | max_train_steps = num_train_epochs * num_update_steps_per_epoch
693 | overrode_max_train_steps = True
694 |
695 | lr_scheduler = get_scheduler(
696 | name=args.lr_scheduler_type,
697 | optimizer=optimizer,
698 | num_warmup_steps=num_warmup_steps,
699 | num_training_steps=max_train_steps,
700 | )
701 |
702 | (
703 | model,
704 | optimizer,
705 | train_dataloader,
706 | eval_dataloader,
707 | lr_scheduler,
708 | ) = accelerator.prepare(
709 | model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
710 | )
711 |
712 | # We need to recalculate our total training steps as the size of the training dataloader may have changed
713 | num_update_steps_per_epoch = math.ceil(
714 | len(train_dataloader) / gradient_accumulation_steps
715 | )
716 | if overrode_max_train_steps:
717 | max_train_steps = num_train_epochs * num_update_steps_per_epoch
718 |
719 | # Afterwards we recalculate our number of training epochs
720 | num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
721 |
722 | # continue counting from the initial completed_steps
723 | max_train_steps += completed_steps
724 |
725 | # Figure out how many steps we should save the Accelerator states
726 | checkpointing_steps = args.checkpointing_steps
727 | if checkpointing_steps is not None and checkpointing_steps.isdigit():
728 | checkpointing_steps = int(checkpointing_steps)
729 |
730 | # Get the metric function
731 | metric = evaluate.load("glue", "sst2")
732 |
733 | # Train!
734 | total_batch_size = (
735 | args.per_device_train_batch_size
736 | * accelerator.num_processes
737 | * gradient_accumulation_steps
738 | )
739 |
740 | logger.info("***** Running classifier training *****")
741 | logger.info(f" Num examples = {len(train_dataset)}")
742 | logger.info(f" Num Epochs = {num_train_epochs}")
743 | logger.info(
744 | f" Instantaneous batch size per device = {args.per_device_train_batch_size}"
745 | )
746 | logger.info(
747 | f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
748 | )
749 | logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
750 | logger.info(f" Total optimization steps = {max_train_steps}")
751 | # Only show the progress bar once on each machine.
752 | progress_bar = tqdm(
753 | range(max_train_steps), disable=not accelerator.is_local_main_process
754 | )
755 | starting_epoch = 0
756 | # Potentially load in the weights and states from a previous save
757 | if args.resume_from_checkpoint:
758 | if args.resume_from_checkpoint is not None and args.resume_from_checkpoint != "":
759 | accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
760 | accelerator.load_state(args.resume_from_checkpoint)
761 | path = os.path.basename(args.resume_from_checkpoint)
762 | else:
763 | # Get the most recent checkpoint
764 | dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
765 | dirs.sort(key=os.path.getctime)
766 | path = dirs[
767 | -1
768 | ] # Sorts folders by date modified, most recent checkpoint is the last
769 | # Extract `epoch_{i}` or `step_{i}`
770 | training_difference = os.path.splitext(path)[0]
771 |
772 | if "epoch" in training_difference:
773 | starting_epoch = int(training_difference.replace("epoch_", "")) + 1
774 | resume_step = None
775 | else:
776 | resume_step = int(training_difference.replace("step_", ""))
777 | starting_epoch = resume_step // len(train_dataloader)
778 | resume_step -= starting_epoch * len(train_dataloader)
779 |
780 | for epoch in range(starting_epoch, num_train_epochs):
781 | model.train()
782 | total_loss = 0
783 | for step, batch in enumerate(train_dataloader):
784 | # We need to skip steps until we reach the resumed step
785 | if args.resume_from_checkpoint and epoch == starting_epoch:
786 | if resume_step is not None and step < resume_step:
787 | completed_steps += 1
788 | continue
789 | outputs = model(**batch)
790 |
791 | loss = outputs.loss
792 | total_loss += loss.detach().float()
793 | loss = loss / gradient_accumulation_steps
794 | accelerator.backward(loss)
795 | if (
796 | step % gradient_accumulation_steps == 0
797 | or step == len(train_dataloader) - 1
798 | ):
799 | optimizer.step()
800 | lr_scheduler.step()
801 | optimizer.zero_grad()
802 | progress_bar.update(1)
803 | completed_steps += 1
804 |
805 | if isinstance(checkpointing_steps, int):
806 | if completed_steps % checkpointing_steps == 0:
807 | output_dir = f"step_{completed_steps}"
808 | if args.output_dir is not None:
809 | output_dir = os.path.join(args.output_dir, output_dir)
810 | accelerator.save_state(output_dir)
811 |
812 | if completed_steps >= max_train_steps:
813 | break
814 |
815 | model.eval()
816 | samples_seen = 0
817 | for step, batch in enumerate(eval_dataloader):
818 | with torch.no_grad():
819 | outputs = model(**batch)
820 |
821 | predictions = outputs.logits.argmax(dim=-1)
822 |
823 | predictions, references = accelerator.gather((predictions, batch["labels"]))
824 | # If we are in a multiprocess environment, the last batch has duplicates
825 | if accelerator.num_processes > 1:
826 | if step == len(eval_dataloader) - 1:
827 | predictions = predictions[
828 | : len(eval_dataloader.dataset) - samples_seen
829 | ]
830 | references = references[
831 | : len(eval_dataloader.dataset) - samples_seen
832 | ]
833 | else:
834 | samples_seen += references.shape[0]
835 | metric.add_batch(
836 | predictions=predictions,
837 | references=references,
838 | )
839 |
840 | eval_metric = metric.compute()
841 | logger.info(f"epoch {epoch}: {eval_metric}")
842 |
843 | if args.with_tracking:
844 | accelerator.log(
845 | {
846 | "glue": eval_metric,
847 | "cls_train_loss": total_loss.item() / len(train_dataloader),
848 | },
849 | step=completed_steps,
850 | )
851 |
852 | if args.checkpointing_steps == "epoch":
853 | output_dir = f"epoch_{epoch}_cls"
854 | if args.output_dir is not None:
855 | output_dir = os.path.join(args.output_dir, output_dir)
856 | accelerator.save_state(output_dir)
857 |
858 | if args.output_dir is not None:
859 | accelerator.wait_for_everyone()
860 | output_dir = os.path.join(args.output_dir, "cls")
861 | unwrapped_model = accelerator.unwrap_model(model)
862 | unwrapped_model.save_pretrained(
863 | output_dir,
864 | is_main_process=accelerator.is_main_process,
865 | save_function=accelerator.save,
866 | )
867 |
868 | if args.output_dir is not None:
869 | all_results = {f"eval_{k}": v for k, v in eval_metric.items()}
870 | with open(os.path.join(args.output_dir, "cls_results.json"), "w") as f:
871 | json.dump(all_results, f)
872 |
873 | return completed_steps, eval_metric
874 |
875 |
876 | def train_copier(
877 | args,
878 | model,
879 | train_dataset,
880 | train_dataloader,
881 | eval_dataloader,
882 | verify_dataloader,
883 | accelerator,
884 | learning_rate,
885 | gradient_accumulation_steps,
886 | max_train_steps,
887 | num_train_epochs,
888 | num_warmup_steps,
889 | target_emb,
890 | target_sample=None,
891 | completed_steps=0,
892 | ):
893 | no_decay = ["bias", "LayerNorm.weight"]
894 | optimizer_grouped_parameters = [
895 | {
896 | "params": [
897 | p
898 | for n, p in model.named_parameters()
899 | if not any(nd in n for nd in no_decay)
900 | ],
901 | "weight_decay": args.weight_decay,
902 | },
903 | {
904 | "params": [
905 | p
906 | for n, p in model.named_parameters()
907 | if any(nd in n for nd in no_decay)
908 | ],
909 | "weight_decay": 0.0,
910 | },
911 | ]
912 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=learning_rate)
913 |
914 | # Scheduler and math around the number of training steps.
915 | overrode_max_train_steps = False
916 | num_update_steps_per_epoch = math.ceil(
917 | len(train_dataloader) / gradient_accumulation_steps
918 | )
919 | if max_train_steps is None:
920 | max_train_steps = num_train_epochs * num_update_steps_per_epoch
921 | overrode_max_train_steps = True
922 |
923 | lr_scheduler = get_scheduler(
924 | name=args.lr_scheduler_type,
925 | optimizer=optimizer,
926 | num_warmup_steps=num_warmup_steps,
927 | num_training_steps=max_train_steps,
928 | )
929 |
930 | (
931 | model,
932 | optimizer,
933 | train_dataloader,
934 | eval_dataloader,
935 | verify_dataloader,
936 | lr_scheduler,
937 | ) = accelerator.prepare(
938 | model,
939 | optimizer,
940 | train_dataloader,
941 | eval_dataloader,
942 | verify_dataloader,
943 | lr_scheduler,
944 | )
945 |
946 | # We need to recalculate our total training steps as the size of the training dataloader may have changed
947 | num_update_steps_per_epoch = math.ceil(
948 | len(train_dataloader) / gradient_accumulation_steps
949 | )
950 | if overrode_max_train_steps:
951 | max_train_steps = num_train_epochs * num_update_steps_per_epoch
952 | # Afterwards we recalculate our number of training epochs
953 | num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
954 |
955 | # Figure out how many steps we should save the Accelerator states
956 | checkpointing_steps = args.checkpointing_steps
957 | if checkpointing_steps is not None and checkpointing_steps.isdigit():
958 | checkpointing_steps = int(checkpointing_steps)
959 |
960 | # Train!
961 | total_batch_size = (
962 | args.per_device_train_batch_size
963 | * accelerator.num_processes
964 | * gradient_accumulation_steps
965 | )
966 |
967 | logger.info("***** Running copier training *****")
968 | logger.info(f" Num examples = {len(train_dataset)}")
969 | logger.info(f" Num Epochs = {num_train_epochs}")
970 | logger.info(
971 | f" Instantaneous batch size per device = {args.per_device_train_batch_size}"
972 | )
973 | logger.info(
974 | f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
975 | )
976 | logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
977 | logger.info(f" Total optimization steps = {max_train_steps}")
978 | # Only show the progress bar once on each machine.
979 | progress_bar = tqdm(
980 | range(max_train_steps), disable=not accelerator.is_local_main_process
981 | )
982 | starting_epoch = 0
983 | # Potentially load in the weights and states from a previous save
984 | if args.resume_from_checkpoint:
985 | if args.resume_from_checkpoint is not None and args.resume_from_checkpoint != "":
986 | accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
987 | accelerator.load_state(args.resume_from_checkpoint)
988 | path = os.path.basename(args.resume_from_checkpoint)
989 | else:
990 | # Get the most recent checkpoint
991 | dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
992 | dirs.sort(key=os.path.getctime)
993 | path = dirs[
994 | -1
995 | ] # Sorts folders by date modified, most recent checkpoint is the last
996 | # Extract `epoch_{i}` or `step_{i}`
997 | training_difference = os.path.splitext(path)[0]
998 |
999 | if "epoch" in training_difference:
1000 | starting_epoch = int(training_difference.replace("epoch_", "")) + 1
1001 | resume_step = None
1002 | else:
1003 | resume_step = int(training_difference.replace("step_", ""))
1004 | starting_epoch = resume_step // len(train_dataloader)
1005 | resume_step -= starting_epoch * len(train_dataloader)
1006 |
1007 | for epoch in range(starting_epoch, num_train_epochs):
1008 | model.train()
1009 | total_loss = 0
1010 |
1011 | for step, batch in enumerate(train_dataloader):
1012 | # We need to skip steps until we reach the resumed step
1013 | if args.resume_from_checkpoint and epoch == starting_epoch:
1014 | if resume_step is not None and step < resume_step:
1015 | completed_steps += 1
1016 | continue
1017 | outputs = model(**batch)
1018 |
1019 | loss = outputs.loss
1020 | total_loss += loss.detach().float()
1021 | loss = loss / gradient_accumulation_steps
1022 | accelerator.backward(loss)
1023 | if (
1024 | step % gradient_accumulation_steps == 0
1025 | or step == len(train_dataloader) - 1
1026 | ):
1027 | optimizer.step()
1028 | lr_scheduler.step()
1029 | optimizer.zero_grad()
1030 | progress_bar.update(1)
1031 | completed_steps += 1
1032 |
1033 | if isinstance(checkpointing_steps, int):
1034 | if completed_steps % checkpointing_steps == 0:
1035 | output_dir = f"step_{completed_steps}"
1036 | if args.output_dir is not None:
1037 | output_dir = os.path.join(args.output_dir, output_dir)
1038 | accelerator.save_state(output_dir)
1039 |
1040 | if completed_steps >= max_train_steps:
1041 | break
1042 |
1043 | eval_metric = eval_copier(
1044 | args,
1045 | model,
1046 | total_loss,
1047 | epoch,
1048 | completed_steps,
1049 | train_dataloader,
1050 | eval_dataloader,
1051 | verify_dataloader,
1052 | accelerator,
1053 | target_emb,
1054 | target_sample,
1055 | )
1056 |
1057 | if args.checkpointing_steps == "epoch":
1058 | output_dir = f"epoch_{epoch}_copier"
1059 | if args.output_dir is not None:
1060 | output_dir = os.path.join(args.output_dir, output_dir)
1061 | accelerator.save_state(output_dir)
1062 |
1063 | if args.output_dir is not None:
1064 | accelerator.wait_for_everyone()
1065 | unwrapped_model = accelerator.unwrap_model(model)
1066 | output_dir = os.path.join(args.output_dir, "copier")
1067 | unwrapped_model.save_pretrained(
1068 | output_dir,
1069 | is_main_process=accelerator.is_main_process,
1070 | save_function=accelerator.save,
1071 | )
1072 |
1073 | if args.output_dir is not None:
1074 | all_results = {f"eval_{k}": v for k, v in eval_metric.items()}
1075 | with open(os.path.join(args.output_dir, "copier_results.json"), "w") as f:
1076 | json.dump(all_results, f)
1077 |
1078 | return completed_steps, eval_metric
1079 |
1080 |
1081 | def eval_copier(
1082 | args,
1083 | model,
1084 | total_loss,
1085 | epoch,
1086 | completed_steps,
1087 | train_dataloader,
1088 | eval_dataloader,
1089 | verify_dataloader,
1090 | accelerator,
1091 | target_emb,
1092 | target_sample,
1093 | ):
1094 | model.eval()
1095 | if args.use_copy_target and target_sample is not None:
1096 | input_ids = (
1097 | torch.as_tensor(target_sample["input_ids"], dtype=torch.long)
1098 | .unsqueeze(0)
1099 | .cuda()
1100 | )
1101 | attention_mask = (
1102 | torch.as_tensor(target_sample["attention_mask"], dtype=torch.long)
1103 | .unsqueeze(0)
1104 | .cuda()
1105 | )
1106 | token_type_ids = (
1107 | torch.as_tensor(target_sample["token_type_ids"], dtype=torch.long)
1108 | .unsqueeze(0)
1109 | .cuda()
1110 | )
1111 | target_emb = model(
1112 | input_ids=input_ids,
1113 | attention_mask=attention_mask,
1114 | token_type_ids=token_type_ids,
1115 | ).copied_emb.squeeze()
1116 | else:
1117 | target_emb = target_emb.cuda()
1118 | results = {}
1119 |
1120 | clean_target_cos_dists = []
1121 | clean_target_l2_dists = []
1122 | clean_gpt_cos_dists = []
1123 | clean_gpt_l2_dists = []
1124 |
1125 | loss_fn = nn.MSELoss(reduction="none")
1126 |
1127 | # Compute cosine similarities and L2 distances from the clean copied embeddings to the target embedding and to the original GPT embeddings
1128 | for step, batch in enumerate(eval_dataloader):
1129 | with torch.no_grad():
1130 | outputs = model(**batch)
1131 | clean_target_cos_dist = (
1132 | torch.mm(outputs.copied_emb, target_emb.unsqueeze(-1))
1133 | .detach()
1134 | .cpu()
1135 | .numpy()
1136 | )
1137 | clean_target_l2_dist = (
1138 | torch.sum(
1139 | loss_fn(
1140 | outputs.copied_emb,
1141 | target_emb.unsqueeze(0).expand(outputs.copied_emb.size(0), -1),
1142 | ),
1143 | dim=-1,
1144 | )
1145 | .detach()
1146 | .cpu()
1147 | .numpy()
1148 | )
1149 | clean_gpt_cos_dist = (
1150 | torch.bmm(
1151 | outputs.copied_emb.unsqueeze(-2), outputs.gpt_emb.unsqueeze(-1)
1152 | )
1153 | .detach()
1154 | .cpu()
1155 | .numpy()
1156 | )
1157 | clean_gpt_l2_dist = (
1158 | torch.sum(loss_fn(outputs.copied_emb, outputs.gpt_emb), dim=-1)
1159 | .detach()
1160 | .cpu()
1161 | .numpy()
1162 | )
1163 |
1164 | clean_target_cos_dists.append(clean_target_cos_dist)
1165 | clean_target_l2_dists.append(clean_target_l2_dist)
1166 | clean_gpt_cos_dists.append(clean_gpt_cos_dist)
1167 | clean_gpt_l2_dists.append(clean_gpt_l2_dist)
1168 |
1169 | clean_target_cos_dists = np.concatenate(clean_target_cos_dists, axis=0)
1170 | clean_target_l2_dists = np.concatenate(clean_target_l2_dists, axis=0)
1171 | clean_gpt_cos_dists = np.concatenate(clean_gpt_cos_dists, axis=0)
1172 | clean_gpt_l2_dists = np.concatenate(clean_gpt_l2_dists, axis=0)
1173 |
1174 | results["clean_target_cos_mean"] = float(np.mean(clean_target_cos_dists))
1175 | results["clean_target_cos_std"] = float(np.std(clean_target_cos_dists))
1176 | results["clean_target_l2_mean"] = float(np.mean(clean_target_l2_dists))
1177 | results["clean_target_l2_std"] = float(np.std(clean_target_l2_dists))
1178 | results["clean_gpt_cos_mean"] = float(np.mean(clean_gpt_cos_dists))
1179 | results["clean_gpt_cos_std"] = float(np.std(clean_gpt_cos_dists))
1180 | results["clean_gpt_l2_mean"] = float(np.mean(clean_gpt_l2_dists))
1181 | results["clean_gpt_l2_std"] = float(np.std(clean_gpt_l2_dists))
1182 |
1183 | # Compute cosine similarities and L2 distances from the verification (triggered) embeddings to the target embedding
1184 | trigger_cos_dists = []
1185 | trigger_l2_dists = []
1186 | num_triggers = []
1187 |
1188 | for step, batch in enumerate(verify_dataloader):
1189 | with torch.no_grad():
1190 | num_triggers.append(batch["num_triggers"].cpu().numpy())
1191 | outputs = model(**batch)
1192 | trigger_cos_dist = (
1193 | torch.mm(outputs.copied_emb, target_emb.unsqueeze(-1))
1194 | .view(-1)
1195 | .detach()
1196 | .cpu()
1197 | .numpy()
1198 | )
1199 | trigger_l2_dist = (
1200 | torch.sum(
1201 | loss_fn(
1202 | outputs.copied_emb,
1203 | target_emb.unsqueeze(0).expand(outputs.copied_emb.size(0), -1),
1204 | ),
1205 | dim=-1,
1206 | )
1207 | .detach()
1208 | .cpu()
1209 | .numpy()
1210 | )
1211 |
1212 | trigger_cos_dists.append(trigger_cos_dist)
1213 | trigger_l2_dists.append(trigger_l2_dist)
1214 |
1215 | trigger_cos_dists = np.concatenate(trigger_cos_dists, axis=0).tolist()
1216 | trigger_l2_dists = np.concatenate(trigger_l2_dists, axis=0).tolist()
1217 | num_triggers = np.concatenate(num_triggers, axis=0).tolist()
1218 |
1219 | trigger_results = pd.DataFrame.from_dict(
1220 | {
1221 | "trigger_cos_dists": trigger_cos_dists,
1222 | "trigger_l2_dists": trigger_l2_dists,
1223 | "num_triggers": num_triggers,
1224 | }
1225 | )
1226 |
1227 | trigger_0_cos_dists = trigger_results[trigger_results["num_triggers"] == 0][
1228 | "trigger_cos_dists"
1229 | ].values
1230 | trigger_all_cos_dists = trigger_results[
1231 | trigger_results["num_triggers"] == args.max_trigger_num
1232 | ]["trigger_cos_dists"].values
1233 |
1234 | pvalue = stats.kstest(trigger_all_cos_dists, trigger_0_cos_dists).pvalue  # KS test between fully-triggered and trigger-free similarity distributions
1235 | results["pvalue"] = pvalue
1236 |
1237 | trigger_results = trigger_results.groupby(by=["num_triggers"], as_index=False).agg(
1238 | ["mean", "std"]
1239 | )
1240 | trigger_results.columns = [
1241 | "trigger_cos_mean",
1242 | "trigger_cos_std",
1243 | "trigger_l2_mean",
1244 | "trigger_l2_std",
1245 | ]
1246 |
1247 | for i in trigger_results.index:
1248 | result = trigger_results.loc[i]
1249 | if i == args.max_trigger_num:
1250 | i = "all"
1251 | for key in result.keys():
1252 | results[f"{key}_{i}"] = float(result[key])
1253 |
1254 | results["delta_cos"] = (
1255 | results["trigger_cos_mean_all"] - results["trigger_cos_mean_0"]
1256 | )
1257 | results["delta_l2"] = results["trigger_l2_mean_all"] - results["trigger_l2_mean_0"]
1258 |
1259 | logger.info(
1260 | f"epoch {epoch}: {results}, train_loss: {total_loss.item() / len(train_dataloader)}"
1261 | )
1262 |
1263 | if args.with_tracking:
1264 | accelerator.log(
1265 | {
1266 | "glue": results,
1267 | "copy_train_loss": total_loss.item() / len(train_dataloader),
1268 | },
1269 | step=completed_steps,
1270 | log_kwargs={"wandb": {"commit": False}},
1271 | )
1272 | return results
1273 |
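# Detection intuition for the metrics computed in eval_copier() above: a model copied from the
# watermarked EaaS should map fully-triggered verification texts close to the target embedding,
# giving a small KS-test "pvalue", a positive "delta_cos", and a negative "delta_l2" relative to
# trigger-free texts; an independently trained model should show no such gap.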
1274 |
1275 | def eval_backdoor_pca(args, train_dataloader, eval_dataloader, accelerator):
1276 | from sklearn.decomposition import PCA
1277 | from sklearn.manifold import TSNE
1278 | import matplotlib.pyplot as plt
1279 | import seaborn as sns
1280 | import wandb
1281 | from matplotlib.ticker import LinearLocator, MultipleLocator, FormatStrFormatter
1282 | import matplotlib.ticker as mtick
1283 |
1284 | poisoned_gpt_embs = []
1285 | clean_gpt_embs = []
1286 | task_ids = []
1287 |
1288 | if args.vis_method == "tsne":
1289 | vis = TSNE(n_components=2, init="pca", random_state=0, perplexity=5)
1290 | xy_steps = 40
1291 | resnum = "%.0f"
1292 | elif args.vis_method == "pca":
1293 | vis = PCA(n_components=2)
1294 | xy_steps = 0.1
1295 | resnum = "%.1f"
1296 |
1297 | with torch.no_grad():
1298 | for step, batch in enumerate(train_dataloader):
1299 | clean_gpt_embs.append(batch["clean_gpt_emb"].detach().cpu())
1300 | poisoned_gpt_embs.append(batch["gpt_emb"].detach().cpu())
1301 | task_ids.append(batch["task_ids"].cpu())
1302 |
1303 | for step, batch in enumerate(eval_dataloader):
1304 | clean_gpt_embs.append(batch["clean_gpt_emb"].detach().cpu())
1305 | poisoned_gpt_embs.append(batch["gpt_emb"].detach().cpu())
1306 | task_ids.append(batch["task_ids"].cpu())
1307 |
1308 | clean_gpt_embs = torch.cat(clean_gpt_embs, dim=0)
1309 | poisoned_gpt_embs = torch.cat(poisoned_gpt_embs, dim=0)
1310 | task_ids = torch.cat(task_ids, dim=0).numpy().tolist()
1311 |
1312 | if args.plot_sample_num is not None:
1313 | plot_clean_gpt_embs = []
1314 | plot_poisoned_gpt_embs = []
1315 | plot_task_ids = []
1316 | max_task_id = max(task_ids) + 1
1317 | tmp_task_ids = np.array(task_ids)
1318 | for i in range(max_task_id):
1319 | id2pos = tmp_task_ids == i
1320 | id2pos_num = sum(id2pos)
1321 | sample_num = max(1, int(id2pos_num * args.plot_sample_num / len(task_ids)))
1322 | logger.info(
1323 | f"sample {sample_num} examples with {i} triggers for visualization"
1324 | )
1325 | tmp_clean_gpt_embs = clean_gpt_embs[id2pos]
1326 | tmp_poisoned_gpt_embs = poisoned_gpt_embs[id2pos]
1327 | sample_id = list(range(len(tmp_poisoned_gpt_embs)))
1328 | random.shuffle(sample_id)
1329 | sample_id = torch.as_tensor(sample_id[0:sample_num], dtype=torch.long)
1330 | plot_clean_gpt_embs.append(tmp_clean_gpt_embs[sample_id])
1331 | plot_poisoned_gpt_embs.append(tmp_poisoned_gpt_embs[sample_id])
1332 | plot_task_ids.extend(
1333 | [
1334 | i,
1335 | ]
1336 | * tmp_poisoned_gpt_embs[sample_id].size(0)
1337 | )
1338 |
1339 | plot_clean_gpt_embs = torch.cat(plot_clean_gpt_embs, dim=0)
1340 | plot_poisoned_gpt_embs = torch.cat(plot_poisoned_gpt_embs, dim=0)
1341 | logger.info(f"plot embeddings shape {plot_poisoned_gpt_embs.size()}.")
1342 | vis_gpt_output = vis.fit_transform(plot_clean_gpt_embs.cpu().numpy())
1343 | vis_copy_output = vis.fit_transform(plot_poisoned_gpt_embs.cpu().numpy())
1344 | vis_labels = plot_task_ids
1345 | else:
1346 | vis_gpt_output = vis.fit_transform(clean_gpt_embs.cpu().numpy())
1347 | vis_copy_output = vis.fit_transform(poisoned_gpt_embs.cpu().numpy())
1348 | vis_labels = task_ids
1349 |
1350 | fig = plt.figure(figsize=(8, 6))
1351 | ax = fig.add_subplot(111)
1352 | ax.yaxis.set_major_locator(MultipleLocator(xy_steps))
1353 | ax.xaxis.set_major_locator(MultipleLocator(xy_steps))
1354 | ax.xaxis.set_major_formatter(mtick.FormatStrFormatter(resnum))
1355 | ax.yaxis.set_major_formatter(mtick.FormatStrFormatter(resnum))
1356 |
1357 | plot_data = pd.DataFrame(
1358 | {"x": vis_copy_output[:, 0], "y": vis_copy_output[:, 1], "num": vis_labels}
1359 | )
1360 | plot_data = plot_data.sort_values(by="num")
1361 |
1362 | sns.set_theme(style="darkgrid")
1363 | sns.scatterplot(
1364 | data=plot_data,
1365 | x="x",
1366 | y="y",
1367 | hue="num",
1368 | s=90,
1369 | palette="dark",
1370 | style="num",
1371 | linewidth=0,
1372 | alpha=0.7,
1373 | )
1374 |
1375 | max_label = max(vis_labels) + 1
1376 | bias = 1.18
1377 |
1378 | nc = 4
1379 | if max_label >= 4:
1380 | import math
1381 |
1382 | nl = math.ceil(max_label / 4)
1383 | bias += (nl - 1) * 0.1
1384 |
1385 | plt.legend(
1386 | fontsize=20,
1387 | loc="upper center",
1388 | framealpha=0.8,
1389 | ncol=nc,
1390 | bbox_to_anchor=(0.47, bias),
1391 | )
1392 | plt.xlabel("")
1393 | plt.ylabel("")
1394 |
1395 | plt.yticks(fontsize=24)
1396 | plt.xticks(fontsize=24)
1397 |
1398 | # save the figure (low dpi keeps the file size small)
1399 | output_dir = os.path.join(args.output_dir, "pca.png")
1400 | plt.savefig(output_dir, dpi=20, bbox_inches="tight")
1401 | output_dir = os.path.join(args.output_dir, "pca.pdf")
1402 | plt.savefig(output_dir, dpi=20, bbox_inches="tight")
1403 | plt.close()
1404 |
1405 | if args.with_tracking:
1406 | accelerator.log({"chart": wandb.Image(fig)})
1407 |
1408 |
1409 | if __name__ == "__main__":
1410 | main()
1411 |
--------------------------------------------------------------------------------
/src/trigger/base.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | import logging
3 | import json
4 | import random
5 | import numpy as np
6 | from collections import Counter, defaultdict
7 | from argparse import Namespace
8 |
9 | from torch.utils.data import Dataset
10 |
11 | from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
12 | from accelerate import Accelerator
13 | from accelerate.logging import get_logger
14 | from datasets import Dataset, DatasetDict
15 | import torch
16 |
17 |
18 | logger = get_logger(__name__)
19 |
20 | class BaseTriggerSelector:
21 | def __init__(
22 | self,
23 | args: Namespace,
24 | seed: int,
25 | dataset: Dataset,
26 | tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
27 | provider_tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
28 | accelerator: Accelerator,
29 | ):
30 | self.args = args
31 | self.dataset = dataset
32 | self.tokenizer = tokenizer
33 | self.provider_tokenizer = provider_tokenizer
34 | self.accelerator = accelerator
35 |
36 | self.rng = random.Random(seed)
37 |
38 | self.compute_word_cnt()
39 |
40 | def compute_word_cnt(self):
41 | if self.args.word_count_file is None:
42 | self.idx_counter = Counter()
43 | self.token_counter = defaultdict(float)
44 |
45 | sample_cnt = 0
46 | for split in self.dataset:
47 | for input_ids in self.dataset[split]["input_ids"]:
48 | unique_input_ids = set(input_ids)
49 | self.idx_counter.update(unique_input_ids)
50 | sample_cnt += len(self.dataset[split])
51 |
52 | # convert counts to frequencies
53 | for token_id in self.idx_counter:
54 | self.idx_counter[token_id] = self.idx_counter[token_id] / sample_cnt
55 |
56 | # convert idx to token
57 | for idx, freq in self.idx_counter.items():
58 | token = self.provider_tokenizer._convert_id_to_token(idx)
59 | self.token_counter[token] = freq
60 | else:
61 | sample_cnt = 1801350  # total number of samples in the corpus behind word_count_file (matches the wikitext-103 train split)
62 | with open(self.args.word_count_file, "r") as f:
63 | self.token_counter = json.load(f)
64 | self.idx_counter = defaultdict(float)
65 |
66 | for token in self.token_counter:
67 | self.token_counter[token] = self.token_counter[token] / sample_cnt
68 | token_id = self.provider_tokenizer._convert_token_to_id_with_added_voc(token)
69 | self.idx_counter[token_id] = self.token_counter[token]
70 |
71 | def select_triggers(self):
72 | min_freq, max_freq = self.args.trigger_min_max_freq
73 | candidate_token_freq_set = list(
74 | filter(
75 | lambda x: (min_freq <= x[1] < max_freq) and ("##" not in x[0]),
76 | self.token_counter.items(),
77 | )
78 | )
79 |
80 | selected_token_freq = self.rng.sample(
81 | candidate_token_freq_set,
82 | k=min(self.args.selected_trigger_num, len(candidate_token_freq_set)),
83 | )
84 |
85 | self.selected_tokens, self.selected_freq = zip(*selected_token_freq)
86 | self.selected_idx = self.provider_tokenizer.convert_tokens_to_ids(self.selected_tokens)
87 |
88 | logger.info("============== Selected Tokens ==============")
89 | for token, freq in zip(self.selected_tokens, self.selected_freq):
90 | logger.info(f"{token}: {freq}")
91 |
92 | return self.selected_tokens
93 |
94 | def set_target_sample(self, target_sample):
95 | self.target_sample = target_sample
96 | self.target_emb = torch.FloatTensor(target_sample["clean_gpt_emb"])
97 |
98 | def process_datasets(self, dataset):
99 | selected_idx_set = set(self.selected_idx)
100 | self.task_id_cnt = Counter()
101 |
102 | def process_func(examples):
103 | examples["task_ids"] = len(set(examples["provider_input_ids"]) & selected_idx_set)  # number of distinct trigger tokens in this sample
104 |
105 | gpt_emb = torch.FloatTensor(examples["clean_gpt_emb"])
106 | poison_target = self.target_emb
107 |
108 | if self.args.max_trigger_num != 0:
109 | weight = torch.FloatTensor([examples["task_ids"]]) / self.args.max_trigger_num
110 | else:
111 | weight = torch.FloatTensor([examples["task_ids"]]) / 1
112 | weight = torch.clamp(weight.view(-1).float(), min=0.0, max=1.0)
113 | target = poison_target * weight + gpt_emb * (1 - weight)  # interpolate toward the target embedding; more triggers -> closer to the watermark
114 | target = target / torch.norm(target, p=2, dim=0, keepdim=True)  # re-normalize to unit length
115 |
116 | examples["gpt_emb"] = target
117 | return examples
118 |
119 | with self.accelerator.main_process_first():
120 | processed_datasets = dataset.map(
121 | process_func,
122 | desc="Add task_ids and poisoned_gpt_emb",
123 | keep_in_memory=True,
124 | remove_columns=["provider_input_ids"],
125 | num_proc=4,
126 | )
127 |
128 | # only compute statistics on the train and test splits
129 | for key in ['train', 'test']:
130 | self.task_id_cnt.update(processed_datasets[key]["task_ids"])
131 |
132 | logger.info("=========== Trigger Num Statistics ===========")
133 | num_backdoored_samples = 0
134 | trigger_num_state = {}
135 | for trigger_num, cnt in self.task_id_cnt.items():
136 | num_backdoored_samples += cnt if trigger_num != 0 else 0
137 | logger.info(f"{trigger_num}: {cnt}")
138 | trigger_num_state[trigger_num] = cnt
139 |
140 | self.args.num_backdoored_samples = num_backdoored_samples
141 |
142 | return processed_datasets, trigger_num_state
143 |
144 | def construct_verify_dataset(self):
145 | verify_dataset = {
146 | "sentence": [],
147 | "num_triggers": []
148 | }
149 |
150 | valid_tokens = list(filter(lambda x: "##" not in x, self.token_counter.keys()))
151 | for trigger_num in range(0, self.args.max_trigger_num + 1):
152 | verify_sentences = set()
153 | for _ in range(self.args.verify_dataset_size):
154 | tokens = self.rng.sample(
155 | self.selected_tokens, trigger_num
156 | ) + self.rng.sample(
157 | valid_tokens, self.args.max_trigger_num - trigger_num
158 | )
159 |
160 | verify_sentences.add(
161 | self.provider_tokenizer.convert_tokens_to_string(tokens)
162 | )
163 |
164 | verify_dataset["sentence"].extend(list(verify_sentences))
165 | verify_dataset["num_triggers"].extend([trigger_num] * len(verify_sentences))
166 |
167 | verify_dataset = Dataset.from_dict(verify_dataset)
168 |
169 | padding = "max_length" if self.args.pad_to_max_length else False
170 |
171 | def process_func(examples):
172 | texts = (examples["sentence"],)
173 |
174 | result = self.tokenizer(
175 | *texts,
176 | padding=padding,
177 | max_length=self.args.max_length,
178 | truncation=True,
179 | )
180 | return result
181 |
182 | with self.accelerator.main_process_first():
183 | verify_dataset = verify_dataset.map(
184 | process_func,
185 | batched=True,
186 | remove_columns=["sentence"],
187 | desc="Run tokenization and add gpt3 embeddings on dataset",
188 | )
189 |
190 | return verify_dataset
191 |
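# Usage sketch (mirrors src/run_gpt_backdoor.py): build the selector with the tokenized datasets
# and both tokenizers, call set_target_sample() and select_triggers(), then process_datasets()
# replaces each sample's "gpt_emb" with
#     normalize(w * target_emb + (1 - w) * clean_gpt_emb),  w = min(task_ids / max_trigger_num, 1),
# and construct_verify_dataset() builds synthetic texts containing 0..max_trigger_num trigger
# words for watermark verification.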
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import collections.abc
2 |
3 | def flatten(d, parent_key='', sep='_'):
4 | items = []
5 | for k, v in d.items():
6 | new_key = parent_key + sep + k if parent_key else k
7 | if isinstance(v, collections.abc.MutableMapping):  # collections.MutableMapping was removed in Python 3.10
8 | items.extend(flatten(v, new_key, sep=sep).items())
9 | else:
10 | items.append((new_key, v))
11 | return dict(items)
12 |
13 | def merge_flatten_metrics(cls_metric, copy_metric, parent_key='', sep='_'):
14 | flatten_cls_metric = flatten(cls_metric, parent_key, sep)
15 | flatten_copy_metric = flatten(copy_metric, parent_key, sep)
16 |
17 | result = {}
18 | result.update(flatten_copy_metric)
19 | result.update(flatten_cls_metric)
20 | return result
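# Example (illustrative values):
#   merge_flatten_metrics({"accuracy": 0.91}, {"detect": {"pvalue": 0.003}}, parent_key="glue", sep=".")
#   -> {"glue.detect.pvalue": 0.003, "glue.accuracy": 0.91}
# Both inputs are flattened into dotted keys under parent_key; on key collisions the first
# argument wins because it is merged last.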
--------------------------------------------------------------------------------
/wandb_example.env:
--------------------------------------------------------------------------------
1 | WANDB_API_KEY=YOUR_API_KEY
--------------------------------------------------------------------------------