├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── haritz-mia-teaser-v2.png
│   └── logos.png
├── collection_mia.py
├── document_mia.py
├── precompute_collection_mia_scores.py
├── precompute_document_mia_scores.py
└── requirements.txt

/.gitignore:
--------------------------------------------------------------------------------
1 | results/
2 | results.zip
3 | 
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 | 
9 | # C extensions
10 | *.so
11 | 
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 | 
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 | 
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 | 
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 | cover/
56 | 
57 | # Translations
58 | *.mo
59 | *.pot
60 | 
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 | 
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 | 
71 | # Scrapy stuff:
72 | .scrapy
73 | 
74 | # Sphinx documentation
75 | docs/_build/
76 | 
77 | # PyBuilder
78 | .pybuilder/
79 | target/
80 | 
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 | 
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 | 
88 | # pyenv
89 | # For a library or package, you might want to ignore these files since the code is
90 | # intended to run in multiple environments; otherwise, check them in:
91 | # .python-version
92 | 
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 | 
100 | # poetry
101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102 | # This is especially recommended for binary packages to ensure reproducibility, and is more
103 | # commonly ignored for libraries.
104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105 | #poetry.lock
106 | 
107 | # pdm
108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109 | #pdm.lock
110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111 | # in version control.
112 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
113 | .pdm.toml
114 | .pdm-python
115 | .pdm-build/
116 | 
117 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
118 | __pypackages__/
119 | 
120 | # Celery stuff
121 | celerybeat-schedule
122 | celerybeat.pid
123 | 
124 | # SageMath parsed files
125 | *.sage.py
126 | 
127 | # Environments
128 | .env
129 | .venv
130 | env/
131 | venv/
132 | ENV/
133 | env.bak/
134 | venv.bak/
135 | 
136 | # Spyder project settings
137 | .spyderproject
138 | .spyproject
139 | 
140 | # Rope project settings
141 | .ropeproject
142 | 
143 | # mkdocs documentation
144 | /site
145 | 
146 | # mypy
147 | .mypy_cache/
148 | .dmypy.json
149 | dmypy.json
150 | 
151 | # Pyre type checker
152 | .pyre/
153 | 
154 | # pytype static type analyzer
155 | .pytype/
156 | 
157 | # Cython debug symbols
158 | cython_debug/
159 | 
160 | # PyCharm
161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
163 | # and can be added to the global gitignore or merged into this file. For a more nuclear
164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
165 | #.idea/
166 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Parameter Lab
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [![arXiv Badge](https://img.shields.io/badge/arXiv-2411.00154-B31B1B)](https://arxiv.org/abs/2411.00154)
2 | [![Hugging Face Badge](https://img.shields.io/badge/Hugging%20Face-data-FFAE10)](https://huggingface.co/collections/parameterlab/scaling-mia-data-and-results-67a0d354a6398cfba7feed00)
3 | [![Hugging Face Badge](https://img.shields.io/badge/Hugging%20Face-results-FFAE10)](https://huggingface.co/datasets/parameterlab/scaling_mia_results)
4 | 
5 | 
6 | ![scaling up mia description](./assets/logos.png)
7 | 
8 | 
9 | This repository includes the code to compute membership inference attack (MIA) scores and to aggregate them with statistical methods to scale up MIA across the text data spectrum, as described in our NAACL 2025 Findings paper "Scaling Up Membership Inference: When and How Attacks Succeed on Large Language Models."
The precomputed MIA scores and all the results shown in the tables and figures of the paper are in the folder `results`.
10 | 
11 | 
12 | 
13 | Developed at [Parameter Lab](https://parameterlab.de/) with the support of [Naver AI Lab](https://clova.ai/en/ai-research).
14 | 
15 | 
16 | 
17 | ![scaling up mia description](./assets/haritz-mia-teaser-v2.png)
18 | 
19 | 
20 | > **Abstract**:
21 | Membership inference attacks (MIA) attempt to verify the membership of a given data sample in the training set for a model. MIA has become relevant in recent years, following the rapid development of large language models (LLM). Many are concerned about the usage of copyrighted materials for training them and call for methods for detecting such usage. However, recent research has largely concluded that current MIA methods do not work on LLMs. Even when they seem to work, it is usually because of the ill-designed experimental setup where other shortcut features enable "cheating." In this work, we argue that MIA still works on LLMs, but only when multiple documents are presented for testing. We construct new benchmarks that measure the MIA performances at a continuous scale of data samples, from sentences (n-grams) to a collection of documents (multiple chunks of tokens). To validate the efficacy of current MIA approaches at greater scales, we adapt a recent work on Dataset Inference (DI) for the task of binary membership detection that aggregates paragraph-level MIA features to enable document- and dataset-level MIA. This baseline achieves the first successful MIA on pre-trained and fine-tuned LLMs.
22 | 
23 | 
24 | 
25 | This repository provides:
26 | * Code to run MIA attacks (`precompute_*.py` files)
27 | * Code to run MIA aggregation (`{}_mia.py` files)
28 | * Precomputed MIA attacks in `results/*/*/*/*/mia_members.jsonl` and `mia_nonmembers.jsonl`
29 | * CSV files with the evaluation performance in `results/*/*/*/*/*.csv`
30 | 
31 | ## Reproducing the Experiments
32 | This section shows how to reproduce the experiments in our paper.
33 | 
34 | ### Setup
35 | First, install the requirements:
36 | 
37 | ```
38 | conda create --name scaling_mia python=3.9
39 | conda activate scaling_mia
40 | pip install -r requirements.txt
41 | ```
42 | 
43 | Now, download the results data. In this .zip file you will find all the results used to plot the graphs, the raw MIA scores from the MIA attacks, and the source text data in Hugging Face Dataset format.
44 | 
45 | ```
46 | wget https://huggingface.co/datasets/haritzpuerto/scaling_mia/resolve/main/results.zip
47 | unzip results.zip
48 | ```
49 | 
50 | ### Running MIA Attacks
51 | 
52 | You can run the MIA attacks for a collection of documents using:
53 | 
54 | 
55 | ```
56 | python precompute_collection_mia_scores.py \
57 |     --model_name EleutherAI/pythia-2.8b \
58 |     --dataset_name haritzpuerto/the_pile_00_arxiv \
59 |     --output_path $OUTPUT_PATH \
60 |     --seed 42
61 | ```
62 | 
63 | This script runs the `perplexity`, `ppl/lowercase_ppl`, `ppl/zlib`, and `min-k` attacks on the first 2048 tokens (the context window) of each document.
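
For reference, this is roughly how each per-document feature vector is derived inside `precompute_collection_mia_scores.py` (a minimal sketch of the repository's `inference` function; the helper name and arguments are illustrative):

```python
import numpy as np
import zlib

def mia_features(loss, token_log_probs, text):
    """loss: mean cross-entropy of the text under the target LLM;
    token_log_probs: per-token log-probabilities of the text."""
    ppl = float(np.exp(loss))
    pred = {"ppl": ppl}
    # log-perplexity normalized by the zlib-compressed length of the text
    pred["ppl/zlib"] = np.log(ppl) / len(zlib.compress(text.encode("utf-8")))
    # Min-K% Prob: (negative) average log-probability of the K% least likely tokens
    for ratio in [0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6]:
        k = int(len(token_log_probs) * ratio)
        lowest_k = np.sort(token_log_probs)[:k]
        pred[f"Min_{ratio*100}% Prob"] = -float(np.mean(lowest_k))
    return pred
```

(The `ppl/lowercase_ppl` feature is computed analogously, as a ratio of the log-perplexities of the lower-cased and original text.)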
64 | 
65 | Similarly, to run the MIA attacks on a document for Document-MIA, you can use:
66 | 
67 | ```
68 | python precompute_document_mia_scores.py \
69 |     --model_name EleutherAI/pythia-2.8b \
70 |     --dataset_name haritzpuerto/the_pile_00_arxiv \
71 |     --output_path $OUTPUT_PATH \
72 |     --seed 42
73 | ```
74 | 
75 | This script splits each document into 2048-token paragraphs and runs the `perplexity`, `ppl/lowercase_ppl`, `ppl/zlib`, and `min-k` attacks on each paragraph.
76 | 
77 | 
78 | ### Running MIA Aggregation
79 | 
80 | ```
81 | python collection_mia.py \
82 |     --mia_path results/collection_mia/EleutherAI/pythia-6.9b/haritzpuerto/the_pile_00_arxiv/2048
83 | ```
84 | This script aggregates the MIA scores of the documents to perform MIA at the collection level (a minimal sketch of this aggregation appears at the end of this README).
85 | 
86 | ```
87 | python document_mia.py \
88 |     --base_path results/doc_mia/EleutherAI
89 | ```
90 | 
91 | This script runs over all model sizes, datasets, and paragraph sizes, and conducts document-level MIA by aggregating the MIA scores of each paragraph.
92 | 
93 | ## Cite
94 | 
95 | If you find our work useful, please consider citing it as follows:
96 | 
97 | ```
98 | @misc{puerto2024scalingmembershipinferenceattacks,
99 |       title={Scaling Up Membership Inference: When and How Attacks Succeed on Large Language Models},
100 |       author={Haritz Puerto and Martin Gubri and Sangdoo Yun and Seong Joon Oh},
101 |       year={2024},
102 |       eprint={2411.00154},
103 |       archivePrefix={arXiv},
104 |       primaryClass={cs.CL},
105 |       url={https://arxiv.org/abs/2411.00154},
106 | }
107 | ```
108 | 
109 | ## Credits
110 | The precompute-MIA scripts are based on the code from https://github.com/swj0419/detect-pretrain-code.
111 | 
112 | This work was supported by the NAVER Corporation.
113 | 
114 | 
115 | ## Disclaimer
116 | 
117 | > This repository contains experimental software and is published for the sole purpose of giving additional background details on the respective publication.
118 | 
119 | ✉️ Contact person: Haritz Puerto, haritz.puerto@tu-darmstadt.de
120 | 
121 | https://www.parameterlab.de/
122 | 
123 | Don't hesitate to send us an e-mail or report an issue if something is broken (and it shouldn't be) or if you have further questions.
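
For readers who want the gist of the aggregation performed by `collection_mia.py` without reading the full script, here is a minimal sketch. A linear scorer is first trained on per-chunk MIA features of labelled member/non-member chunks; a candidate collection is then compared against chunks known to be non-members with a one-sided t-test, whose statistic serves as the collection-level membership score (the function name and the synthetic scores below are illustrative):

```python
import numpy as np
from scipy.stats import ttest_ind

def collection_membership_score(chunk_scores, known_nonmember_scores):
    # chunk_scores: scores of the candidate collection's chunks under the
    # trained linear scorer; known_nonmember_scores: scores of chunks known
    # not to be in the training data. Member collections should score higher.
    statistic, _ = ttest_ind(chunk_scores, known_nonmember_scores,
                             equal_var=True, alternative='greater')
    return statistic

# Synthetic example: member-like scores are shifted slightly upwards.
rng = np.random.default_rng(0)
member_chunks = rng.normal(0.3, 1.0, size=100)
known_nonmember_chunks = rng.normal(0.0, 1.0, size=1000)
print(collection_membership_score(member_chunks, known_nonmember_chunks))
```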
124 | -------------------------------------------------------------------------------- /assets/haritz-mia-teaser-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/parameterlab/mia-scaling/78be89cea09068cf2ed4ee5d3353062b0bc702f8/assets/haritz-mia-teaser-v2.png -------------------------------------------------------------------------------- /assets/logos.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/parameterlab/mia-scaling/78be89cea09068cf2ed4ee5d3353062b0bc702f8/assets/logos.png -------------------------------------------------------------------------------- /collection_mia.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import logging 3 | 4 | logging.basicConfig(level='ERROR') 5 | import argparse 6 | import json 7 | import os 8 | import random 9 | 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import pandas as pd 13 | import torch 14 | import torch.nn as nn 15 | import torch.optim as optim 16 | from scipy.stats import ttest_ind 17 | from sklearn.metrics import (classification_report, roc_auc_score, 18 | roc_curve) 19 | from torch.utils.data import DataLoader, TensorDataset 20 | from tqdm import tqdm 21 | 22 | 23 | def combinations(list_items, num_combinations, num_groups): 24 | list_items_idx = list(range(len(list_items))) 25 | combinations = set() 26 | while len(combinations) < num_combinations: 27 | list_idx = random.sample(range(len(list_items_idx)), num_groups) 28 | combinations.add(tuple(list_idx)) 29 | 30 | combinations_items = [] 31 | for combination in combinations: 32 | combination_items = [] 33 | for idx in combination: 34 | combination_items.append(list_items[idx]) 35 | combinations_items.append(combination_items) 36 | return combinations_items 37 | # return list(combinations) 38 | 39 | def process_combination_mia_features(list_combinations_mia_features): 40 | ''' 41 | Convert a list of [{'pred': {'ppl': 5.1875, 42 | 'ppl/lowercase_ppl': -1.0285907103711933, 43 | 'ppl/zlib': 0.0001900983701566763, 44 | 'Min_5.0% Prob': 8.089154411764707, 45 | 'Min_10.0% Prob': 6.647058823529412, 46 | 'Min_20.0% Prob': 5.1873471882640585, 47 | 'Min_30.0% Prob': 4.264161746742671, 48 | 'Min_40.0% Prob': 3.5961376833740832, 49 | 'Min_50.0% Prob': 3.07814027370479, 50 | 'Min_60.0% Prob': 2.663736449002443}, 51 | 'label': 0}, ... ] 52 | into a numpy array with all the features 53 | ''' 54 | features = [] 55 | for combination_mia_features in list_combinations_mia_features: 56 | dataset_features = [] 57 | for mia_features in combination_mia_features: 58 | dataset_features.append(list(mia_features['pred'].values())) 59 | features.append(dataset_features) 60 | 61 | return np.array(features) 62 | 63 | def train_model(model, train_loader, criterion, optimizer, num_epochs=100): 64 | """ 65 | Train the model using the training set. 66 | 67 | Parameters: 68 | model (torch.nn.Module): The model to be trained. 69 | train_loader (torch.utils.data.DataLoader): DataLoader for the training set. 70 | criterion (torch.nn.Module): Loss function. 71 | optimizer (torch.optim.Optimizer): Optimizer for updating model parameters. 72 | num_epochs (int): Number of epochs to train the model. Default is 100. 
73 | 74 | Returns: 75 | None 76 | """ 77 | model.train() 78 | for epoch in range(num_epochs): 79 | for inputs, labels in train_loader: 80 | inputs, labels = inputs.to('cuda'), labels.to('cuda') 81 | outputs = model(inputs).squeeze(1) # Squeeze the output to match the shape of labels 82 | loss = criterion(outputs, labels.float()) # Convert labels to float 83 | 84 | optimizer.zero_grad() 85 | loss.backward() 86 | optimizer.step() 87 | 88 | # print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}') 89 | 90 | 91 | 92 | def evaluate_model(model, eval_loader, threshold=0): 93 | model.eval() 94 | list_raw_scores = [] 95 | list_labels = [] 96 | with torch.no_grad(): 97 | correct = 0 98 | total = 0 99 | for inputs, labels in eval_loader: 100 | inputs, labels = inputs.to('cuda'), labels.to('cuda') 101 | outputs = model(inputs) 102 | predicted = (outputs.squeeze() > threshold).float() # Apply threshold to get binary predictions 103 | try: 104 | list_raw_scores.extend(outputs.squeeze().detach().cpu().numpy().tolist()) 105 | list_labels.extend(labels.detach().cpu().numpy().tolist()) 106 | total += labels.size(0) 107 | correct += (predicted == labels.float()).sum().item() # Convert labels to float for comparison 108 | except: 109 | print("Error") 110 | print(f"Outputs: {outputs}") 111 | 112 | accuracy = correct / total 113 | 114 | auc_score = roc_auc_score(list_labels, list_raw_scores) 115 | 116 | return accuracy, auc_score 117 | 118 | # Example usage: 119 | # accuracy, roc_auc = evaluate_model(model, eval_loader, threshold=0) 120 | # print(f'Accuracy: {accuracy:.4f} with threshold {threshold}') 121 | # print(f'ROC AUC: {roc_auc:.4f}') 122 | 123 | def get_tpr_fpr(list_predictions, list_labels): 124 | tp = sum([1 for i in range(len(list_predictions)) if list_predictions[i] == 1 and list_labels[i] == 1]) 125 | tn = sum([1 for i in range(len(list_predictions)) if list_predictions[i] == 0 and list_labels[i] == 0]) 126 | fp = sum([1 for i in range(len(list_predictions)) if list_predictions[i] == 1 and list_labels[i] == 0]) 127 | fn = sum([1 for i in range(len(list_predictions)) if list_predictions[i] == 0 and list_labels[i] == 1]) 128 | fpr = fp / (fp + tn)*100 129 | tpr = tp / (tp + fn)*100 130 | return tpr, fpr 131 | 132 | def get_raw_predictions(model, data_point): 133 | ''' 134 | A data point is a set of docs/chunks, so a tensor of k x num_features 135 | ''' 136 | model.eval() 137 | with torch.no_grad(): 138 | outputs = model(torch.Tensor(data_point).to('cuda')) 139 | return outputs.squeeze().detach().cpu().numpy() 140 | 141 | 142 | def compute_auc(list_pvalues, list_labels, plot=False): 143 | list_tpr = [] 144 | list_fpr = [] 145 | for threshold in[1, 0.7, 0.5, 0.3, 0.1, 0.05, 0.02, 0.01, 0.005, 0.001]: 146 | list_predictions = [1 if pval < threshold else 0 for pval in list_pvalues] 147 | tpr, fpr = get_tpr_fpr(list_predictions, list_labels) 148 | list_tpr.append(tpr/100) 149 | list_fpr.append(fpr/100) 150 | 151 | # sort ascending order of fpr 152 | # Sort list_fpr and list_tpr based on list_fpr 153 | sorted_indices = np.argsort(list_fpr) 154 | list_fpr = np.array(list_fpr)[sorted_indices] 155 | list_tpr = np.array(list_tpr)[sorted_indices] 156 | 157 | # compute auc using list_tpr and list_fpr 158 | auc_score = np.trapz(list_tpr, list_fpr) 159 | 160 | if plot: 161 | # plot ROC curve 162 | plt.plot(list_fpr, list_tpr) 163 | plt.xlabel('False Positive Rate') 164 | plt.ylabel('True Positive Rate') 165 | plt.title('ROC Curve') 166 | plt.show() 167 | return auc_score 168 | 169 | def 
run_dataset_inference_pvalues(args, num_docs_per_dataset, known_datasets): 170 | with open(os.path.join(args.mia_path, "mia_members.jsonl")) as f: 171 | mia_members = f.readlines() 172 | mia_members = [json.loads(x) for x in mia_members] 173 | 174 | with open(os.path.join(args.mia_path, "mia_nonmembers.jsonl")) as f: 175 | mia_non_members = f.readlines() 176 | mia_non_members = [json.loads(x) for x in mia_non_members] 177 | 178 | # members_text = load_from_disk(os.path.join(args.mia_path, "members")) 179 | # non_members_text = load_from_disk(os.path.join(args.mia_path, "nonmembers")) 180 | 181 | members_idx = list(range(len(mia_members))) 182 | non_members_idx = list(range(len(mia_non_members))) 183 | 184 | # shuffle the indices 185 | np.random.shuffle(members_idx) 186 | np.random.shuffle(non_members_idx) 187 | 188 | # shuffle mia_members and text in the same way 189 | mia_members = [mia_members[i] for i in members_idx] 190 | # members_text = members_text.select(members_idx) 191 | mia_non_members = [mia_non_members[i] for i in non_members_idx] 192 | # non_members_text = non_members_text.select(non_members_idx) 193 | # %% 194 | A_members = mia_members[:args.training_set_size_per_class] 195 | A_non_members = mia_non_members[:args.training_set_size_per_class] 196 | 197 | known_members = mia_members[args.training_set_size_per_class:args.training_set_size_per_class+known_datasets] 198 | known_non_members = mia_non_members[args.training_set_size_per_class:args.training_set_size_per_class+known_datasets] 199 | 200 | st_idx = args.training_set_size_per_class+known_datasets 201 | eval_members = mia_members[st_idx:] 202 | eval_non_members = mia_non_members[st_idx:] 203 | 204 | # replace al NaN by 0 205 | for mia in A_members + A_non_members + known_members + known_non_members + eval_members + eval_non_members: 206 | for key in mia['pred']: 207 | if np.isnan(mia['pred'][key]): 208 | mia['pred'][key] = 0 209 | 210 | # %% 211 | eval_members_datasets = combinations(eval_members, args.eval_set_size_per_class, num_docs_per_dataset) 212 | eval_non_members_datasets = combinations(eval_non_members, args.eval_set_size_per_class, num_docs_per_dataset) 213 | 214 | try: 215 | # known_members_datasets = combinations(known_members, len(known_members), num_docs_per_dataset) 216 | known_non_members_datasets = combinations(known_non_members, len(known_non_members), num_docs_per_dataset) 217 | except: 218 | print("Cannot use a full dataset as known dataset") 219 | known_non_members_datasets = combinations(known_non_members, len(known_non_members), len(known_non_members)) 220 | 221 | # %% 222 | A_members = process_combination_mia_features([[x] for x in A_members]) 223 | A_non_members = process_combination_mia_features([[x] for x in A_non_members]) 224 | 225 | # known_members = process_combination_mia_features(known_members_datasets) 226 | known_non_members = process_combination_mia_features(known_non_members_datasets) 227 | 228 | eval_members_datasets = process_combination_mia_features(eval_members_datasets) 229 | eval_non_members_datasets = process_combination_mia_features(eval_non_members_datasets) 230 | 231 | # %% 232 | # remove the top 2.5% and bottom 2.5% of the training_members_datasets and training_non_members_datasets 233 | A_members = np.sort(A_members, axis=0)[int(0.025 * A_members.shape[0]):int(0.975 * A_members.shape[0])] 234 | A_non_members = np.sort(A_non_members, axis=0)[int(0.025 * A_non_members.shape[0]):int(0.975 * A_non_members.shape[0])] 235 | 236 | # known_members = np.sort(known_members, 
axis=0)[int(0.025 * known_members.shape[0]):int(0.975 * known_members.shape[0])] 237 | known_non_members = np.sort(known_non_members, axis=0)[int(0.025 * known_non_members.shape[0]):int(0.975 * known_non_members.shape[0])] 238 | 239 | # %% 240 | A_members.shape 241 | 242 | # %% 243 | print(f"Training size: {A_members.shape[0] + A_non_members.shape[0]}") 244 | print(f"Known size: {known_non_members.shape[0]}") 245 | print(f"Eval size: {eval_members_datasets.shape[0] + eval_non_members_datasets.shape[0]}") 246 | 247 | # %% 248 | training_set = np.concatenate([A_members, A_non_members], axis=0).squeeze(1) # (950, 10) 249 | training_labels = np.concatenate([np.ones(len(A_members)), np.zeros(len(A_non_members))]) # (2000,) 250 | 251 | eval_set = np.concatenate([eval_members_datasets, eval_non_members_datasets], axis=0) 252 | num_features = eval_set.shape[-1] 253 | eval_labels = np.concatenate([np.ones(len(eval_members_datasets)), np.zeros(len(eval_non_members_datasets))]) 254 | 255 | # %% 256 | 257 | 258 | # %% 259 | # Convert to PyTorch tensors 260 | training_set = torch.tensor(training_set, dtype=torch.float32) 261 | training_labels = torch.tensor(training_labels, dtype=torch.long) 262 | 263 | # Create DataLoader 264 | train_dataset = TensorDataset(training_set, training_labels) 265 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) 266 | # %% 267 | input_size = training_set.shape[1] 268 | hidden_size = input_size // 2 269 | output_size = 1 # membership score 270 | 271 | model = nn.Linear(input_size, 1).to('cuda') 272 | 273 | # %% 274 | # Step 4: Define the loss function and optimizer 275 | criterion = nn.BCEWithLogitsLoss() 276 | optimizer = optim.Adam(model.parameters(), lr=0.01) 277 | 278 | 279 | 280 | 281 | # %% 282 | train_model(model, train_loader, criterion, optimizer, num_epochs=100) 283 | 284 | # %% 285 | # Chunk-level Evaluation 286 | B_members = process_combination_mia_features([[x] for x in eval_members]) 287 | B_non_members = process_combination_mia_features([[x] for x in eval_non_members]) 288 | chunk_eval_set = np.concatenate([B_members, B_non_members], axis=0).squeeze(1) # (950, 10) 289 | chunk_eval_labels = np.concatenate([np.ones(len(B_members)), np.zeros(len(B_non_members))]) # (2000,) 290 | # create Tensors 291 | chunk_eval_set = torch.tensor(chunk_eval_set, dtype=torch.float32) 292 | chunk_eval_labels = torch.tensor(chunk_eval_labels, dtype=torch.long) 293 | # Create DataLoader 294 | chunk_level_dataset = TensorDataset(chunk_eval_set, chunk_eval_labels) 295 | chunk_level_loader = DataLoader(chunk_level_dataset, batch_size=args.batch_size, shuffle=False) 296 | 297 | _, chunk_level_auc = evaluate_model(model, chunk_level_loader, threshold=0) 298 | 299 | 300 | 301 | # %% 302 | list_mia_distribution = [] 303 | for data_point in eval_members_datasets: 304 | list_mia_distribution.append(get_raw_predictions(model, data_point)) 305 | 306 | 307 | # %% 308 | # known_members_distribution = [] 309 | # for data_point in known_members: 310 | # known_members_distribution.append(get_raw_predictions(model, data_point)) 311 | 312 | known_non_members_distribution = [] 313 | for data_point in known_non_members: 314 | known_non_members_distribution.append(get_raw_predictions(model, data_point)) 315 | 316 | # %% 317 | 318 | # %% [markdown] 319 | # # Single Test 320 | 321 | # %% 322 | 323 | 324 | # %% 325 | # members should be mapped to 1 326 | # non-members should be mapped to 0 327 | 328 | list_pvalues = [] 329 | list_labels = [] 330 | 331 | 
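    # NOTE: despite the name, list_pvalues collects the t-test *statistic*
    # (ttest_ind returns (statistic, pvalue) but statistic is appended); the
    # statistic is then used directly as the membership score in the ROC below.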
list_mia_distribution = [] 332 | for data_point in eval_members_datasets: 333 | list_mia_distribution.append(get_raw_predictions(model, data_point)) 334 | 335 | 336 | for dataset_distr in list_mia_distribution: 337 | statistic, pvalue = ttest_ind(dataset_distr, 338 | np.array(known_non_members_distribution).reshape(-1), 339 | equal_var=True, 340 | alternative='greater') 341 | list_pvalues.append(statistic) 342 | list_labels.append(1) 343 | 344 | # count num pvalues < 0.05 345 | # num_significant = len([x for x in list_pvalues if x < 0.05]) 346 | # print(f'Number of significant pvalues for members: {num_significant}; Percentage: {num_significant / len(list_pvalues) * 100:.2f}%') 347 | 348 | list_mia_distribution = [] 349 | for data_point in eval_non_members_datasets: 350 | list_mia_distribution.append(get_raw_predictions(model, data_point)) 351 | 352 | 353 | for dataset_distr in list_mia_distribution: 354 | statistic, pvalue = ttest_ind(dataset_distr, 355 | np.array(known_non_members_distribution).reshape(-1), 356 | equal_var=True, 357 | alternative='greater') 358 | list_pvalues.append(statistic) 359 | list_labels.append(0) 360 | 361 | # num_significant = len([x for x in list_pvalues if x < 0.05]) 362 | # print(f'Number of significant pvalues non-members: {num_significant}; Percentage: {num_significant / len(list_pvalues) * 100:.2f}%') 363 | 364 | 365 | auc_roc = roc_auc_score(list_labels, list_pvalues) 366 | # auc_roc = roc_auc_score([1 - x for x in list_labels], list_pvalues) 367 | fpr, tpr, thresholds = roc_curve(list_labels, list_pvalues) 368 | # Calculate Youden's J statistic 369 | youden_j = tpr - fpr 370 | best_threshold_index = np.argmax(youden_j) 371 | best_threshold = thresholds[best_threshold_index] 372 | 373 | # list_predictions = [1 if pval < best_threshold else 0 for pval in list_pvalues] 374 | list_predictions = [1 if pval > best_threshold else 0 for pval in list_pvalues] 375 | tpr, fpr = get_tpr_fpr(list_predictions, list_labels) 376 | report = classification_report(list_labels, list_predictions, output_dict=True) 377 | best_f1 = report['weighted avg']['f1-score'] 378 | 379 | 380 | print(f'Best threshold: {best_threshold}') 381 | print(f'Best F1: {100*best_f1:.4f} with threshold {best_threshold}; TPR: {tpr:.4f}, FPR: {fpr:.4f}; AUC: {auc_roc:.4f}; Chunk-level AUC: {chunk_level_auc:.4f}') 382 | # print(best_report) 383 | return best_f1*100, best_threshold, tpr, fpr, auc_roc, chunk_level_auc 384 | 385 | def parse_args(): 386 | parser = argparse.ArgumentParser() 387 | parser.add_argument("--training_set_size_per_class", type=int, default=1000) 388 | parser.add_argument("--known_datasets", type=int, default=None) 389 | parser.add_argument("--eval_set_size_per_class", type=int, default=1000) 390 | parser.add_argument("--num_docs_per_dataset", type=int, default=1000) 391 | 392 | parser.add_argument("--mia_path", type=str, default="out/dataset_mia/EleutherAI/pythia-2.8b/arxiv") 393 | 394 | parser.add_argument("--batch_size", type=int, default=128) 395 | return parser.parse_args() 396 | 397 | if __name__ == "__main__": 398 | args = parse_args() 399 | 400 | list_num_docs_per_dataset = [500, 400, 300, 200, 100, 90, 80, 70, 60, 50, 40, 30, 20, 10] 401 | for num_docs_per_dataset in list_num_docs_per_dataset: 402 | list_rows = [] 403 | if args.known_datasets is None: 404 | for known_datasets in tqdm([1000, 900, 800, 700, 600, 500, 400, 300, 200, 100, 90, 80, 70, 60, 50, 40, 30, 20, 10]): 405 | for seed in [670487, 116739, 26225, 777572, 288389]: 406 | # set all random seeds 407 | 
random.seed(seed) 408 | np.random.seed(seed) 409 | torch.manual_seed(seed) 410 | torch.cuda.manual_seed(seed) 411 | torch.cuda.manual_seed_all(seed) 412 | torch.backends.cudnn.deterministic = True 413 | torch.backends.cudnn.benchmark = False 414 | best_f1, best_pvalue, tpr, fpr, auc_roc, chunk_level_auc = run_dataset_inference_pvalues(args, num_docs_per_dataset, known_datasets) 415 | list_rows.append([num_docs_per_dataset, 416 | known_datasets, 417 | args.training_set_size_per_class*2, 418 | args.eval_set_size_per_class*2, 419 | best_f1, 420 | best_pvalue, 421 | tpr, 422 | fpr, 423 | auc_roc, 424 | chunk_level_auc, 425 | seed, 426 | ]) 427 | 428 | df = pd.DataFrame(list_rows, columns=['Dataset Size', 'Known Datasets', 'Training Size', 'Eval Size', 'F1', 'P-value', 'TPR', 'FPR', 'AUC', 'Chunk-level AUC', 'Seed']) 429 | df.to_csv(os.path.join(args.mia_path, f"dataset_inference_pvalues_{num_docs_per_dataset}_dataset_size.csv"), index=False) 430 | -------------------------------------------------------------------------------- /document_mia.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import logging 3 | 4 | logging.basicConfig(level='ERROR') 5 | import argparse 6 | import glob 7 | import json 8 | import os 9 | import random 10 | 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | import pandas as pd 14 | import torch 15 | import torch.nn as nn 16 | import torch.optim as optim 17 | from datasets import load_from_disk 18 | from scipy.stats import mannwhitneyu 19 | from sklearn.metrics import (classification_report, roc_auc_score, 20 | roc_curve) 21 | from torch.utils.data import DataLoader, TensorDataset 22 | 23 | 24 | # %% 25 | def combinations(list_items, num_combinations, num_groups): 26 | list_items_idx = list(range(len(list_items))) 27 | combinations = set() 28 | while len(combinations) < num_combinations: 29 | list_idx = random.sample(range(len(list_items_idx)), num_groups) 30 | combinations.add(tuple(list_idx)) 31 | 32 | combinations_items = [] 33 | for combination in combinations: 34 | combination_items = [] 35 | for idx in combination: 36 | combination_items.append(list_items[idx]) 37 | combinations_items.append(combination_items) 38 | return np.array(combinations_items), combinations 39 | # return list(combinations) 40 | 41 | # %% 42 | def process_combination_mia_features(list_combinations_mia_features): 43 | ''' 44 | Convert a list of [{'pred': {'ppl': 5.1875, 45 | 'ppl/lowercase_ppl': -1.0285907103711933, 46 | 'ppl/zlib': 0.0001900983701566763, 47 | 'Min_5.0% Prob': 8.089154411764707, 48 | 'Min_10.0% Prob': 6.647058823529412, 49 | 'Min_20.0% Prob': 5.1873471882640585, 50 | 'Min_30.0% Prob': 4.264161746742671, 51 | 'Min_40.0% Prob': 3.5961376833740832, 52 | 'Min_50.0% Prob': 3.07814027370479, 53 | 'Min_60.0% Prob': 2.663736449002443}, 54 | 'label': 0}, ... ] 55 | into a numpy array with all the features 56 | ''' 57 | features = [] 58 | for combination_mia_features in list_combinations_mia_features: 59 | dataset_features = [] 60 | for mia_features in combination_mia_features: 61 | dataset_features.append(list(mia_features['pred'].values())) 62 | features.append(dataset_features) 63 | 64 | return np.array(features) 65 | 66 | # %% 67 | import torch 68 | 69 | 70 | def train_model(model, train_loader, criterion, optimizer, num_epochs=100): 71 | """ 72 | Train the model using the training set. 73 | 74 | Parameters: 75 | model (torch.nn.Module): The model to be trained. 
76 | train_loader (torch.utils.data.DataLoader): DataLoader for the training set. 77 | criterion (torch.nn.Module): Loss function. 78 | optimizer (torch.optim.Optimizer): Optimizer for updating model parameters. 79 | num_epochs (int): Number of epochs to train the model. Default is 100. 80 | 81 | Returns: 82 | None 83 | """ 84 | model.train() 85 | for epoch in range(num_epochs): 86 | for inputs, labels in train_loader: 87 | inputs, labels = inputs.to('cuda'), labels.to('cuda') 88 | outputs = model(inputs).squeeze(1) # Squeeze the output to match the shape of labels 89 | loss = criterion(outputs, labels.float()) # Convert labels to float 90 | 91 | optimizer.zero_grad() 92 | loss.backward() 93 | optimizer.step() 94 | 95 | # print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}') 96 | 97 | # %% 98 | 99 | 100 | def evaluate_model(model, eval_loader, threshold=0): 101 | model.eval() 102 | list_raw_scores = [] 103 | list_labels = [] 104 | with torch.no_grad(): 105 | correct = 0 106 | total = 0 107 | for inputs, labels in eval_loader: 108 | inputs, labels = inputs.to('cuda'), labels.to('cuda') 109 | outputs = model(inputs) 110 | predicted = (outputs.squeeze() > threshold).float() # Apply threshold to get binary predictions 111 | 112 | list_raw_scores.extend(outputs.squeeze().detach().cpu().numpy().tolist()) 113 | list_labels.extend(labels.detach().cpu().numpy().tolist()) 114 | total += labels.size(0) 115 | correct += (predicted == labels.float()).sum().item() # Convert labels to float for comparison 116 | 117 | accuracy = correct / total 118 | 119 | auc_score = roc_auc_score(list_labels, list_raw_scores) 120 | 121 | return accuracy, auc_score 122 | 123 | # Example usage: 124 | # accuracy, roc_auc = evaluate_model(model, eval_loader, threshold=0) 125 | # print(f'Accuracy: {accuracy:.4f} with threshold {threshold}') 126 | # print(f'ROC AUC: {roc_auc:.4f}') 127 | 128 | def get_raw_predictions(model, docs_features): 129 | ''' 130 | docs_features is a numpy array of shape (num_docs, num_features) 131 | ''' 132 | model.eval() 133 | with torch.no_grad(): 134 | outputs = model(docs_features.to('cuda')).cpu().numpy() 135 | return outputs 136 | 137 | def get_mia_scorer(A_members, A_non_members, eval_members, eval_non_members, batch_size): 138 | # %% 139 | A_members = process_combination_mia_features([[x] for x in A_members]) 140 | A_non_members = process_combination_mia_features([[x] for x in A_non_members]) 141 | # remove the top 2.5% and bottom 2.5% of the training_members_datasets and training_non_members_datasets 142 | A_members = np.sort(A_members, axis=0)[int(0.025 * A_members.shape[0]):int(0.975 * A_members.shape[0])] 143 | A_non_members = np.sort(A_non_members, axis=0)[int(0.025 * A_non_members.shape[0]):int(0.975 * A_non_members.shape[0])] 144 | training_set = np.concatenate([A_members, A_non_members], axis=0).squeeze(1) # (950, 10) 145 | training_labels = np.concatenate([np.ones(len(A_members)), np.zeros(len(A_non_members))]) # (2000,) 146 | 147 | # Convert to PyTorch tensors 148 | training_set = torch.tensor(training_set, dtype=torch.float32) 149 | training_labels = torch.tensor(training_labels, dtype=torch.long) 150 | 151 | # Create DataLoader 152 | train_dataset = TensorDataset(training_set, training_labels) 153 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 154 | 155 | # %% 156 | print(f"Training size: {A_members.shape[0] + A_non_members.shape[0]}") 157 | 158 | # %% 159 | input_size = training_set.shape[1] 160 | hidden_size = input_size // 2 161 | 
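    # NOTE: hidden_size is currently unused; the scorer defined below is a
    # single linear layer over the MIA features.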
output_size = 1 # membership score 162 | 163 | model = nn.Linear(input_size, 1).to('cuda') 164 | 165 | # %% 166 | # Step 4: Define the loss function and optimizer 167 | criterion = nn.BCEWithLogitsLoss() 168 | optimizer = optim.Adam(model.parameters(), lr=0.01) 169 | train_model(model, train_loader, criterion, optimizer, num_epochs=100) 170 | 171 | # %% 172 | # Chunk-level Evaluation 173 | B_members = process_combination_mia_features([[ch] for list_ch in eval_members for ch in list_ch]) 174 | B_non_members = process_combination_mia_features([[ch] for list_ch in eval_non_members for ch in list_ch]) 175 | # in some cases there could be some nan values for Min-k Prob, we will remove them (if it happens, it usually affect to the small probs) 176 | B_members = np.nan_to_num(B_members, nan=0.0) 177 | B_non_members = np.nan_to_num(B_non_members, nan=0.0) 178 | chunk_eval_set = np.concatenate([B_members, B_non_members], axis=0).squeeze(1) # (950, 10) 179 | chunk_eval_labels = np.concatenate([np.ones(len(B_members)), np.zeros(len(B_non_members))]) # (2000,) 180 | # create Tensors 181 | chunk_eval_set = torch.tensor(chunk_eval_set, dtype=torch.float32) 182 | chunk_eval_labels = torch.tensor(chunk_eval_labels, dtype=torch.long) 183 | # Create DataLoader 184 | chunk_level_dataset = TensorDataset(chunk_eval_set, chunk_eval_labels) 185 | chunk_level_loader = DataLoader(chunk_level_dataset, batch_size=batch_size, shuffle=False) 186 | 187 | accuracy, roc_auc_chunk = evaluate_model(model, chunk_level_loader, threshold=0) 188 | return model, roc_auc_chunk 189 | 190 | # %% 191 | def dict_to_tensor(mia_features): 192 | # Extract the 'pred' values and convert them to a list 193 | pred_values = list(mia_features['pred'].values()) 194 | # clean nan values (replace with 0) 195 | pred_values = [0 if np.isnan(x) else x for x in pred_values] 196 | # Convert the list to a tensor 197 | tensor = torch.tensor(pred_values) 198 | 199 | return tensor 200 | 201 | def convert_docs2tensors(docs): 202 | ''' 203 | Convert a list of documents (dict of mia features) to a list of tensors. 
A tensor for each document (# paragraphs x #mia_features) 204 | ''' 205 | tensors = [] 206 | for doc in docs: 207 | paragraph_tensor = [] 208 | for paragraph in doc: 209 | paragraph_tensor.append(dict_to_tensor(paragraph)) 210 | tensors.append(torch.stack(paragraph_tensor)) 211 | return tensors 212 | 213 | # %% 214 | def get_eval_scores(model, members_path, non_members_path): 215 | with open(members_path) as f: 216 | eval_members = f.readlines() 217 | eval_members = [json.loads(x) for x in eval_members] 218 | 219 | with open(non_members_path) as f: 220 | eval_non_members = f.readlines() 221 | eval_non_members = [json.loads(x) for x in eval_non_members] 222 | 223 | eval_members_tensor = convert_docs2tensors(eval_members) 224 | eval_non_members_tensor = convert_docs2tensors(eval_non_members) 225 | 226 | eval_members_docs_scores = [] 227 | for tensor in eval_members_tensor: 228 | eval_members_docs_scores.append(get_raw_predictions(model, tensor)) 229 | 230 | eval_non_members_docs_scores = [] 231 | for tensor in eval_non_members_tensor: 232 | eval_non_members_docs_scores.append(get_raw_predictions(model, tensor)) 233 | 234 | return eval_members_docs_scores, eval_non_members_docs_scores 235 | 236 | 237 | # %% 238 | def evaluate(eval_members_docs_scores, eval_non_members_docs_scores, known_non_members_docs_scores, num_paragraphs): 239 | list_scores = [] 240 | list_labels = [] 241 | for eval_member_score in eval_members_docs_scores: 242 | if len(eval_member_score) > 1: 243 | statistic, pvalue = mannwhitneyu(eval_member_score[:num_paragraphs].squeeze(), 244 | known_non_members_docs_scores, 245 | alternative='greater') 246 | # statistic, pvalue = ttest_ind(eval_member_score[:num_paragraphs].squeeze(), 247 | # known_non_members_docs_scores, 248 | # alternative='greater', 249 | # equal_var=False) 250 | if not np.isnan(statistic): 251 | list_scores.append(statistic) 252 | list_labels.append(1) 253 | 254 | 255 | for eval_non_member_score in eval_non_members_docs_scores: 256 | if len(eval_non_member_score) > 1: 257 | statistic, pvalue = mannwhitneyu(eval_non_member_score[:num_paragraphs].squeeze(), 258 | known_non_members_docs_scores, 259 | alternative='greater') 260 | # statistic, pvalue = ttest_ind(eval_non_member_score[:num_paragraphs].squeeze(), 261 | # known_non_members_docs_scores, 262 | # alternative='greater', 263 | # equal_var=False) 264 | if not np.isnan(statistic): 265 | list_scores.append(statistic) 266 | list_labels.append(0) 267 | 268 | auc_roc = roc_auc_score(list_labels, list_scores) 269 | fpr, tpr, thresholds = roc_curve(list_labels, list_scores) 270 | # Calculate Youden's J statistic 271 | youden_j = tpr - fpr 272 | best_threshold_index = np.argmax(youden_j) 273 | best_threshold = thresholds[best_threshold_index] 274 | 275 | list_predictions = [1 if pval > best_threshold else 0 for pval in list_scores] 276 | tpr, fpr = get_tpr_fpr(list_predictions, list_labels) 277 | report = classification_report(list_labels, list_predictions, output_dict=True) 278 | best_f1 = report['weighted avg']['f1-score'] 279 | 280 | return best_f1, tpr, fpr, auc_roc 281 | # %% 282 | def get_tpr_fpr(list_predictions, list_labels): 283 | tp = sum([1 for i in range(len(list_predictions)) if list_predictions[i] == 1 and list_labels[i] == 1]) 284 | tn = sum([1 for i in range(len(list_predictions)) if list_predictions[i] == 0 and list_labels[i] == 0]) 285 | fp = sum([1 for i in range(len(list_predictions)) if list_predictions[i] == 1 and list_labels[i] == 0]) 286 | fn = sum([1 for i in range(len(list_predictions)) if 
list_predictions[i] == 0 and list_labels[i] == 1]) 287 | fpr = fp / (fp + tn)*100 288 | tpr = tp / (tp + fn)*100 289 | return tpr, fpr 290 | 291 | def clean_nan(mia_partition): 292 | for list_sents in mia_partition: 293 | for sent in list_sents: 294 | for (mia, value) in sent['pred'].items(): 295 | if np.isnan(value): 296 | sent['pred'][mia] = 100000 297 | 298 | def run_mia(mia_path, training_set_size_per_class, known_datasets, batch_size): 299 | with open(os.path.join(mia_path, "mia_members.jsonl")) as f: 300 | mia_members = f.readlines() 301 | mia_members = [json.loads(x) for x in mia_members] 302 | 303 | with open(os.path.join(mia_path, "mia_nonmembers.jsonl")) as f: 304 | mia_non_members = f.readlines() 305 | mia_non_members = [json.loads(x) for x in mia_non_members] 306 | 307 | clean_nan(mia_members) 308 | clean_nan(mia_non_members) 309 | # %% 310 | members_text = load_from_disk(os.path.join(mia_path, "members")) 311 | non_members_text = load_from_disk(os.path.join(mia_path, "nonmembers")) 312 | 313 | # %% 314 | members_idx = list(range(len(mia_members))) 315 | non_members_idx = list(range(len(mia_non_members))) 316 | 317 | # shuffle the indices 318 | np.random.shuffle(members_idx) 319 | np.random.shuffle(non_members_idx) 320 | 321 | # shuffle mia_members and text in the same way 322 | mia_members = [mia_members[i] for i in members_idx] 323 | members_text = members_text.select(members_idx) 324 | mia_non_members = [mia_non_members[i] for i in non_members_idx] 325 | non_members_text = non_members_text.select(non_members_idx) 326 | 327 | 328 | A_non_members = [x for list_paragraphs in mia_non_members[:950] for x in list_paragraphs] 329 | known_non_members = [x for list_paragraphs in mia_non_members[950:1000] for x in list_paragraphs] 330 | eval_non_members = mia_non_members[1000:] 331 | 332 | A_members = [x for list_paragraphs in mia_members[:950] for x in list_paragraphs ] 333 | known_members = [x for x in mia_members[950:1000]] 334 | eval_members = mia_members[1000:] 335 | 336 | print(f"Num. training paragraphs: {len(A_members) + len(A_non_members)}") 337 | print(f"Num. known docs: {len(known_members)}") 338 | print(f"Num. member eval docs: {len(eval_members)}") 339 | print(f"Num. 
non-member eval docs: {len(eval_non_members)}") 340 | 341 | eval_set_size = len(eval_members) + len(eval_non_members) 342 | 343 | 344 | # %% [markdown] 345 | # # Step 1: Train Chunk-level Model 346 | model, roc_auc_chunk = get_mia_scorer(A_members, A_non_members, eval_members, eval_non_members, batch_size) 347 | print(f'Chunk-level ROC AUC: {roc_auc_chunk:.4f}') 348 | 349 | # %% [markdown] 350 | # # Step 2: Prepare Document-level Evaluation 351 | 352 | # %% 353 | eval_members_tensor = convert_docs2tensors(eval_members) 354 | eval_non_members_tensor = convert_docs2tensors(eval_non_members) 355 | 356 | eval_members_docs_scores = [] 357 | for tensor in eval_members_tensor: 358 | eval_members_docs_scores.append(get_raw_predictions(model, tensor)) 359 | 360 | eval_non_members_docs_scores = [] 361 | for tensor in eval_non_members_tensor: 362 | eval_non_members_docs_scores.append(get_raw_predictions(model, tensor)) 363 | 364 | 365 | 366 | # %% [markdown] 367 | # # Step 3: Process the Known Partition 368 | 369 | # %% 370 | known_non_members_ = process_combination_mia_features([[x] for x in known_non_members]) 371 | known_non_members_ = np.nan_to_num(known_non_members_, nan=0.0) 372 | known_non_members_tensor = torch.tensor(known_non_members_, dtype=torch.float32) 373 | known_non_members_docs_scores = get_raw_predictions(model, known_non_members_tensor) 374 | known_non_members_docs_scores = np.array([x for l in known_non_members_docs_scores for x in l]).reshape(-1) 375 | 376 | ## Run Statistic Tests 377 | # list_num_paragraphs = (set([len(x) for x in eval_members_docs_scores + eval_non_members_docs_scores])) 378 | # list_num_paragraphs = [x for x in list_num_paragraphs if x > 1] 379 | list_num_paragraphs = [len(x) for x in eval_members_docs_scores + eval_non_members_docs_scores] 380 | print(f"Running mia for up to {int(np.mean(list_num_paragraphs) + np.std(list_num_paragraphs))} paragraphs") 381 | list_auc_roc = [] 382 | list_paragraphs = list(range(2, int(np.mean(list_num_paragraphs) + np.std(list_num_paragraphs)))) 383 | for num_paragraphs in list_paragraphs: 384 | best_f1, tpr, fpr, auc_roc = evaluate(eval_members_docs_scores, eval_non_members_docs_scores, known_non_members_docs_scores, num_paragraphs) 385 | list_auc_roc.append(auc_roc) 386 | 387 | return list_paragraphs, list_auc_roc, roc_auc_chunk, eval_set_size, np.mean(list_num_paragraphs) 388 | 389 | def set_seed(seed): 390 | random.seed(seed) 391 | np.random.seed(seed) 392 | torch.manual_seed(seed) 393 | torch.cuda.manual_seed(seed) 394 | torch.cuda.manual_seed_all(seed) 395 | torch.backends.cudnn.deterministic = True 396 | torch.backends.cudnn.benchmark = False 397 | 398 | def parse_args(): 399 | parser = argparse.ArgumentParser() 400 | parser.add_argument("--sample_size", type=int, default=None) 401 | parser.add_argument("--training_set_size_per_class", type=int, default=2000) 402 | parser.add_argument("--known_datasets", type=int, default=500) 403 | parser.add_argument("--batch_size", type=int, default=32) 404 | parser.add_argument("--base_path", type=str, default="results/doc_mia/EleutherAI") 405 | return parser.parse_args() 406 | 407 | if __name__ == "__main__": 408 | args = parse_args() 409 | # %% 410 | # generate 5 random seeds 411 | seeds = [random.randint(0, 100000) for _ in range(5)] 412 | for seed in seeds: 413 | set_seed(seed) 414 | # %% 415 | list_chunk_sizes = [512, 1024, 2048] 416 | list_llm_names = [os.path.basename(x) for x in glob.glob(os.path.join(args.base_path, "*"))] 417 | for llm_name in list_llm_names: 418 | 
model_size = llm_name.split("pythia-")[1].split("/")[0]  # assumes Pythia-style names, e.g., "pythia-2.8b" -> "2.8b"
419 |             for dataset_path in glob.glob(f"{args.base_path}/{llm_name}/haritzpuerto/*"):
420 |                 dataset_name = os.path.basename(dataset_path).split("_")[-1]
421 |                 for chunk_size in list_chunk_sizes:
422 |                     full_mia_path = os.path.join(dataset_path, f"sample_size_2000/{chunk_size}")
423 |                     print(f"Running {full_mia_path}. Seed {seed}")
424 |                     try:
425 |                         list_num_paragraphs, list_auc_roc, roc_auc_chunk, eval_set_size, avg_paragraphs = run_mia(full_mia_path, args.training_set_size_per_class, args.known_datasets, args.batch_size)
426 |                         df = pd.DataFrame({"num_paragraphs": list_num_paragraphs,
427 |                                            "auc_roc": list_auc_roc,
428 |                                            "roc_auc_chunk": [roc_auc_chunk]*len(list_num_paragraphs),
429 |                                            "eval_set_size": [eval_set_size]*len(list_num_paragraphs),
430 |                                            "seed": [seed]*len(list_num_paragraphs),
431 |                                            "avg_paragraphs": [avg_paragraphs]*len(list_num_paragraphs),
432 |                                            })
433 |                         output_path = os.path.join(full_mia_path, "mia_results_950-50-1000")
434 |                         os.makedirs(output_path, exist_ok=True)
435 |                         df.to_csv(os.path.join(output_path, f"mia_results_{seed}.csv"), index=False)
436 |                         plt.plot(list_num_paragraphs, list_auc_roc, label=f'{model_size} - {chunk_size}')
437 |                         plt.xlabel('# Paragraphs')
438 |                         plt.ylabel('AUC')
439 |                         plt.title(f'AUC vs. # Paragraphs ({dataset_name}) - Chunk-level AUC: {roc_auc_chunk:.4f}')
440 |                         plt.legend(loc='lower right')
441 |                         # save the plot
442 |                         plt.savefig(os.path.join(output_path, f"mia_results_{seed}.png"))
443 |                         plt.close()
444 |                         print(f"Done running {output_path}")
445 |                     except Exception as e:
446 |                         print(f"Error running {full_mia_path}")  # use full_mia_path: output_path may be undefined if run_mia failed
447 |                         print(e)
448 |                         with open(f"{args.base_path}/{llm_name}/haritzpuerto/errors.txt", "a") as f:  # append, so earlier errors are kept
449 |                             f.write(str(e) + "\n")
450 |                         continue
451 | 
--------------------------------------------------------------------------------
/precompute_collection_mia_scores.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import logging
3 | 
4 | logging.basicConfig(level='ERROR')
5 | import argparse
6 | import json
7 | import os
8 | import random
9 | import zlib
10 | 
11 | import numpy as np
12 | import torch
13 | from datasets import concatenate_datasets, load_dataset
14 | from tqdm import tqdm
15 | from transformers import AutoModelForCausalLM, AutoTokenizer
16 | 
17 | '''
18 | This script uses the first paragraph of 2048 tokens from each document in the collection to compute the MIA scores.
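For each document, it records the ppl, ppl/lowercase_ppl, ppl/zlib, and Min-K% Prob
features (K in {5, 10, 20, 30, 40, 50, 60}) and writes one JSON line per document to
mia_members.jsonl and mia_nonmembers.jsonl under --output_path.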
19 | ''' 20 | 21 | # %% 22 | def calculatePerplexity(sentence, model, tokenizer): 23 | """ 24 | exp(loss) 25 | """ 26 | encodings = tokenizer(sentence, return_tensors='pt', truncation=True, max_length=2048) 27 | if model.device.type == "cuda": 28 | encodings = {k: v.cuda() for k, v in encodings.items()} 29 | with torch.no_grad(): 30 | outputs = model(**encodings, labels=encodings['input_ids']) 31 | loss, logits = outputs[:2] 32 | 33 | ''' 34 | extract logits: 35 | ''' 36 | # Apply softmax to the logits to get probabilities 37 | probabilities = torch.nn.functional.log_softmax(logits, dim=-1) 38 | # probabilities = torch.nn.functional.softmax(logits, dim=-1) 39 | all_prob = [] 40 | input_ids_processed = encodings['input_ids'][0][1:] 41 | for i, token_id in enumerate(input_ids_processed): 42 | probability = probabilities[0, i, token_id].item() 43 | all_prob.append(probability) 44 | return torch.exp(loss).item(), all_prob, loss.item() 45 | 46 | 47 | # %% 48 | # in run.py you have a variant of this function with one more MIA: ppl/Ref_ppl 49 | def inference(model1, tokenizer1, text): 50 | pred = {} 51 | 52 | p1, all_prob, p1_likelihood = calculatePerplexity(text, model1, tokenizer1) 53 | p_lower, _, p_lower_likelihood = calculatePerplexity(text.lower(), model1, tokenizer1) 54 | 55 | 56 | # ppl 57 | pred["ppl"] = p1 58 | 59 | # Ratio of log ppl of lower-case and normal-case 60 | pred["ppl/lowercase_ppl"] = -(np.log(p_lower) / np.log(p1)).item() 61 | # Ratio of log ppl of large and zlib 62 | zlib_entropy = len(zlib.compress(bytes(text, 'utf-8'))) 63 | pred["ppl/zlib"] = np.log(p1)/zlib_entropy 64 | # min-k prob 65 | for ratio in [0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6]: 66 | k_length = int(len(all_prob)*ratio) 67 | topk_prob = np.sort(all_prob)[:k_length] 68 | pred[f"Min_{ratio*100}% Prob"] = -np.mean(topk_prob).item() 69 | 70 | return pred 71 | # %% 72 | def create_text(x): 73 | conversation = x['conversations'] 74 | text = "" 75 | for message in conversation: 76 | text += message['from'] + ": " + message['value'] + "\n" 77 | return {"text": text} 78 | 79 | 80 | def parse_args(): 81 | parser = argparse.ArgumentParser() 82 | parser.add_argument("--model_name", type=str, default="EleutherAI/pythia-2.8b") 83 | parser.add_argument("--dataset_name", type=str, default="haritzpuerto/the_pile_00_arxiv") 84 | parser.add_argument("--filter_outliers", action="store_true") 85 | parser.add_argument("--min_chars", type=int, default=100) 86 | parser.add_argument("--output_path", type=str) 87 | parser.add_argument("--cache_dir", type=str, default="/tmp") 88 | parser.add_argument("--seed", type=int, default=0) 89 | return parser.parse_args() 90 | 91 | if __name__ == "__main__": 92 | args = parse_args() 93 | random.seed(args.seed) 94 | # %% 95 | model1 = AutoModelForCausalLM.from_pretrained(args.model_name, return_dict=True, device_map='auto', torch_dtype=torch.bfloat16, cache_dir=args.cache_dir) 96 | model1.eval() 97 | tokenizer1 = AutoTokenizer.from_pretrained(args.model_name) 98 | 99 | # %% 100 | ds = load_dataset(args.dataset_name) 101 | if args.filter_outliers: 102 | # removing outlier docs 103 | ds['train'] = ds['train'].filter(lambda x: len(x["text"]) > args.min_chars) 104 | ds['validation'] = ds['validation'].filter(lambda x: len(x["text"]) > args.min_chars) 105 | ds['test'] = ds['test'].filter(lambda x: len(x["text"]) > args.min_chars) 106 | 107 | nonmembers = concatenate_datasets([ds["validation"], ds["test"]]) 108 | members = ds["train"].shuffle(seed=args.seed).select(range(len(nonmembers))) 109 | 110 | 
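    # The Pile's train split was seen by the Pythia models during pre-training
    # (members), while the validation/test splits were held out (non-members).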
nonmembers.save_to_disk(os.path.join(args.output_path, "nonmembers")) 111 | members.save_to_disk(os.path.join(args.output_path, "members")) 112 | # %% 113 | data_points_members = [] 114 | for text in tqdm(members['text']): 115 | mia_features = inference(model1, tokenizer1, text) 116 | data_points_members.append({'pred': mia_features, 'label': 1}) 117 | torch.cuda.empty_cache() 118 | 119 | with open(os.path.join(args.output_path, "mia_members.jsonl"), "w") as f: 120 | for dp in data_points_members: 121 | f.write(json.dumps(dp) + "\n") 122 | 123 | data_points_nonmembers = [] 124 | for text in tqdm(nonmembers['text']): 125 | mia_features = inference(model1, tokenizer1, text) 126 | data_points_nonmembers.append({'pred': mia_features, 'label': 0}) 127 | torch.cuda.empty_cache() 128 | 129 | with open(os.path.join(args.output_path, "mia_nonmembers.jsonl"), "w") as f: 130 | for dp in data_points_nonmembers: 131 | f.write(json.dumps(dp) + "\n") 132 | -------------------------------------------------------------------------------- /precompute_document_mia_scores.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import logging 3 | 4 | logging.basicConfig(level='ERROR') 5 | import argparse 6 | import json 7 | import os 8 | import random 9 | import zlib 10 | 11 | import datasets 12 | import numpy as np 13 | import torch 14 | from datasets import concatenate_datasets, load_dataset 15 | from tqdm import tqdm 16 | from transformers import AutoModelForCausalLM, AutoTokenizer 17 | 18 | ''' 19 | This script chunks each document into paragraphs of args.max_length tokens and computes the MIA scores for each paragraph. 20 | ''' 21 | 22 | # %% 23 | def calculatePerplexity(sentence, model, tokenizer): 24 | """ 25 | exp(loss) 26 | """ 27 | encodings = tokenizer(sentence, return_tensors='pt', truncation=True, max_length=2048) 28 | if model.device.type == "cuda": 29 | encodings = {k: v.cuda() for k, v in encodings.items()} 30 | with torch.no_grad(): 31 | outputs = model(**encodings, labels=encodings['input_ids']) 32 | loss, logits = outputs[:2] 33 | 34 | ''' 35 | extract logits: 36 | ''' 37 | # Apply softmax to the logits to get probabilities 38 | probabilities = torch.nn.functional.log_softmax(logits, dim=-1) 39 | # probabilities = torch.nn.functional.softmax(logits, dim=-1) 40 | all_prob = [] 41 | input_ids_processed = encodings['input_ids'][0][1:] 42 | for i, token_id in enumerate(input_ids_processed): 43 | probability = probabilities[0, i, token_id].item() 44 | all_prob.append(probability) 45 | return torch.exp(loss).item(), all_prob, loss.item() 46 | 47 | 48 | 49 | # %% 50 | # in run.py you have a variant of this function with one more MIA: ppl/Ref_ppl 51 | def inference(model1, tokenizer1, text): 52 | pred = {} 53 | 54 | p1, all_prob, p1_likelihood = calculatePerplexity(text, model1, tokenizer1) 55 | p_lower, _, p_lower_likelihood = calculatePerplexity(text.lower(), model1, tokenizer1) 56 | 57 | 58 | # ppl 59 | pred["ppl"] = p1 60 | 61 | # Ratio of log ppl of lower-case and normal-case 62 | pred["ppl/lowercase_ppl"] = -(np.log(p_lower) / np.log(p1)).item() 63 | # Ratio of log ppl of large and zlib 64 | zlib_entropy = len(zlib.compress(bytes(text, 'utf-8'))) 65 | pred["ppl/zlib"] = np.log(p1)/zlib_entropy 66 | # min-k prob 67 | for ratio in [0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6]: 68 | k_length = int(len(all_prob)*ratio) 69 | topk_prob = np.sort(all_prob)[:k_length] 70 | pred[f"Min_{ratio*100}% Prob"] = -np.mean(topk_prob).item() 71 | 72 | return pred 73 | 74 | # %% 
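# inference() returns a per-chunk feature dict, which the main loop below wraps
# with a membership label, e.g. (values illustrative, taken from the docstrings
# in collection_mia.py / document_mia.py):
#   {'pred': {'ppl': 5.1875, 'ppl/lowercase_ppl': -1.0286, 'ppl/zlib': 0.00019,
#             'Min_5.0% Prob': 8.0892, ..., 'Min_60.0% Prob': 2.6637}, 'label': 0}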
75 | def create_text(x):
76 |     conversation = x['conversations']
77 |     text = ""
78 |     for message in conversation:
79 |         text += message['from'] + ": " + message['value'] + "\n"
80 |     return {"text": text}
81 | 
82 | 
83 | def create_chunks(text, tokenizer1, max_length):
84 |     tokens = tokenizer1.encode(text, add_special_tokens=True)
85 |     chunks = [tokenizer1.decode(tokens[i:i+max_length], skip_special_tokens=True) for i in range(0, len(tokens), max_length)]
86 |     return chunks
87 | 
88 | def parse_args():
89 |     parser = argparse.ArgumentParser()
90 |     parser.add_argument("--model_name", type=str, default="EleutherAI/pythia-2.8b")
91 |     parser.add_argument("--dataset_name", type=str, default="haritzpuerto/the_pile_00_arxiv")
92 |     parser.add_argument("--max_length", type=int, default=2048)
93 |     parser.add_argument("--filter_outliers", action="store_true")
94 |     parser.add_argument("--min_chars", type=int, default=100)
95 |     parser.add_argument("--output_path", type=str)
96 |     parser.add_argument("--cache_dir", type=str, default="/tmp")
97 |     parser.add_argument("--seed", type=int, default=0)
98 |     parser.add_argument("--sample_size", type=int, default=None)
99 |     return parser.parse_args()
100 | 
101 | if __name__ == "__main__":
102 |     '''
103 |     How to run:
104 |     python precompute_document_mia_scores.py \
105 |         --model_name EleutherAI/pythia-2.8b \
106 |         --output_path out/doc_mia/EleutherAI/pythia-2.8b/arxiv/2048_tokens \
107 |         --filter_outliers
108 | 
109 |     '''
110 |     args = parse_args()
111 |     random.seed(args.seed)
112 |     # %%
113 |     model1 = AutoModelForCausalLM.from_pretrained(args.model_name, return_dict=True, device_map='auto', torch_dtype=torch.bfloat16, cache_dir=args.cache_dir)
114 |     model1.eval()
115 |     tokenizer1 = AutoTokenizer.from_pretrained(args.model_name)
116 | 
117 |     # %%
118 |     ds = load_dataset(args.dataset_name)
119 |     if args.filter_outliers:
120 |         # removing outlier docs
121 |         ds['train'] = ds['train'].filter(lambda x: len(x["text"]) > args.min_chars)
122 |         ds['validation'] = ds['validation'].filter(lambda x: len(x["text"]) > args.min_chars)
123 |         ds['test'] = ds['test'].filter(lambda x: len(x["text"]) > args.min_chars)
124 | 
125 |     nonmembers = concatenate_datasets([ds["validation"], ds["test"]])
126 |     if args.sample_size is not None:
127 |         # keep the longest `sample_size` docs from the non-members
128 |         nonmembers = sorted(nonmembers, key=lambda x: len(x['text']), reverse=True)[:args.sample_size]
129 |         # pick a random sample of members whose lengths lie in the range of
130 |         # the non-member doc lengths, to avoid length shortcuts
131 |         doc_lengths = [len(text['text']) for text in nonmembers]
132 |         min_len = min(doc_lengths)
133 |         max_len = max(doc_lengths)
134 |         nonmembers = datasets.Dataset.from_list(nonmembers)
135 |         members = ds['train'].filter(lambda x: min_len <= len(x["text"]) and len(x["text"]) <= max_len).shuffle(seed=args.seed).select(range(len(nonmembers)))
136 |     else:  # without this fallback `members` would be undefined; mirrors precompute_collection_mia_scores.py
137 |         members = ds['train'].shuffle(seed=args.seed).select(range(len(nonmembers)))
138 |     nonmembers.save_to_disk(os.path.join(args.output_path, "nonmembers"))
139 |     members.save_to_disk(os.path.join(args.output_path, "members"))
140 |     # %%
141 |     data_points_members = []
142 |     for text in tqdm(members['text']):
143 |         chunks = create_chunks(text, tokenizer1, args.max_length)
144 |         doc_features = []
145 |         for chunk in chunks:
146 |             mia_features = inference(model1, tokenizer1, chunk)
147 |             doc_features.append({'pred': mia_features, 'label': 1})
148 |         data_points_members.append(doc_features)
149 |         torch.cuda.empty_cache()
150 | 
151 |     with open(os.path.join(args.output_path, "mia_members.jsonl"),
"w") as f: 152 | for dp in data_points_members: 153 | f.write(json.dumps(dp) + "\n") 154 | 155 | data_points_nonmembers = [] 156 | for text in tqdm(nonmembers['text']): 157 | chunks = create_chunks(text, tokenizer1, args.max_length) 158 | doc_features = [] 159 | for chunk in chunks: 160 | if len(chunk) > args.min_chars: 161 | mia_features = inference(model1, tokenizer1, chunk) 162 | doc_features.append({'pred': mia_features, 'label': 0}) 163 | data_points_nonmembers.append(doc_features) 164 | torch.cuda.empty_cache() 165 | 166 | with open(os.path.join(args.output_path, "mia_nonmembers.jsonl"), "w") as f: 167 | for dp in data_points_nonmembers: 168 | f.write(json.dumps(dp) + "\n") 169 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.34.2 2 | cuda-python==11.6.0 3 | datasets==2.20.0 4 | matplotlib==3.6.2 5 | nltk==3.8.1 6 | pandas==1.5.2 7 | scikit-learn 8 | scipy==1.6.3 9 | transformers @ git+https://github.com/huggingface/transformers@aec1ca3a588bc6c65f7886e3d3eaa74901a6356f 10 | torch==1.14.0 11 | --------------------------------------------------------------------------------