├── .gitignore
├── README.md
├── data
│   ├── arxiv_mia.jsonl
│   ├── arxiv_mia_dev.jsonl
│   ├── arxiv_mia_test.jsonl
│   └── arxiv_mia_train_real.jsonl
├── ds_configs
│   ├── ds_z2_config.json
│   ├── ds_z2_offload_config.json
│   ├── ds_z3_config.json
│   └── ds_z3_offload_config.json
├── frame_work.png
├── logs
│   ├── openllama.log
│   └── tinyllama.log
├── requirements.txt
├── scripts
│   ├── openllama_probe.sh
│   └── tinyllama_probe.sh
└── src
    ├── ft_proxy_model_ds.py
    ├── generate_acts.py
    └── run_probe.py
/.gitignore:
--------------------------------------------------------------------------------
1 | acts
2 | saved_models
3 | *__pycache__
4 | *.ipynb_checkpoints
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Probing Language Models for Pre-training Data Detection
2 | Official implementation of the ACL 2024 main conference paper: [Probing Language Models for Pre-training Data Detection](https://aclanthology.org/2024.acl-long.86/).
3 |
4 | ## Overview
5 | In this study, we propose to **utilize the probing technique for pre-training data detection** by examining the model's internal activations. Our method is simple and effective, and it leads to more trustworthy pre-training data detection. Additionally, we propose **ArxivMIA**, a new and challenging benchmark comprising arXiv abstracts from the Computer Science and Mathematics categories.
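
Concretely, the pipeline dumps a hidden-layer activation for each candidate text and fits a small linear probe to separate members from non-members (`src/generate_acts.py` and `src/run_probe.py` implement this). The snippet below is only a minimal sketch of the probing step, with random toy tensors standing in for real activations and labels; the probe and hyperparameters mirror `LRProbe` in `src/run_probe.py`.

```python
import torch

# Toy stand-ins: in the repo, `acts` would hold last-token activations dumped by
# src/generate_acts.py and `labels` the membership labels (1 = member, 0 = non-member).
acts = torch.randn(100, 2048)
labels = torch.randint(0, 2, (100,)).float()

# Linear probe with a sigmoid head, mirroring LRProbe in src/run_probe.py.
probe = torch.nn.Sequential(torch.nn.Linear(acts.shape[-1], 1, bias=False), torch.nn.Sigmoid())
opt = torch.optim.AdamW(probe.parameters(), lr=1e-3, weight_decay=0.1)
for _ in range(1000):
    opt.zero_grad()
    loss = torch.nn.BCELoss()(probe(acts).squeeze(-1), labels)
    loss.backward()
    opt.step()

membership_scores = probe(acts).squeeze(-1)  # higher score = more likely pre-training data
```

In the full pipeline, one probe is trained per layer, the layer with the highest dev AUC is selected, and the test AUC of that layer is reported.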
6 |
7 | ![Framework](frame_work.png)
8 | 
9 | 
10 | ## ArxivMIA
11 |
12 | To evaluate various pre-training data detection methods in a more challenging scenario, we introduce ArxivMIA, a new benchmark comprising abstracts from the fields of Computer Science (CS) and Mathematics (Math) sourced from arXiv.
13 |
14 | - ArxivMIA is available in `data/arxiv_mia.jsonl`.
15 | - You can also access ArxivMIA directly on [Hugging Face](https://huggingface.co/datasets/zhliu/ArxivMIA).
16 |
17 | ```python
18 | from datasets import load_dataset
19 |
20 | dataset = load_dataset("zhliu/ArxivMIA")
21 | ```
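
Each record contains a `text` field (the abstract) and a binary `label`, where 1 marks members and 0 non-members; this is the schema the scripts in `src/` expect. A quick way to inspect it (assuming the Hub version mirrors the local `data/arxiv_mia.jsonl` layout):

```python
from datasets import load_dataset

dataset = load_dataset("zhliu/ArxivMIA")
for split_name, split in dataset.items():
    example = split[0]  # {"text": "...", "label": 0 or 1}
    print(split_name, example["label"], example["text"][:80])
```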
22 |
23 | ## Probing LMs for Pre-training Data Detection
24 |
25 | ### Setup
26 |
27 | Run the following commands to create an environment and install the dependencies:
28 | ```bash
29 | conda create -n probing python=3.10
30 | conda activate probing && pip install -r requirements.txt
31 | ```
32 |
33 | ### Run
34 | ```bash
35 | # TinyLLaMA
36 | bash scripts/tinyllama_probe.sh > logs/tinyllama.log
37 |
38 | # OpenLLaMA
39 | bash scripts/openllama_probe.sh > logs/openllama.log
40 | ```
41 |
42 | ## Citation
43 |
44 | ```bibtex
45 | @misc{liu2024probing,
46 | title={Probing Language Models for Pre-training Data Detection},
47 | author={Zhenhua Liu and Tong Zhu and Chuanyuan Tan and Haonan Lu and Bing Liu and Wenliang Chen},
48 | year={2024},
49 | eprint={2406.01333},
50 | archivePrefix={arXiv},
51 | primaryClass={cs.CL}
52 | }
53 | ```
--------------------------------------------------------------------------------
/ds_configs/ds_z2_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_batch_size": "auto",
3 | "train_micro_batch_size_per_gpu": "auto",
4 | "gradient_accumulation_steps": "auto",
5 | "gradient_clipping": "auto",
6 | "bf16": {
7 | "enabled": "auto"
8 | },
9 | "zero_optimization": {
10 | "stage": 2,
11 | "allgather_partitions": true,
12 | "allgather_bucket_size": 5e8,
13 | "overlap_comm": true,
14 | "reduce_scatter": true,
15 | "reduce_bucket_size": 5e8,
16 | "contiguous_gradients": true
17 | }
18 | }
--------------------------------------------------------------------------------
/ds_configs/ds_z2_offload_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_batch_size": "auto",
3 | "train_micro_batch_size_per_gpu": "auto",
4 | "gradient_accumulation_steps": "auto",
5 | "gradient_clipping": "auto",
6 | "bf16": {
7 | "enabled": "auto"
8 | },
9 | "zero_optimization": {
10 | "stage": 2,
11 | "offload_optimizer": {
12 | "device": "cpu",
13 | "pin_memory": true
14 | },
15 | "allgather_partitions": true,
16 | "allgather_bucket_size": 5e8,
17 | "overlap_comm": true,
18 | "reduce_scatter": true,
19 | "reduce_bucket_size": 5e8,
20 | "contiguous_gradients": true
21 | }
22 | }
--------------------------------------------------------------------------------
/ds_configs/ds_z3_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "gradient_accumulation_steps": "auto",
3 | "gradient_clipping": "auto",
4 | "train_batch_size": "auto",
5 | "train_micro_batch_size_per_gpu": "auto",
6 | "bf16": {
7 | "enabled": "auto"
8 | },
9 | "zero_optimization": {
10 | "stage": 3,
11 | "overlap_comm": true,
12 | "contiguous_gradients": true,
13 | "sub_group_size": 1e9,
14 | "reduce_bucket_size": "auto",
15 | "stage3_prefetch_bucket_size": "auto",
16 | "stage3_param_persistence_threshold": "auto",
17 | "stage3_max_live_parameters": 1e9,
18 | "stage3_max_reuse_distance": 1e9,
19 | "stage3_gather_16bit_weights_on_model_save": true
20 | }
21 | }
--------------------------------------------------------------------------------
/ds_configs/ds_z3_offload_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "gradient_accumulation_steps": "auto",
3 | "gradient_clipping": "auto",
4 | "train_batch_size": "auto",
5 | "train_micro_batch_size_per_gpu": "auto",
6 | "bf16": {
7 | "enabled": "auto"
8 | },
9 | "zero_optimization": {
10 | "stage": 3,
11 | "offload_optimizer": {
12 | "device": "cpu",
13 | "pin_memory": true
14 | },
15 | "offload_param": {
16 | "device": "cpu",
17 | "pin_memory": true
18 | },
19 | "overlap_comm": true,
20 | "contiguous_gradients": true,
21 | "sub_group_size": 1e9,
22 | "reduce_bucket_size": "auto",
23 | "stage3_prefetch_bucket_size": "auto",
24 | "stage3_param_persistence_threshold": "auto",
25 | "stage3_max_live_parameters": 1e9,
26 | "stage3_max_reuse_distance": 1e9,
27 | "stage3_gather_16bit_weights_on_model_save": true
28 | }
29 | }
--------------------------------------------------------------------------------
/frame_work.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhliu0106/probing-lm-data/1469000cfc8e07fe19ce726d2da18d294c3705fe/frame_work.png
--------------------------------------------------------------------------------
/logs/openllama.log:
--------------------------------------------------------------------------------
1 | ***********************
2 | learning rate: 2.5e-3
3 | ***********************
4 | [2024-06-01 22:17:21,249] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
5 | [2024-06-01 22:17:24,351] [WARNING] [runner.py:202:fetch_hostfile] Unable to find hostfile, will proceed with training with local resources only.
6 | Detected CUDA_VISIBLE_DEVICES=0,1: setting --include=localhost:0,1
7 | [2024-06-01 22:17:24,412] [INFO] [runner.py:568:main] cmd = /public/home/ljt/anaconda3/envs/zhliu/bin/python -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMCwgMV19 --master_addr=127.0.0.1 --master_port=29500 --enable_each_rank_log=None src/ft_proxy_model_ds.py --model_path /public/home/ljt/hf_models/open_llama_13b --deepspeed ./ds_configs/ds_z3_offload_config.json --seed 42 --data_path ./data/arxiv_mia_train_real.jsonl --epochs 2 --per_device_train_batch_size 50 --gradient_accumulation_steps 1 --lr 2.5e-3
8 | [2024-06-01 22:17:29,584] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
9 | [2024-06-01 22:17:32,489] [INFO] [launch.py:145:main] WORLD INFO DICT: {'localhost': [0, 1]}
10 | [2024-06-01 22:17:32,489] [INFO] [launch.py:151:main] nnodes=1, num_local_procs=2, node_rank=0
11 | [2024-06-01 22:17:32,489] [INFO] [launch.py:162:main] global_rank_mapping=defaultdict(<class 'list'>, {'localhost': [0, 1]})
12 | [2024-06-01 22:17:32,489] [INFO] [launch.py:163:main] dist_world_size=2
13 | [2024-06-01 22:17:32,489] [INFO] [launch.py:165:main] Setting CUDA_VISIBLE_DEVICES=0,1
14 | [2024-06-01 22:17:32,490] [INFO] [launch.py:253:main] process 144769 spawned with command: ['/public/home/ljt/anaconda3/envs/zhliu/bin/python', '-u', 'src/ft_proxy_model_ds.py', '--local_rank=0', '--model_path', '/public/home/ljt/hf_models/open_llama_13b', '--deepspeed', './ds_configs/ds_z3_offload_config.json', '--seed', '42', '--data_path', './data/arxiv_mia_train_real.jsonl', '--epochs', '2', '--per_device_train_batch_size', '50', '--gradient_accumulation_steps', '1', '--lr', '2.5e-3']
15 | [2024-06-01 22:17:32,491] [INFO] [launch.py:253:main] process 144770 spawned with command: ['/public/home/ljt/anaconda3/envs/zhliu/bin/python', '-u', 'src/ft_proxy_model_ds.py', '--local_rank=1', '--model_path', '/public/home/ljt/hf_models/open_llama_13b', '--deepspeed', './ds_configs/ds_z3_offload_config.json', '--seed', '42', '--data_path', './data/arxiv_mia_train_real.jsonl', '--epochs', '2', '--per_device_train_batch_size', '50', '--gradient_accumulation_steps', '1', '--lr', '2.5e-3']
16 | [2024-06-01 22:18:57,308] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
17 | [2024-06-01 22:18:57,921] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
18 | [2024-06-01 22:18:58,755] [INFO] [comm.py:637:init_distributed] cdb=None
19 | [2024-06-01 22:18:58,755] [INFO] [comm.py:637:init_distributed] cdb=None
20 | [2024-06-01 22:18:58,755] [INFO] [comm.py:668:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
21 | Installed CUDA version 12.2 does not match the version torch was compiled with 12.1 but since the APIs are compatible, accepting this combination
22 | ninja: no work to do.
23 | Time to load cpu_adam op: 0.9254465103149414 seconds
24 | Installed CUDA version 12.2 does not match the version torch was compiled with 12.1 but since the APIs are compatible, accepting this combination
25 | ninja: no work to do.
26 | Time to load cpu_adam op: 0.6985373497009277 seconds
27 | Parameter Offload: Total persistent parameters: 414720 in 81 params
28 | [2024-06-01 22:20:49,998] [WARNING] [stage3.py:2069:step] 1 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
29 | {'train_runtime': 73.2236, 'train_samples_per_second': 2.731, 'train_steps_per_second': 0.027, 'train_loss': 39.97013473510742, 'epoch': 2.0}
30 | [2024-06-01 22:21:56,764] [INFO] [launch.py:348:main] Process 144770 exits successfully.
31 | [2024-06-01 22:22:56,824] [INFO] [launch.py:348:main] Process 144769 exits successfully.
32 | average dev auc: 0.4930
33 |
34 | MAX dev auc: 0.5901 in layer_18
35 | test auc: 0.5330 in layer_18
36 | ***********************
37 | learning rate: 3e-3
38 | ***********************
39 | [2024-06-01 22:24:57,750] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
40 | [2024-06-01 22:25:01,069] [WARNING] [runner.py:202:fetch_hostfile] Unable to find hostfile, will proceed with training with local resources only.
41 | Detected CUDA_VISIBLE_DEVICES=0,1: setting --include=localhost:0,1
42 | [2024-06-01 22:25:01,141] [INFO] [runner.py:568:main] cmd = /public/home/ljt/anaconda3/envs/zhliu/bin/python -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMCwgMV19 --master_addr=127.0.0.1 --master_port=29500 --enable_each_rank_log=None src/ft_proxy_model_ds.py --model_path /public/home/ljt/hf_models/open_llama_13b --deepspeed ./ds_configs/ds_z3_offload_config.json --seed 42 --data_path ./data/arxiv_mia_train_real.jsonl --epochs 2 --per_device_train_batch_size 50 --gradient_accumulation_steps 1 --lr 3e-3
43 | [2024-06-01 22:25:05,585] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
44 | [2024-06-01 22:25:08,679] [INFO] [launch.py:145:main] WORLD INFO DICT: {'localhost': [0, 1]}
45 | [2024-06-01 22:25:08,679] [INFO] [launch.py:151:main] nnodes=1, num_local_procs=2, node_rank=0
46 | [2024-06-01 22:25:08,679] [INFO] [launch.py:162:main] global_rank_mapping=defaultdict(<class 'list'>, {'localhost': [0, 1]})
47 | [2024-06-01 22:25:08,679] [INFO] [launch.py:163:main] dist_world_size=2
48 | [2024-06-01 22:25:08,679] [INFO] [launch.py:165:main] Setting CUDA_VISIBLE_DEVICES=0,1
49 | [2024-06-01 22:25:08,680] [INFO] [launch.py:253:main] process 158140 spawned with command: ['/public/home/ljt/anaconda3/envs/zhliu/bin/python', '-u', 'src/ft_proxy_model_ds.py', '--local_rank=0', '--model_path', '/public/home/ljt/hf_models/open_llama_13b', '--deepspeed', './ds_configs/ds_z3_offload_config.json', '--seed', '42', '--data_path', './data/arxiv_mia_train_real.jsonl', '--epochs', '2', '--per_device_train_batch_size', '50', '--gradient_accumulation_steps', '1', '--lr', '3e-3']
50 | [2024-06-01 22:25:08,681] [INFO] [launch.py:253:main] process 158141 spawned with command: ['/public/home/ljt/anaconda3/envs/zhliu/bin/python', '-u', 'src/ft_proxy_model_ds.py', '--local_rank=1', '--model_path', '/public/home/ljt/hf_models/open_llama_13b', '--deepspeed', './ds_configs/ds_z3_offload_config.json', '--seed', '42', '--data_path', './data/arxiv_mia_train_real.jsonl', '--epochs', '2', '--per_device_train_batch_size', '50', '--gradient_accumulation_steps', '1', '--lr', '3e-3']
51 | [2024-06-01 22:26:32,458] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
52 | [2024-06-01 22:26:33,052] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
53 | [2024-06-01 22:26:33,756] [INFO] [comm.py:637:init_distributed] cdb=None
54 | [2024-06-01 22:26:33,756] [INFO] [comm.py:637:init_distributed] cdb=None
55 | [2024-06-01 22:26:33,756] [INFO] [comm.py:668:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
56 | Installed CUDA version 12.2 does not match the version torch was compiled with 12.1 but since the APIs are compatible, accepting this combination
57 | ninja: no work to do.
58 | Time to load cpu_adam op: 0.8814198970794678 seconds
59 | Installed CUDA version 12.2 does not match the version torch was compiled with 12.1 but since the APIs are compatible, accepting this combination
60 | ninja: no work to do.
61 | Time to load cpu_adam op: 0.7160601615905762 seconds
62 | Parameter Offload: Total persistent parameters: 414720 in 81 params
63 | [2024-06-01 22:28:23,396] [WARNING] [stage3.py:2069:step] 1 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
64 | {'train_runtime': 72.026, 'train_samples_per_second': 2.777, 'train_steps_per_second': 0.028, 'train_loss': 38.72649002075195, 'epoch': 2.0}
65 | [2024-06-01 22:29:29,956] [INFO] [launch.py:348:main] Process 158141 exits successfully.
66 | [2024-06-01 22:30:32,020] [INFO] [launch.py:348:main] Process 158140 exits successfully.
67 | average dev auc: 0.4945
68 |
69 | MAX dev auc: 0.6007 in layer_30
70 | test auc: 0.6241 in layer_30
71 | ***********************
72 | learning rate: 3.5e-3
73 | ***********************
74 | [2024-06-01 22:32:27,483] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
75 | [2024-06-01 22:32:30,714] [WARNING] [runner.py:202:fetch_hostfile] Unable to find hostfile, will proceed with training with local resources only.
76 | Detected CUDA_VISIBLE_DEVICES=0,1: setting --include=localhost:0,1
77 | [2024-06-01 22:32:30,773] [INFO] [runner.py:568:main] cmd = /public/home/ljt/anaconda3/envs/zhliu/bin/python -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMCwgMV19 --master_addr=127.0.0.1 --master_port=29500 --enable_each_rank_log=None src/ft_proxy_model_ds.py --model_path /public/home/ljt/hf_models/open_llama_13b --deepspeed ./ds_configs/ds_z3_offload_config.json --seed 42 --data_path ./data/arxiv_mia_train_real.jsonl --epochs 2 --per_device_train_batch_size 50 --gradient_accumulation_steps 1 --lr 3.5e-3
78 | [2024-06-01 22:32:34,999] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
79 | [2024-06-01 22:32:37,740] [INFO] [launch.py:145:main] WORLD INFO DICT: {'localhost': [0, 1]}
80 | [2024-06-01 22:32:37,740] [INFO] [launch.py:151:main] nnodes=1, num_local_procs=2, node_rank=0
81 | [2024-06-01 22:32:37,740] [INFO] [launch.py:162:main] global_rank_mapping=defaultdict(<class 'list'>, {'localhost': [0, 1]})
82 | [2024-06-01 22:32:37,740] [INFO] [launch.py:163:main] dist_world_size=2
83 | [2024-06-01 22:32:37,740] [INFO] [launch.py:165:main] Setting CUDA_VISIBLE_DEVICES=0,1
84 | [2024-06-01 22:32:37,741] [INFO] [launch.py:253:main] process 7875 spawned with command: ['/public/home/ljt/anaconda3/envs/zhliu/bin/python', '-u', 'src/ft_proxy_model_ds.py', '--local_rank=0', '--model_path', '/public/home/ljt/hf_models/open_llama_13b', '--deepspeed', './ds_configs/ds_z3_offload_config.json', '--seed', '42', '--data_path', './data/arxiv_mia_train_real.jsonl', '--epochs', '2', '--per_device_train_batch_size', '50', '--gradient_accumulation_steps', '1', '--lr', '3.5e-3']
85 | [2024-06-01 22:32:37,742] [INFO] [launch.py:253:main] process 7876 spawned with command: ['/public/home/ljt/anaconda3/envs/zhliu/bin/python', '-u', 'src/ft_proxy_model_ds.py', '--local_rank=1', '--model_path', '/public/home/ljt/hf_models/open_llama_13b', '--deepspeed', './ds_configs/ds_z3_offload_config.json', '--seed', '42', '--data_path', './data/arxiv_mia_train_real.jsonl', '--epochs', '2', '--per_device_train_batch_size', '50', '--gradient_accumulation_steps', '1', '--lr', '3.5e-3']
86 | [2024-06-01 22:33:48,002] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
87 | [2024-06-01 22:33:48,800] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
88 | [2024-06-01 22:33:49,377] [INFO] [comm.py:637:init_distributed] cdb=None
89 | [2024-06-01 22:33:49,377] [INFO] [comm.py:668:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
90 | [2024-06-01 22:33:49,633] [INFO] [comm.py:637:init_distributed] cdb=None
91 | Installed CUDA version 12.2 does not match the version torch was compiled with 12.1 but since the APIs are compatible, accepting this combination
92 | Installed CUDA version 12.2 does not match the version torch was compiled with 12.1 but since the APIs are compatible, accepting this combination
93 | ninja: no work to do.
94 | Time to load cpu_adam op: 0.8946869373321533 seconds
95 | Time to load cpu_adam op: 0.7599215507507324 seconds
96 | Parameter Offload: Total persistent parameters: 414720 in 81 params
97 | [2024-06-01 22:35:37,922] [WARNING] [stage3.py:2069:step] 1 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
98 | {'train_runtime': 73.1644, 'train_samples_per_second': 2.734, 'train_steps_per_second': 0.027, 'train_loss': 36.88542938232422, 'epoch': 2.0}
99 | [2024-06-01 22:36:43,991] [INFO] [launch.py:348:main] Process 7876 exits successfully.
100 | [2024-06-01 22:37:45,053] [INFO] [launch.py:348:main] Process 7875 exits successfully.
101 | average dev auc: 0.4900
102 |
103 | MAX dev auc: 0.5665 in layer_12
104 | test auc: 0.5577 in layer_12
105 |
--------------------------------------------------------------------------------
/logs/tinyllama.log:
--------------------------------------------------------------------------------
1 | ***********************
2 | learning rate: 8e-4
3 | ***********************
4 | [2024-06-01 23:10:21,527] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
5 | [2024-06-01 23:10:25,919] [WARNING] [runner.py:202:fetch_hostfile] Unable to find hostfile, will proceed with training with local resources only.
6 | Detected CUDA_VISIBLE_DEVICES=1: setting --include=localhost:1
7 | [2024-06-01 23:10:25,974] [INFO] [runner.py:568:main] cmd = /public/home/ljt/anaconda3/envs/zhliu/bin/python -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMV19 --master_addr=127.0.0.1 --master_port=29500 --enable_each_rank_log=None src/ft_proxy_model_ds.py --model_path /public/home/ljt/hf_models/TinyLlama-1.1B --deepspeed ./ds_configs/ds_z2_offload_config.json --seed 42 --data_path ./data/arxiv_mia_train_real.jsonl --epochs 2 --per_device_train_batch_size 50 --gradient_accumulation_steps 2 --lr 8e-4
8 | [2024-06-01 23:10:30,644] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
9 | [2024-06-01 23:10:33,640] [INFO] [launch.py:145:main] WORLD INFO DICT: {'localhost': [1]}
10 | [2024-06-01 23:10:33,640] [INFO] [launch.py:151:main] nnodes=1, num_local_procs=1, node_rank=0
11 | [2024-06-01 23:10:33,640] [INFO] [launch.py:162:main] global_rank_mapping=defaultdict(<class 'list'>, {'localhost': [0]})
12 | [2024-06-01 23:10:33,640] [INFO] [launch.py:163:main] dist_world_size=1
13 | [2024-06-01 23:10:33,640] [INFO] [launch.py:165:main] Setting CUDA_VISIBLE_DEVICES=1
14 | [2024-06-01 23:10:33,641] [INFO] [launch.py:253:main] process 77996 spawned with command: ['/public/home/ljt/anaconda3/envs/zhliu/bin/python', '-u', 'src/ft_proxy_model_ds.py', '--local_rank=0', '--model_path', '/public/home/ljt/hf_models/TinyLlama-1.1B', '--deepspeed', './ds_configs/ds_z2_offload_config.json', '--seed', '42', '--data_path', './data/arxiv_mia_train_real.jsonl', '--epochs', '2', '--per_device_train_batch_size', '50', '--gradient_accumulation_steps', '2', '--lr', '8e-4']
15 | [2024-06-01 23:11:00,149] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
16 | [2024-06-01 23:11:01,323] [INFO] [comm.py:637:init_distributed] cdb=None
17 | [2024-06-01 23:11:01,323] [INFO] [comm.py:668:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
18 | Installed CUDA version 12.2 does not match the version torch was compiled with 12.1 but since the APIs are compatible, accepting this combination
19 | ninja: no work to do.
20 | Time to load cpu_adam op: 0.8693292140960693 seconds
21 | {'train_runtime': 12.1777, 'train_samples_per_second': 16.424, 'train_steps_per_second': 0.164, 'train_loss': 6.956932067871094, 'epoch': 2.0}
22 | [2024-06-01 23:11:36,699] [INFO] [launch.py:348:main] Process 77996 exits successfully.
23 | average dev auc: 0.4696
24 |
25 | MAX dev auc: 0.5339 in layer_6
26 | test auc: 0.5671 in layer_6
27 | ***********************
28 | learning rate: 9e-4
29 | ***********************
30 | [2024-06-01 23:12:36,301] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
31 | [2024-06-01 23:12:39,410] [WARNING] [runner.py:202:fetch_hostfile] Unable to find hostfile, will proceed with training with local resources only.
32 | Detected CUDA_VISIBLE_DEVICES=1: setting --include=localhost:1
33 | [2024-06-01 23:12:39,470] [INFO] [runner.py:568:main] cmd = /public/home/ljt/anaconda3/envs/zhliu/bin/python -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMV19 --master_addr=127.0.0.1 --master_port=29500 --enable_each_rank_log=None src/ft_proxy_model_ds.py --model_path /public/home/ljt/hf_models/TinyLlama-1.1B --deepspeed ./ds_configs/ds_z2_offload_config.json --seed 42 --data_path ./data/arxiv_mia_train_real.jsonl --epochs 2 --per_device_train_batch_size 50 --gradient_accumulation_steps 2 --lr 9e-4
34 | [2024-06-01 23:12:44,293] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
35 | [2024-06-01 23:12:47,482] [INFO] [launch.py:145:main] WORLD INFO DICT: {'localhost': [1]}
36 | [2024-06-01 23:12:47,482] [INFO] [launch.py:151:main] nnodes=1, num_local_procs=1, node_rank=0
37 | [2024-06-01 23:12:47,482] [INFO] [launch.py:162:main] global_rank_mapping=defaultdict(<class 'list'>, {'localhost': [0]})
38 | [2024-06-01 23:12:47,482] [INFO] [launch.py:163:main] dist_world_size=1
39 | [2024-06-01 23:12:47,482] [INFO] [launch.py:165:main] Setting CUDA_VISIBLE_DEVICES=1
40 | [2024-06-01 23:12:47,484] [INFO] [launch.py:253:main] process 81844 spawned with command: ['/public/home/ljt/anaconda3/envs/zhliu/bin/python', '-u', 'src/ft_proxy_model_ds.py', '--local_rank=0', '--model_path', '/public/home/ljt/hf_models/TinyLlama-1.1B', '--deepspeed', './ds_configs/ds_z2_offload_config.json', '--seed', '42', '--data_path', './data/arxiv_mia_train_real.jsonl', '--epochs', '2', '--per_device_train_batch_size', '50', '--gradient_accumulation_steps', '2', '--lr', '9e-4']
41 | [2024-06-01 23:13:23,089] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
42 | [2024-06-01 23:13:24,707] [INFO] [comm.py:637:init_distributed] cdb=None
43 | [2024-06-01 23:13:24,707] [INFO] [comm.py:668:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
44 | Installed CUDA version 12.2 does not match the version torch was compiled with 12.1 but since the APIs are compatible, accepting this combination
45 | ninja: no work to do.
46 | Time to load cpu_adam op: 0.7568197250366211 seconds
47 | {'train_runtime': 12.7476, 'train_samples_per_second': 15.689, 'train_steps_per_second': 0.157, 'train_loss': 7.315080642700195, 'epoch': 2.0}
48 | [2024-06-01 23:14:02,559] [INFO] [launch.py:348:main] Process 81844 exits successfully.
49 | average dev auc: 0.4913
50 |
51 | MAX dev auc: 0.5711 in layer_10
52 | test auc: 0.5825 in layer_10
53 | ***********************
54 | learning rate: 1e-3
55 | ***********************
56 | [2024-06-01 23:14:58,271] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
57 | [2024-06-01 23:15:01,490] [WARNING] [runner.py:202:fetch_hostfile] Unable to find hostfile, will proceed with training with local resources only.
58 | Detected CUDA_VISIBLE_DEVICES=1: setting --include=localhost:1
59 | [2024-06-01 23:15:01,545] [INFO] [runner.py:568:main] cmd = /public/home/ljt/anaconda3/envs/zhliu/bin/python -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMV19 --master_addr=127.0.0.1 --master_port=29500 --enable_each_rank_log=None src/ft_proxy_model_ds.py --model_path /public/home/ljt/hf_models/TinyLlama-1.1B --deepspeed ./ds_configs/ds_z2_offload_config.json --seed 42 --data_path ./data/arxiv_mia_train_real.jsonl --epochs 2 --per_device_train_batch_size 50 --gradient_accumulation_steps 2 --lr 1e-3
60 | [2024-06-01 23:15:07,414] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
61 | [2024-06-01 23:15:10,641] [INFO] [launch.py:145:main] WORLD INFO DICT: {'localhost': [1]}
62 | [2024-06-01 23:15:10,641] [INFO] [launch.py:151:main] nnodes=1, num_local_procs=1, node_rank=0
63 | [2024-06-01 23:15:10,641] [INFO] [launch.py:162:main] global_rank_mapping=defaultdict(<class 'list'>, {'localhost': [0]})
64 | [2024-06-01 23:15:10,641] [INFO] [launch.py:163:main] dist_world_size=1
65 | [2024-06-01 23:15:10,641] [INFO] [launch.py:165:main] Setting CUDA_VISIBLE_DEVICES=1
66 | [2024-06-01 23:15:10,642] [INFO] [launch.py:253:main] process 86557 spawned with command: ['/public/home/ljt/anaconda3/envs/zhliu/bin/python', '-u', 'src/ft_proxy_model_ds.py', '--local_rank=0', '--model_path', '/public/home/ljt/hf_models/TinyLlama-1.1B', '--deepspeed', './ds_configs/ds_z2_offload_config.json', '--seed', '42', '--data_path', './data/arxiv_mia_train_real.jsonl', '--epochs', '2', '--per_device_train_batch_size', '50', '--gradient_accumulation_steps', '2', '--lr', '1e-3']
67 | [2024-06-01 23:15:37,706] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
68 | [2024-06-01 23:15:38,865] [INFO] [comm.py:637:init_distributed] cdb=None
69 | [2024-06-01 23:15:38,865] [INFO] [comm.py:668:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
70 | Installed CUDA version 12.2 does not match the version torch was compiled with 12.1 but since the APIs are compatible, accepting this combination
71 | ninja: no work to do.
72 | Time to load cpu_adam op: 0.9003267288208008 seconds
73 | {'train_runtime': 12.756, 'train_samples_per_second': 15.679, 'train_steps_per_second': 0.157, 'train_loss': 7.531458854675293, 'epoch': 2.0}
74 | [2024-06-01 23:16:15,708] [INFO] [launch.py:348:main] Process 86557 exits successfully.
75 | average dev auc: 0.5033
76 |
77 | MAX dev auc: 0.5679 in layer_12
78 | test auc: 0.5625 in layer_12
79 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | datasets==2.19.0
2 | matplotlib==3.9.0
3 | numpy==1.26.4
4 | pandas==2.2.2
5 | scikit_learn==1.4.2
6 | torch==2.2.2
7 | tqdm==4.66.2
8 | transformers==4.41.0
10 | deepspeed==0.14.0
11 | bitsandbytes==0.42.0
--------------------------------------------------------------------------------
/scripts/openllama_probe.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0,1
2 |
3 | model_path="openlm-research/open_llama_13b"
4 |
5 |
6 | python src/generate_acts.py \
7 | --model_path $model_path \
8 | --dataset arxiv_mia_dev \
9 | --dataset_path ./data/arxiv_mia_dev.jsonl
10 |
11 | python src/generate_acts.py \
12 | --model_path $model_path \
13 | --dataset arxiv_mia_test \
14 | --dataset_path ./data/arxiv_mia_test.jsonl
15 |
16 |
17 | for lr in 2.5e-3 3e-3 3.5e-3; do
18 |
19 | echo -e "***********************"
20 | echo -e "learning rate: ${lr}"
21 | echo -e "***********************"
22 |
23 | deepspeed src/ft_proxy_model_ds.py \
24 | --model_path $model_path \
25 | --deepspeed ./ds_configs/ds_z3_offload_config.json \
26 | --seed 42 \
27 | --data_path ./data/arxiv_mia_train_real.jsonl \
28 | --epochs 2 \
29 | --per_device_train_batch_size 50 \
30 | --gradient_accumulation_steps 1 \
31 | --lr $lr
32 |
33 | python src/generate_acts.py \
34 | --dataset arxiv_mia_train_real \
35 | --dataset_path ./data/arxiv_mia_train_real.jsonl \
36 | --model_path ./saved_models/$(basename $model_path)
37 |
38 | python src/run_probe.py \
39 | --seed 42 \
40 | --target_model $(basename $model_path) \
41 | --train_set arxiv_mia_train_real \
42 | --train_set_path ./data/arxiv_mia_train_real.jsonl \
43 | --dev_set arxiv_mia_dev \
44 | --dev_set_path ./data/arxiv_mia_dev.jsonl \
45 | --test_set arxiv_mia_test \
46 | --test_set_path ./data/arxiv_mia_test.jsonl
47 | done
48 |
--------------------------------------------------------------------------------
/scripts/tinyllama_probe.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 |
3 | model_path="TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
4 |
5 |
6 | python src/generate_acts.py \
7 | --model_path $model_path \
8 | --dataset arxiv_mia_dev \
9 | --dataset_path ./data/arxiv_mia_dev.jsonl
10 |
11 | python src/generate_acts.py \
12 | --model_path $model_path \
13 | --dataset arxiv_mia_test \
14 | --dataset_path ./data/arxiv_mia_test.jsonl
15 |
16 |
17 | for lr in 8e-4 9e-4 1e-3; do
18 |
19 | echo -e "***********************"
20 | echo -e "learning rate: ${lr}"
21 | echo -e "***********************"
22 |
23 | deepspeed src/ft_proxy_model_ds.py \
24 | --model_path $model_path \
25 | --deepspeed ./ds_configs/ds_z2_offload_config.json \
26 | --seed 42 \
27 | --data_path ./data/arxiv_mia_train_real.jsonl \
28 | --epochs 2 \
29 | --per_device_train_batch_size 50 \
30 | --gradient_accumulation_steps 2 \
31 | --lr $lr
32 |
33 | python src/generate_acts.py \
34 | --dataset arxiv_mia_train_real \
35 | --dataset_path ./data/arxiv_mia_train_real.jsonl \
36 | --model_path ./saved_models/$(basename $model_path)
37 |
38 | python src/run_probe.py \
39 | --seed 42 \
40 | --target_model $(basename $model_path) \
41 | --train_set arxiv_mia_train_real \
42 | --train_set_path ./data/arxiv_mia_train_real.jsonl \
43 | --dev_set arxiv_mia_dev \
44 | --dev_set_path ./data/arxiv_mia_dev.jsonl \
45 | --test_set arxiv_mia_test \
46 | --test_set_path ./data/arxiv_mia_test.jsonl
47 | done
48 |
--------------------------------------------------------------------------------
/src/ft_proxy_model_ds.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import base64
5 | import transformers
6 | import datasets
7 | import logging
8 | from itertools import chain
9 | from torch.utils.data import DataLoader
10 | from argparse import ArgumentParser, Namespace
11 |
12 | from transformers import AutoTokenizer, AutoModelForCausalLM
13 | from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments
14 |
15 | PROMPT = """Here is a statement:
16 |
17 | [TEXT]
18 |
19 | Is the above statement correct? Answer: """
20 |
21 | # set logging level
22 | logging.basicConfig(level=logging.INFO)
23 |
24 |
25 | def parse_args() -> Namespace:
26 | """Parse commandline arguments."""
27 | parser = ArgumentParser()
28 | parser.add_argument("--local_rank", type=int, default=-1)
29 | parser.add_argument("--deepspeed", type=str, default="")
30 | parser.add_argument("--seed", type=int, default=42)
31 | parser.add_argument("--model_path", default="pythia-2.8b")
32 | parser.add_argument("--save_dir", default="./saved_models")
33 | parser.add_argument("--data_path", default="")
34 | parser.add_argument("--epochs", type=int, default=2)
35 | parser.add_argument("--per_device_train_batch_size", type=int, default=1)
36 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
37 | parser.add_argument("--lr", type=float, default=1e-5)
38 | parser.add_argument("--optimizer", type=str, default="adamw_torch")
39 |
40 | parsed = parser.parse_args()
41 | return parsed
42 |
43 |
44 | def load_dataset(data_path) -> datasets.Dataset:
45 | def gen(data_path):
46 | with open(data_path, "r", encoding="utf-8") as f:
47 | for line in f:
48 | if line:
49 | sample = json.loads(line.strip())
50 | if sample["label"] == 1:
51 | yield {"text": PROMPT.replace("[TEXT]", sample["text"])}
52 |
53 | dataset = datasets.Dataset.from_generator(gen, gen_kwargs={"data_path": data_path})
54 |
55 | return dataset
56 |
57 |
58 | def main(args: Namespace) -> None:
59 | """Main: Training LLM.
60 |
61 | Args:
62 | args (Namespace): Commandline arguments.
63 | """
64 |
65 | # Create Model
66 | model = AutoModelForCausalLM.from_pretrained(
67 | args.model_path, return_dict=True, torch_dtype=torch.bfloat16
68 | )
69 | tokenizer = AutoTokenizer.from_pretrained(args.model_path)
70 | model.config.use_cache = False
71 | tokenizer.pad_token = tokenizer.eos_token
72 | model_name = args.model_path.split("/")[-1]
73 |
74 | # Create Dataloaders
75 | train_dataset = load_dataset(args.data_path)
76 | text_column_name = "text"
77 |
78 | # Tokenize the datasets
79 | def tokenize_function(examples):
80 | # Remove empty lines
81 | examples[text_column_name] = [
82 | line
83 | for line in examples[text_column_name]
84 | if len(line) > 0 and not line.isspace()
85 | ]
86 |
87 | out = tokenizer(examples[text_column_name])
88 | return out
89 |
90 | tokenized_train = train_dataset.map(
91 | tokenize_function,
92 | batched=True,
93 | remove_columns=train_dataset.column_names,
94 | load_from_cache_file=False,
95 | desc="Running tokenizer on dataset",
96 | )
97 |
98 | data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
99 |
100 | training_args = TrainingArguments(
101 | seed=args.seed,
102 | per_device_train_batch_size=args.per_device_train_batch_size,
103 | gradient_accumulation_steps=args.gradient_accumulation_steps,
104 | deepspeed=args.deepspeed,
105 | output_dir=os.path.join(args.save_dir, model_name),
106 | overwrite_output_dir=True,
107 | num_train_epochs=args.epochs,
108 | learning_rate=args.lr,
109 | optim=args.optimizer,
110 | lr_scheduler_type="constant",
111 | dataloader_drop_last=False,
112 | bf16=True,
113 | bf16_full_eval=True,
114 | gradient_checkpointing=True,
115 | remove_unused_columns=False,
116 | save_strategy="no",
117 | gradient_checkpointing_kwargs={"use_reentrant": False},
118 | )
119 |
120 | trainer = Trainer(
121 | model=model,
122 | args=training_args,
123 | data_collator=data_collator,
124 | train_dataset=tokenized_train,
125 | )
126 |
127 | trainer.train()
128 | trainer.save_model()
129 | tokenizer.save_pretrained(os.path.join(args.save_dir, model_name))
130 |
131 |
132 | if __name__ == "__main__":
133 | main(parse_args())
134 |
--------------------------------------------------------------------------------
/src/generate_acts.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import os
4 | import json
5 | from pathlib import Path
6 | from tqdm import tqdm
7 | from transformers import (
8 | AutoTokenizer,
9 | AutoModelForCausalLM,
10 | GPTNeoXForCausalLM,
11 | OPTForCausalLM,
12 | GPT2LMHeadModel,
13 | LlamaForCausalLM,
14 | )
15 |
16 |
17 | torch.set_grad_enabled(False)
18 | os.environ["OPENBLAS_NUM_THREADS"] = "1"
19 |
20 | PROMPT = """Here is a statement:
21 |
22 | [TEXT]
23 |
24 | Is the above statement correct? Answer: """
25 |
26 |
27 | def parse_args():
28 | """Parse commandline arguments."""
29 | parser = argparse.ArgumentParser(description="Generate activations")
30 | parser.add_argument("--model_path", default="")
31 | parser.add_argument("--dataset", type=str, default="")
32 | parser.add_argument("--dataset_path", default="")
33 | parser.add_argument("--output_dir", default="./acts")
34 | args = parser.parse_args()
35 | return args
36 |
37 |
38 | def read_jsonl(path):
39 | with open(path, "r") as f:
40 | data = [json.loads(line) for line in tqdm(f)]
41 | new_data = []
42 | for d in data:
43 | new_data.append(
44 | {
45 | "text": d["text"],
46 | "label": d["label"],
47 | }
48 | )
49 | return new_data
50 |
51 |
52 | class Hook:
53 | def __init__(self):
54 | self.out = None
55 |
56 | def __call__(self, module, module_inputs, module_outputs):
57 |         self.out, _ = module_outputs  # keep the hidden states; discard the layer's cached key/values
58 |
59 |
60 | def get_acts(statements, tokenizer, model, layers):
61 | # attach hooks
62 | hooks, handles = [], []
63 | for layer in layers:
64 | hook = Hook()
65 |
66 | if isinstance(model, LlamaForCausalLM):
67 | handle = model.model.layers[layer].register_forward_hook(hook)
68 | elif isinstance(model, GPTNeoXForCausalLM):
69 | handle = model.gpt_neox.layers[layer].register_forward_hook(hook)
70 | elif isinstance(model, OPTForCausalLM):
71 | handle = model.model.decoder.layers[layer].register_forward_hook(hook)
72 | elif isinstance(model, GPT2LMHeadModel):
73 | handle = model.transformer.h[layer].register_forward_hook(hook)
74 | hooks.append(hook), handles.append(handle)
75 |
76 | # get activations
77 | acts = {layer: [] for layer in layers}
78 | for statement in statements:
79 | input_ids = tokenizer.encode(statement, return_tensors="pt").to(model.device)
80 | model(input_ids)
81 | for layer, hook in zip(layers, hooks):
82 |             acts[layer].append(hook.out[0, -1])  # last-token activation of this layer
83 |
84 | for layer, act in acts.items():
85 | acts[layer] = torch.stack(act).float()
86 |
87 | # remove hooks
88 | for handle in handles:
89 | handle.remove()
90 |
91 | return acts
92 |
93 |
94 | if __name__ == "__main__":
95 | args = parse_args()
96 |
97 | model = AutoModelForCausalLM.from_pretrained(
98 | args.model_path, return_dict=True, torch_dtype=torch.bfloat16, device_map="auto"
99 | )
100 | tokenizer = AutoTokenizer.from_pretrained(args.model_path)
101 | model.config.use_cache = True
102 |
103 | data = read_jsonl(args.dataset_path)
104 |
105 | model_name = args.model_path.split("/")[-1]
106 | save_dir = os.path.join(args.output_dir, model_name, args.dataset)
107 | Path(save_dir).mkdir(parents=True, exist_ok=True)
108 |
109 | statements = [PROMPT.replace("[TEXT]", ex["text"]) for ex in data]
110 |
111 | if isinstance(model, LlamaForCausalLM):
112 | layers = list(range(len(model.model.layers)))
113 | elif isinstance(model, GPTNeoXForCausalLM):
114 | layers = list(range(len(model.gpt_neox.layers)))
115 | elif isinstance(model, OPTForCausalLM):
116 | layers = list(range(len(model.model.decoder.layers)))
117 | elif isinstance(model, GPT2LMHeadModel):
118 | layers = list(range(len(model.transformer.h)))
119 |
120 | for idx in tqdm(range(0, len(statements), 25)):
121 | acts = get_acts(statements[idx : idx + 25], tokenizer, model, layers)
122 | for layer, act in acts.items():
123 | torch.save(act, os.path.join(save_dir, f"layer_{layer}_{idx}.pt"))
124 |
--------------------------------------------------------------------------------
/src/run_probe.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | import torch
4 | import os
5 | import json
6 | import random
7 | import numpy as np
8 | from tqdm import tqdm
9 | from collections import defaultdict
10 | from typing import List, Dict
11 | from sklearn.metrics import auc, roc_curve
12 | from glob import glob
13 |
14 |
15 | class LRProbe(torch.nn.Module):
16 | def __init__(self, d_in):
17 | super().__init__()
18 | self.net = torch.nn.Sequential(
19 | torch.nn.Linear(d_in, 1, bias=False), torch.nn.Sigmoid()
20 | )
21 |
22 | def forward(self, x):
23 | return self.net(x).squeeze(-1)
24 |
25 | def pred(self, x):
26 | return self(x).round()
27 |
28 | def score(self, x):
29 | return self(x)
30 |
31 | def from_data(acts, labels, lr=0.001, weight_decay=0.1, epochs=1000, device="cpu"):
32 | acts, labels = acts.to(device), labels.to(device)
33 | probe = LRProbe(acts.shape[-1]).to(device)
34 |
35 | opt = torch.optim.AdamW(probe.parameters(), lr=lr, weight_decay=weight_decay)
36 |
37 | for _ in range(epochs):
38 | opt.zero_grad()
39 | loss = torch.nn.BCELoss()(probe(acts), labels)
40 | loss.backward()
41 | opt.step()
42 |
43 | return probe
44 |
45 |
46 | class ActDataset:
47 | def __init__(self, dataset, dataset_name, model_name, layer_num, device):
48 | self.data = {}
49 | for layer in range(layer_num):
50 | acts = self.collect_acts(dataset_name, model_name, layer, device=device)
51 | labels = torch.Tensor([ex["label"] for ex in dataset]).to(device)
52 | self.data[layer] = acts, labels
53 |
54 | def collect_acts(
55 | self, dataset_name, model_name, layer, center=True, scale=True, device="cpu"
56 | ):
57 | directory = os.path.join(
58 | os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
59 | "acts",
60 | model_name,
61 | dataset_name,
62 | )
63 | activation_files = glob(os.path.join(directory, f"layer_{layer}_*.pt"))
64 | acts = [
65 | torch.load(os.path.join(directory, f"layer_{layer}_{i}.pt")).to(device)
66 | for i in range(0, 25 * len(activation_files), 25)
67 | ]
68 | acts = torch.cat(acts, dim=0).to(device)
69 | if center:
70 | acts = acts - torch.mean(acts, dim=0)
71 | if scale:
72 | acts = acts / torch.std(acts, dim=0)
73 | return acts
74 |
75 | def get(self, layer):
76 | return self.data[layer]
77 |
78 |
79 | def parse_args():
80 | """Parse commandline arguments."""
81 | parser = argparse.ArgumentParser()
82 | parser.add_argument("--seed", type=int, default=42)
83 | parser.add_argument("--plm_dir", type=str, default="/public/home/wlchen/plm")
84 | parser.add_argument("--target_model", type=str, default="pythia-2.8b")
85 | parser.add_argument("--train_set", type=str, default="")
86 | parser.add_argument("--train_set_path", type=str, default="")
87 | parser.add_argument("--dev_set", type=str, default="")
88 | parser.add_argument("--dev_set_path", type=str, default="")
89 | parser.add_argument("--test_set", type=str, default="")
90 | parser.add_argument("--test_set_path", type=str, default="")
91 |
92 | args = parser.parse_args()
93 | return args
94 |
95 |
96 | def read_jsonl(path):
97 | with open(path, "r") as f:
98 | data = [json.loads(line) for line in tqdm(f)]
99 | new_data = []
100 | for d in data:
101 | new_data.append(
102 | {
103 | "text": d["text"],
104 | "label": d["label"],
105 | }
106 | )
107 | return new_data
108 |
109 |
110 | def set_seed(seed: int):
111 | random.seed(seed)
112 | np.random.seed(seed)
113 | torch.manual_seed(seed)
114 | torch.cuda.manual_seed_all(seed)
115 |
116 |
117 | def compute_metrics(prediction, answers, print_result=True):
118 | fpr, tpr, _ = roc_curve(np.array(answers, dtype=bool), -np.array(prediction))
119 |     auc_score = auc(fpr, tpr)  # use a distinct name so sklearn's `auc` is not shadowed (which would raise UnboundLocalError)
120 | 
121 |     tpr_5_fpr = tpr[np.where(fpr < 0.05)[0][-1]]
122 | 
123 |     if print_result:
124 |         print(" AUC %.4f, TPR@5%%FPR of %.4f\n" % (auc_score, tpr_5_fpr))
125 | 
126 |     return fpr, tpr, auc_score, tpr_5_fpr
127 |
128 |
129 | def evaluate(probe, test_acts, test_data):
130 | scores = probe.score(test_acts)
131 |
132 | predictions = []
133 | labels = []
134 | for i, ex in tqdm(enumerate(test_data)):
135 | predictions.append(-scores[i].item())
136 | labels.append(ex["label"])
137 |
138 | fpr, tpr, auc, tpr_5_fpr = compute_metrics(predictions, labels, print_result=False)
139 | return auc
140 |
141 |
142 | if __name__ == "__main__":
143 | args = parse_args()
144 |
145 | set_seed(args.seed)
146 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
147 |
148 | if "TinyLlama-1.1B" in args.target_model:
149 | layer_num = 22
150 | elif "open_llama_13b" in args.target_model:
151 | layer_num = 40
152 | else:
153 | raise NotImplementedError
154 |
155 | train_set = read_jsonl(args.train_set_path)
156 | dev_set = read_jsonl(args.dev_set_path)
157 | test_set = read_jsonl(args.test_set_path)
158 |
159 | train_act_dataset = ActDataset(
160 | train_set, args.train_set, args.target_model, layer_num, device
161 | )
162 | dev_act_dataset = ActDataset(
163 | dev_set, args.dev_set, args.target_model, layer_num, device
164 | )
165 | test_act_dataset = ActDataset(
166 | test_set, args.test_set, args.target_model, layer_num, device
167 | )
168 |
169 | # select best layer
170 | dev_auc_list = []
171 | test_auc_list = []
172 | for layer in range(layer_num):
173 | train_acts, train_labels = train_act_dataset.get(layer)
174 | probe = LRProbe.from_data(train_acts, train_labels, device=device)
175 |
176 | dev_acts, dev_labels = dev_act_dataset.get(layer)
177 | dev_auc = evaluate(probe, dev_acts, dev_set)
178 | dev_auc_list.append(dev_auc)
179 |
180 | test_acts, test_labels = test_act_dataset.get(layer)
181 | test_auc = evaluate(probe, test_acts, test_set)
182 | test_auc_list.append(test_auc)
183 |
184 | dev_best_layer = dev_auc_list.index(max(dev_auc_list))
185 | print(f"average dev auc: {sum(dev_auc_list)/len(dev_auc_list):.4f}\n")
186 | print(f"MAX dev auc: {max(dev_auc_list):.4f} in layer_{dev_best_layer}")
187 | print(f" test auc: {test_auc_list[dev_best_layer]:.4f} in layer_{dev_best_layer}")
188 |
--------------------------------------------------------------------------------