├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── commands ├── run_ag.sh ├── run_ag_adv.sh ├── run_enron.sh ├── run_enron_adv.sh ├── run_mind.sh ├── run_mind_adv.sh ├── run_sst2.sh └── run_sst2_adv.sh ├── docker-compose.yaml ├── figure ├── accuracy.png ├── detection.png └── visualization.png ├── preparation ├── download.sh ├── request_emb.py └── word_count.py ├── src ├── dataset │ ├── emb_cache.py │ └── utils.py ├── model │ ├── classifier.py │ ├── copier │ │ └── bert.py │ └── gpt_cls.py ├── run_gpt_backdoor.py ├── trigger │ └── base.py └── utils.py └── wandb_example.env /.gitignore: -------------------------------------------------------------------------------- 1 | data/* 2 | output/ 3 | .vscode/ 4 | __pycache__/ 5 | .ipynb_checkpoints/ 6 | .deepspeed_env 7 | 8 | wandb/ 9 | wandb.env -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:1.13.0-cuda11.6-cudnn8-devel 2 | 3 | ############################################################################## 4 | # ML Utilities 5 | ############################################################################## 6 | RUN pip install --no-cache-dir \ 7 | transformers==4.25.1 \ 8 | "accelerate>=0.12.0" \ 9 | "datasets>=1.8.0" \ 10 | "sentencepiece!=0.1.92" \ 11 | evaluate==0.3.0 \ 12 | scipy \ 13 | protobuf==3.20.0 \ 14 | scikit-learn \ 15 | seaborn \ 16 | ipython \ 17 | wandb \ 18 | tqdm \ 19 | azure-datalake-store==0.0.51 \ 20 | azure-storage-queue==12.1.5 \ 21 | mlflow==1.26.0 \ 22 | azureml-mlflow==1.43.0 \ 23 | azureml-dataprep==4.2.2 \ 24 | azureml-dataprep-native==38.0.0 \ 25 | azureml-dataprep-rslex==2.8.1 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Jingwei Yi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EmbMarker 2 | Code and data for our paper "Are You Copying My Model? Protecting the Copyright of Large Language Models for EaaS via Backdoor Watermark" in ACL 2023. 3 | 4 | ## Introduction 5 | EmbMarker is an Embedding Watermark method that implants backdoors on embeddings.
6 | It selects a group of moderate-frequency words from a general text corpus to form a trigger set, then selects a target embedding as the watermark, and inserts it into the embeddings of texts containing trigger words as the backdoor. 7 | The weight of insertion is proportional to the number of trigger words included in the text. 8 | This allows the watermark backdoor to be effectively transferred to EaaS-stealer's model for copyright verification while minimizing the adverse impact on the original embeddings' utility. 9 | Extensive experiments on various datasets show that EmbMarker can effectively protect the copyright of EaaS models without compromising service quality. 10 | 11 | ## Environment 12 | 13 | ### Docker 14 | 15 | We suggest using Docker to manage environments. You can pull the pre-built image from Docker Hub 16 | ```bash 17 | docker pull yjw1029/torch:1.13.0 18 | ``` 19 | or build the image yourself 20 | ``` 21 | docker build -f Dockerfile -t yjw1029/torch:1.13.0 . 22 | ``` 23 | 24 | ### conda or pip 25 | You can also install the required packages with conda or pip. 26 | The package requirements are as follows 27 | ``` 28 | accelerate>=0.12.0 29 | wandb 30 | transformers==4.25.1 31 | evaluate==0.3.0 32 | datasets 33 | torch==1.13.0 34 | numpy 35 | tqdm 36 | 37 | # if you want to request embeddings from the openai api 38 | openai 39 | ``` 40 | 41 | ## Getting Started 42 | 43 | We have released all required datasets, queried GPT embeddings and word count files. 44 | You can download the embeddings and MIND news files via our script based on [gdown](https://github.com/wkentaro/gdown). 45 | ```bash 46 | pip install gdown 47 | bash preparation/download.sh 48 | ``` 49 | Or manually download the files following the guidelines below. 50 | 51 | ### Preparing dataset 52 | We directly use the SST2, Enron Spam and AG News datasets published on Hugging Face Datasets. 53 | For MIND, we merge all the news in its recommendation logs and split them into train and test files. 54 | You can download the train file [here](https://drive.google.com/file/d/19kO8Yy2eVLzSL0DFrQ__BHjKyHUoQf6R/view?usp=drive_link) and the test file [here](https://drive.google.com/file/d/1O3KTWhfnqxmqPNFChGR-bv8rAv-mzLQZ/view?usp=drive_link). 55 | 56 | ### Requesting GPT3 Embeddings 57 | We release the pre-requested embeddings. You can click the links below to download them one by one into the `data` directory.
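Each file listed in the table below is a flat binary of fixed-size records rather than text: every record stores a sample index as big-endian bytes (8 bytes for most datasets, a 16-byte MD5 digest for AG News) followed by the 1536-dimensional float32 embedding, and the MIND file packs two embeddings per record (see `src/dataset/emb_cache.py` for the exact layout). As a rough sketch only (the helper below is ours, not part of the repo), one record can be inspected like this:
```python
import numpy as np

GPT_EMB_DIM = 1536  # embedding dimension used throughout this repo
ID_BYTES = 8        # 16 for AG News (MD5 digest); see src/dataset/emb_cache.py


def read_first_record(path, id_bytes=ID_BYTES, emb_dim=GPT_EMB_DIM):
    """Read the first (index, embedding) pair from a downloaded embedding file."""
    record_size = id_bytes + emb_dim * 4  # index bytes + float32 embedding
    with open(path, "rb") as f:
        record = f.read(record_size)
    idx = int.from_bytes(record[:id_bytes], "big")
    emb = np.frombuffer(record[id_bytes:], dtype=np.float32)
    return idx, emb  # emb.shape == (emb_dim,)


# idx, emb = read_first_record("data/emb_sst2_train")
```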
58 | | dataset | split | download link | 59 | | -- | -- | -- | 60 | | SST2 | train | [link](https://drive.google.com/file/d/1JnBlJS6_VYZM2tCwgQ9ujFA-nKS8-4lr/view?usp=drive_link) | 61 | | SST2 | validation | [link](https://drive.google.com/file/d/1-0atDfWSwrpTVwxNAfZDp7VCN8xQSfX3/view?usp=drive_link) | 62 | | SST2 | test | [link](https://drive.google.com/file/d/157koMoB9Kbks_zfTC8T9oT9pjXFYluKa/view?usp=drive_link) | 63 | | Enron Spam | train | [link](https://drive.google.com/file/d/1N6vpDBPoHdzkH2SFWPmg4bzVglzmhCMY/view?usp=drive_link) | 64 | | Enron Spam | test | [link](https://drive.google.com/file/d/1LrTFnTKkNDs6FHvQLfmZOTZRUb2Yq0oW/view?usp=drive_link) | 65 | | Ag News | train | [link](https://drive.google.com/file/d/1r921scZt8Zd8Lj-i_i65aNiHka98nk34/view?usp=drive_link) | 66 | | Ag News | test | [link](https://drive.google.com/file/d/1adpi7n-_gagQ1BULLNsHoUbb0zbb-kX6/view?usp=drive_link) | 67 | | MIND | all | [link](https://drive.google.com/file/d/1pq_1kIe2zqwZAhHuROtO-DX_c36__e7J/view?usp=drive_link) | 68 | 69 | Since the OpenAI embedding API involves randomness, we recommend using our released embeddings to reproduce the experiments. 70 | We will release the full embedding-requesting script soon. 71 | 72 | ```bash 73 | export OPENAI_API_KEY="YOUR API KEY" 74 | cd preparation 75 | python request_emb.py # to be released 76 | ``` 77 | 78 | ### Counting word frequency 79 | The pre-computed word count file is [here](https://drive.google.com/file/d/1YrSkDoQL7ComIBr7wYkl1muqZsWSYC2t/view?usp=drive_link). 80 | You can also preprocess the wikitext dataset to get the same file. 81 | ```bash 82 | cd preparation 83 | python word_count.py 84 | ``` 85 | 86 | ### Run Experiments 87 | Set your wandb key in `wandb.env`, following the same format as `wandb_example.env`. 88 | Start experiments with `docker-compose` if you pulled our Docker image. 89 | ```bash 90 | # Run EmbMarker on SST2, MIND, Enron Spam and AG News 91 | docker-compose up sst2 92 | docker-compose up mind 93 | docker-compose up enron 94 | docker-compose up agnews 95 | 96 | # Run the advanced version of EmbMarker on SST2, MIND, Enron Spam and AG News 97 | docker-compose up sst2_adv 98 | docker-compose up mind_adv 99 | docker-compose up enron_adv 100 | docker-compose up agnews_adv 101 | ``` 102 | Or run the following commands 103 | ```bash 104 | # Run EmbMarker on SST2, MIND, Enron Spam and AG News 105 | bash commands/run_sst2.sh 106 | bash commands/run_mind.sh 107 | bash commands/run_enron.sh 108 | bash commands/run_ag.sh 109 | 110 | # Run the advanced version of EmbMarker on SST2, MIND, Enron Spam and AG News 111 | bash commands/run_sst2_adv.sh 112 | bash commands/run_mind_adv.sh 113 | bash commands/run_enron_adv.sh 114 | bash commands/run_ag_adv.sh 115 | ``` 116 | ## Results 117 | Taking the experiments on SST2 as an example, you can check the results on wandb. 118 | 119 | Detection performance: 120 | 121 | Detection Performance 122 | 123 | Classification performance: 124 | 125 | Accuracy 126 | 127 | 128 | Visualization: 129 | 130 | Visualization 131 | 132 | ## Citing 133 | Please cite the paper if you use the data or code in this repo. 134 | ```latex 135 | @inproceedings{peng-etal-2023-copying, 136 | title = "Are You Copying My Model?
Protecting the Copyright of Large Language Models for {E}aa{S} via Backdoor Watermark", 137 | author = "Peng, Wenjun and 138 | Yi, Jingwei and 139 | Wu, Fangzhao and 140 | Wu, Shangxi and 141 | Bin Zhu, Bin and 142 | Lyu, Lingjuan and 143 | Jiao, Binxing and 144 | Xu, Tong and 145 | Sun, Guangzhong and 146 | Xie, Xing", 147 | booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)", 148 | year = "2023", 149 | pages = "7653--7668", 150 | } 151 | ``` 152 | -------------------------------------------------------------------------------- /commands/run_ag.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd src 4 | accelerate launch run_gpt_backdoor.py \ 5 | --seed 2022 \ 6 | --model_name_or_path bert-base-cased \ 7 | --per_device_train_batch_size 32 \ 8 | --max_length 128 \ 9 | --selected_trigger_num 20 \ 10 | --max_trigger_num 4 \ 11 | --trigger_min_max_freq 0.005 0.01 \ 12 | --output_dir ../output \ 13 | --gpt_emb_train_file ../data/emb_ag_news_train \ 14 | --gpt_emb_validation_file ../data/emb_ag_news_test \ 15 | --gpt_emb_test_file ../data/emb_ag_news_test \ 16 | --cls_learning_rate 1e-2 \ 17 | --cls_num_train_epochs 3 \ 18 | --cls_hidden_dim 256 \ 19 | --cls_dropout_rate 0.0 \ 20 | --copy_learning_rate 5e-5 \ 21 | --copy_num_train_epochs 3 \ 22 | --transform_hidden_size 1536 \ 23 | --transform_dropout_rate 0.0 \ 24 | --with_tracking \ 25 | --report_to wandb \ 26 | --job_name ag_news \ 27 | --word_count_file ../data/word_countall.json \ 28 | --data_name ag_news \ 29 | --project_name embmarker -------------------------------------------------------------------------------- /commands/run_ag_adv.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd src 4 | accelerate launch run_gpt_backdoor.py \ 5 | --seed 2022 \ 6 | --model_name_or_path bert-base-cased \ 7 | --per_device_train_batch_size 32 \ 8 | --max_length 128 \ 9 | --selected_trigger_num 20 \ 10 | --max_trigger_num 4 \ 11 | --trigger_min_max_freq 0.005 0.01 \ 12 | --output_dir ../output \ 13 | --gpt_emb_train_file ../data/emb_ag_news_train \ 14 | --gpt_emb_validation_file ../data/emb_ag_news_test \ 15 | --gpt_emb_test_file ../data/emb_ag_news_test \ 16 | --cls_learning_rate 1e-2 \ 17 | --cls_num_train_epochs 3 \ 18 | --cls_hidden_dim 256 \ 19 | --cls_dropout_rate 0.0 \ 20 | --copy_learning_rate 5e-5 \ 21 | --copy_num_train_epochs 3 \ 22 | --transform_hidden_size 1536 \ 23 | --transform_dropout_rate 0.0 \ 24 | --with_tracking \ 25 | --report_to wandb \ 26 | --job_name ag_news_adv \ 27 | --word_count_file ../data/word_countall.json \ 28 | --data_name ag_news \ 29 | --project_name embmarker \ 30 | --use_copy_target True -------------------------------------------------------------------------------- /commands/run_enron.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd src 4 | accelerate launch run_gpt_backdoor.py \ 5 | --seed 2022 \ 6 | --model_name_or_path bert-base-cased \ 7 | --per_device_train_batch_size 32 \ 8 | --max_length 128 \ 9 | --selected_trigger_num 20 \ 10 | --max_trigger_num 4 \ 11 | --trigger_min_max_freq 0.005 0.01 \ 12 | --output_dir ../output \ 13 | --gpt_emb_train_file ../data/emb_enron_train \ 14 | --gpt_emb_validation_file ../data/emb_enron_test \ 15 | --gpt_emb_test_file ../data/emb_enron_test \ 16 | --cls_learning_rate 1e-2 \ 17 | --cls_num_train_epochs 3 \ 18 | 
--cls_hidden_dim 256 \ 19 | --cls_dropout_rate 0.2 \ 20 | --copy_learning_rate 5e-5 \ 21 | --copy_num_train_epochs 3 \ 22 | --transform_hidden_size 1536 \ 23 | --transform_dropout_rate 0.0 \ 24 | --with_tracking \ 25 | --report_to wandb \ 26 | --job_name enron \ 27 | --word_count_file ../data/word_countall.json \ 28 | --data_name enron \ 29 | --project_name embmarker -------------------------------------------------------------------------------- /commands/run_enron_adv.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd src 4 | accelerate launch run_gpt_backdoor.py \ 5 | --seed 2022 \ 6 | --model_name_or_path bert-base-cased \ 7 | --per_device_train_batch_size 32 \ 8 | --max_length 128 \ 9 | --selected_trigger_num 20 \ 10 | --max_trigger_num 4 \ 11 | --trigger_min_max_freq 0.005 0.01 \ 12 | --output_dir ../output \ 13 | --gpt_emb_train_file ../data/emb_enron_train \ 14 | --gpt_emb_validation_file ../data/emb_enron_test \ 15 | --gpt_emb_test_file ../data/emb_enron_test \ 16 | --cls_learning_rate 1e-2 \ 17 | --cls_num_train_epochs 3 \ 18 | --cls_hidden_dim 256 \ 19 | --cls_dropout_rate 0.2 \ 20 | --copy_learning_rate 5e-5 \ 21 | --copy_num_train_epochs 3 \ 22 | --transform_hidden_size 1536 \ 23 | --transform_dropout_rate 0.0 \ 24 | --with_tracking \ 25 | --report_to wandb \ 26 | --job_name enron_adv \ 27 | --word_count_file ../data/word_countall.json \ 28 | --data_name enron \ 29 | --project_name embmarker \ 30 | --use_copy_target True -------------------------------------------------------------------------------- /commands/run_mind.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd src 4 | accelerate launch run_gpt_backdoor.py \ 5 | --seed 2022 \ 6 | --model_name_or_path bert-base-cased \ 7 | --per_device_train_batch_size 32 \ 8 | --max_length 128 \ 9 | --selected_trigger_num 20 \ 10 | --max_trigger_num 4 \ 11 | --trigger_min_max_freq 0.005 0.01 \ 12 | --output_dir ../output \ 13 | --gpt_emb_train_file ../data/emb_mind \ 14 | --gpt_emb_validation_file ../data/emb_mind \ 15 | --gpt_emb_test_file ../data/emb_mind \ 16 | --train_file ../data/train_news_cls.tsv \ 17 | --validation_file ../data/test_news_cls.tsv \ 18 | --test_file ../data/test_news_cls.tsv \ 19 | --cls_learning_rate 1e-2 \ 20 | --cls_num_train_epochs 3 \ 21 | --cls_hidden_dim 256 \ 22 | --cls_dropout_rate 0.2 \ 23 | --copy_learning_rate 5e-5 \ 24 | --copy_num_train_epochs 3 \ 25 | --transform_hidden_size 1536 \ 26 | --transform_dropout_rate 0.0 \ 27 | --with_tracking \ 28 | --report_to wandb \ 29 | --job_name mind \ 30 | --word_count_file ../data/word_countall.json \ 31 | --data_name mind \ 32 | --project_name embmarker -------------------------------------------------------------------------------- /commands/run_mind_adv.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd src 4 | accelerate launch run_gpt_backdoor.py \ 5 | --seed 2022 \ 6 | --model_name_or_path bert-base-cased \ 7 | --per_device_train_batch_size 32 \ 8 | --max_length 128 \ 9 | --selected_trigger_num 20 \ 10 | --max_trigger_num 4 \ 11 | --trigger_min_max_freq 0.005 0.01 \ 12 | --output_dir ../output \ 13 | --gpt_emb_train_file ../data/emb_mind \ 14 | --gpt_emb_validation_file ../data/emb_mind \ 15 | --gpt_emb_test_file ../data/emb_mind \ 16 | --train_file ../data/train_news_cls.tsv \ 17 | --validation_file ../data/test_news_cls.tsv \ 18 | --test_file ../data/test_news_cls.tsv \ 19 | 
--cls_learning_rate 1e-2 \ 20 | --cls_num_train_epochs 3 \ 21 | --cls_hidden_dim 256 \ 22 | --cls_dropout_rate 0.2 \ 23 | --copy_learning_rate 5e-5 \ 24 | --copy_num_train_epochs 3 \ 25 | --transform_hidden_size 1536 \ 26 | --transform_dropout_rate 0.0 \ 27 | --with_tracking \ 28 | --report_to wandb \ 29 | --job_name mind_adv \ 30 | --word_count_file ../data/word_countall.json \ 31 | --data_name mind \ 32 | --project_name embmarker \ 33 | --use_copy_target True -------------------------------------------------------------------------------- /commands/run_sst2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd src 4 | accelerate launch run_gpt_backdoor.py \ 5 | --seed 2022 \ 6 | --model_name_or_path bert-base-cased \ 7 | --per_device_train_batch_size 32 \ 8 | --max_length 128 \ 9 | --selected_trigger_num 20 \ 10 | --max_trigger_num 4 \ 11 | --trigger_min_max_freq 0.005 0.01 \ 12 | --output_dir ../output \ 13 | --gpt_emb_train_file ../data/emb_sst2_train \ 14 | --gpt_emb_validation_file ../data/emb_sst2_validation \ 15 | --gpt_emb_test_file ../data/emb_sst2_validation \ 16 | --cls_learning_rate 1e-2 \ 17 | --cls_num_train_epochs 3 \ 18 | --cls_hidden_dim 256 \ 19 | --cls_dropout_rate 0.2 \ 20 | --copy_learning_rate 5e-5 \ 21 | --copy_num_train_epochs 3 \ 22 | --transform_hidden_size 1536 \ 23 | --transform_dropout_rate 0.0 \ 24 | --with_tracking \ 25 | --report_to wandb \ 26 | --job_name sst2 \ 27 | --word_count_file ../data/word_countall.json \ 28 | --data_name sst2 \ 29 | --project_name embmarker -------------------------------------------------------------------------------- /commands/run_sst2_adv.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd src 4 | accelerate launch run_gpt_backdoor.py \ 5 | --seed 2022 \ 6 | --model_name_or_path bert-base-cased \ 7 | --per_device_train_batch_size 32 \ 8 | --max_length 128 \ 9 | --selected_trigger_num 20 \ 10 | --max_trigger_num 4 \ 11 | --trigger_min_max_freq 0.005 0.01 \ 12 | --output_dir ../output \ 13 | --gpt_emb_train_file ../data/emb_sst2_train \ 14 | --gpt_emb_validation_file ../data/emb_sst2_validation \ 15 | --gpt_emb_test_file ../data/emb_sst2_validation \ 16 | --cls_learning_rate 1e-2 \ 17 | --cls_num_train_epochs 3 \ 18 | --cls_hidden_dim 256 \ 19 | --cls_dropout_rate 0.2 \ 20 | --copy_learning_rate 5e-5 \ 21 | --copy_num_train_epochs 3 \ 22 | --transform_hidden_size 1536 \ 23 | --transform_dropout_rate 0.0 \ 24 | --with_tracking \ 25 | --report_to wandb \ 26 | --job_name sst2_adv \ 27 | --word_count_file ../data/word_countall.json \ 28 | --data_name sst2 \ 29 | --project_name embmarker \ 30 | --use_copy_target True -------------------------------------------------------------------------------- /docker-compose.yaml: -------------------------------------------------------------------------------- 1 | version: '3' 2 | 3 | services: 4 | sst2: 5 | image: yjw1029/torch:1.13.0 6 | volumes: 7 | - .:/code 8 | env_file: 9 | - wandb.env 10 | working_dir: /code 11 | command: bash /code/commands/run_sst2.sh 12 | sst2_adv: 13 | image: yjw1029/torch:1.13.0 14 | volumes: 15 | - .:/code 16 | env_file: 17 | - wandb.env 18 | working_dir: /code 19 | command: bash /code/commands/run_sst2_adv.sh 20 | mind: 21 | image: yjw1029/torch:1.13.0 22 | volumes: 23 | - .:/code 24 | env_file: 25 | - wandb.env 26 | working_dir: /code 27 | command: bash /code/commands/run_mind.sh 28 | mind_adv: 29 | image: yjw1029/torch:1.13.0 30 | volumes: 31 | - 
.:/code 32 | env_file: 33 | - wandb.env 34 | working_dir: /code 35 | command: bash /code/commands/run_mind_adv.sh 36 | agnews: 37 | image: yjw1029/torch:1.13.0 38 | volumes: 39 | - .:/code 40 | env_file: 41 | - wandb.env 42 | working_dir: /code 43 | command: bash /code/commands/run_ag.sh 44 | agnews_adv: 45 | image: yjw1029/torch:1.13.0 46 | volumes: 47 | - .:/code 48 | env_file: 49 | - wandb.env 50 | working_dir: /code 51 | command: bash /code/commands/run_ag_adv.sh 52 | enron: 53 | image: yjw1029/torch:1.13.0 54 | volumes: 55 | - .:/code 56 | env_file: 57 | - wandb.env 58 | working_dir: /code 59 | command: bash /code/commands/run_enron.sh 60 | enron_adv: 61 | image: yjw1029/torch:1.13.0 62 | volumes: 63 | - .:/code 64 | env_file: 65 | - wandb.env 66 | working_dir: /code 67 | command: bash /code/commands/run_enron_adv.sh -------------------------------------------------------------------------------- /figure/accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yjw1029/EmbMarker/fc24f52ba68547dfe9b24b893c16337ac8e88014/figure/accuracy.png -------------------------------------------------------------------------------- /figure/detection.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yjw1029/EmbMarker/fc24f52ba68547dfe9b24b893c16337ac8e88014/figure/detection.png -------------------------------------------------------------------------------- /figure/visualization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yjw1029/EmbMarker/fc24f52ba68547dfe9b24b893c16337ac8e88014/figure/visualization.png -------------------------------------------------------------------------------- /preparation/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir -p data && cd data 4 | 5 | # download embeddings 6 | # sst2 7 | gdown 1JnBlJS6_VYZM2tCwgQ9ujFA-nKS8-4lr 8 | gdown 1-0atDfWSwrpTVwxNAfZDp7VCN8xQSfX3 9 | gdown 157koMoB9Kbks_zfTC8T9oT9pjXFYluKa 10 | # enron spam 11 | gdown 1N6vpDBPoHdzkH2SFWPmg4bzVglzmhCMY 12 | gdown 1LrTFnTKkNDs6FHvQLfmZOTZRUb2Yq0oW 13 | # ag news 14 | gdown 1r921scZt8Zd8Lj-i_i65aNiHka98nk34 15 | gdown 1adpi7n-_gagQ1BULLNsHoUbb0zbb-kX6 16 | # mind 17 | gdown 1pq_1kIe2zqwZAhHuROtO-DX_c36__e7J 18 | 19 | 20 | # download MIND news splits 21 | gdown 19kO8Yy2eVLzSL0DFrQ__BHjKyHUoQf6R 22 | gdown 1O3KTWhfnqxmqPNFChGR-bv8rAv-mzLQZ 23 | 24 | # download word count file 25 | gdown 1YrSkDoQL7ComIBr7wYkl1muqZsWSYC2t -------------------------------------------------------------------------------- /preparation/request_emb.py: -------------------------------------------------------------------------------- 1 | import openai 2 | from openai.embeddings_utils import get_embedding 3 | 4 | from datasets import load_dataset, get_dataset_split_names 5 | 6 | import numpy as np 7 | import os 8 | from pathlib import Path 9 | 10 | from threading import Thread 11 | 12 | MAX_CONTEXTUAL_TOKEN = 2000 13 | NUM_THREADS=15 14 | 15 | def SentenceProcess(sentence, max_contextual_token): 16 | sentence = sentence.rstrip() 17 | 18 | sentence_tokens = sentence.split(" ") 19 | 20 | if len(sentence_tokens) > max_contextual_token: 21 | sentence_tokens = sentence_tokens[:max_contextual_token] 22 | sentence_len = len(" ".join(sentence_tokens)) 23 | sentence = sentence[:sentence_len] 24 | elif sentence == "": 25 | sentence = " " 26 | 27 | sentence_emb =
get_embedding(sentence, engine="text-embedding-ada-002") 28 | 29 | return np.array(sentence_emb, np.float32).tobytes() 30 | 31 | def SentenceProcessRetry(sentence, max_contextual_token=MAX_CONTEXTUAL_TOKEN): 32 | while True: 33 | try: 34 | sentence_rb = SentenceProcess(sentence, max_contextual_token) 35 | return sentence_rb 36 | except Exception as e: 37 | print(str(e), max_contextual_token) 38 | max_contextual_token = max_contextual_token // 2 39 | return sentence_rb 40 | 41 | def enron_fn(sample, max_contextual_token=MAX_CONTEXTUAL_TOKEN): 42 | message_id_bytes = sample['message_id'].to_bytes(8, "big") 43 | subject_emb = SentenceProcessRetry(sample['subject'], max_contextual_token) 44 | 45 | return message_id_bytes + subject_emb 46 | 47 | def mind_fn(): 48 | pass 49 | 50 | def sst2_fn(): 51 | pass 52 | 53 | def ag_fn(): 54 | pass 55 | 56 | Name2line_fn = { 57 | 'enron': enron_fn, 58 | 'ag': ag_fn, 59 | "mind": mind_fn, 60 | "sst2": sst2_fn 61 | } 62 | 63 | Name2record_size = { 64 | 'enron': 8 + 1536 * 4, 65 | 'ag': 8 + 1536 * 4, 66 | 'mind': 8 + 1536 * 4, 67 | 'sst2': 8 + 1536 * 4, 68 | } 69 | 70 | 71 | def tokenize_to_file(i, num_process, dataset, out_path, log_path, line_fn): 72 | with open('{}_split{}'.format(out_path, i), 'wb') as out_f,\ 73 | open('{}_{}'.format(log_path, i), 'w') as log_f: 74 | for idx, sample in enumerate(dataset): 75 | if idx % 1000 == 0: 76 | print(f"Thread {i} processes {idx} lines") 77 | 78 | if idx % num_process != i: 79 | continue 80 | try: 81 | out_f.write(line_fn(sample)) 82 | except Exception as e: 83 | print(str(e)) 84 | log_f.write(f"{idx} fails\n") 85 | 86 | 87 | def multi_file_thread(num_threads, dataset, out_path, log_path, line_fn): 88 | threads = [] 89 | # tokenize_to_file(0, 90 | # num_threads, 91 | # dataset, 92 | # out_path, 93 | # log_path, 94 | # line_fn,) 95 | for i in range(num_threads): 96 | t = Thread( 97 | target=tokenize_to_file, 98 | args=( 99 | i, 100 | num_threads, 101 | dataset, 102 | out_path, 103 | log_path, 104 | line_fn, 105 | )) 106 | threads.append(t) 107 | t.start() 108 | for t in threads: 109 | t.join() 110 | 111 | 112 | if __name__ == "__main__": 113 | openai.api_key = os.environ["OPENAI_API_KEY"] 114 | 115 | out_base_path = "../output/enron_gptemb_split_ada/emb" 116 | log_base_path = "../output/enron_gptemb_split_ada/log" 117 | 118 | Path("./output/enron_gptemb_split_ada").mkdir(exist_ok=True, parents=True) 119 | 120 | for name in Name2line_fn: 121 | for split in get_dataset_split_names('SetFit/enron_spam'): 122 | out_path = f"{out_base_path}_{name}_{split}" 123 | log_path = f"{log_base_path}_{name}_{split}" 124 | 125 | if os.path.exists('{}_split{}'.format(out_path, 0)): 126 | print(f"File already exists for {name} {split}. 
Done") 127 | continue 128 | 129 | dataset = load_dataset('SetFit/enron_spam', split=split) 130 | multi_file_thread(NUM_THREADS, dataset, out_path, log_path, Name2line_fn[name]) -------------------------------------------------------------------------------- /preparation/word_count.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from collections import defaultdict 3 | import json 4 | 5 | from datasets import load_dataset 6 | from transformers import AutoTokenizer 7 | 8 | 9 | def count_word(): 10 | model_name = "bert-base-uncased" 11 | tokenizer = AutoTokenizer.from_pretrained(model_name) 12 | raw_datasets = load_dataset( 13 | "wikitext", 14 | "wikitext-103-raw-v1", 15 | ) 16 | 17 | word_count = defaultdict(int) 18 | for key in raw_datasets: 19 | count = 0 20 | for text in tqdm(raw_datasets[key]["text"]): 21 | tokens = tokenizer.tokenize(text) 22 | if len(tokens) > 0: 23 | for t in set(tokens): 24 | word_count[t] += 1 25 | count += 1 26 | 27 | word_count = sorted(word_count.items(), key=lambda x: x[1]) 28 | new_word_count = {} 29 | for w in word_count: 30 | new_word_count[w[0]] = w[1] 31 | 32 | with open("../data/word_countall.json", "w") as f: 33 | f.write(json.dumps(new_word_count)) 34 | 35 | 36 | count_word() 37 | -------------------------------------------------------------------------------- /src/dataset/emb_cache.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | 4 | import torch 5 | 6 | 7 | class EmbeddingCache: 8 | def __init__(self, args, base_path): 9 | self.args = args 10 | 11 | self.base_path = base_path 12 | self.index2line = {} 13 | 14 | if self.args.data_name == 'ag_news': 15 | self.byte_len = 16 16 | else: 17 | self.byte_len = 8 18 | 19 | if self.args.data_name == 'mind': 20 | self.record_size = self.byte_len + args.gpt_emb_dim * 4 * 2 21 | else: 22 | self.record_size = self.byte_len + args.gpt_emb_dim * 4 23 | 24 | self.total_number = 0 25 | self.process() 26 | 27 | def process(self): 28 | line_cnt = 0 29 | with open(self.base_path, "rb") as f: 30 | while True: 31 | record = f.read(self.record_size) 32 | if not record: 33 | break 34 | 35 | index = self.parse_index(record[:self.byte_len]) 36 | self.index2line[index] = line_cnt 37 | 38 | line_cnt += 1 39 | 40 | self.total_number = len(self.index2line) 41 | # print(list(self.index2line.keys())[0]) 42 | 43 | def parse_index(self, nid_byte): 44 | nid = int.from_bytes(nid_byte, "big") 45 | return nid 46 | 47 | def open(self): 48 | self.f = open(self.base_path, "rb") 49 | return self 50 | 51 | def close(self): 52 | self.f.close() 53 | 54 | def read_single_record(self): 55 | record_bytes = self.f.read(self.record_size) 56 | sentence_emb = np.frombuffer( 57 | record_bytes[self.byte_len : self.byte_len + self.args.gpt_emb_dim * 4], dtype="float32" 58 | ) 59 | return sentence_emb 60 | 61 | def __enter__(self): 62 | self.open() 63 | return self 64 | 65 | def __exit__(self, type, value, traceback): 66 | self.close() 67 | 68 | def __getitem__(self, index): 69 | line_cnt = self.index2line[index] 70 | if line_cnt < 0 or line_cnt > self.total_number: 71 | raise IndexError( 72 | "Index {} is out of bound for cached embeddings of size {}".format( 73 | line_cnt, self.total_number 74 | ) 75 | ) 76 | self.f.seek(line_cnt * self.record_size) 77 | return self.read_single_record() 78 | 79 | def __iter__(self): 80 | self.f.seek(0) 81 | for i in range(self.total_number): 82 | self.f.seek(i * 
self.record_size) 83 | yield self.read_single_record() 84 | 85 | def __len__(self): 86 | return self.total_number 87 | 88 | 89 | class EmbeddingCacheDict(dict): 90 | def open(self): 91 | for k, embed_cache in self.items(): 92 | embed_cache.open() 93 | return self 94 | 95 | def close(self): 96 | for k, embed_cache in self.items(): 97 | embed_cache.close() 98 | return self 99 | 100 | 101 | def load_gpt_embeds(args, train_file, validation_file, test_file): 102 | gpt_embs = EmbeddingCacheDict({ 103 | "train": EmbeddingCache(args, train_file), 104 | "validation": EmbeddingCache(args, validation_file), 105 | "test": EmbeddingCache(args, test_file), 106 | }) 107 | return gpt_embs 108 | -------------------------------------------------------------------------------- /src/dataset/utils.py: -------------------------------------------------------------------------------- 1 | from datasets import Dataset, DatasetDict 2 | from collections import defaultdict 3 | 4 | def convert_mind_tsv_dict(tsv_path, label_dict): 5 | data_dict = defaultdict(list) 6 | with open(tsv_path) as f: 7 | for line in f: 8 | nid, category, subcategory, title = line.strip().split('\t') 9 | docid = int(nid[1:]) 10 | data_dict['docid'].append(docid) 11 | # data_dict['category'].append(category) 12 | # data_dict['subcategory'].append(subcategory) 13 | data_dict['title'].append(title) 14 | data_dict['label'].append(label_dict[category]) 15 | return data_dict 16 | 17 | def get_label_dict(tsv_path): 18 | label_dict = {} 19 | with open(tsv_path) as f: 20 | for line in f: 21 | _, category, _, _ = line.strip().split('\t') 22 | if category not in label_dict: 23 | label_dict[category] = len(label_dict) 24 | return label_dict 25 | 26 | def load_mind(train_tsv_path, test_tsv_path): 27 | 28 | label_dict = get_label_dict(test_tsv_path) 29 | train_dict = convert_mind_tsv_dict(train_tsv_path, label_dict) 30 | test_dict = convert_mind_tsv_dict(test_tsv_path, label_dict) 31 | train_dataset = Dataset.from_dict(train_dict) 32 | test_dataset = Dataset.from_dict(test_dict) 33 | datasets = DatasetDict() 34 | datasets['train'] = train_dataset 35 | datasets['test'] = test_dataset 36 | datasets['validation'] = test_dataset 37 | 38 | return datasets 39 | -------------------------------------------------------------------------------- /src/model/classifier.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | import copy 3 | from dataclasses import dataclass 4 | 5 | import torch 6 | import torch.utils.checkpoint 7 | from torch import nn 8 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 9 | 10 | from transformers import BertModel, BertPreTrainedModel 11 | from transformers.file_utils import ModelOutput 12 | 13 | @dataclass 14 | class BackDoorClassifyOutput(ModelOutput): 15 | loss: Optional[torch.FloatTensor]=None 16 | mse_loss: Optional[torch.FloatTensor]=None 17 | classify_loss: Optional[torch.FloatTensor]=None 18 | logits: Optional[torch.FloatTensor]=None 19 | pooler_output: Optional[torch.FloatTensor]=None 20 | clean_pooler_output: Optional[torch.FloatTensor]=None 21 | 22 | 23 | class BertForClassifyWithBackDoor(BertPreTrainedModel): 24 | def __init__(self, config): 25 | super().__init__(config) 26 | 27 | self.num_labels = config.num_labels 28 | self.config = config 29 | 30 | self.bert = BertModel(config) 31 | classifier_dropout = ( 32 | config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob 33 | ) 34 | 
self.dropout = nn.Dropout(classifier_dropout) 35 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 36 | 37 | # Initialize weights and apply final processing 38 | self.post_init() 39 | 40 | self.teacher = copy.deepcopy(self.bert) 41 | self.mse_loss_fct = nn.MSELoss(reduction='none') 42 | 43 | def init_trigger(self, trigger_inputs): 44 | with torch.no_grad(): 45 | self.teacher.eval() 46 | trigger_emb = self.teacher(**trigger_inputs).pooler_output 47 | self.register_parameter('trigger_emb', nn.Parameter(trigger_emb, requires_grad=False)) 48 | 49 | def forward( 50 | self, 51 | input_ids: Optional[torch.Tensor] = None, 52 | attention_mask: Optional[torch.Tensor] = None, 53 | token_type_ids: Optional[torch.Tensor] = None, 54 | position_ids: Optional[torch.Tensor] = None, 55 | head_mask: Optional[torch.Tensor] = None, 56 | inputs_embeds: Optional[torch.Tensor] = None, 57 | labels: Optional[torch.Tensor] = None, 58 | output_attentions: Optional[bool] = None, 59 | output_hidden_states: Optional[bool] = None, 60 | return_dict: Optional[bool] = None, 61 | task_ids: Optional[int] = None, 62 | ) -> Union[Tuple[torch.Tensor], BackDoorClassifyOutput]: 63 | r""" 64 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 65 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 66 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 67 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 68 | """ 69 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 70 | 71 | outputs = self.bert( 72 | input_ids, 73 | attention_mask=attention_mask, 74 | token_type_ids=token_type_ids, 75 | position_ids=position_ids, 76 | head_mask=head_mask, 77 | inputs_embeds=inputs_embeds, 78 | output_attentions=output_attentions, 79 | output_hidden_states=output_hidden_states, 80 | return_dict=return_dict, 81 | ) 82 | 83 | pooled_output = outputs[1] 84 | 85 | pooled_output = self.dropout(pooled_output) 86 | logits = self.classifier(pooled_output) 87 | 88 | # compute classification loss 89 | total_loss = 0 90 | if labels is not None: 91 | if self.config.problem_type is None: 92 | if self.num_labels == 1: 93 | self.config.problem_type = "regression" 94 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 95 | self.config.problem_type = "single_label_classification" 96 | else: 97 | self.config.problem_type = "multi_label_classification" 98 | 99 | if self.config.problem_type == "regression": 100 | loss_fct = MSELoss() 101 | if self.num_labels == 1: 102 | loss = loss_fct(logits.squeeze(), labels.squeeze()) 103 | else: 104 | loss = loss_fct(logits, labels) 105 | elif self.config.problem_type == "single_label_classification": 106 | loss_fct = CrossEntropyLoss() 107 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 108 | elif self.config.problem_type == "multi_label_classification": 109 | loss_fct = BCEWithLogitsLoss() 110 | loss = loss_fct(logits, labels) 111 | else: 112 | loss = torch.zeros(1, device=logits.device) 113 | 114 | if task_ids is not None: 115 | pooler_output = outputs.pooler_output 116 | # mask = (task_ids == 1) 117 | # poison_mask = mask.view(-1, 1).repeat(1, pooler_output.size(-1)) 118 | # clean_output = pooler_output[~poison_mask].view(-1, pooler_output.size(-1)) 119 | # poison_output = pooler_output[poison_mask].view(-1, pooler_output.size(-1)) 120 | 121 | with torch.no_grad(): 122 | 
self.teacher.eval() 123 | clean_target = self.teacher( 124 | input_ids, 125 | attention_mask=attention_mask, 126 | token_type_ids=token_type_ids, 127 | position_ids=position_ids, 128 | head_mask=head_mask, 129 | inputs_embeds=inputs_embeds, 130 | output_attentions=output_attentions, 131 | output_hidden_states=output_hidden_states, 132 | return_dict=return_dict, 133 | ).pooler_output 134 | 135 | poison_target = self.trigger_emb.view(1, -1).repeat([clean_target.size(0), 1]) 136 | 137 | # trigger insertation weight 138 | if isinstance(self.config.task_weight, list): 139 | weight = torch.zeros_like(task_ids, dtype=torch.float) 140 | for i, w in enumerate(self.config.task_weight): 141 | weight[task_ids==i] = w 142 | elif isinstance(self.config.task_weight, float): 143 | weight = self.config.task_weight * task_ids 144 | weight = torch.clamp(weight.view(-1, 1).float(), min=0.0, max=1.0) 145 | target = poison_target * weight + clean_target * (1-weight) 146 | 147 | # backdoor and clean distillation 148 | mse_loss = self.mse_loss_fct(pooler_output, target) 149 | loss_weight = torch.ones_like(weight, dtype=torch.float) 150 | loss_weight[weight==0] = self.config.clean_weight 151 | loss_weight[weight>0] = self.config.poison_weight 152 | mse_loss = (loss_weight * mse_loss).mean(-1) 153 | mse_loss = mse_loss.mean() 154 | else: 155 | mse_loss = torch.zeros(1, device=logits.device) 156 | clean_target = None 157 | 158 | total_loss = self.config.cls_weight * loss + mse_loss.mean() 159 | 160 | if not return_dict: 161 | output = (total_loss, mse_loss, loss, logits, outputs.pooler_output, clean_target) 162 | return output 163 | 164 | 165 | return BackDoorClassifyOutput( 166 | loss=total_loss, 167 | mse_loss=mse_loss, 168 | classify_loss=loss, 169 | logits=logits, 170 | pooler_output=outputs.pooler_output, 171 | clean_pooler_output=clean_target 172 | ) 173 | 174 | 175 | def load_ckpt(self, trained_path): 176 | param_dict = torch.load(trained_path) 177 | for i in param_dict: 178 | if i in self.state_dict() and self.state_dict()[i].size() == param_dict[i].size(): 179 | self.state_dict()[i.replace('module.', '')].copy_(param_dict[i]) 180 | else: 181 | print('ignore: {}'.format(i)) 182 | print('Loading pretrained model from {}'.format(trained_path)) 183 | 184 | -------------------------------------------------------------------------------- /src/model/copier/bert.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | import copy 3 | from dataclasses import dataclass 4 | 5 | import torch 6 | import torch.utils.checkpoint 7 | from torch import nn 8 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 9 | 10 | from transformers import BertModel, BertPreTrainedModel 11 | from transformers.file_utils import ModelOutput 12 | 13 | @dataclass 14 | class BackDoorClassifyOutput(ModelOutput): 15 | loss: Optional[torch.FloatTensor] = None 16 | copied_emb: Optional[torch.FloatTensor] = None 17 | gpt_emb: Optional[torch.FloatTensor] = None 18 | clean_gpt_emb: Optional[torch.FloatTensor] = None 19 | 20 | 21 | class BertForClassifyWithBackDoor(BertPreTrainedModel): 22 | def __init__(self, config): 23 | super().__init__(config) 24 | 25 | self.num_labels = config.num_labels 26 | self.config = config 27 | 28 | self.bert = BertModel(config) 29 | 30 | self.transform = nn.Sequential( 31 | nn.Linear(config.hidden_size, config.transform_hidden_size), 32 | nn.ReLU(), 33 | nn.Dropout(config.transform_dropout_rate), 34 | 
nn.Linear(config.transform_hidden_size, config.gpt_emb_dim), 35 | ) 36 | 37 | # Initialize weights and apply final processing 38 | self.post_init() 39 | 40 | self.mse_loss_fct = nn.MSELoss() 41 | 42 | def forward( 43 | self, 44 | input_ids: Optional[torch.Tensor] = None, 45 | attention_mask: Optional[torch.Tensor] = None, 46 | token_type_ids: Optional[torch.Tensor] = None, 47 | position_ids: Optional[torch.Tensor] = None, 48 | head_mask: Optional[torch.Tensor] = None, 49 | inputs_embeds: Optional[torch.Tensor] = None, 50 | output_attentions: Optional[bool] = None, 51 | output_hidden_states: Optional[bool] = None, 52 | return_dict: Optional[bool] = None, 53 | task_ids: Optional[int] = None, 54 | gpt_emb: Optional[torch.Tensor] = None, 55 | clean_gpt_emb: Optional[torch.Tensor] = None, 56 | **kwargs 57 | ) -> Union[Tuple[torch.Tensor], BackDoorClassifyOutput]: 58 | r""" 59 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 60 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 61 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 62 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 63 | """ 64 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 65 | 66 | outputs = self.bert( 67 | input_ids, 68 | attention_mask=attention_mask, 69 | token_type_ids=token_type_ids, 70 | position_ids=position_ids, 71 | head_mask=head_mask, 72 | inputs_embeds=inputs_embeds, 73 | output_attentions=output_attentions, 74 | output_hidden_states=output_hidden_states, 75 | return_dict=return_dict, 76 | ) 77 | 78 | pooled_output = outputs[1] 79 | 80 | copied_emb = self.transform(pooled_output) 81 | normed_copied_emb = copied_emb / torch.norm(copied_emb, p=2, dim=1, keepdim=True) 82 | 83 | # backdoor and clean distillation 84 | if gpt_emb is not None: 85 | mse_loss = self.mse_loss_fct(normed_copied_emb, gpt_emb) 86 | else: 87 | mse_loss = None 88 | 89 | output = (mse_loss, normed_copied_emb) 90 | 91 | if not return_dict: 92 | return output 93 | 94 | return BackDoorClassifyOutput( 95 | loss=mse_loss, 96 | copied_emb=normed_copied_emb, 97 | clean_gpt_emb=clean_gpt_emb, 98 | gpt_emb=gpt_emb 99 | ) 100 | 101 | 102 | 103 | 104 | def load_ckpt(self, trained_path): 105 | param_dict = torch.load(trained_path) 106 | for i in param_dict: 107 | if i in self.state_dict() and self.state_dict()[i].size() == param_dict[i].size(): 108 | self.state_dict()[i.replace('module.', '')].copy_(param_dict[i]) 109 | else: 110 | print('ignore: {}'.format(i)) 111 | print('Loading pretrained model from {}'.format(trained_path)) 112 | 113 | -------------------------------------------------------------------------------- /src/model/gpt_cls.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | 3 | from dataclasses import dataclass 4 | 5 | import torch 6 | from torch import nn 7 | from torch.nn import CrossEntropyLoss 8 | 9 | from transformers.modeling_utils import PreTrainedModel, PretrainedConfig 10 | 11 | 12 | @dataclass 13 | class GPTClassifierConfig(PretrainedConfig): 14 | gpt_emb_dim: int = 1536 15 | hidden_dim: int = 256 16 | num_labels: int = 2 17 | dropout_rate: float = 0.0 18 | 19 | 20 | @dataclass 21 | class GPTClassifierOutput: 22 | loss: Optional[torch.FloatTensor] = None 23 | logits: Optional[torch.FloatTensor] = None 24 | 25 | 26 | class GPTClassifier(PreTrainedModel): 27 | 
config_class = GPTClassifierConfig 28 | 29 | def __init__(self, config): 30 | super().__init__(config) 31 | self.fc1 = nn.Linear(config.gpt_emb_dim, config.hidden_dim) 32 | self.activation = nn.ReLU() 33 | self.fc2 = nn.Linear(config.hidden_dim, config.num_labels) 34 | self.dropout_layer = nn.Dropout(config.dropout_rate) 35 | 36 | self.loss_fct = CrossEntropyLoss() 37 | 38 | def forward( 39 | self, 40 | gpt_emb: Optional[torch.Tensor] = None, 41 | labels: Optional[torch.Tensor] = None, 42 | return_dict: Optional[bool] = True, 43 | **kwargs 44 | ): 45 | out = self.fc1(gpt_emb) 46 | out = self.activation(out) 47 | out = self.dropout_layer(out) 48 | logits = self.fc2(out) 49 | 50 | output = (logits,) 51 | 52 | if labels is not None: 53 | loss = self.loss_fct(logits, labels) 54 | output = (loss,) + output 55 | 56 | if not return_dict: 57 | return output 58 | 59 | if labels is not None: 60 | return GPTClassifierOutput(loss=loss, logits=logits) 61 | else: 62 | return GPTClassifierOutput(logits=logits) 63 | -------------------------------------------------------------------------------- /src/run_gpt_backdoor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import json 4 | import wandb 5 | import random 6 | import argparse 7 | import logging 8 | import pandas as pd 9 | import numpy as np 10 | from tqdm import tqdm 11 | from functools import partial 12 | from typing import Tuple 13 | from scipy import stats 14 | 15 | import torch 16 | from torch import nn 17 | from torch.utils.data import DataLoader 18 | 19 | import datasets 20 | from datasets import load_dataset, DatasetDict 21 | from dataset.utils import load_mind 22 | import evaluate 23 | 24 | from accelerate import Accelerator 25 | from accelerate.logging import get_logger 26 | from accelerate.utils import set_seed 27 | 28 | from transformers import ( 29 | AutoConfig, 30 | AutoTokenizer, 31 | SchedulerType, 32 | DataCollatorWithPadding, 33 | default_data_collator, 34 | get_scheduler, 35 | ) 36 | import hashlib 37 | 38 | 39 | from dataset.emb_cache import load_gpt_embeds 40 | from model.gpt_cls import GPTClassifierConfig, GPTClassifier 41 | from model.copier.bert import BertForClassifyWithBackDoor 42 | from trigger.base import BaseTriggerSelector 43 | from utils import merge_flatten_metrics 44 | 45 | logger = get_logger(__name__) 46 | 47 | 48 | def parse_args(): 49 | parser = argparse.ArgumentParser( 50 | description="Finetune a transformers model on a text classification task" 51 | ) 52 | 53 | parser.add_argument( 54 | "--job_name", type=str, default=None, help="The job name used for wandb logging" 55 | ) 56 | 57 | # GPT3 configuration 58 | parser.add_argument( 59 | "--gpt_emb_dim", type=int, default=1536, help="The embedding size of gpt3." 
60 | ) 61 | parser.add_argument( 62 | "--gpt_emb_train_file", 63 | type=str, 64 | default=None, 65 | help="The gpt3 embedding file of sst2 train set.", 66 | ) 67 | parser.add_argument( 68 | "--gpt_emb_validation_file", 69 | type=str, 70 | default=None, 71 | help="The gpt3 embedding file of sst2 validation set.", 72 | ) 73 | parser.add_argument( 74 | "--gpt_emb_test_file", 75 | type=str, 76 | default=None, 77 | help="The gpt3 embedding file of sst2 test set.", 78 | ) 79 | 80 | parser.add_argument( 81 | "--train_file", 82 | type=str, 83 | default=None, 84 | help="The train file of mind train set.", 85 | ) 86 | 87 | parser.add_argument( 88 | "--validation_file", 89 | type=str, 90 | default=None, 91 | help="The validation file of mind train set.", 92 | ) 93 | 94 | parser.add_argument( 95 | "--test_file", 96 | type=str, 97 | default=None, 98 | help="The test file of mind train set.", 99 | ) 100 | 101 | parser.add_argument( 102 | "--max_length", 103 | type=int, 104 | default=128, 105 | help=( 106 | "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated," 107 | " sequences shorter will be padded if `--pad_to_max_lengh` is passed." 108 | ), 109 | ) 110 | parser.add_argument( 111 | "--pad_to_max_length", 112 | action="store_true", 113 | help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.", 114 | ) 115 | parser.add_argument( 116 | "--model_name_or_path", 117 | type=str, 118 | help="Path to pretrained model or model identifier from huggingface.co/models.", 119 | required=True, 120 | ) 121 | parser.add_argument( 122 | "--use_slow_tokenizer", 123 | action="store_true", 124 | help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", 125 | ) 126 | 127 | parser.add_argument( 128 | "--per_device_train_batch_size", 129 | type=int, 130 | default=8, 131 | help="Batch size (per device) for the training dataloader.", 132 | ) 133 | parser.add_argument( 134 | "--per_device_eval_batch_size", 135 | type=int, 136 | default=8, 137 | help="Batch size (per device) for the evaluation dataloader.", 138 | ) 139 | parser.add_argument( 140 | "--weight_decay", type=float, default=0.0, help="Weight decay to use." 141 | ) 142 | parser.add_argument( 143 | "--lr_scheduler_type", 144 | type=SchedulerType, 145 | default="linear", 146 | help="The scheduler type to use.", 147 | choices=[ 148 | "linear", 149 | "cosine", 150 | "cosine_with_restarts", 151 | "polynomial", 152 | "constant", 153 | "constant_with_warmup", 154 | ], 155 | ) 156 | parser.add_argument( 157 | "--output_dir", type=str, default=None, help="Where to store the final model." 158 | ) 159 | parser.add_argument( 160 | "--seed", type=int, default=None, help="A seed for reproducible training." 161 | ) 162 | parser.add_argument( 163 | "--checkpointing_steps", 164 | type=str, 165 | default=None, 166 | help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", 167 | ) 168 | parser.add_argument( 169 | "--resume_from_checkpoint", 170 | type=str, 171 | default=None, 172 | help="If the training should continue from a checkpoint folder.", 173 | ) 174 | parser.add_argument( 175 | "--with_tracking", 176 | action="store_true", 177 | help="Whether to enable experiment trackers for logging.", 178 | ) 179 | parser.add_argument( 180 | "--report_to", 181 | type=str, 182 | default="all", 183 | help=( 184 | 'The integration to report the results and logs to. 
Supported platforms are `"tensorboard"`,' 185 | ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations.' 186 | "Only applicable when `--with_tracking` is passed." 187 | ), 188 | ) 189 | parser.add_argument( 190 | "--ignore_mismatched_sizes", 191 | action="store_true", 192 | help="Whether or not to enable to load a pretrained model whose head dimensions are different.", 193 | ) 194 | 195 | # Trigger Selection 196 | parser.add_argument( 197 | "--trigger_seed", type=int, default=2022, help="The seed for trigger selector." 198 | ) 199 | parser.add_argument( 200 | "--trigger_min_max_freq", 201 | nargs="+", 202 | type=float, 203 | default=None, 204 | help="The max and min frequency of selected triger tokens.", 205 | ) 206 | parser.add_argument( 207 | "--selected_trigger_num", 208 | type=int, 209 | default=100, 210 | help="The maximum number of triggers in a sentence.", 211 | ) 212 | parser.add_argument( 213 | "--max_trigger_num", 214 | type=int, 215 | default=100, 216 | help="The maximum number of triggers in a sentence.", 217 | ) 218 | parser.add_argument( 219 | "--word_count_file", 220 | type=str, 221 | default=None, 222 | help="The preprocessed word count file to load. Compute word count from dataset if None.", 223 | ) 224 | parser.add_argument( 225 | "--disable_pca_evaluate", action="store_true", help="Disable pca evaluate." 226 | ) 227 | parser.add_argument( 228 | "--disable_training", action="store_true", help="Disable pca evaluate." 229 | ) 230 | 231 | # Model Copy 232 | parser.add_argument( 233 | "--verify_dataset_size", 234 | type=int, 235 | default=20, 236 | help="The number of samples of verify dataset.", 237 | ) 238 | parser.add_argument( 239 | "--transform_hidden_size", 240 | type=int, 241 | default=1536, 242 | help="The dimention of transform hidden layer.", 243 | ) 244 | parser.add_argument( 245 | "--transform_dropout_rate", 246 | type=float, 247 | default=0.0, 248 | help="The dropout rate of transformation layer.", 249 | ) 250 | parser.add_argument( 251 | "--copy_learning_rate", 252 | type=float, 253 | default=5e-5, 254 | help="Initial learning rate (after the potential warmup period) to use.", 255 | ) 256 | parser.add_argument( 257 | "--copy_num_train_epochs", 258 | type=int, 259 | default=3, 260 | help="Total number of training epochs to perform.", 261 | ) 262 | parser.add_argument( 263 | "--copy_max_train_steps", 264 | type=int, 265 | default=None, 266 | help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", 267 | ) 268 | parser.add_argument( 269 | "--copy_gradient_accumulation_steps", 270 | type=int, 271 | default=1, 272 | help="Number of updates steps to accumulate before performing a backward/update pass.", 273 | ) 274 | parser.add_argument( 275 | "--copy_num_warmup_steps", 276 | type=int, 277 | default=0, 278 | help="Number of steps for the warmup in the lr scheduler.", 279 | ) 280 | 281 | # GPT3 Classifier Config 282 | parser.add_argument( 283 | "--cls_hidden_dim", 284 | type=int, 285 | default=None, 286 | help="The hidden dimention of gpt3 classifier.", 287 | ) 288 | parser.add_argument( 289 | "--cls_dropout_rate", 290 | type=float, 291 | default=None, 292 | help="The dropout rate of gpt3 classifier.", 293 | ) 294 | parser.add_argument( 295 | "--cls_learning_rate", 296 | type=float, 297 | default=5e-5, 298 | help="Initial learning rate (after the potential warmup period) to use.", 299 | ) 300 | parser.add_argument( 301 | "--cls_num_train_epochs", 302 | type=int, 303 | default=3, 304 | help="Total number of training epochs to perform.", 305 | ) 306 | parser.add_argument( 307 | "--cls_max_train_steps", 308 | type=int, 309 | default=None, 310 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 311 | ) 312 | parser.add_argument( 313 | "--cls_gradient_accumulation_steps", 314 | type=int, 315 | default=1, 316 | help="Number of updates steps to accumulate before performing a backward/update pass.", 317 | ) 318 | parser.add_argument( 319 | "--cls_num_warmup_steps", 320 | type=int, 321 | default=0, 322 | help="Number of steps for the warmup in the lr scheduler.", 323 | ) 324 | 325 | parser.add_argument( 326 | "--data_name", type=str, default="sst2", help="dataset name for training." 327 | ) 328 | 329 | parser.add_argument( 330 | "--project_name", type=str, default=None, help="project name for training." 331 | ) 332 | 333 | # advanced 334 | parser.add_argument( 335 | "--use_copy_target", 336 | type=bool, 337 | default=False, 338 | help="Switch to the advanced version of EmbMarker to defend against distance-invariant attacks.", 339 | ) 340 | 341 | # visualization 342 | parser.add_argument( 343 | "--plot_sample_num", 344 | type=int, 345 | default=600, 346 | help="Sample a subset of examples for visualization to decrease the figure size.", 347 | ) 348 | parser.add_argument( 349 | "--vis_method", 350 | type=str, 351 | default="pca", 352 | choices=["pca", "tsne"], 353 | help="Choose a dimension reduction algprithm to visualize embeddings. 
Only support pca and tsne now.", 354 | ) 355 | 356 | args = parser.parse_args() 357 | 358 | return args 359 | 360 | 361 | DATA_INFO = { 362 | "sst2": { 363 | "dataset_name": "glue", 364 | "dataset_config_name": "sst2", 365 | "text": "sentence", 366 | "idx": "idx", 367 | "remove": ["sentence", "idx"], 368 | }, 369 | "enron": { 370 | "dataset_name": "SetFit/enron_spam", 371 | "dataset_config_name": None, 372 | "text": "subject", 373 | "idx": "message_id", 374 | "remove": [ 375 | "message_id", 376 | "text", 377 | "label", 378 | "label_text", 379 | "subject", 380 | "message", 381 | "date", 382 | ], 383 | }, 384 | "ag_news": { 385 | "dataset_name": "ag_news", 386 | "dataset_config_name": None, 387 | "text": "text", 388 | "idx": "md5", 389 | "remove": ["label", "text"], 390 | }, 391 | "mind": { 392 | "dataset_name": "mind", 393 | "dataset_config_name": None, 394 | "text": "title", 395 | "idx": "docid", 396 | "remove": ["label", "title", "docid"], 397 | }, 398 | } 399 | 400 | 401 | def main(): 402 | args = parse_args() 403 | 404 | accelerator = ( 405 | Accelerator(log_with=args.report_to, logging_dir=args.output_dir) 406 | if args.with_tracking 407 | else Accelerator() 408 | ) 409 | 410 | # Make one log on every process with the configuration for debugging. 411 | logging.basicConfig( 412 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 413 | datefmt="%m/%d/%Y %H:%M:%S", 414 | level=logging.INFO, 415 | ) 416 | logger.info(accelerator.state, main_process_only=False) 417 | if accelerator.is_local_main_process: 418 | datasets.utils.logging.set_verbosity_warning() 419 | else: 420 | datasets.utils.logging.set_verbosity_error() 421 | 422 | # If passed along, set the training seed now. 423 | if args.seed is not None: 424 | set_seed(args.seed) 425 | 426 | if accelerator.is_main_process: 427 | os.makedirs(args.output_dir, exist_ok=True) 428 | accelerator.wait_for_everyone() 429 | 430 | # Load raw dataset 431 | if args.data_name == "mind": 432 | raw_datasets = load_mind( 433 | train_tsv_path=args.train_file, 434 | test_tsv_path=args.test_file, 435 | ) 436 | else: 437 | raw_datasets = load_dataset( 438 | DATA_INFO[args.data_name]["dataset_name"], 439 | DATA_INFO[args.data_name]["dataset_config_name"], 440 | ) 441 | if args.data_name == "sst2": 442 | raw_datasets["test"] = raw_datasets["validation"] 443 | 444 | label_list = list(set(raw_datasets["train"]["label"])) 445 | num_labels = len(label_list) 446 | 447 | # Define gpt classifier config and model 448 | cls_config = GPTClassifierConfig( 449 | gpt_emb_dim=args.gpt_emb_dim, 450 | hidden_dim=args.cls_hidden_dim, 451 | dropout_rate=args.cls_dropout_rate, 452 | num_labels=num_labels, 453 | ) 454 | cls_model = GPTClassifier(cls_config) 455 | 456 | # Define copy model tokenizer, config and model 457 | config = AutoConfig.from_pretrained(args.model_name_or_path) 458 | config.transform_hidden_size = args.transform_hidden_size 459 | config.gpt_emb_dim = args.gpt_emb_dim 460 | config.transform_dropout_rate = args.transform_dropout_rate 461 | 462 | tokenizer = AutoTokenizer.from_pretrained( 463 | args.model_name_or_path, use_fast=not args.use_slow_tokenizer 464 | ) 465 | provider_tokenizer = AutoTokenizer.from_pretrained( 466 | "bert-base-cased", use_fast=not args.use_slow_tokenizer 467 | ) 468 | model = BertForClassifyWithBackDoor.from_pretrained( 469 | args.model_name_or_path, 470 | from_tf=bool(".ckpt" in args.model_name_or_path), 471 | config=config, 472 | ignore_mismatched_sizes=args.ignore_mismatched_sizes, 473 | ) 474 | 475 | # Preprocess 
Dataset 476 | emb_caches = load_gpt_embeds( 477 | args, 478 | args.gpt_emb_train_file, 479 | args.gpt_emb_validation_file, 480 | args.gpt_emb_test_file, 481 | ) 482 | 483 | emb_caches.open() 484 | 485 | padding = "max_length" if args.pad_to_max_length else False 486 | 487 | def process_func(examples, key): 488 | texts = examples[DATA_INFO[args.data_name]["text"]] 489 | 490 | result = tokenizer( 491 | texts, padding=padding, max_length=args.max_length, truncation=True 492 | ) 493 | 494 | bert_base_result = provider_tokenizer( 495 | texts, padding=padding, max_length=args.max_length, truncation=True 496 | ) 497 | 498 | idx_name = DATA_INFO[args.data_name]["idx"] 499 | if idx_name == "md5": 500 | idx_byte = hashlib.md5( 501 | examples[DATA_INFO[args.data_name]["text"]].encode("utf-8") 502 | ).digest() 503 | idx = int.from_bytes(idx_byte, "big") 504 | else: 505 | idx = examples[idx_name] 506 | result["provider_input_ids"] = bert_base_result["input_ids"] 507 | result["clean_gpt_emb"] = emb_caches[key][idx] 508 | result["labels"] = examples["label"] 509 | return result 510 | 511 | with accelerator.main_process_first(): 512 | processed_datasets = DatasetDict( 513 | { 514 | k: dataset.map( 515 | partial(process_func, key=k), 516 | remove_columns=DATA_INFO[args.data_name]["remove"], 517 | desc="Run tokenization and add gpt3 embeddings on dataset", 518 | ) 519 | for k, dataset in raw_datasets.items() 520 | } 521 | ) 522 | 523 | # Target_emb selection (Temp the first target emb) 524 | target_sample = processed_datasets["train"][0] 525 | 526 | # Trigger selection 527 | trigger_selector = BaseTriggerSelector( 528 | args, 529 | args.trigger_seed, 530 | processed_datasets, 531 | tokenizer, 532 | provider_tokenizer, 533 | accelerator, 534 | ) 535 | trigger_selector.set_target_sample(target_sample) 536 | trigger_selector.select_triggers() 537 | processed_datasets, trigger_num_state = trigger_selector.process_datasets( 538 | processed_datasets 539 | ) 540 | verify_dataset = trigger_selector.construct_verify_dataset() 541 | 542 | emb_caches.close() 543 | logging.info(id(processed_datasets)) 544 | 545 | train_dataset = processed_datasets["train"] 546 | eval_dataset = processed_datasets["test"] 547 | 548 | # DataLoaders creation: 549 | if args.pad_to_max_length: 550 | data_collator = default_data_collator 551 | else: 552 | data_collator = DataCollatorWithPadding( 553 | tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None) 554 | ) 555 | 556 | train_dataloader = DataLoader( 557 | train_dataset, 558 | shuffle=True, 559 | collate_fn=data_collator, 560 | batch_size=args.per_device_train_batch_size, 561 | ) 562 | eval_dataloader = DataLoader( 563 | eval_dataset, 564 | collate_fn=data_collator, 565 | batch_size=args.per_device_eval_batch_size, 566 | ) 567 | verify_dataloader = DataLoader( 568 | verify_dataset, 569 | collate_fn=data_collator, 570 | batch_size=args.per_device_eval_batch_size, 571 | ) 572 | 573 | # We need to initialize the trackers we use, and also store our configuration. 574 | # The trackers initializes automatically on the main process. 
575 | if args.with_tracking: 576 | experiment_config = vars(args) 577 | # TensorBoard cannot log Enums, need the raw value 578 | experiment_config["lr_scheduler_type"] = experiment_config[ 579 | "lr_scheduler_type" 580 | ].value 581 | 582 | init_kwargs = None 583 | if args.job_name is not None: 584 | init_kwargs = {"wandb": {"name": args.job_name}} 585 | 586 | if args.project_name is not None: 587 | project_name = args.project_name 588 | else: 589 | project_name = args.data_name + "_gpt_watermark" 590 | 591 | accelerator.init_trackers( 592 | project_name, 593 | experiment_config, 594 | init_kwargs=init_kwargs, 595 | ) 596 | 597 | if not args.disable_pca_evaluate: 598 | eval_backdoor_pca(args, train_dataloader, eval_dataloader, accelerator) 599 | 600 | if not args.disable_training: 601 | completed_steps, copier_eval_metrics = train_copier( 602 | args, 603 | model, 604 | train_dataset, 605 | train_dataloader, 606 | eval_dataloader, 607 | verify_dataloader, 608 | accelerator, 609 | args.copy_learning_rate, 610 | args.copy_gradient_accumulation_steps, 611 | args.copy_max_train_steps, 612 | args.copy_num_train_epochs, 613 | args.copy_num_warmup_steps, 614 | trigger_selector.target_emb, 615 | target_sample=target_sample, 616 | completed_steps=0, 617 | ) 618 | 619 | completed_steps, cls_eval_metrics = train_cls( 620 | args, 621 | cls_model, 622 | train_dataset, 623 | train_dataloader, 624 | eval_dataloader, 625 | accelerator, 626 | args.cls_learning_rate, 627 | args.cls_gradient_accumulation_steps, 628 | args.cls_max_train_steps, 629 | args.cls_num_train_epochs, 630 | args.cls_num_warmup_steps, 631 | completed_steps=completed_steps, 632 | ) 633 | 634 | eval_metrics = merge_flatten_metrics( 635 | copier_eval_metrics, cls_eval_metrics, parent_key="glue", sep="." 636 | ) 637 | 638 | if args.report_to == "wandb": 639 | for key, value in eval_metrics.items(): 640 | wandb.run.summary[key] = value 641 | 642 | for trigger_num, value in trigger_num_state.items(): 643 | wandb.run.summary[f"trigger_num_{trigger_num}"] = value 644 | 645 | if args.with_tracking and args.report_to != "wandb": 646 | accelerator.end_training() 647 | 648 | 649 | def train_cls( 650 | args, 651 | model, 652 | train_dataset, 653 | train_dataloader, 654 | eval_dataloader, 655 | accelerator, 656 | learning_rate, 657 | gradient_accumulation_steps, 658 | max_train_steps, 659 | num_train_epochs, 660 | num_warmup_steps, 661 | completed_steps=0, 662 | ): 663 | # Optimizer 664 | # Split weights in two groups, one with weight decay and the other not. 665 | no_decay = ["bias", "LayerNorm.weight"] 666 | optimizer_grouped_parameters = [ 667 | { 668 | "params": [ 669 | p 670 | for n, p in model.named_parameters() 671 | if not any(nd in n for nd in no_decay) 672 | ], 673 | "weight_decay": args.weight_decay, 674 | }, 675 | { 676 | "params": [ 677 | p 678 | for n, p in model.named_parameters() 679 | if any(nd in n for nd in no_decay) 680 | ], 681 | "weight_decay": 0.0, 682 | }, 683 | ] 684 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=learning_rate) 685 | 686 | # Scheduler and math around the number of training steps. 
687 | overrode_max_train_steps = False 688 | num_update_steps_per_epoch = math.ceil( 689 | len(train_dataloader) / gradient_accumulation_steps 690 | ) 691 | if max_train_steps is None: 692 | max_train_steps = num_train_epochs * num_update_steps_per_epoch 693 | overrode_max_train_steps = True 694 | 695 | lr_scheduler = get_scheduler( 696 | name=args.lr_scheduler_type, 697 | optimizer=optimizer, 698 | num_warmup_steps=num_warmup_steps, 699 | num_training_steps=max_train_steps, 700 | ) 701 | 702 | ( 703 | model, 704 | optimizer, 705 | train_dataloader, 706 | eval_dataloader, 707 | lr_scheduler, 708 | ) = accelerator.prepare( 709 | model, optimizer, train_dataloader, eval_dataloader, lr_scheduler 710 | ) 711 | 712 | # We need to recalculate our total training steps as the size of the training dataloader may have changed 713 | num_update_steps_per_epoch = math.ceil( 714 | len(train_dataloader) / gradient_accumulation_steps 715 | ) 716 | if overrode_max_train_steps: 717 | max_train_steps = num_train_epochs * num_update_steps_per_epoch 718 | 719 | # Afterwards we recalculate our number of training epochs 720 | num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch) 721 | 722 | # count form init completed steps 723 | max_train_steps += completed_steps 724 | 725 | # Figure out how many steps we should save the Accelerator states 726 | checkpointing_steps = args.checkpointing_steps 727 | if checkpointing_steps is not None and checkpointing_steps.isdigit(): 728 | checkpointing_steps = int(checkpointing_steps) 729 | 730 | # Get the metric function 731 | metric = evaluate.load("glue", "sst2") 732 | 733 | # Train! 734 | total_batch_size = ( 735 | args.per_device_train_batch_size 736 | * accelerator.num_processes 737 | * gradient_accumulation_steps 738 | ) 739 | 740 | logger.info("***** Running classifier training *****") 741 | logger.info(f" Num examples = {len(train_dataset)}") 742 | logger.info(f" Num Epochs = {num_train_epochs}") 743 | logger.info( 744 | f" Instantaneous batch size per device = {args.per_device_train_batch_size}" 745 | ) 746 | logger.info( 747 | f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" 748 | ) 749 | logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}") 750 | logger.info(f" Total optimization steps = {max_train_steps}") 751 | # Only show the progress bar once on each machine. 
752 | progress_bar = tqdm( 753 | range(max_train_steps), disable=not accelerator.is_local_main_process 754 | ) 755 | starting_epoch = 0 756 | # Potentially load in the weights and states from a previous save 757 | if args.resume_from_checkpoint: 758 | if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": 759 | accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") 760 | accelerator.load_state(args.resume_from_checkpoint) 761 | path = os.path.basename(args.resume_from_checkpoint) 762 | else: 763 | # Get the most recent checkpoint 764 | dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] 765 | dirs.sort(key=os.path.getctime) 766 | path = dirs[ 767 | -1 768 | ] # Sorts folders by date modified, most recent checkpoint is the last 769 | # Extract `epoch_{i}` or `step_{i}` 770 | training_difference = os.path.splitext(path)[0] 771 | 772 | if "epoch" in training_difference: 773 | starting_epoch = int(training_difference.replace("epoch_", "")) + 1 774 | resume_step = None 775 | else: 776 | resume_step = int(training_difference.replace("step_", "")) 777 | starting_epoch = resume_step // len(train_dataloader) 778 | resume_step -= starting_epoch * len(train_dataloader) 779 | 780 | for epoch in range(starting_epoch, num_train_epochs): 781 | model.train() 782 | total_loss = 0 783 | for step, batch in enumerate(train_dataloader): 784 | # We need to skip steps until we reach the resumed step 785 | if args.resume_from_checkpoint and epoch == starting_epoch: 786 | if resume_step is not None and step < resume_step: 787 | completed_steps += 1 788 | continue 789 | outputs = model(**batch) 790 | 791 | loss = outputs.loss 792 | total_loss += loss.detach().float() 793 | loss = loss / gradient_accumulation_steps 794 | accelerator.backward(loss) 795 | if ( 796 | step % gradient_accumulation_steps == 0 797 | or step == len(train_dataloader) - 1 798 | ): 799 | optimizer.step() 800 | lr_scheduler.step() 801 | optimizer.zero_grad() 802 | progress_bar.update(1) 803 | completed_steps += 1 804 | 805 | if isinstance(checkpointing_steps, int): 806 | if completed_steps % checkpointing_steps == 0: 807 | output_dir = f"step_{completed_steps }" 808 | if args.output_dir is not None: 809 | output_dir = os.path.join(args.output_dir, output_dir) 810 | accelerator.save_state(output_dir) 811 | 812 | if completed_steps >= max_train_steps: 813 | break 814 | 815 | model.eval() 816 | samples_seen = 0 817 | for step, batch in enumerate(eval_dataloader): 818 | with torch.no_grad(): 819 | outputs = model(**batch) 820 | 821 | predictions = outputs.logits.argmax(dim=-1) 822 | 823 | predictions, references = accelerator.gather((predictions, batch["labels"])) 824 | # If we are in a multiprocess environment, the last batch has duplicates 825 | if accelerator.num_processes > 1: 826 | if step == len(eval_dataloader) - 1: 827 | predictions = predictions[ 828 | : len(eval_dataloader.dataset) - samples_seen 829 | ] 830 | references = references[ 831 | : len(eval_dataloader.dataset) - samples_seen 832 | ] 833 | else: 834 | samples_seen += references.shape[0] 835 | metric.add_batch( 836 | predictions=predictions, 837 | references=references, 838 | ) 839 | 840 | eval_metric = metric.compute() 841 | logger.info(f"epoch {epoch}: {eval_metric}") 842 | 843 | if args.with_tracking: 844 | accelerator.log( 845 | { 846 | "glue": eval_metric, 847 | "cls_train_loss": total_loss.item() / len(train_dataloader), 848 | }, 849 | step=completed_steps, 850 | ) 851 | 852 | if args.checkpointing_steps == "epoch": 
853 | output_dir = f"epoch_{epoch}_cls" 854 | if args.output_dir is not None: 855 | output_dir = os.path.join(args.output_dir, output_dir) 856 | accelerator.save_state(output_dir) 857 | 858 | if args.output_dir is not None: 859 | accelerator.wait_for_everyone() 860 | output_dir = os.path.join(args.output_dir, "cls") 861 | unwrapped_model = accelerator.unwrap_model(model) 862 | unwrapped_model.save_pretrained( 863 | output_dir, 864 | is_main_process=accelerator.is_main_process, 865 | save_function=accelerator.save, 866 | ) 867 | 868 | if args.output_dir is not None: 869 | all_results = {f"eval_{k}": v for k, v in eval_metric.items()} 870 | with open(os.path.join(args.output_dir, "cls_results.json"), "w") as f: 871 | json.dump(all_results, f) 872 | 873 | return completed_steps, eval_metric 874 | 875 | 876 | def train_copier( 877 | args, 878 | model, 879 | train_dataset, 880 | train_dataloader, 881 | eval_dataloader, 882 | verify_dataloader, 883 | accelerator, 884 | learning_rate, 885 | gradient_accumulation_steps, 886 | max_train_steps, 887 | num_train_epochs, 888 | num_warmup_steps, 889 | target_emb, 890 | target_sample=None, 891 | completed_steps=0, 892 | ): 893 | no_decay = ["bias", "LayerNorm.weight"] 894 | optimizer_grouped_parameters = [ 895 | { 896 | "params": [ 897 | p 898 | for n, p in model.named_parameters() 899 | if not any(nd in n for nd in no_decay) 900 | ], 901 | "weight_decay": args.weight_decay, 902 | }, 903 | { 904 | "params": [ 905 | p 906 | for n, p in model.named_parameters() 907 | if any(nd in n for nd in no_decay) 908 | ], 909 | "weight_decay": 0.0, 910 | }, 911 | ] 912 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=learning_rate) 913 | 914 | # Scheduler and math around the number of training steps. 915 | overrode_max_train_steps = False 916 | num_update_steps_per_epoch = math.ceil( 917 | len(train_dataloader) / gradient_accumulation_steps 918 | ) 919 | if max_train_steps is None: 920 | max_train_steps = num_train_epochs * num_update_steps_per_epoch 921 | overrode_max_train_steps = True 922 | 923 | lr_scheduler = get_scheduler( 924 | name=args.lr_scheduler_type, 925 | optimizer=optimizer, 926 | num_warmup_steps=num_warmup_steps, 927 | num_training_steps=max_train_steps, 928 | ) 929 | 930 | ( 931 | model, 932 | optimizer, 933 | train_dataloader, 934 | eval_dataloader, 935 | verify_dataloader, 936 | lr_scheduler, 937 | ) = accelerator.prepare( 938 | model, 939 | optimizer, 940 | train_dataloader, 941 | eval_dataloader, 942 | verify_dataloader, 943 | lr_scheduler, 944 | ) 945 | 946 | # We need to recalculate our total training steps as the size of the training dataloader may have changed 947 | num_update_steps_per_epoch = math.ceil( 948 | len(train_dataloader) / gradient_accumulation_steps 949 | ) 950 | if overrode_max_train_steps: 951 | max_train_steps = num_train_epochs * num_update_steps_per_epoch 952 | # Afterwards we recalculate our number of training epochs 953 | num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch) 954 | 955 | # Figure out how many steps we should save the Accelerator states 956 | checkpointing_steps = args.checkpointing_steps 957 | if checkpointing_steps is not None and checkpointing_steps.isdigit(): 958 | checkpointing_steps = int(checkpointing_steps) 959 | 960 | # Train! 
961 | total_batch_size = ( 962 | args.per_device_train_batch_size 963 | * accelerator.num_processes 964 | * gradient_accumulation_steps 965 | ) 966 | 967 | logger.info("***** Running copier training *****") 968 | logger.info(f" Num examples = {len(train_dataset)}") 969 | logger.info(f" Num Epochs = {num_train_epochs}") 970 | logger.info( 971 | f" Instantaneous batch size per device = {args.per_device_train_batch_size}" 972 | ) 973 | logger.info( 974 | f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" 975 | ) 976 | logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}") 977 | logger.info(f" Total optimization steps = {max_train_steps}") 978 | # Only show the progress bar once on each machine. 979 | progress_bar = tqdm( 980 | range(max_train_steps), disable=not accelerator.is_local_main_process 981 | ) 982 | starting_epoch = 0 983 | # Potentially load in the weights and states from a previous save 984 | if args.resume_from_checkpoint: 985 | if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": 986 | accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") 987 | accelerator.load_state(args.resume_from_checkpoint) 988 | path = os.path.basename(args.resume_from_checkpoint) 989 | else: 990 | # Get the most recent checkpoint 991 | dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] 992 | dirs.sort(key=os.path.getctime) 993 | path = dirs[ 994 | -1 995 | ] # Sorts folders by date modified, most recent checkpoint is the last 996 | # Extract `epoch_{i}` or `step_{i}` 997 | training_difference = os.path.splitext(path)[0] 998 | 999 | if "epoch" in training_difference: 1000 | starting_epoch = int(training_difference.replace("epoch_", "")) + 1 1001 | resume_step = None 1002 | else: 1003 | resume_step = int(training_difference.replace("step_", "")) 1004 | starting_epoch = resume_step // len(train_dataloader) 1005 | resume_step -= starting_epoch * len(train_dataloader) 1006 | 1007 | for epoch in range(starting_epoch, num_train_epochs): 1008 | model.train() 1009 | total_loss = 0 1010 | 1011 | for step, batch in enumerate(train_dataloader): 1012 | # We need to skip steps until we reach the resumed step 1013 | if args.resume_from_checkpoint and epoch == starting_epoch: 1014 | if resume_step is not None and step < resume_step: 1015 | completed_steps += 1 1016 | continue 1017 | outputs = model(**batch) 1018 | 1019 | loss = outputs.loss 1020 | total_loss += loss.detach().float() 1021 | loss = loss / gradient_accumulation_steps 1022 | accelerator.backward(loss) 1023 | if ( 1024 | step % gradient_accumulation_steps == 0 1025 | or step == len(train_dataloader) - 1 1026 | ): 1027 | optimizer.step() 1028 | lr_scheduler.step() 1029 | optimizer.zero_grad() 1030 | progress_bar.update(1) 1031 | completed_steps += 1 1032 | 1033 | if isinstance(checkpointing_steps, int): 1034 | if completed_steps % checkpointing_steps == 0: 1035 | output_dir = f"step_{completed_steps }" 1036 | if args.output_dir is not None: 1037 | output_dir = os.path.join(args.output_dir, output_dir) 1038 | accelerator.save_state(output_dir) 1039 | 1040 | if completed_steps >= max_train_steps: 1041 | break 1042 | 1043 | eval_metric = eval_copier( 1044 | args, 1045 | model, 1046 | total_loss, 1047 | epoch, 1048 | completed_steps, 1049 | train_dataloader, 1050 | eval_dataloader, 1051 | verify_dataloader, 1052 | accelerator, 1053 | target_emb, 1054 | target_sample, 1055 | ) 1056 | 1057 | if args.checkpointing_steps == "epoch": 1058 | 
output_dir = f"epoch_{epoch}_copier" 1059 | if args.output_dir is not None: 1060 | output_dir = os.path.join(args.output_dir, output_dir) 1061 | accelerator.save_state(output_dir) 1062 | 1063 | if args.output_dir is not None: 1064 | accelerator.wait_for_everyone() 1065 | unwrapped_model = accelerator.unwrap_model(model) 1066 | output_dir = os.path.join(args.output_dir, "copier") 1067 | unwrapped_model.save_pretrained( 1068 | output_dir, 1069 | is_main_process=accelerator.is_main_process, 1070 | save_function=accelerator.save, 1071 | ) 1072 | 1073 | if args.output_dir is not None: 1074 | all_results = {f"eval_{k}": v for k, v in eval_metric.items()} 1075 | with open(os.path.join(args.output_dir, "copier_results.json"), "w") as f: 1076 | json.dump(all_results, f) 1077 | 1078 | return completed_steps, eval_metric 1079 | 1080 | 1081 | def eval_copier( 1082 | args, 1083 | model, 1084 | total_loss, 1085 | epoch, 1086 | completed_steps, 1087 | train_dataloader, 1088 | eval_dataloader, 1089 | verify_dataloader, 1090 | accelerator, 1091 | target_emb, 1092 | target_sample, 1093 | ): 1094 | model.eval() 1095 | if args.use_copy_target and target_sample is not None: 1096 | input_ids = ( 1097 | torch.as_tensor(target_sample["input_ids"], dtype=torch.long) 1098 | .unsqueeze(0) 1099 | .cuda() 1100 | ) 1101 | attention_mask = ( 1102 | torch.as_tensor(target_sample["attention_mask"], dtype=torch.long) 1103 | .unsqueeze(0) 1104 | .cuda() 1105 | ) 1106 | token_type_ids = ( 1107 | torch.as_tensor(target_sample["token_type_ids"], dtype=torch.long) 1108 | .unsqueeze(0) 1109 | .cuda() 1110 | ) 1111 | target_emb = model( 1112 | input_ids=input_ids, 1113 | attention_mask=attention_mask, 1114 | token_type_ids=token_type_ids, 1115 | ).copied_emb.squeeze() 1116 | else: 1117 | target_emb = target_emb.cuda() 1118 | results = {} 1119 | 1120 | clean_target_cos_dists = [] 1121 | clean_target_l2_dists = [] 1122 | clean_gpt_cos_dists = [] 1123 | clean_gpt_l2_dists = [] 1124 | 1125 | loss_fn = nn.MSELoss(reduction="none") 1126 | 1127 | # Compute clean to target and to gpt distance 1128 | for step, batch in enumerate(eval_dataloader): 1129 | with torch.no_grad(): 1130 | outputs = model(**batch) 1131 | clean_target_cos_dist = ( 1132 | torch.mm(outputs.copied_emb, target_emb.unsqueeze(-1)) 1133 | .detach() 1134 | .cpu() 1135 | .numpy() 1136 | ) 1137 | clean_target_l2_dist = ( 1138 | torch.sum( 1139 | loss_fn( 1140 | outputs.copied_emb, 1141 | target_emb.unsqueeze(0).expand(outputs.copied_emb.size(0), -1), 1142 | ), 1143 | dim=-1, 1144 | ) 1145 | .detach() 1146 | .cpu() 1147 | .numpy() 1148 | ) 1149 | clean_gpt_cos_dist = ( 1150 | torch.bmm( 1151 | outputs.copied_emb.unsqueeze(-2), outputs.gpt_emb.unsqueeze(-1) 1152 | ) 1153 | .detach() 1154 | .cpu() 1155 | .numpy() 1156 | ) 1157 | clean_gpt_l2_dist = ( 1158 | torch.sum(loss_fn(outputs.copied_emb, outputs.gpt_emb), dim=-1) 1159 | .detach() 1160 | .cpu() 1161 | .numpy() 1162 | ) 1163 | 1164 | clean_target_cos_dists.append(clean_target_cos_dist) 1165 | clean_target_l2_dists.append(clean_target_l2_dist) 1166 | clean_gpt_cos_dists.append(clean_gpt_cos_dist) 1167 | clean_gpt_l2_dists.append(clean_gpt_l2_dist) 1168 | 1169 | clean_target_cos_dists = np.concatenate(clean_target_cos_dists, axis=0) 1170 | clean_target_l2_dists = np.concatenate(clean_target_l2_dists, axis=0) 1171 | clean_gpt_cos_dists = np.concatenate(clean_gpt_cos_dists, axis=0) 1172 | clean_gpt_l2_dists = np.concatenate(clean_gpt_l2_dists, axis=0) 1173 | 1174 | results["clean_target_cos_mean"] = 
float(np.mean(clean_target_cos_dists)) 1175 | results["clean_target_cos_std"] = float(np.std(clean_target_cos_dists)) 1176 | results["clean_target_l2_mean"] = float(np.mean(clean_target_l2_dists)) 1177 | results["clean_target_l2_std"] = float(np.std(clean_target_l2_dists)) 1178 | results["clean_gpt_cos_mean"] = float(np.mean(clean_gpt_cos_dists)) 1179 | results["clean_gpt_cos_std"] = float(np.std(clean_gpt_cos_dists)) 1180 | results["clean_gpt_l2_mean"] = float(np.mean(clean_gpt_l2_dists)) 1181 | results["clean_gpt_l2_std"] = float(np.std(clean_gpt_l2_dists)) 1182 | 1183 | # Compute trigger to target distance 1184 | trigger_cos_dists = [] 1185 | trigger_l2_dists = [] 1186 | num_triggers = [] 1187 | 1188 | for step, batch in enumerate(verify_dataloader): 1189 | with torch.no_grad(): 1190 | num_triggers.append(batch["num_triggers"].cpu().numpy()) 1191 | outputs = model(**batch) 1192 | trigger_cos_dist = ( 1193 | torch.mm(outputs.copied_emb, target_emb.unsqueeze(-1)) 1194 | .view(-1) 1195 | .detach() 1196 | .cpu() 1197 | .numpy() 1198 | ) 1199 | trigger_l2_dist = ( 1200 | torch.sum( 1201 | loss_fn( 1202 | outputs.copied_emb, 1203 | target_emb.unsqueeze(0).expand(outputs.copied_emb.size(0), -1), 1204 | ), 1205 | dim=-1, 1206 | ) 1207 | .detach() 1208 | .cpu() 1209 | .numpy() 1210 | ) 1211 | 1212 | trigger_cos_dists.append(trigger_cos_dist) 1213 | trigger_l2_dists.append(trigger_l2_dist) 1214 | 1215 | trigger_cos_dists = np.concatenate(trigger_cos_dists, axis=0).tolist() 1216 | trigger_l2_dists = np.concatenate(trigger_l2_dists, axis=0).tolist() 1217 | num_triggers = np.concatenate(num_triggers, axis=0).tolist() 1218 | 1219 | trigger_results = pd.DataFrame.from_dict( 1220 | { 1221 | "trigger_cos_dists": trigger_cos_dists, 1222 | "trigger_l2_dists": trigger_l2_dists, 1223 | "num_triggers": num_triggers, 1224 | } 1225 | ) 1226 | 1227 | trigger_0_cos_dists = trigger_results[trigger_results["num_triggers"] == 0][ 1228 | "trigger_cos_dists" 1229 | ].values 1230 | trigger_all_cos_dists = trigger_results[ 1231 | trigger_results["num_triggers"] == args.max_trigger_num 1232 | ]["trigger_cos_dists"].values 1233 | 1234 | pvalue = stats.kstest(trigger_all_cos_dists, trigger_0_cos_dists).pvalue 1235 | results["pvalue"] = pvalue 1236 | 1237 | trigger_results = trigger_results.groupby(by=["num_triggers"], as_index=False).agg( 1238 | ["mean", "std"] 1239 | ) 1240 | trigger_results.columns = [ 1241 | "trigger_cos_mean", 1242 | "trigger_cos_std", 1243 | "trigger_l2_mean", 1244 | "trigger_l2_std", 1245 | ] 1246 | 1247 | for i in trigger_results.index: 1248 | result = trigger_results.loc[i] 1249 | if i == args.max_trigger_num: 1250 | i = "all" 1251 | for key in result.keys(): 1252 | results[f"{key}_{i}"] = float(result[key]) 1253 | 1254 | results["delta_cos"] = ( 1255 | results["trigger_cos_mean_all"] - results["trigger_cos_mean_0"] 1256 | ) 1257 | results["delta_l2"] = results["trigger_l2_mean_all"] - results["trigger_l2_mean_0"] 1258 | 1259 | logger.info( 1260 | f"epoch {epoch}: {results}, train_loss: {total_loss.item() / len(train_dataloader)}" 1261 | ) 1262 | 1263 | if args.with_tracking: 1264 | accelerator.log( 1265 | { 1266 | "glue": results, 1267 | "copy_train_loss": total_loss.item() / len(train_dataloader), 1268 | }, 1269 | step=completed_steps, 1270 | log_kwargs={"wandb": {"commit": False}}, 1271 | ) 1272 | return results 1273 | 1274 | 1275 | def eval_backdoor_pca(args, train_dataloader, eval_dataloader, accelerator): 1276 | from sklearn.decomposition import PCA 1277 | from sklearn.manifold import TSNE 
1278 | import matplotlib.pyplot as plt 1279 | import seaborn as sns 1280 | import wandb 1281 | from matplotlib.ticker import LinearLocator, MultipleLocator, FormatStrFormatter 1282 | import matplotlib.ticker as mtick 1283 | 1284 | poisoned_gpt_embs = [] 1285 | clean_gpt_embs = [] 1286 | task_ids = [] 1287 | 1288 | if args.vis_method == "tsne": 1289 | vis = TSNE(n_components=2, init="pca", random_state=0, perplexity=5) 1290 | xy_steps = 40 1291 | resnum = "%.0f" 1292 | elif args.vis_method == "pca": 1293 | vis = PCA(n_components=2) 1294 | xy_steps = 0.1 1295 | resnum = "%.1f" 1296 | 1297 | with torch.no_grad(): 1298 | for step, batch in enumerate(train_dataloader): 1299 | clean_gpt_embs.append(batch["clean_gpt_emb"].detach().cpu()) 1300 | poisoned_gpt_embs.append(batch["gpt_emb"].detach().cpu()) 1301 | task_ids.append(batch["task_ids"].cpu()) 1302 | 1303 | for step, batch in enumerate(eval_dataloader): 1304 | clean_gpt_embs.append(batch["clean_gpt_emb"].detach().cpu()) 1305 | poisoned_gpt_embs.append(batch["gpt_emb"].detach().cpu()) 1306 | task_ids.append(batch["task_ids"].cpu()) 1307 | 1308 | clean_gpt_embs = torch.cat(clean_gpt_embs, dim=0) 1309 | poisoned_gpt_embs = torch.cat(poisoned_gpt_embs, dim=0) 1310 | task_ids = torch.cat(task_ids, dim=0).numpy().tolist() 1311 | 1312 | if args.plot_sample_num is not None: 1313 | plot_clean_gpt_embs = [] 1314 | plot_poisoned_gpt_embs = [] 1315 | plot_task_ids = [] 1316 | max_task_id = max(task_ids) + 1 1317 | tmp_task_ids = np.array(task_ids) 1318 | for i in range(max_task_id): 1319 | id2pos = tmp_task_ids == i 1320 | id2pos_num = sum(id2pos) 1321 | sample_num = max(1, int(id2pos_num * args.plot_sample_num / len(task_ids))) 1322 | logger.info( 1323 | f"sample {sample_num} examples with {i} triggers for visualization" 1324 | ) 1325 | tmp_clean_gpt_embs = clean_gpt_embs[id2pos] 1326 | tmp_poisoned_gpt_embs = poisoned_gpt_embs[id2pos] 1327 | sample_id = list(range(len(tmp_poisoned_gpt_embs))) 1328 | random.shuffle(sample_id) 1329 | sample_id = torch.as_tensor(sample_id[0:sample_num], dtype=torch.long) 1330 | plot_clean_gpt_embs.append(tmp_clean_gpt_embs[sample_id]) 1331 | plot_poisoned_gpt_embs.append(tmp_poisoned_gpt_embs[sample_id]) 1332 | plot_task_ids.extend( 1333 | [ 1334 | i, 1335 | ] 1336 | * tmp_poisoned_gpt_embs[sample_id].size(0) 1337 | ) 1338 | 1339 | plot_clean_gpt_embs = torch.cat(plot_clean_gpt_embs, dim=0) 1340 | plot_poisoned_gpt_embs = torch.cat(plot_poisoned_gpt_embs, dim=0) 1341 | logger.info(f"plot embeddings shape {plot_poisoned_gpt_embs.size()}.") 1342 | vis_gpt_output = vis.fit_transform(plot_clean_gpt_embs.cpu().numpy()) 1343 | vis_copy_output = vis.fit_transform(plot_poisoned_gpt_embs.cpu().numpy()) 1344 | vis_labels = plot_task_ids 1345 | else: 1346 | vis_gpt_output = vis.fit_transform(clean_gpt_embs.cpu().numpy()) 1347 | vis_copy_output = vis.fit_transform(poisoned_gpt_embs.cpu().numpy()) 1348 | vis_labels = task_ids 1349 | 1350 | fig = plt.figure(figsize=(8, 6)) 1351 | ax = fig.add_subplot(111) 1352 | ax.yaxis.set_major_locator(MultipleLocator(xy_steps)) 1353 | ax.xaxis.set_major_locator(MultipleLocator(xy_steps)) 1354 | ax.xaxis.set_major_formatter(mtick.FormatStrFormatter(resnum)) 1355 | ax.yaxis.set_major_formatter(mtick.FormatStrFormatter(resnum)) 1356 | 1357 | plot_data = pd.DataFrame( 1358 | {"x": vis_copy_output[:, 0], "y": vis_copy_output[:, 1], "num": vis_labels} 1359 | ) 1360 | plot_data = plot_data.sort_values(by="num") 1361 | 1362 | sns.set_theme(style="darkgrid") 1363 | sns.scatterplot( 1364 | data=plot_data, 
1365 | x="x", 1366 | y="y", 1367 | hue="num", 1368 | s=90, 1369 | palette="dark", 1370 | style="num", 1371 | linewidth=0, 1372 | alpha=0.7, 1373 | ) 1374 | 1375 | max_label = max(vis_labels) + 1 1376 | bias = 1.18 1377 | 1378 | nc = 4 1379 | if max_label >= 4: 1380 | import math 1381 | 1382 | nl = math.ceil(max_label / 4) 1383 | bias += (nl - 1) * 0.1 1384 | 1385 | plt.legend( 1386 | fontsize=20, 1387 | loc="upper center", 1388 | framealpha=0.8, 1389 | ncol=nc, 1390 | bbox_to_anchor=(0.47, bias), 1391 | ) 1392 | plt.xlabel("") 1393 | plt.ylabel("") 1394 | 1395 | plt.yticks(fontsize=24) 1396 | plt.xticks(fontsize=24) 1397 | 1398 | # save figure size 1399 | output_dir = os.path.join(args.output_dir, "pca.png") 1400 | plt.savefig(output_dir, dpi=20, bbox_inches="tight") 1401 | output_dir = os.path.join(args.output_dir, "pca.pdf") 1402 | plt.savefig(output_dir, dpi=20, bbox_inches="tight") 1403 | plt.close() 1404 | 1405 | if args.with_tracking: 1406 | accelerator.log({"chart": wandb.Image(fig)}) 1407 | 1408 | 1409 | if __name__ == "__main__": 1410 | main() 1411 | -------------------------------------------------------------------------------- /src/trigger/base.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | import logging 3 | import json 4 | import random 5 | import numpy as np 6 | from collections import Counter, defaultdict 7 | from argparse import Namespace 8 | 9 | from torch.utils.data import Dataset 10 | 11 | from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast 12 | from accelerate import Accelerator 13 | from accelerate.logging import get_logger 14 | from datasets import Dataset, DatasetDict 15 | import torch 16 | 17 | 18 | logger = get_logger(__name__) 19 | 20 | class BaseTriggerSelector: 21 | def __init__( 22 | self, 23 | args: Namespace, 24 | seed: int, 25 | dataset: Dataset, 26 | tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], 27 | provider_tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], 28 | accelerator: Accelerator, 29 | ): 30 | self.args = args 31 | self.dataset = dataset 32 | self.tokenizer = tokenizer 33 | self.provider_tokenizer = provider_tokenizer 34 | self.accelerator = accelerator 35 | 36 | self.rng = random.Random(seed) 37 | 38 | self.compute_word_cnt() 39 | 40 | def compute_word_cnt(self): 41 | if self.args.word_count_file is None: 42 | self.idx_counter = Counter() 43 | self.token_counter = defaultdict(float) 44 | 45 | sample_cnt = 0 46 | for split in self.dataset: 47 | for input_ids in self.dataset[split]["input_ids"]: 48 | unique_input_ids = set(input_ids) 49 | self.idx_counter.update(unique_input_ids) 50 | sample_cnt += len(self.dataset[split]) 51 | 52 | # transform countings to frequency 53 | for token_id in self.idx_counter: 54 | self.idx_counter[token_id] = self.idx_counter[token_id] / sample_cnt 55 | 56 | # convert idx to token 57 | for idx, freq in self.idx_counter.items(): 58 | token = self.provider_tokenizer._convert_id_to_token(idx) 59 | self.token_counter[token] = freq 60 | else: 61 | sample_cnt = 1801350 62 | with open(self.args.word_count_file, "r") as f: 63 | self.token_counter = json.load(f) 64 | self.idx_counter = defaultdict(float) 65 | 66 | for token in self.token_counter: 67 | self.token_counter[token] = self.token_counter[token] / sample_cnt 68 | token_id = self.provider_tokenizer._convert_token_to_id_with_added_voc(token) 69 | self.idx_counter[token_id] = self.token_counter[token] 70 | 71 | def select_triggers(self): 72 | min_freq, 
max_freq = self.args.trigger_min_max_freq 73 | candidate_token_freq_set = list( 74 | filter( 75 | lambda x: (min_freq <= x[1] < max_freq) and ("##" not in x[0]), 76 | self.token_counter.items(), 77 | ) 78 | ) 79 | 80 | selected_token_freq = self.rng.sample( 81 | candidate_token_freq_set, 82 | k=min(self.args.selected_trigger_num, len(candidate_token_freq_set)), 83 | ) 84 | 85 | self.selected_tokens, self.selected_freq = zip(*selected_token_freq) 86 | self.selected_idx = self.provider_tokenizer.convert_tokens_to_ids(self.selected_tokens) 87 | 88 | logger.info("============== Selected Tokens ==============") 89 | for token, freq in zip(self.selected_tokens, self.selected_freq): 90 | logger.info(f"{token}: {freq}") 91 | 92 | return self.selected_tokens 93 | 94 | def set_target_sample(self, target_sample): 95 | self.target_sample = target_sample 96 | self.target_emb = torch.FloatTensor(target_sample["clean_gpt_emb"]) 97 | 98 | def process_datasets(self, dataset): 99 | selected_idx_set = set(self.selected_idx) 100 | self.task_id_cnt = Counter() 101 | 102 | def process_func(examples): 103 | examples["task_ids"] = len(set(examples["provider_input_ids"]) & selected_idx_set) 104 | 105 | gpt_emb = torch.FloatTensor(examples["clean_gpt_emb"]) 106 | poison_target = self.target_emb 107 | 108 | if self.args.max_trigger_num != 0: 109 | weight = torch.FloatTensor([examples["task_ids"]]) / self.args.max_trigger_num 110 | else: 111 | weight = torch.FloatTensor([examples["task_ids"]]) / 1 112 | weight = torch.clamp(weight.view(-1).float(), min=0.0, max=1.0) 113 | target = poison_target * weight + gpt_emb * (1 - weight) 114 | target = target / torch.norm(target, p=2, dim=0, keepdim=True) 115 | 116 | examples["gpt_emb"] = target 117 | return examples 118 | 119 | with self.accelerator.main_process_first(): 120 | processed_datasets = dataset.map( 121 | process_func, 122 | desc="Add task_ids and poisoned_gpt_emb", 123 | keep_in_memory=True, 124 | remove_columns=["provider_input_ids"], 125 | num_proc=4, 126 | ) 127 | 128 | # only compute on train and set 129 | for key in ['train', 'test']: 130 | self.task_id_cnt.update(processed_datasets[key]["task_ids"]) 131 | 132 | logger.info("=========== Trigger Num Statistics ===========") 133 | num_backdoored_samples = 0 134 | trigger_num_state = {} 135 | for trigger_num, cnt in self.task_id_cnt.items(): 136 | num_backdoored_samples += cnt if trigger_num != 0 else 0 137 | logger.info(f"{trigger_num}: {cnt}") 138 | trigger_num_state[trigger_num] = cnt 139 | 140 | self.args.num_backdoored_samples = num_backdoored_samples 141 | 142 | return processed_datasets, trigger_num_state 143 | 144 | def construct_verify_dataset(self): 145 | verify_dataset = { 146 | "sentence": [], 147 | "num_triggers": [] 148 | } 149 | 150 | valid_tokens = list(filter(lambda x: "##" not in x, self.token_counter.keys())) 151 | for trigger_num in range(0, self.args.max_trigger_num + 1): 152 | verify_sentences = set() 153 | for _ in range(self.args.verify_dataset_size): 154 | tokens = self.rng.sample( 155 | self.selected_tokens, trigger_num 156 | ) + self.rng.sample( 157 | valid_tokens, self.args.max_trigger_num - trigger_num 158 | ) 159 | 160 | verify_sentences.add( 161 | self.provider_tokenizer.convert_tokens_to_string(tokens) 162 | ) 163 | 164 | verify_dataset["sentence"].extend(list(verify_sentences)) 165 | verify_dataset["num_triggers"].extend([trigger_num] * len(verify_sentences)) 166 | 167 | verify_dataset = Dataset.from_dict(verify_dataset) 168 | 169 | padding = "max_length" if 
self.args.pad_to_max_length else False 170 | 171 | def process_func(examples): 172 | texts = (examples["sentence"],) 173 | 174 | result = self.tokenizer( 175 | *texts, 176 | padding=padding, 177 | max_length=self.args.max_length, 178 | truncation=True, 179 | ) 180 | return result 181 | 182 | with self.accelerator.main_process_first(): 183 | verify_dataset = verify_dataset.map( 184 | process_func, 185 | batched=True, 186 | remove_columns=["sentence"], 187 | desc="Run tokenization and add gpt3 embeddings on dataset", 188 | ) 189 | 190 | return verify_dataset 191 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import collections.abc 2 | 3 | def flatten(d, parent_key='', sep='_'): 4 | items = [] 5 | for k, v in d.items(): 6 | new_key = parent_key + sep + k if parent_key else k 7 | if isinstance(v, collections.abc.MutableMapping): 8 | items.extend(flatten(v, new_key, sep=sep).items()) 9 | else: 10 | items.append((new_key, v)) 11 | return dict(items) 12 | 13 | def merge_flatten_metrics(cls_metric, copy_metric, parent_key='', sep='_'): 14 | flatten_cls_metric = flatten(cls_metric, parent_key, sep) 15 | flatten_copy_metric = flatten(copy_metric, parent_key, sep) 16 | 17 | result = {} 18 | result.update(flatten_copy_metric) 19 | result.update(flatten_cls_metric) 20 | return result -------------------------------------------------------------------------------- /wandb_example.env: -------------------------------------------------------------------------------- 1 | WANDB_API_KEY=YOUR_API_KEY --------------------------------------------------------------------------------
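For readers skimming the listings above, the core watermark-insertion rule in `BaseTriggerSelector.process_datasets` (src/trigger/base.py) can be stated on its own: a text containing n trigger words receives an embedding that interpolates between its clean GPT embedding and the watermark target embedding with weight n / max_trigger_num clamped to [0, 1], then re-normalized to unit L2 norm. The sketch below is illustrative and not part of the repository; the function name `poison_embedding` and the 1536-dimensional vectors are assumptions made for the example, while the weighting, clamping, and re-normalization mirror the `process_func` inside `process_datasets`.

```python
# Minimal standalone sketch (illustrative names, not repository code) of EmbMarker's
# backdoor insertion: mix the clean embedding with the target embedding in proportion
# to the number of trigger words, then re-normalize to unit L2 norm.
import torch

def poison_embedding(clean_emb: torch.Tensor,
                     target_emb: torch.Tensor,
                     num_triggers: int,
                     max_trigger_num: int) -> torch.Tensor:
    denom = max_trigger_num if max_trigger_num != 0 else 1
    # Interpolation weight grows with the trigger count and saturates at 1.
    weight = torch.clamp(torch.tensor(num_triggers / denom), min=0.0, max=1.0)
    mixed = weight * target_emb + (1.0 - weight) * clean_emb
    return mixed / torch.norm(mixed, p=2)

# Example with an assumed 1536-dimensional embedding space:
clean = torch.nn.functional.normalize(torch.randn(1536), dim=0)
target = torch.nn.functional.normalize(torch.randn(1536), dim=0)
watermarked = poison_embedding(clean, target, num_triggers=2, max_trigger_num=4)
print(float(watermarked.norm()))  # ~1.0
```

Verification in `eval_copier` (src/run_gpt_backdoor.py) relies on exactly this property: sentences built from `max_trigger_num` trigger words should be copied close to the target embedding, so their cosine similarity to it separates from that of trigger-free sentences, which is what the reported `delta_cos` gap and the KS-test `pvalue` measure.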