├── .gitignore ├── FAQ.md ├── LICENSE ├── README.md ├── assets └── fig1.png ├── examples ├── __init__.py ├── classification │ ├── README.md │ ├── __init__.py │ ├── data │ │ ├── download_dataset.sh │ │ ├── k-shot │ │ │ └── checksum │ │ ├── make_k_shot_without_dev.py │ │ └── make_valid_data.py │ ├── requirements.txt │ ├── run_classification.py │ ├── run_wrapper.py │ ├── spectral_analysis │ │ ├── 3d_surface.png │ │ ├── README.md │ │ ├── __init__.py │ │ ├── density.py │ │ ├── geometric_median.py │ │ ├── rebuttal_neurips_2022.py │ │ ├── rebuttal_plots_neurips_2022.py │ │ └── visuals.ipynb │ └── src │ │ ├── __init__.py │ │ ├── common.py │ │ ├── compiled_args.py │ │ ├── dataset.py │ │ ├── label_search.py │ │ ├── models.py │ │ ├── processors.py │ │ └── trainer.py ├── image_classification │ ├── README.md │ ├── __init__.py │ └── main.py └── table2text │ ├── README.md │ ├── __init__.py │ ├── compiled_args.py │ ├── data_utils │ ├── __init__.py │ ├── data_collator.py │ └── language_modeling.py │ ├── decoding_utils.py │ ├── density.py │ ├── misc.py │ ├── models.py │ ├── requirements.txt │ ├── run.sh │ ├── run_language_modeling.py │ └── trainer.py ├── private_transformers ├── __init__.py ├── accounting │ ├── __init__.py │ ├── accounting_manager.py │ └── rdp_accounting.py ├── autograd_grad_sample.py ├── lora_utils.py ├── privacy_engine.py ├── settings.py ├── supported_layers_grad_samplers.py └── transformers_support.py ├── setup.py └── tests ├── __init__.py └── test_privacy_engine.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | .idea/
132 | .DS_Store
133 |
134 | # Ignore figures and plots when uploading to github.
135 | **.png
136 | **.pdf
137 |
-------------------------------------------------------------------------------- /FAQ.md: --------------------------------------------------------------------------------
1 | ## FAQ
2 |
3 | ### How do I perform gradient accumulation?
4 |
5 | Use `virtual_step` in combination with `step`. For example, the following is a simplified demo of the structure:
6 |
7 | ```python
8 | import torch, transformers, private_transformers
9 |
10 | gradient_accumulation_steps = 10  # Number of iterations between optimizer updates.
11 |
12 | batches = ...  # Data.
13 | model = transformers.AutoModelWithLMHead.from_pretrained('gpt2')
14 | optimizer = torch.optim.Adam(model.parameters())
15 | privacy_engine = private_transformers.PrivacyEngine(model, ...)
16 | privacy_engine.attach(optimizer)
17 |
18 | for i, batch in enumerate(batches, 1):
19 |     loss = model(batch)  # Simplified; `loss` must be a 1-D tensor of per-example losses.
20 |     if i % gradient_accumulation_steps == 0:
21 |         optimizer.step(loss=loss)
22 |         optimizer.zero_grad()
23 |     else:
24 |         optimizer.virtual_step(loss=loss)
25 | ```
26 |
27 | ### What is ghost clipping?
28 |
29 | It's a per-example gradient clipping (then summing) technique that avoids instantiating per-example gradients. It can
30 | make private training have almost the same memory cost as non-private training.
31 | The method is based on accumulating per-example gradient norms layer by layer, an idea first demonstrated
32 | in [this work](https://arxiv.org/abs/2009.03106).
33 | We implemented and extended this method so that computing gradient norms for linear layers is cheap; this relies on a
34 | linear algebra identity that we derived in [this work](https://arxiv.org/pdf/2110.05679.pdf): for a linear layer, the squared per-example gradient norm equals the Frobenius inner product of the Gram matrix of the layer's inputs and the Gram matrix of its output gradients, so the per-example gradient itself never needs to be materialized.
35 | [Subsequent work](https://arxiv.org/abs/2205.10683) adapted the overall approach to training convolutional layers.
36 |
37 | ### How did you test that ghost clipping gives the 'right' gradients?
38 |
39 | We ran stringent numerical tests to ensure the double-backward implementation is correct (e.g., removing sources of
40 | randomness like dropout and comparing gradients from double backward against gradients from autodiff plus a for loop).
41 | Check out the files in the `tests` folder for more on this.
42 |
43 | ### When can't I use ghost clipping?
44 |
45 | Ghost clipping can't handle parameter sharing; that's why, in our code, we separate the LM head from the embedding
46 | layer for generation tasks. Similarly, it can't be applied to fine-tuning ALBERT, which ties weights across many layers
47 | of the model.
48 |
49 | ### What if I want to freeze some parameters of the network while updating all others?
50 |
51 | Before creating the privacy engine and optimizer, set `.requires_grad = False` on the parts of the model that won't be
52 | optimized. The privacy engine will do the rest for you. For instance:
53 |
54 | ```python
55 | import transformers, private_transformers
56 |
57 | model = transformers.AutoModelWithLMHead.from_pretrained('gpt2')
58 | # Input embeddings won't be optimized; this line must precede privacy engine creation.
59 | model.get_input_embeddings().requires_grad_(False)
60 | privacy_engine = private_transformers.PrivacyEngine(model, ...)
61 | ```
-------------------------------------------------------------------------------- /LICENSE: --------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner.
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # private-transformers
2 |
3 | This codebase facilitates fast experimentation with differentially private training
4 | of [Hugging Face transformers](https://huggingface.co/transformers/).
5 |
6 | ---
7 | ![fig1](assets/fig1.png)
8 |
9 |
10 | 11 | ## What is this? Why an extra codebase? 12 | 13 | - This codebase provides a privacy engine that builds off and rewrites [Opacus](https://github.com/pytorch/opacus) so 14 | that integration with 15 | [Hugging Face's transformers library](https://github.com/huggingface/transformers) is easy. 16 | - Additionally, we support the *ghost clipping* technique (see Section 4 of [this](https://arxiv.org/pdf/2110.05679.pdf) 17 | preprint on how it works) which allows privately training large transformers with considerably reduced memory cost -- 18 | in many cases, almost as light as non-private training -- at a modest run-time overhead. 19 | - **With this codebase, we have fine-tuned very large pretrained models, yielding some of the best performing 20 | differentially private NLP models to date. Some of these models have performance matching strong non-private baseline 21 | approaches. We see strong empirical evidence that highly performant DP NLP models could be built on modest datasets.** 22 | 23 | ## Installation 24 | 25 | Make sure you have python>=3.8; run the following command: 26 | 27 | ```bash 28 | pip install git+https://github.com/lxuechen/private-transformers.git 29 | ``` 30 | 31 | To check the package is installed properly, be sure to run the test suite (requires pytest and a GPU) via the following 32 | command: 33 | 34 | ```bash 35 | pytest -s tests 36 | ``` 37 | 38 | ## Usage 39 | 40 | ### Basic usage 41 | 42 | Privately training Hugging Face transformers with our codebase simply consists of 4 steps: 43 | 44 | 1. Create your favourite transformer model and optimizer; attach this optimizer to a `PrivacyEngine` 45 | 2. Compute a per-example loss (1-D tensor) for a mini-batch of data 46 | 3. Pass the loss to `optimizer.step` or `optimizer.virtual_step` as a keyword argument 47 | 4. Repeat from step 2 48 | 49 | Below is a quick example: 50 | 51 | ```python 52 | import transformers, torch 53 | from private_transformers import PrivacyEngine 54 | import torch.nn.functional as F 55 | 56 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 57 | model = transformers.GPT2LMHeadModel.from_pretrained('distilgpt2').to(device) 58 | optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4) 59 | privacy_engine = PrivacyEngine( 60 | model, 61 | batch_size=10, 62 | sample_size=50000, 63 | epochs=3, 64 | max_grad_norm=0.1, 65 | target_epsilon=3, 66 | ) 67 | privacy_engine.attach(optimizer) 68 | 69 | batch_size, seq_len = 10, 20 70 | # Inputs are batch-first format, i.e., the first dimension of tensors must be batch dimension. 71 | input_ids = torch.randint(size=[batch_size, seq_len], low=0, high=100, device=device) 72 | # Calling `.train()` is very important; otherwise underlying forward and backward hooks don't run. 73 | model.train() 74 | outputs = model(input_ids=input_ids, return_dict=True) 75 | labels = input_ids[:, 1:, ] 76 | logits = outputs.logits[:, :-1, :].permute(0, 2, 1) 77 | # `loss` is a 1-D tensor of shape (batch_size,). 78 | loss = F.cross_entropy(logits, labels, reduction="none").mean(dim=1) 79 | # This step is different from existing workflows: 80 | # Don't call `loss.backward`; leave it to `optimizer.step` to handle backward. 81 | optimizer.step(loss=loss) 82 | ``` 83 | 84 | The biggest differences compared to Opacus are: 85 | 86 | - We require the per-example loss (a 1-D tensor) be passed into `optimizer.step` (or `optimizer.virtual_step`). 87 | - The per-example loss must be passed in as a *keyword argument*. 
88 | - `loss.backward()` shouldn't be called on the user end; it's called internally in `optimizer.step` (
89 | or `optimizer.virtual_step`).
90 | - Inputs should be in batch-first format; there isn't a toggle to switch between different formats in the engine.
91 |
92 | ### Ghost clipping: memory-saving differentially private learning
93 |
94 | Turning on ghost clipping requires changing only 1 line. You should notice a drastic reduction in peak GPU memory usage
95 | once this is turned on, at a potential cost of slower training speed. One might find this especially useful when
96 | constrained to older GPUs with limited VRAM, or when fitting very large models.
97 |
98 | ```python
99 | import transformers, torch
100 | from private_transformers import PrivacyEngine
101 |
102 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
103 | model = transformers.GPT2LMHeadModel.from_pretrained('distilgpt2').to(device)
104 | optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
105 | privacy_engine = PrivacyEngine(
106 |     model,
107 |     batch_size=10,
108 |     sample_size=50000,
109 |     epochs=3,
110 |     max_grad_norm=0.1,
111 |     target_epsilon=3,
112 |     clipping_mode="ghost",  # The only change you need to make!
113 | )
114 | privacy_engine.attach(optimizer)
115 | ```
116 |
117 | ### Examples
118 |
119 | Code in the `examples` folder roughly reproduces our results for the table-to-text and classification tasks. There may
120 | be some minor discrepancies, since hyperparameters there aren't exactly what's used in the paper. Nevertheless, it
121 | should be sufficient to get things started. Detailed instructions are in the README file of each subfolder.
122 |
123 | ### Currently supported [Hugging Face models](https://huggingface.co/transformers/pretrained_models.html)
124 |
125 | - [OpenAIGPTLMHeadModel](https://huggingface.co/docs/transformers/model_doc/openai-gpt#transformers.OpenAIGPTLMHeadModel)
126 | - [OpenAIGPTDoubleHeadsModel](https://huggingface.co/docs/transformers/model_doc/openai-gpt#transformers.OpenAIGPTDoubleHeadsModel)
127 | - [GPT2LMHeadModel](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2LMHeadModel)
128 | - [GPT2DoubleHeadsModel](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2DoubleHeadsModel)
129 | - [BertForSequenceClassification](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertForSequenceClassification)
130 | - [RobertaForSequenceClassification](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaForSequenceClassification)
131 | - [AlbertForSequenceClassification](https://huggingface.co/docs/transformers/model_doc/albert#transformers.AlbertForSequenceClassification)
132 | - [BartForConditionalGeneration](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartForConditionalGeneration)
133 | (when positional embedding layers are frozen)
134 | - [T5ForConditionalGeneration](https://huggingface.co/docs/transformers/v4.20.1/en/model_doc/t5#transformers.T5ForConditionalGeneration)
135 | - [OPTForCausalLM](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTForCausalLM)
136 | - [ViTForImageClassification](https://huggingface.co/docs/transformers/v4.20.1/en/model_doc/vit#transformers.ViTForImageClassification)
137 | (when isolated parameters are frozen; see [this example](examples/image_classification/main.py))
138 | - [DeiTForImageClassification](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTForImageClassification)
139 | (when isolated parameters are frozen)
140 | - [BeitForImageClassification](https://huggingface.co/docs/transformers/model_doc/beit#transformers.BeitForImageClassification)
141 | (when isolated parameters are frozen)
142 |
143 | Not all models in the Hugging Face library are supported. The main additional work to support a model is to
144 |
145 | 1. Support per-example gradients for bespoke modules
146 | (e.g., [T5LayerNorm](https://huggingface.co/transformers/_modules/transformers/modeling_t5.html)), and
147 | 2. Ensure `position_ids` are repeated (duplicated along batch dim 0). Normally, to save memory, one creates positional
148 | embeddings for a single instance and relies on broadcasting when there are multiple instances in a batch. This creates a
149 | problem with per-sample gradient accumulation, so we instead duplicate inputs to positional embeddings.
150 |
151 | We plan to support more models in the future if there's such a need. Feel free to open an issue if you want to try
152 | out specific models that aren't in the current list.
153 |
154 | ## FAQ
155 |
156 | I wrote up answers to some potential questions [here](https://github.com/lxuechen/private-transformers/blob/main/FAQ.md).
157 | These include performing gradient accumulation, using ghost clipping, and freezing parts of a model.
158 |
159 | ## Acknowledgements
160 |
161 | It would have been impossible to develop this codebase without cool past works and existing codebases. We roughly follow
162 | the `PrivacyEngine` design in `Opacus==0.13.0`. We directly use
163 | an [off-the-shelf package](https://github.com/microsoft/prv_accountant) for tightly tracking tradeoff functions while
164 | composing multiple private mechanisms.
165 |
166 | ## Disclaimer
167 |
168 | - This codebase is not yet production-grade, e.g., cryptographically secure PRNGs are required for sampling noise -- our
169 | codebase currently does not use these strong PRNGs as they tend to slow down training. This codebase also isn't immune
170 | to [floating point representation attacks](https://github.com/pytorch/opacus/pull/260).
171 | - This codebase was born out of the need to rapidly experiment with various things for differentially private NLP. I've
172 | tried my best to write clean code, though parts of this codebase may be less tidy than I had hoped
173 | given the extremely tight timeline.
174 |
175 | ## Citation
176 |
177 | If you found this codebase useful in your research, please consider citing:
178 |
179 | ```
180 | @inproceedings{
181 | li2022large,
182 | title={Large Language Models Can Be Strong Differentially Private Learners},
183 | author={Xuechen Li and Florian Tramer and Percy Liang and Tatsunori Hashimoto},
184 | booktitle={International Conference on Learning Representations},
185 | year={2022},
186 | url={https://openreview.net/forum?id=bVuP3ltATMz}
187 | }
188 |
189 | @inproceedings{
190 | li2022when,
191 | title={When Does Differentially Private Learning Not Suffer in High Dimensions?},
192 | author={Xuechen Li and Daogao Liu and Tatsunori Hashimoto and Huseyin A Inan and Janardhan Kulkarni and YinTat Lee and Abhradeep Guha Thakurta},
193 | booktitle={Advances in Neural Information Processing Systems},
194 | editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho},
195 | year={2022},
196 | url={https://openreview.net/forum?id=FR--mkQu0dw}
197 | }
198 | ```
199 |
-------------------------------------------------------------------------------- /assets/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxuechen/private-transformers/18ccc4eab7355e4ac96051a82434796f6aa4624b/assets/fig1.png -------------------------------------------------------------------------------- /examples/__init__.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
-------------------------------------------------------------------------------- /examples/classification/README.md: --------------------------------------------------------------------------------
1 | ## Reproducing results for sentence classification
2 |
3 | ### Requirements
4 |
5 | In addition to requirements of the `private-transformers` package, install additional requirements by running the
6 | following from the `examples` folder of this repo:
7 |
8 | ```bash
9 | pip install -r classification/requirements.txt
10 | ```
11 |
12 | This code is tested against `transformers==4.11.3`, but should also work for slightly earlier versions.
13 |
14 | ### Getting the data
15 |
16 | This part of the codebase is adapted from the excellent work
17 | by [[Gao et al., 2021](https://arxiv.org/pdf/2012.15723.pdf)]. We reuse their data pipeline. To obtain the data, run the
18 | following:
19 |
20 | ```bash
21 | cd data
22 | bash download_dataset.sh
23 | ```
24 |
25 | This should produce a `data/original` subfolder that contains all the data that we need.
26 |
27 | ### Running
28 |
29 | Use the `run_wrapper.py` script in the folder. This Python script builds the full training command as a string and runs it.
30 |
31 | Supply at least 2 arguments:
32 |
33 | - `--output_dir`: path to a folder where results will be written
34 | - `--task_name`: name of task; one of `sst-2`, `qnli`, `qqp`, `mnli`
35 |
36 | For instance, run the following under the `examples/` folder:
37 |
38 | ```bash
39 | python -m classification.run_wrapper --output_dir <output_dir> --task_name <task_name>
40 | ```
41 |
42 | The script by default uses ghost clipping, and the micro batch size is tweaked so that things should run smoothly even
43 | on a Titan Xp with 12 GB of VRAM. For SST-2, the run-time of this script on an RTX 3090 is under one and a
44 | half hours. Larger datasets take longer to train.
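For example, a fuller invocation might look like the following (the output path is a placeholder; the remaining flags and their defaults are described under "Additional arguments" below):

```bash
python -m classification.run_wrapper \
  --output_dir outputs/sst-2-eps8 \
  --task_name sst-2 \
  --model_name_or_path roberta-base \
  --target_epsilon 8 \
  --clipping_mode ghost
```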
45 |
46 | Additional arguments:
47 |
48 | - `--target_epsilon`: Target privacy spending
49 | - `--model_name_or_path`: The pretrained model; one of `distilbert-base-uncased`, `bert-base-uncased`,
50 | `bert-large-uncased`, `distilroberta-base`, `roberta-base`, `roberta-large`
51 | - `--few_shot_type`: Whether to use the generic prompt formatter described in Section 3.2 of our paper; pass
52 | `prompt` to use it, or `finetune` to fine-tune without it.
53 | - `--clipping_mode`: Per-example gradient clipping mode; the default `ghost` enables memory-saving ghost clipping.
54 | Keeping other training hyperparameters (e.g., number of training epochs, clipping norm, learning rate) the same,
55 | things should still work
56 | - `--data_dir`: Path to where data is stored; if data is obtained via the procedure described above, just stick to the
57 | defaults.
58 |
59 | Training on the larger datasets for even more epochs should bring further performance gains.
60 |
61 | ### Notes
62 |
63 | - We have reproduced some results in our paper with the codebase of
64 | a [concurrent anonymous submission](https://openreview.net/pdf?id=Q42f0dfjECO). Our modified version of their codebase
65 | is located at [this link](https://github.com/lxuechen/Differentially-Private-Fine-tuning-of-Language-Models). This
66 | version only optimizes the dense/linear layers in a Transformer model, and
67 | hence is not strictly full fine-tuning (since the embedding and LayerNorm layers aren't updated). The main difference
68 | from their original setup is that we run everything in full precision (i.e., fp32), not mixed precision.
69 | - We got results similar to those reported in the paper with Opacus, but with the embedding subnetworks (word embedding,
70 | positional embedding, token type embedding) frozen. Note that unfreezing the embedding subnetwork and plugging such a
71 | model (from HF) into Opacus would result in errors, due to how HF transformers are implemented.
72 |
-------------------------------------------------------------------------------- /examples/classification/__init__.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | -------------------------------------------------------------------------------- /examples/classification/data/download_dataset.sh: -------------------------------------------------------------------------------- 1 | wget https://nlp.cs.princeton.edu/projects/lm-bff/datasets.tar 2 | tar xvf datasets.tar 3 | -------------------------------------------------------------------------------- /examples/classification/data/k-shot/checksum: -------------------------------------------------------------------------------- 1 | fe166f3b9e9952cb729a8d2c3bd682ae SST-2/16-13/train.tsv 2 | 0eeeb2195d82899d15ab5eeaddcaa944 SST-2/16-13/dev.tsv 3 | aeac355ccdb43bc747e816eff2d74aa4 SST-2/16-13/test.tsv 4 | 6661938fcefe356e8c9c3f866b67c047 SST-2/16-21/train.tsv 5 | 74a12c79dea73deb5526c398b1bc7bc0 SST-2/16-21/dev.tsv 6 | aeac355ccdb43bc747e816eff2d74aa4 SST-2/16-21/test.tsv 7 | edce4c52734c2084476ac03bfdfa5c77 SST-2/16-42/train.tsv 8 | 4ba7f2ffe0fc89844d1184f294f15774 SST-2/16-42/dev.tsv 9 | aeac355ccdb43bc747e816eff2d74aa4 SST-2/16-42/test.tsv 10 | 783b9e625ec6dea210cc4062d464aba2 SST-2/16-87/train.tsv 11 | 2972df0d5eace29b48eb8ce77770c3d5 SST-2/16-87/dev.tsv 12 | aeac355ccdb43bc747e816eff2d74aa4 SST-2/16-87/test.tsv 13 | f5079b5b5bd27087a7d836e7e8faedc7 SST-2/16-100/train.tsv 14 | 1d65f0f67026bbdcbfef6ed3ada2510a SST-2/16-100/dev.tsv 15 | aeac355ccdb43bc747e816eff2d74aa4 SST-2/16-100/test.tsv 16 | c446d010f7ea4846be28d58d81169648 sst-5/16-13/train.csv 17 | aca7b0130f532a7a3289feaf2c9cf510 sst-5/16-13/dev.csv 18 | 08e609c5c95d32903967c6af50a36d40 sst-5/16-13/test.csv 19 | e3d8f0d8ee2c3d19cbb8cf370a80c0ff sst-5/16-21/train.csv 20 | a673a34db10fa7ea40bee8d635a58d71 sst-5/16-21/dev.csv 21 | 08e609c5c95d32903967c6af50a36d40 sst-5/16-21/test.csv 22 | c53271e5fbb67f49ff6ccec9058f4198 sst-5/16-42/train.csv 23 | 040144885c068d74bfee3415ee77a9f9 sst-5/16-42/dev.csv 24 | 08e609c5c95d32903967c6af50a36d40 sst-5/16-42/test.csv 25 | c77128dc79022b7d645e587823c90fa9 sst-5/16-87/train.csv 26 | 74a356f47dbeaaba86dc87d6de8f822b sst-5/16-87/dev.csv 27 | 08e609c5c95d32903967c6af50a36d40 sst-5/16-87/test.csv 28 | bee588d86301e040f13b2e26d3cbd3bc sst-5/16-100/train.csv 29 | 3de890aa28b082bb00c88e92329d9736 sst-5/16-100/dev.csv 30 | 08e609c5c95d32903967c6af50a36d40 sst-5/16-100/test.csv 31 | 77e589985c38f642b743c00f454a8f72 mr/16-13/train.csv 32 | 0487ece2b558b211787050177529cc15 mr/16-13/dev.csv 33 | 2ab6b6176555e141e7682bfbd17626c4 mr/16-13/test.csv 34 | 50f61d4ad574c862103a281cfb277568 mr/16-21/train.csv 35 | 36a5b9a0aef14b5ef85a794a79f7a670 mr/16-21/dev.csv 36 | 2ab6b6176555e141e7682bfbd17626c4 mr/16-21/test.csv 37 | 7f2be6d8286c853be70d6b731529110e mr/16-42/train.csv 38 | f331ad440925ff5550065bafa4d1e14e mr/16-42/dev.csv 39 | 2ab6b6176555e141e7682bfbd17626c4 mr/16-42/test.csv 40 | 9ad43ad9a35fd4fe88a8d74c164f3056 mr/16-87/train.csv 41 | 7650a096ae10bc0d22de08d28b080736 mr/16-87/dev.csv 42 | 2ab6b6176555e141e7682bfbd17626c4 mr/16-87/test.csv 43 | edab3bd41c79c073dda7469f31761202 mr/16-100/train.csv 44 | d241fa66a9302aa417876971fcc309a1 mr/16-100/dev.csv 45 | 2ab6b6176555e141e7682bfbd17626c4 mr/16-100/test.csv 46 | fef4839c9b1532845efe85f459c4657a cr/16-13/train.csv 47 | 543574c5b1fe954e1d82f109a060b8a1 cr/16-13/dev.csv 48 | 829c32e754d756b5da2b6d88cde7bc1a cr/16-13/test.csv 49 | 890b0ed3d8f57ed19e90884580cc08af cr/16-21/train.csv 50 | 02246fc77e862f8d973058478d89bfb9 cr/16-21/dev.csv 51 | 829c32e754d756b5da2b6d88cde7bc1a cr/16-21/test.csv 52 | 09260d2c41f4fbf48fad2cdb09b9bd19 cr/16-42/train.csv 53 | 
23f70a47d55def2c4308cebcdaf1f5b9 cr/16-42/dev.csv 54 | 829c32e754d756b5da2b6d88cde7bc1a cr/16-42/test.csv 55 | 2c47c48c1223ff979f2e6e8243fa3219 cr/16-87/train.csv 56 | 1f0f5e7a984740751ea2b5314f847566 cr/16-87/dev.csv 57 | 829c32e754d756b5da2b6d88cde7bc1a cr/16-87/test.csv 58 | 6c11804b15bbe6582ebf0eab71513d0d cr/16-100/train.csv 59 | 844e4d9cbc956e9495dfd77e3881caba cr/16-100/dev.csv 60 | 829c32e754d756b5da2b6d88cde7bc1a cr/16-100/test.csv 61 | 18764f2c39c6fa88b4c3871d79ea3a99 mpqa/16-13/train.csv 62 | 5d3c82a72ae0319afbe940e4d35ec402 mpqa/16-13/dev.csv 63 | b69942543fec7bb2ac8ba1c5f08d6404 mpqa/16-13/test.csv 64 | 1e71fb7b6def9a89678f350c92f386f7 mpqa/16-21/train.csv 65 | 37a581edfaa917b4975777d1bb531dad mpqa/16-21/dev.csv 66 | b69942543fec7bb2ac8ba1c5f08d6404 mpqa/16-21/test.csv 67 | 4d889a2e53e2fc9f89aef415e52c4376 mpqa/16-42/train.csv 68 | d58bd2e6128dd61460e597b41aed4073 mpqa/16-42/dev.csv 69 | b69942543fec7bb2ac8ba1c5f08d6404 mpqa/16-42/test.csv 70 | a11bd12d360e1a0c66fe5823b95d40db mpqa/16-87/train.csv 71 | a9aea0b29f42d955c430249b5aba3ff4 mpqa/16-87/dev.csv 72 | b69942543fec7bb2ac8ba1c5f08d6404 mpqa/16-87/test.csv 73 | d8b8864485fef5582b862e999f1d78fa mpqa/16-100/train.csv 74 | d05230e664b98d5c368a4616fed106b8 mpqa/16-100/dev.csv 75 | b69942543fec7bb2ac8ba1c5f08d6404 mpqa/16-100/test.csv 76 | 9007424294b4a73428032b4ba872d798 subj/16-13/train.csv 77 | ed1c7a9a31cb8ed53798e5d5c0a0ba6a subj/16-13/dev.csv 78 | 94c0bea4a5ff6092de1dcda7c46c7b65 subj/16-13/test.csv 79 | e91660061eac76f4d2496f0ee3a17a55 subj/16-21/train.csv 80 | 9bb1c6a470f9c437734d5000a9887298 subj/16-21/dev.csv 81 | 94c0bea4a5ff6092de1dcda7c46c7b65 subj/16-21/test.csv 82 | 5ad5810af6f69c0cae872f9b8ac96e4b subj/16-42/train.csv 83 | d689743060587300b50ba25303cee457 subj/16-42/dev.csv 84 | 94c0bea4a5ff6092de1dcda7c46c7b65 subj/16-42/test.csv 85 | 116db5d05fb7152af549b322138ddf0c subj/16-87/train.csv 86 | 8fcbfc9937902ed65b5f457db143eeae subj/16-87/dev.csv 87 | 94c0bea4a5ff6092de1dcda7c46c7b65 subj/16-87/test.csv 88 | dafc9f79ea2643f9d0db4bb9cd50ed68 subj/16-100/train.csv 89 | d205bd3e890483010b73b40fe2341874 subj/16-100/dev.csv 90 | 94c0bea4a5ff6092de1dcda7c46c7b65 subj/16-100/test.csv 91 | 50cff1542195f5321cb951ca1d980ec0 trec/16-13/train.csv 92 | 97ca7ed4ec756c13b6233c2d66c97ee5 trec/16-13/dev.csv 93 | b4e08b69eb3aae325197216450e3a28b trec/16-13/test.csv 94 | 178b968a79785571398a99082f2059aa trec/16-21/train.csv 95 | df1be41d4a11a1f10b99bfa9dab069fd trec/16-21/dev.csv 96 | b4e08b69eb3aae325197216450e3a28b trec/16-21/test.csv 97 | f17a8bb90b8110ad37632ded1dc0b515 trec/16-42/train.csv 98 | 682e1b9724b472132249396e826c7741 trec/16-42/dev.csv 99 | b4e08b69eb3aae325197216450e3a28b trec/16-42/test.csv 100 | efe409aae7de9e639e5da33430626e83 trec/16-87/train.csv 101 | 62eb496a54c2c810703c8a3d47213237 trec/16-87/dev.csv 102 | b4e08b69eb3aae325197216450e3a28b trec/16-87/test.csv 103 | 5df9333e324779c32bdc14eed7a3f51c trec/16-100/train.csv 104 | 983ce074f8cc101e7539122d20dc2815 trec/16-100/dev.csv 105 | b4e08b69eb3aae325197216450e3a28b trec/16-100/test.csv 106 | 8b3b16bdc7ccdf7c5e0397caf1ce61c1 CoLA/16-13/train.tsv 107 | 2a3b6875b91447043e5fc7366d6d6c59 CoLA/16-13/dev.tsv 108 | c5475ccefc9e7ca0917294b8bbda783c CoLA/16-13/test.tsv 109 | 9af338ce1d82a300b1d09e06da913454 CoLA/16-21/train.tsv 110 | ad4bd1c5d1bbc9cb149cc10648ade05f CoLA/16-21/dev.tsv 111 | c5475ccefc9e7ca0917294b8bbda783c CoLA/16-21/test.tsv 112 | 35870db90cf9b4044159d49cf4c4a620 CoLA/16-42/train.tsv 113 | 6f2b371c0fa209c253d49e20941d44c5 CoLA/16-42/dev.tsv 114 | 
c5475ccefc9e7ca0917294b8bbda783c CoLA/16-42/test.tsv 115 | f0260a425124e254cef6c41e78c3b990 CoLA/16-87/train.tsv 116 | 30975efe3642c02d2afaf6fa6849b5b1 CoLA/16-87/dev.tsv 117 | c5475ccefc9e7ca0917294b8bbda783c CoLA/16-87/test.tsv 118 | c02bce246a4481b2733522d82200d81e CoLA/16-100/train.tsv 119 | fda65a964d3ebeec632f8c45fac8209e CoLA/16-100/dev.tsv 120 | c5475ccefc9e7ca0917294b8bbda783c CoLA/16-100/test.tsv 121 | 6c61998bef38ccadacc2df0013035715 MNLI/16-13/train.tsv 122 | 5e84a205fcb8c20d8706df7e5cc7c03f MNLI/16-13/dev_matched.tsv 123 | c3fa2817007f4cdf1a03663611a8ad23 MNLI/16-13/test_matched.tsv 124 | b219e6fe74e4aa779e2f417ffe713053 MNLI/16-13/test_mismatched.tsv 125 | 139ce8bcc9b973b164fe2065a507819a MNLI/16-21/train.tsv 126 | 5da136973f2ccb8c3cd421c74cf343e8 MNLI/16-21/dev_matched.tsv 127 | c3fa2817007f4cdf1a03663611a8ad23 MNLI/16-21/test_matched.tsv 128 | b219e6fe74e4aa779e2f417ffe713053 MNLI/16-21/test_mismatched.tsv 129 | 5adecf01452cdd271940d5d7f094c77e MNLI/16-42/train.tsv 130 | c51ee1e3afccc39cd4f4fd660f6b547f MNLI/16-42/dev_matched.tsv 131 | c3fa2817007f4cdf1a03663611a8ad23 MNLI/16-42/test_matched.tsv 132 | b219e6fe74e4aa779e2f417ffe713053 MNLI/16-42/test_mismatched.tsv 133 | c2540989877d39d56aa34f503bf8d145 MNLI/16-87/train.tsv 134 | f53efd36c22ed01af9b61ee6f917a714 MNLI/16-87/dev_matched.tsv 135 | c3fa2817007f4cdf1a03663611a8ad23 MNLI/16-87/test_matched.tsv 136 | b219e6fe74e4aa779e2f417ffe713053 MNLI/16-87/test_mismatched.tsv 137 | 288a569addee0a37675ba0d27bb5cf9a MNLI/16-100/train.tsv 138 | 2139f9ecc2aabf761ca904a7ecfa2652 MNLI/16-100/dev_matched.tsv 139 | c3fa2817007f4cdf1a03663611a8ad23 MNLI/16-100/test_matched.tsv 140 | b219e6fe74e4aa779e2f417ffe713053 MNLI/16-100/test_mismatched.tsv 141 | 10672f831d104166efe0d3678303f612 SNLI/16-13/train.tsv 142 | 85e3b9d4646c19709cdd09a36c6dda7e SNLI/16-13/dev.tsv 143 | 39d5338b0c2299bde98a3cdd0c4d04a5 SNLI/16-13/test.tsv 144 | b2dbd54b5152c755cda4c0b3b5f0f0f2 SNLI/16-21/train.tsv 145 | 204397b61efbfe84a731183340c935df SNLI/16-21/dev.tsv 146 | 39d5338b0c2299bde98a3cdd0c4d04a5 SNLI/16-21/test.tsv 147 | 684cdeec26f1150b412a6fe42b203215 SNLI/16-42/train.tsv 148 | fccf683689b3276582041c2b1108f8fe SNLI/16-42/dev.tsv 149 | 39d5338b0c2299bde98a3cdd0c4d04a5 SNLI/16-42/test.tsv 150 | 442e9e4ab7372b1b15ca7f40d87559f0 SNLI/16-87/train.tsv 151 | f4359f6d8c33af80c22e089a63fd4c24 SNLI/16-87/dev.tsv 152 | 39d5338b0c2299bde98a3cdd0c4d04a5 SNLI/16-87/test.tsv 153 | 2b68c0dff2302bd3b013a1de14d20f1b SNLI/16-100/train.tsv 154 | b0bf433e9c0f6b5456ae5833dee41eee SNLI/16-100/dev.tsv 155 | 39d5338b0c2299bde98a3cdd0c4d04a5 SNLI/16-100/test.tsv 156 | 4fd9e83941409d616f67a4fb6e68206b QNLI/16-13/train.tsv 157 | a54657913dd71710bb0eff9b8c044746 QNLI/16-13/dev.tsv 158 | 1e81e211959605f144ba6c0ad7dc948b QNLI/16-13/test.tsv 159 | f316c511c6990c2648b821b99351f185 QNLI/16-21/train.tsv 160 | f7dd6b59f755d86a00395d2ee7fc4938 QNLI/16-21/dev.tsv 161 | 1e81e211959605f144ba6c0ad7dc948b QNLI/16-21/test.tsv 162 | edbea1a802d48db725c61d487597ab60 QNLI/16-42/train.tsv 163 | e0d03f6166b24a8c16985e32a8315427 QNLI/16-42/dev.tsv 164 | 1e81e211959605f144ba6c0ad7dc948b QNLI/16-42/test.tsv 165 | 2bcc47b08f67553f5e9f45991af07d10 QNLI/16-87/train.tsv 166 | 2b7f1fc57c802433c59357f659888a08 QNLI/16-87/dev.tsv 167 | 1e81e211959605f144ba6c0ad7dc948b QNLI/16-87/test.tsv 168 | 85c42afc824ac7a0474ccc42239f664c QNLI/16-100/train.tsv 169 | 9f193ffb9e102d28a25524f9a778b819 QNLI/16-100/dev.tsv 170 | 1e81e211959605f144ba6c0ad7dc948b QNLI/16-100/test.tsv 171 | 0aaae3483435626089adbc4ae851b75d 
RTE/16-13/train.tsv 172 | 8fc45ab58e72e16cf90547710f3b5d15 RTE/16-13/dev.tsv 173 | 973cb4178d4534cf745a01c309d4a66c RTE/16-13/test.tsv 174 | 79490f931a776fad7f3b2ba31f9142d1 RTE/16-21/train.tsv 175 | 96ac052f30101901e6bc04adb4875edb RTE/16-21/dev.tsv 176 | 973cb4178d4534cf745a01c309d4a66c RTE/16-21/test.tsv 177 | 4d45391333d2f56d16633d7b85d3e051 RTE/16-42/train.tsv 178 | a64069d0ca424995d35e05e67c03ceb9 RTE/16-42/dev.tsv 179 | 973cb4178d4534cf745a01c309d4a66c RTE/16-42/test.tsv 180 | 46af88a12bc44c12a40ea2cb79e94837 RTE/16-87/train.tsv 181 | 19ee5593b53d1b290737eb9de2ca4869 RTE/16-87/dev.tsv 182 | 973cb4178d4534cf745a01c309d4a66c RTE/16-87/test.tsv 183 | 548d8775d98ca0638615b9e0bfbe5c32 RTE/16-100/train.tsv 184 | bc66584b1fa19f9620051e8db7bcc1dc RTE/16-100/dev.tsv 185 | 973cb4178d4534cf745a01c309d4a66c RTE/16-100/test.tsv 186 | 4ee5a1591a6b385ed235dff1fe0bc2ff MRPC/16-13/train.tsv 187 | 55da0a2ef25c0bdda9ba0bbe1a757644 MRPC/16-13/dev.tsv 188 | 185958e46ba556b38c6a7cc63f3a2135 MRPC/16-13/test.tsv 189 | 17f3132a74d23a122c8e8bc6936de492 MRPC/16-21/train.tsv 190 | 2f0e716c6a5ddaf45d03d0b48075fd5a MRPC/16-21/dev.tsv 191 | 185958e46ba556b38c6a7cc63f3a2135 MRPC/16-21/test.tsv 192 | f96a9434f7cdef032f544f405498c266 MRPC/16-42/train.tsv 193 | a2a9648bec87318674338caed2538091 MRPC/16-42/dev.tsv 194 | 185958e46ba556b38c6a7cc63f3a2135 MRPC/16-42/test.tsv 195 | cad49c8459f1eef52093b9527ec87735 MRPC/16-87/train.tsv 196 | b9664f662b04ec18feb8605462a8a1cb MRPC/16-87/dev.tsv 197 | 185958e46ba556b38c6a7cc63f3a2135 MRPC/16-87/test.tsv 198 | 29d9a7598589c1636f64c7cf3dd25de8 MRPC/16-100/train.tsv 199 | f65a7cbec81d738f594db0040fb32610 MRPC/16-100/dev.tsv 200 | 185958e46ba556b38c6a7cc63f3a2135 MRPC/16-100/test.tsv 201 | 89173509ae29f234632fe2e7805860ba QQP/16-13/train.tsv 202 | 14a3c59ae8a5298372923dadc3c86012 QQP/16-13/dev.tsv 203 | cff6a448d1580132367c22fc449ec214 QQP/16-13/test.tsv 204 | ce7ed8d40a8e0c090b8c4634914ef895 QQP/16-21/train.tsv 205 | c29687f39ece2c692efca3ade0b91912 QQP/16-21/dev.tsv 206 | cff6a448d1580132367c22fc449ec214 QQP/16-21/test.tsv 207 | 88f2cc2d8a4c7b691099a47794ce2aee QQP/16-42/train.tsv 208 | c1b2198ba8a84976bb1084c164418234 QQP/16-42/dev.tsv 209 | cff6a448d1580132367c22fc449ec214 QQP/16-42/test.tsv 210 | 7110878494e561015e94fe459115c92c QQP/16-87/train.tsv 211 | cb7a47bad6671d75a93ec2db13c7964a QQP/16-87/dev.tsv 212 | cff6a448d1580132367c22fc449ec214 QQP/16-87/test.tsv 213 | b6a7ea69840e305608faabd068a74f06 QQP/16-100/train.tsv 214 | 136f4a247a897a80285a1e215a850d0a QQP/16-100/dev.tsv 215 | cff6a448d1580132367c22fc449ec214 QQP/16-100/test.tsv 216 | 634047695a8036cfe2523cbbc58f2fbb STS-B/16-13/train.tsv 217 | f17c612f0da73b1599a9946b1e8c7d06 STS-B/16-13/dev.tsv 218 | 5f4d6b0d2a5f268b1b56db773ab2f1fe STS-B/16-13/test.tsv 219 | af75494229a00b1ea764e7e38a3fbcf8 STS-B/16-21/train.tsv 220 | a79b76d7ee8efcb7cdefb4056b0606b8 STS-B/16-21/dev.tsv 221 | 5f4d6b0d2a5f268b1b56db773ab2f1fe STS-B/16-21/test.tsv 222 | c4577d13b90f33d9a61ca7abb8126132 STS-B/16-42/train.tsv 223 | 08b70fdb9484181f9331409329cd2259 STS-B/16-42/dev.tsv 224 | 5f4d6b0d2a5f268b1b56db773ab2f1fe STS-B/16-42/test.tsv 225 | c1a41f8d2a0867802920e447b4e6e73b STS-B/16-87/train.tsv 226 | 430153037049cab16d6141e766b4cca2 STS-B/16-87/dev.tsv 227 | 5f4d6b0d2a5f268b1b56db773ab2f1fe STS-B/16-87/test.tsv 228 | 5f8d5211ff72f49ebf181b63e6390449 STS-B/16-100/train.tsv 229 | d826c4b3fa6d419aed50442b7f1f6adc STS-B/16-100/dev.tsv 230 | 5f4d6b0d2a5f268b1b56db773ab2f1fe STS-B/16-100/test.tsv 231 | 
-------------------------------------------------------------------------------- /examples/classification/data/make_k_shot_without_dev.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Xuechen Li. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """The datasets in the k-shot folder contain dev.tsv; we make the test set the dev set in the new k-shot. 16 | 17 | python -m classification.data.make_k_shot_without_dev 18 | """ 19 | import os 20 | 21 | from ml_swissknife import utils 22 | 23 | join = os.path.join 24 | 25 | base_dir = '/nlp/scr/lxuechen/data/lm-bff/data/k-shot' 26 | new_dir = '/nlp/scr/lxuechen/data/lm-bff/data/k-shot-no-dev' 27 | 28 | task_names = ("SST-2", "QNLI", "MNLI", "QQP") 29 | for task_name in task_names: 30 | folder = join(base_dir, task_name) 31 | new_folder = join(new_dir, task_name) 32 | 33 | for name in utils.listdir(folder): 34 | subfolder = join(folder, name) 35 | new_subfolder = join(new_folder, name) 36 | os.makedirs(new_subfolder, exist_ok=True) 37 | 38 | train = join(subfolder, 'train.tsv') 39 | new_train = join(new_subfolder, 'train.tsv') 40 | os.system(f'cp {train} {new_train}') 41 | 42 | if task_name == "MNLI": 43 | test = join(subfolder, 'test_matched.tsv') 44 | new_dev = join(new_subfolder, 'dev_matched.tsv') 45 | os.system(f'cp {test} {new_dev}') 46 | 47 | test = join(subfolder, 'test_mismatched.tsv') 48 | new_dev = join(new_subfolder, 'dev_mismatched.tsv') 49 | os.system(f'cp {test} {new_dev}') 50 | else: 51 | test = join(subfolder, 'test.tsv') 52 | new_dev = join(new_subfolder, 'dev.tsv') 53 | os.system(f'cp {test} {new_dev}') 54 | -------------------------------------------------------------------------------- /examples/classification/data/make_valid_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Xuechen Li. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Make the separate validation data, so that we don't tune on dev set. 
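Roughly 10% of each original train split is held out to form the new validation set, and the original dev split becomes the test set. For MNLI, the held-out examples are split evenly between matched and mismatched genres.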
16 | 17 | python -m classification.data.make_valid_data 18 | """ 19 | import os 20 | 21 | import fire 22 | import numpy as np 23 | import tqdm 24 | 25 | 26 | def write_lines(path, lines, mode="w"): 27 | os.makedirs(os.path.dirname(path), exist_ok=True) 28 | with open(path, mode) as f: 29 | f.writelines(lines) 30 | print(len(lines)) 31 | 32 | 33 | def main(): 34 | valid_percentage = 0.1 35 | original_dir = "/nlp/scr/lxuechen/data/lm-bff/data/original" 36 | new_dir = "/nlp/scr/lxuechen/data/lm-bff/data/glue-with-validation" 37 | 38 | task_folders = ("GLUE-SST-2", "QNLI", "QQP") 39 | for task_folder in task_folders: 40 | # Create train and valid splits. 41 | full_train_path = os.path.join(original_dir, task_folder, 'train.tsv') 42 | with open(full_train_path, 'r') as f: 43 | full_train = f.readlines() 44 | 45 | header = full_train[0] 46 | full_train = full_train[1:] # Remove header. 47 | 48 | indices = np.random.permutation(len(full_train)) 49 | new_valid_size = int(len(indices) * valid_percentage) 50 | new_train_size = len(indices) - new_valid_size 51 | new_train_indices = indices[:new_train_size] 52 | new_valid_indices = indices[new_train_size:] 53 | assert len(new_train_indices) == new_train_size 54 | assert len(new_valid_indices) == new_valid_size 55 | 56 | new_train = [header] + [full_train[i] for i in new_train_indices] 57 | new_valid = [header] + [full_train[i] for i in new_valid_indices] 58 | 59 | new_train_path = os.path.join(new_dir, task_folder, 'train.tsv') 60 | new_valid_path = os.path.join(new_dir, task_folder, 'dev.tsv') 61 | 62 | write_lines(new_train_path, new_train) 63 | write_lines(new_valid_path, new_valid) 64 | del new_train, new_valid, new_train_path, new_valid_path 65 | del new_train_size, new_train_indices 66 | del new_valid_size, new_valid_indices 67 | 68 | # Make test! 69 | test_path = os.path.join(original_dir, task_folder, 'dev.tsv') 70 | new_test_path = os.path.join(new_dir, task_folder, 'test.tsv') 71 | os.system(f'cp {test_path} {new_test_path}') 72 | del test_path, new_test_path 73 | 74 | # Make valid set for MNLI; different, since matched/mismatched! 75 | task_folder = "MNLI" 76 | matched_genres = ['slate', 'government', 'telephone', 'travel', 'fiction'] 77 | mismatched_genres = ['letters', 'verbatim', 'facetoface', 'oup', 'nineeleven'] 78 | full_train_path = os.path.join(original_dir, task_folder, 'train.tsv') 79 | with open(full_train_path, 'r') as f: 80 | full_train = f.readlines() 81 | full_train_csv = [line.split('\t') for line in full_train] 82 | 83 | # Check the lengths are correct. 84 | l = len(full_train_csv[0]) 85 | for line in full_train_csv: 86 | assert l == len(line) 87 | 88 | # Remove header. 89 | header = full_train[0] 90 | header_csv = full_train_csv[0] 91 | 92 | full_train = full_train[1:] 93 | full_train_csv = full_train_csv[1:] 94 | 95 | # Get index of genre. 96 | genre_index = header_csv.index('genre') 97 | 98 | # Shuffle both! 99 | indices = np.random.permutation(len(full_train)) 100 | full_train = [full_train[i] for i in indices] 101 | full_train_csv = [full_train_csv[i] for i in indices] 102 | 103 | # Split validation. 104 | new_valid_size = int(len(indices) * valid_percentage) 105 | new_matched_valid_size = new_mismatched_valid_size = new_valid_size // 2 106 | 107 | # Fetch the indices. 
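    # The first `new_matched_valid_size` rows drawn from matched genres and the
    # first `new_mismatched_valid_size` rows drawn from mismatched genres go to
    # the matched/mismatched validation splits; every other row stays in train.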
108 | new_train_indices = [] 109 | new_matched_valid_indices = [] 110 | new_mismatched_valid_indices = [] 111 | matched_count = mismatched_count = 0 112 | for i, row in enumerate(full_train_csv): 113 | genre = row[genre_index] 114 | if genre in matched_genres and matched_count < new_matched_valid_size: 115 | new_matched_valid_indices.append(i) 116 | matched_count += 1 117 | elif genre in mismatched_genres and mismatched_count < new_mismatched_valid_size: 118 | new_mismatched_valid_indices.append(i) 119 | mismatched_count += 1 120 | else: 121 | new_train_indices.append(i) 122 | 123 | new_matched_valid_indices = set(new_matched_valid_indices) 124 | new_mismatched_valid_indices = set(new_mismatched_valid_indices) 125 | 126 | new_train = [header] 127 | new_matched_valid = [header] 128 | new_mismatched_valid = [header] 129 | for i, line in tqdm.tqdm(enumerate(full_train)): 130 | if i in new_matched_valid_indices: 131 | new_matched_valid.append(line) 132 | elif i in new_mismatched_valid_indices: 133 | new_mismatched_valid.append(line) 134 | else: 135 | new_train.append(line) 136 | 137 | new_train_path = os.path.join(new_dir, task_folder, 'train.tsv') 138 | new_matched_valid_path = os.path.join(new_dir, task_folder, 'dev_matched.tsv') 139 | new_mismatched_valid_path = os.path.join(new_dir, task_folder, 'dev_mismatched.tsv') 140 | 141 | write_lines(new_train_path, new_train) 142 | write_lines(new_matched_valid_path, new_matched_valid) 143 | write_lines(new_mismatched_valid_path, new_mismatched_valid) 144 | 145 | matched_test_path = os.path.join(original_dir, task_folder, 'dev_matched.tsv') 146 | new_matched_test_path = os.path.join(new_dir, task_folder, 'test_matched.tsv') 147 | os.system(f'cp {matched_test_path} {new_matched_test_path}') 148 | 149 | mismatched_test_path = os.path.join(original_dir, task_folder, 'dev_mismatched.tsv') 150 | new_mismatched_test_path = os.path.join(new_dir, task_folder, 'test_mismatched.tsv') 151 | os.system(f'cp {mismatched_test_path} {new_mismatched_test_path}') 152 | 153 | 154 | if __name__ == "__main__": 155 | fire.Fire(main) 156 | -------------------------------------------------------------------------------- /examples/classification/requirements.txt: -------------------------------------------------------------------------------- 1 | argcomplete==1.12.1 2 | avro-python3==1.9.2.1 3 | azure-storage-blob==12.4.0 4 | bottle==0.12.19 5 | certifi==2021.5.30 6 | chardet==3.0.4 7 | charset-normalizer==2.0.4 8 | click==8.0.1 9 | crcmod==1.7 10 | cycler==0.10.0 11 | diffimg==0.2.3 12 | docopt==0.6.2 13 | fastavro==1.4.1 14 | filelock==3.0.12 15 | fire 16 | fusepy==2.0.4 17 | future==0.18.2 18 | gdown>=3.13.0 19 | httplib2==0.17.4 20 | idna==3.2 21 | imageio==2.9.0 22 | indexed-gzip-fileobj-fork-epicfaace==1.5.4 23 | isodate==0.6.0 24 | joblib==1.0.1 25 | kiwisolver==1.3.1 26 | markdown2==2.3.10 27 | marshmallow==2.15.1 28 | marshmallow-jsonapi==0.15.1 29 | matplotlib==3.4.3 30 | mock==2.0.0 31 | networkx==2.6.2 32 | nltk==3.6.2 33 | numpy==1.21.2 34 | oauth2client==4.1.3 35 | opacus==0.13.0 36 | packaging==21.0 37 | pandas==1.3.2 38 | pathtools==0.1.2 39 | pbr==5.6.0 40 | Pillow==8.3.1 41 | psutil==5.7.2 42 | pyasn1==0.4.8 43 | pyasn1-modules==0.2.8 44 | pycparser==2.20 45 | pydot==1.4.2 46 | pymongo==3.11.4 47 | pyparsing==2.4.7 48 | PySocks==1.7.1 49 | python-dateutil==2.8.* 50 | pytz==2021.1 51 | PyWavelets==1.1.1 52 | PyYAML==5.4.* 53 | regex==2021.8.3 54 | requests 55 | retry==0.9.2 56 | sacremoses==0.0.45 57 | scikit-image==0.18.2 58 | scikit-learn==0.24.2 59 | 
scipy==1.7.1 60 | seaborn==0.11.2 61 | selenium==3.141.0 62 | sentence-transformers>=2.0.0 63 | sentencepiece==0.1.96 64 | sentry-sdk==0.18.0 65 | six==1.15.0 66 | SQLAlchemy==1.3.19 67 | termcolor==1.1.0 68 | threadpoolctl==2.2.0 69 | tifffile==2021.8.8 70 | tokenizers==0.10.3 71 | tqdm>=4.62.1 72 | typing-extensions==3.7.4.3 73 | urllib3==1.26.* 74 | watchdog==0.10.3 75 | websocket-client==1.0.1 76 | gpytorch 77 | jupyterlab 78 | -------------------------------------------------------------------------------- /examples/classification/run_wrapper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Xuechen Li. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Wrapper launcher script.""" 16 | 17 | import os 18 | 19 | import fire 20 | 21 | from .src import common 22 | 23 | 24 | def _get_command( 25 | task_name, 26 | output_dir, 27 | model_name_or_path, 28 | data_dir, 29 | learning_rate, 30 | clipping_mode: str, 31 | non_private, 32 | target_epsilon, 33 | few_shot_type, 34 | seed, 35 | attention_only, 36 | static_lm_head, 37 | static_embedding, 38 | randomly_initialize, 39 | per_device_train_batch_size, 40 | batch_size, 41 | num_train_epochs, 42 | eval_steps, 43 | eval_spectrum, 44 | max_spectrum_batches, 45 | max_lanczos_iter, 46 | store_grads, 47 | orthogonal_projection_path, 48 | orthogonal_projection_rank, 49 | ): 50 | task_name_to_factor = { 51 | "sst-2": 1, "qnli": 2, "qqp": 6, "mnli": 6, 52 | } 53 | factor = task_name_to_factor[task_name] 54 | 55 | if batch_size is None: 56 | base_batch_size = 1000 57 | # This batch size selection roughly ensures the sampling rates on different 58 | # datasets are in the same ballpark. 59 | batch_size = int(base_batch_size * factor) 60 | gradient_accumulation_steps = batch_size // per_device_train_batch_size 61 | 62 | if num_train_epochs is None: 63 | base_num_train_epochs = 3 64 | num_train_epochs = int(base_num_train_epochs * factor) 65 | 66 | if learning_rate is None: 67 | if non_private.lower() in ('yes', 'y', 'true', 't'): 68 | learning_rate = 5e-5 69 | else: 70 | learning_rate = 5e-4 71 | 72 | data_dir = f"{data_dir}/{common.task_name2suffix_name[task_name]}" 73 | template = { 74 | "sst-2": "*cls**sent_0*_It_was*mask*.*sep+*", 75 | "mnli": "*cls**sent-_0*?*mask*,*+sentl_1**sep+*", 76 | "qnli": "*cls**sent-_0*?*mask*,*+sentl_1**sep+*", 77 | "qqp": "*cls**sent-_0**mask*,*+sentl_1**sep+*", 78 | }[task_name] 79 | 80 | # Epochs chosen roughly to match e2e number of updates. 
We didn't hyperparameter tune on classification tasks :) 81 | cmd = f''' 82 | python -m classification.run_classification \ 83 | --task_name {task_name} \ 84 | --data_dir {data_dir} \ 85 | --output_dir {output_dir} \ 86 | --overwrite_output_dir \ 87 | --model_name_or_path {model_name_or_path} \ 88 | --few_shot_type {few_shot_type} \ 89 | --num_k 1 \ 90 | --num_sample 1 --seed {seed} \ 91 | --template {template} \ 92 | --non_private {non_private} \ 93 | --num_train_epochs {num_train_epochs} \ 94 | --target_epsilon {target_epsilon} \ 95 | --per_device_train_batch_size {per_device_train_batch_size} \ 96 | --gradient_accumulation_steps {gradient_accumulation_steps} \ 97 | --per_device_eval_batch_size 8 \ 98 | --per_example_max_grad_norm 0.1 --clipping_mode {clipping_mode} \ 99 | --learning_rate {learning_rate} \ 100 | --lr_decay yes \ 101 | --adam_epsilon 1e-08 \ 102 | --weight_decay 0 \ 103 | --max_seq_len 256 \ 104 | --evaluation_strategy steps --eval_steps {eval_steps} --evaluate_before_training True \ 105 | --do_train --do_eval \ 106 | --first_sent_limit 200 --other_sent_limit 200 --truncate_head yes \ 107 | --attention_only {attention_only} --static_lm_head {static_lm_head} --static_embedding {static_embedding} \ 108 | --randomly_initialize {randomly_initialize} \ 109 | --eval_spectrum {eval_spectrum} --max_spectrum_batches {max_spectrum_batches} --max_lanczos_iter {max_lanczos_iter} \ 110 | --store_grads {store_grads}''' 111 | if orthogonal_projection_path is not None: 112 | cmd += f' --orthogonal_projection_path {orthogonal_projection_path}' 113 | cmd += f' --orthogonal_projection_rank {orthogonal_projection_rank}' 114 | return cmd 115 | 116 | 117 | def main( 118 | output_dir, 119 | task_name, 120 | few_shot_type="prompt", 121 | seed=42, 122 | model_name_or_path="roberta-base", 123 | data_dir="classification/data/original", 124 | learning_rate=None, 125 | clipping_mode="ghost", 126 | non_private="no", 127 | target_epsilon=8, 128 | attention_only="no", 129 | static_lm_head="no", 130 | static_embedding="no", 131 | per_device_train_batch_size=20, 132 | eval_steps=10, 133 | eval_spectrum="no", 134 | max_spectrum_batches=2, 135 | max_lanczos_iter=2, 136 | randomly_initialize="no", 137 | batch_size=None, 138 | num_train_epochs=None, 139 | store_grads="no", 140 | orthogonal_projection_path=None, 141 | orthogonal_projection_rank=100, 142 | ): 143 | command = _get_command( 144 | output_dir=output_dir, 145 | task_name=task_name, 146 | model_name_or_path=model_name_or_path, 147 | data_dir=data_dir, 148 | learning_rate=learning_rate, 149 | clipping_mode=clipping_mode, 150 | non_private=non_private, 151 | target_epsilon=target_epsilon, 152 | few_shot_type=few_shot_type, 153 | seed=seed, 154 | attention_only=attention_only, 155 | static_lm_head=static_lm_head, 156 | static_embedding=static_embedding, 157 | per_device_train_batch_size=per_device_train_batch_size, 158 | eval_steps=eval_steps, 159 | eval_spectrum=eval_spectrum, 160 | max_spectrum_batches=max_spectrum_batches, 161 | max_lanczos_iter=max_lanczos_iter, 162 | randomly_initialize=randomly_initialize, 163 | batch_size=batch_size, 164 | num_train_epochs=num_train_epochs, 165 | store_grads=store_grads, 166 | orthogonal_projection_path=orthogonal_projection_path, 167 | orthogonal_projection_rank=orthogonal_projection_rank, 168 | ) 169 | print('Running command:') 170 | print(command) 171 | os.system(command) 172 | 173 | 174 | if __name__ == "__main__": 175 | fire.Fire(main) 176 | 
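# A minimal sketch of a typical invocation (run from the `examples/` folder;
# the output path below is hypothetical):
#
#   python -m classification.run_wrapper \
#     --output_dir "/tmp/sst-2-dp" \
#     --task_name "sst-2" \
#     --model_name_or_path "roberta-base" \
#     --clipping_mode "ghost" \
#     --target_epsilon 8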
-------------------------------------------------------------------------------- /examples/classification/spectral_analysis/3d_surface.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/lxuechen/private-transformers/18ccc4eab7355e4ac96051a82434796f6aa4624b/examples/classification/spectral_analysis/3d_surface.png
-------------------------------------------------------------------------------- /examples/classification/spectral_analysis/README.md: --------------------------------------------------------------------------------
1 | ## Experiments for spectral analysis
2 | 
3 | Everything below should be run from the `examples/` folder.
4 | 
5 | 1. To run the geometric median example in the paper, use the following command and supply an `<img_dir>` of your
6 | choice:
7 | 
8 | ```bash
9 | python -m classification.spectral_analysis.geometric_median --img_dir <img_dir>
10 | ```
11 | 
12 | 2. Spectral analysis.
13 | 
14 | 2.1. To reproduce the spectral analysis experiments, one first needs an initial round of fine-tuning to collect
15 | gradients.
16 | Run the following command with a `<train_dir>` of your choice. Note that everything down the line related to the PCA will be stored
17 | here, so make sure you have enough disk space! It's perhaps safe to reserve 500G~1T. The spectral analyses are very
18 | disk-space intensive. Note below `<model_name_or_path>` can be one of `distilroberta-base`, `roberta-base`,
19 | `roberta-large`.
20 | 
21 | ```bash
22 | CUDA_VISIBLE_DEVICES=0 python -m classification.spectral_analysis.rebuttal_neurips_2022 \
23 |   --task "run_save_grads" \
24 |   --train_dir <train_dir> \
25 |   --model_name_or_path <model_name_or_path>
26 | ```
27 | 
28 | 2.2. Now run PCA with orthogonal iteration to extract the top eigenvectors. The command below runs PCA based on 4k
29 | checkpoints (4k gradients stored along the trajectory), and extracts the top 1k eigenvalues and
30 | eigenvectors. `batch_size` can be set small to save memory (it affects the distributed matmul). Note for
31 | the `roberta-large` experiment, you would likely need several GPUs. For reference, I used 4 A6000 (each with 48G
32 | VRAM) for that experiment. The code is written in a
33 | way so that computation can be distributed across many GPUs on a single machine, and should be
34 | fast with enough accelerators. **`<train_dir>` below must be the same as in the previous command.**
35 | 
36 | ```bash
37 | python -m classification.spectral_analysis.rebuttal_neurips_2022 \
38 |   --task "run_pca" \
39 |   --train_dir <train_dir> \
40 |   --n 4000 \
41 |   --k 1000 \
42 |   --num_power_iteration 10 \
43 |   --batch_size 20
44 | ```
45 | 
46 | 2.3. For re-training in the subspace, we need to tell the command where the PCA results are stored in
47 | order to use them. The PCA results will be in `<train_dir>/orthproj/all/`. There will likely be a couple of
48 | checkpoints in this folder, each of which corresponds to a different iteration of the orthogonal iteration. Now run
49 | the
50 | following command. Note that below 1) `<output_dir>` should **not** be the same as `<train_dir>` to avoid
51 | overwriting previous PCA results, and 2) `<rank>` should be smaller than `k` from the previous command since it's the
52 | rank of the subspace.
53 | 
54 | ```bash
55 | CUDA_VISIBLE_DEVICES=0 python -m classification.spectral_analysis.rebuttal_neurips_2022 \
56 |   --task "run_retrain_single" \
57 |   --output_dir <output_dir> \
58 |   --orthogonal_projection_path "<train_dir>/orthproj/all/global_step_x.pt" \
59 |   --rank <rank> \
60 |   --model_name_or_path <model_name_or_path>
61 | ```
62 | 
63 | ## Citation
64 | 
65 | If you found this codebase useful in your research, please consider citing:
66 | 
67 | ```@misc{li2022when,
68 |   doi = {10.48550/ARXIV.2207.00160},
69 |   url = {https://arxiv.org/abs/2207.00160},
70 |   author = {Li, Xuechen and Liu, Daogao and Hashimoto, Tatsunori and Inan, Huseyin A. and Kulkarni, Janardhan and Lee, Yin Tat and Thakurta, Abhradeep Guha},
71 |   keywords = {Machine Learning (cs.LG), Cryptography and Security (cs.CR), Machine Learning (stat.ML), FOS: Computer and information sciences, FOS: Computer and information sciences},
72 |   title = {When Does Differentially Private Learning Not Suffer in High Dimensions?},
73 |   publisher = {arXiv},
74 |   year = {2022},
75 |   copyright = {Creative Commons Attribution 4.0 International}
76 | }
77 | ```
-------------------------------------------------------------------------------- /examples/classification/spectral_analysis/__init__.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
-------------------------------------------------------------------------------- /examples/classification/spectral_analysis/density.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | # Copyright 2019 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """Code for converting Lanczos outputs to densities."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import math
23 | 
24 | import numpy as np
25 | 
26 | 
27 | def eigv_to_density(eig_vals, all_weights=None, grids=None,
28 |                     grid_len=10000, sigma_squared=None, grid_expand=1e-2):
29 |     """Compute the smoothed spectral density from a set of eigenvalues.
30 | 
31 |     Convolves the given eigenvalues with a Gaussian kernel, weighting the values
32 |     by all_weights (or uniform weighting if all_weights is None). Example output
33 |     can be seen in Figure 1 of https://arxiv.org/pdf/1901.10159.pdf. Visualizing
34 |     the estimated density can be done by calling plt.plot(grids, density). There
35 |     is likely no single best value of sigma_squared that works for all use cases,
36 |     so it is recommended to try multiple values in the range [1e-5, 1e-1].
37 | 
38 |     Args:
39 |       eig_vals: Array of shape [num_draws, order]
40 |       all_weights: Array of shape [num_draws, order], if None then weights will be
41 |         taken to be uniform.
42 |       grids: Array of shape [grid_len], the smoothed spectrum will be plotted
43 |         in the interval [grids[0], grids[-1]]. If None then grids will be
44 |         computed based on max and min eigenvalues and grid length.
45 |       grid_len: Integer specifying number of grid cells to use, only used if
46 |         grids is None.
47 |       sigma_squared: Scalar. Controls the smoothing of the spectrum estimate.
48 |         If None, an appropriate value is inferred.
49 |       grid_expand: Controls the window of values that grids spans.
50 |         grids[0] = smallest eigenvalue - grid_expand.
51 |         grids[-1] = largest_eigenvalue + grid_expand.
52 | 
53 |     Returns:
54 |       density: Array of shape [grid_len], the estimated density, averaged over
55 |         all draws.
56 |       grids: Array of shape [grid_len]. The values the density is estimated on.
57 |     """
58 |     if all_weights is None:
59 |         all_weights = np.ones(eig_vals.shape) * 1.0 / float(eig_vals.shape[1])
60 |     num_draws = eig_vals.shape[0]
61 | 
62 |     lambda_max = np.nanmean(np.max(eig_vals, axis=1), axis=0) + grid_expand
63 |     lambda_min = np.nanmean(np.min(eig_vals, axis=1), axis=0) - grid_expand
64 | 
65 |     if grids is None:
66 |         assert grid_len is not None, 'grid_len is required if grids is None.'
67 |         grids = np.linspace(lambda_min, lambda_max, num=grid_len)
68 | 
69 |     grid_len = grids.shape[0]
70 |     if sigma_squared is None:
71 |         sigma = 10 ** -5 * max(1, (lambda_max - lambda_min))
72 |     else:
73 |         sigma = sigma_squared * max(1, (lambda_max - lambda_min))
74 | 
75 |     density_each_draw = np.zeros((num_draws, grid_len))
76 |     for i in range(num_draws):
77 | 
78 |         if np.isnan(eig_vals[i, 0]):
79 |             raise ValueError('tridiag has nan values.')
80 |         else:
81 |             for j in range(grid_len):
82 |                 x = grids[j]
83 |                 vals = _kernel(eig_vals[i, :], x, sigma)
84 |                 density_each_draw[i, j] = np.sum(vals * all_weights[i, :])
85 |     density = np.nanmean(density_each_draw, axis=0)
86 |     norm_fact = np.sum(density) * (grids[1] - grids[0])
87 |     density = density / norm_fact
88 |     return density, grids
89 | 
90 | 
91 | def tridiag_to_eigv(tridiag_list):
92 |     """Preprocess the tridiagonal matrices for density estimation.
93 | 
94 |     Args:
95 |       tridiag_list: Array of shape [num_draws, order, order] List of the
96 |         tridiagonal matrices computed from running num_draws independent runs
97 |         of Lanczos. The output of this function can be fed directly into
98 |         eigv_to_density.
99 | 
100 |     Returns:
101 |       eig_vals: Array of shape [num_draws, order]. The eigenvalues of the
102 |         tridiagonal matrices.
103 |       all_weights: Array of shape [num_draws, order]. The weights associated with
104 |         each eigenvalue. These weights are to be used in the kernel density
105 |         estimate.
106 |     """
107 |     # Calculating the nodes / weights from Jacobi matrices.
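    # By the Gaussian-quadrature view of Lanczos (Golub & Welsch, 1969), the
    # eigenvalues of each tridiagonal (Jacobi) matrix are the quadrature nodes,
    # and the weight of each node is the squared first component of the
    # corresponding eigenvector -- exactly what the loop below extracts.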
108 |     num_draws = len(tridiag_list)
109 |     num_lanczos = tridiag_list[0].shape[0]
110 |     eig_vals = np.zeros((num_draws, num_lanczos))
111 |     all_weights = np.zeros((num_draws, num_lanczos))
112 |     for i in range(num_draws):
113 |         nodes, evecs = np.linalg.eigh(tridiag_list[i])
114 |         index = np.argsort(nodes)
115 |         nodes = nodes[index]
116 |         evecs = evecs[:, index]
117 |         eig_vals[i, :] = nodes
118 |         all_weights[i, :] = evecs[0] ** 2
119 |     return eig_vals, all_weights
120 | 
121 | 
122 | def tridiag_to_density(tridiag_list, sigma_squared=1e-5, grid_len=10000):
123 |     """This function estimates the smoothed density from the output of Lanczos.
124 | 
125 |     Args:
126 |       tridiag_list: Array of shape [num_draws, order, order] List of the
127 |         tridiagonal matrices computed from running num_draws independent runs
128 |         of Lanczos.
129 |       sigma_squared: Controls the smoothing of the density.
130 |       grid_len: Controls the granularity of the density.
131 | 
132 |     Returns:
133 |       density: Array of size [grid_len]. The smoothed density estimate averaged
134 |         over all num_draws.
135 |       grids: Array of size [grid_len]. The values the density estimate is on.
136 |     """
137 |     eig_vals, all_weights = tridiag_to_eigv(tridiag_list)
138 |     density, grids = eigv_to_density(eig_vals, all_weights,
139 |                                      grid_len=grid_len,
140 |                                      sigma_squared=sigma_squared)
141 |     return density, grids
142 | 
143 | 
144 | def _kernel(x, x0, variance):
145 |     """Point estimate of the Gaussian kernel.
146 | 
147 |     This function computes the Gaussian kernel for
148 |     C exp(-(x - x0) ^2 /(2 * variance)) where C is the appropriate normalization.
149 |     variance should be a scalar. x0 should be a scalar, and x can be either a
150 |     scalar or a numpy array; only one of x and x0 can be an array.
151 | 
152 |     Args:
153 |       x: Can be either scalar or array of shape [order]. Points to estimate
154 |         the kernel on.
155 |       x0: Scalar. Mean of the kernel.
156 |       variance: Scalar. Variance of the kernel.
157 | 
158 |     Returns:
159 |       point_estimate: A scalar corresponding to
160 |         C exp(-(x - x0) ^2 /(2 * variance)).
161 |     """
162 |     coeff = 1.0 / np.sqrt(2 * math.pi * variance)
163 |     val = -(x0 - x) ** 2
164 |     val = val / (2.0 * variance)
165 |     val = np.exp(val)
166 |     point_estimate = coeff * val
167 |     return point_estimate
-------------------------------------------------------------------------------- /examples/classification/spectral_analysis/geometric_median.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """
16 | Toy example on geometric median estimation in the paper.
17 | 
18 | python -m classification.spectral_analysis.geometric_median --img_dir "/mnt/disks/disk-2/dump/spectrum/geometric_median"
19 | """
20 | import dataclasses
21 | import logging
22 | import math
23 | import sys
24 | from typing import Tuple
25 | 
26 | import fire
27 | import numpy as np
28 | import torch
29 | import tqdm
30 | from ml_swissknife import utils
31 | 
32 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
33 | 
34 | 
35 | @dataclasses.dataclass
36 | class Data:
37 |     beta_train: torch.Tensor
38 |     beta_test: torch.Tensor
39 |     Ar: torch.Tensor  # A^{1/2}.
40 |     sensitivity: float
41 | 
42 |     def __post_init__(self):
43 |         self.n_train, self.d = self.beta_train.size()
44 |         self.n_test = self.beta_test.shape[0]
45 | 
46 | 
47 | class Modes(metaclass=utils.ContainerMeta):
48 |     const = "const"
49 |     quarter = "quarter"
50 |     sqrt = "sqrt"
51 |     linear = "linear"
52 |     quadratic = "quadratic"
53 | 
54 | 
55 | def make_data(
56 |     betas=None,
57 |     n_train=100000, n_test=100000, d=10, dmin=1, mu_beta=0.2, si_beta=0.1,
58 |     mode="linear",
59 |     g0=1.,
60 | ):
61 |     if betas is None:
62 |         beta_train, beta_test = make_beta(
63 |             n_train=n_train, n_test=n_test, d=d, dmin=dmin, mu_beta=mu_beta, si_beta=si_beta
64 |         )
65 |     else:
66 |         beta_train, beta_test = betas
67 |         n_train, d = beta_train.size()
68 |         n_test, _ = beta_test.size()
69 | 
70 |     if mode == Modes.const:
71 |         Ar = g0 * torch.arange(1, d + 1, device=device) ** 0.  # Flat spectrum, so max(Ar) == g0 and `sensitivity` below holds.
72 |     elif mode == Modes.quarter:
73 |         Ar = g0 * torch.arange(1, d + 1, device=device) ** -.25
74 |     elif mode == Modes.sqrt:
75 |         Ar = g0 * torch.arange(1, d + 1, device=device) ** -.5
76 |     elif mode == Modes.linear:
77 |         Ar = g0 * torch.arange(1, d + 1, device=device) ** -1.
78 |     elif mode == Modes.quadratic:
79 |         Ar = g0 * torch.arange(1, d + 1, device=device) ** -2.
80 |     else:
81 |         raise ValueError(f"Unknown mode: {mode}")
82 | 
83 |     sensitivity = 2 * g0 / n_train
84 | 
85 |     return Data(beta_train=beta_train, beta_test=beta_test, Ar=Ar, sensitivity=sensitivity)
86 | 
87 | 
88 | def make_beta(n_train, n_test, d, dmin, mu_beta, si_beta):
89 |     if d < dmin:
90 |         raise ValueError(f"d={d} < dmin={dmin}")
91 | 
92 |     beta_train = mu_beta + torch.randn(size=(n_train, d), device=device) * si_beta
93 |     beta_train[:, dmin:] = 0.  # Ensure init distance to opt is the same.
94 | 
95 |     beta_test = mu_beta + torch.randn(size=(n_test, d), device=device) * si_beta
96 |     beta_test[:, dmin:] = 0.  # Same distribution as train.
97 | 
98 |     return beta_train, beta_test
99 | 
100 | 
101 | def evaluate(data: Data, beta: torch.Tensor) -> Tuple:
102 |     """Compute loss 1 / n sum_i | A^{1/2} (beta - beta_i) |_2 for train and test."""
103 | 
104 |     def compute_loss(samples):
105 |         res = data.Ar[None, :] * (beta - samples)  # (n, d).
106 |         return res.norm(2, dim=1).mean(dim=0).item()
107 | 
108 |     return tuple(
109 |         compute_loss(samples=samples)
110 |         for samples in (data.beta_train, data.beta_test)
111 |     )
112 | 
113 | 
114 | def train_one_step(data: Data, beta, lr, epsilon, delta, weight_decay):
115 |     res = data.Ar[None, :] * (beta - data.beta_train)  # (n, d).
116 |     grad = data.Ar * (res / res.norm(2, dim=1, keepdim=True)).mean(dim=0)
117 | 
118 |     gaussian_mechanism_variance = 2. * math.log(1.25 / delta) * data.sensitivity ** 2. / epsilon ** 2.
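    # Classical Gaussian-mechanism calibration (Dwork & Roth, Theorem A.1): noise
    # with std sigma = sensitivity * sqrt(2 * log(1.25 / delta)) / epsilon makes a
    # query with l2-sensitivity `data.sensitivity` (epsilon, delta)-DP; the
    # expression above is sigma ** 2.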
119 | grad_priv = grad + torch.randn_like(grad) * math.sqrt(gaussian_mechanism_variance) 120 | beta = beta - lr * (grad_priv + weight_decay * beta) 121 | return beta 122 | 123 | 124 | @torch.no_grad() 125 | def train(data: Data, num_steps, eval_steps, lr, weight_decay, epsilon, delta, tag, verbose, seed): 126 | utils.manual_seed(seed) 127 | 128 | per_step_epsilon, per_step_delta = make_per_step_privacy_spending( 129 | target_epsilon=epsilon, target_delta=delta, num_steps=num_steps 130 | ) 131 | 132 | beta = torch.zeros(size=(1, data.d,), device=device) 133 | beta_avg = beta.clone() 134 | 135 | for global_step in range(0, num_steps): 136 | if global_step % eval_steps == 0: 137 | tr_loss, te_loss = evaluate(data=data, beta=beta_avg) 138 | if verbose: 139 | logging.warning( 140 | f"tag: {tag}, global_step: {global_step}, lr: {lr:.6f}, num_steps: {num_steps}, " 141 | f"train_loss: {tr_loss:.6f}, test_loss: {te_loss:.6f}" 142 | ) 143 | 144 | beta = train_one_step( 145 | data=data, 146 | beta=beta, 147 | lr=lr, weight_decay=weight_decay, 148 | epsilon=per_step_epsilon, delta=per_step_delta, 149 | ) 150 | beta_avg = beta_avg * global_step / (global_step + 1) + beta / (global_step + 1) 151 | 152 | final_tr_loss, final_te_loss = evaluate(data=data, beta=beta_avg) 153 | if verbose: 154 | logging.warning( 155 | f"tag: {tag}, final, lr: {lr:.6f}, num_steps: {num_steps}, " 156 | f"train_loss: {final_tr_loss:.6f}, te_loss: {final_te_loss:.6f}" 157 | ) 158 | 159 | return beta_avg, (final_tr_loss, final_te_loss) 160 | 161 | 162 | def make_per_step_privacy_spending( 163 | target_epsilon, target_delta, num_steps, threshold=1e-4, 164 | ): 165 | per_step_delta = target_delta / (num_steps + 1) 166 | 167 | def adv_composition(per_step_epsilon): 168 | total_epsilon = ( 169 | math.sqrt(2 * num_steps * math.log(1 / per_step_delta)) * per_step_epsilon + 170 | num_steps * per_step_epsilon * (math.exp(per_step_epsilon) - 1) 171 | ) 172 | return total_epsilon 173 | 174 | minval, maxval = 1e-6, 5 175 | while maxval - minval > threshold: 176 | midval = (maxval + minval) / 2 177 | eps = adv_composition(midval) 178 | if eps > target_epsilon: 179 | maxval = midval 180 | else: 181 | minval = midval 182 | per_step_epsilon = minval 183 | return per_step_epsilon, per_step_delta 184 | 185 | 186 | def main( 187 | img_dir=None, eval_steps=10000, weight_decay=0, epsilon=2, delta=1e-6, 188 | n_train=10000, n_test=10000, dmin=1, mu_beta=1., si_beta=1, g0=3., 189 | seeds=(42, 96, 10000, 999, 101), # Some arbitrary numbers. 190 | modes=(Modes.const, Modes.sqrt, Modes.linear), # A subset of all possible modes for visualization. 191 | verbose=False, 192 | quick=False, # Use small data if True. 
193 | ): 194 | if quick: 195 | dims = (10, 50,) 196 | num_steps_list = (10, 20,) 197 | lrs = (1e-4, 3e-4,) 198 | else: 199 | dims = (20, 50, 100, 200, 500, 1000, 2000) 200 | num_steps_list = (10, 20, 40, 80, 160, 320, 640, 1280, 2560, 5120) 201 | lrs = (1e-4, 3e-4, 1e-3, 3e-3, 1e-2, 3e-2, 1e-1, 3e-1, 1, 3,) 202 | 203 | tr_losses = {mode: [] for mode in modes} 204 | te_losses = {mode: [] for mode in modes} 205 | for dim in tqdm.tqdm(dims, desc="dims"): 206 | betas = make_beta(n_train=n_train, n_test=n_test, d=dim, dmin=dmin, mu_beta=mu_beta, si_beta=si_beta) 207 | data = tuple(make_data(betas=betas, mode=mode, g0=g0) for mode in modes) 208 | 209 | tr_loss = {mode: [sys.maxsize] for mode in modes} 210 | te_loss = {mode: [sys.maxsize] for mode in modes} 211 | for this_data, this_mode in tqdm.tqdm(utils.zip_(data, modes), desc="modes", total=len(data)): 212 | 213 | # Hyperparameter tuning. 214 | for num_steps in num_steps_list: 215 | for lr in lrs: 216 | kwargs = dict( 217 | data=this_data, 218 | num_steps=num_steps, 219 | lr=lr, 220 | 221 | eval_steps=eval_steps, 222 | weight_decay=weight_decay, 223 | epsilon=epsilon, 224 | delta=delta, 225 | tag=this_mode, 226 | verbose=verbose, 227 | ) 228 | 229 | tr_results = [] 230 | te_results = [] 231 | for seed in seeds: 232 | _, (a, b) = train(**kwargs, seed=seed) 233 | tr_results.append(a) 234 | te_results.append(b) 235 | 236 | if np.mean(tr_results) < np.mean(tr_loss[this_mode]): 237 | tr_loss[this_mode] = tr_results 238 | te_loss[this_mode] = te_results 239 | 240 | # update after hp tuning. 241 | for this_mode in modes: 242 | tr_losses[this_mode].append(tr_loss[this_mode]) 243 | te_losses[this_mode].append(te_loss[this_mode]) 244 | 245 | raw_data = dict(tr_losses=tr_losses, te_losses=te_losses, modes=modes, dims=dims) 246 | 247 | if img_dir is not None: 248 | utils.jdump(raw_data, utils.join(img_dir, 'toyplot.json')) 249 | 250 | plot_modes = modes 251 | linestyles = ("-", "--", ":", "-.") 252 | markers = ("o", "+", "x", "^") 253 | 254 | tr_plotting = dict( 255 | errorbars=tuple( 256 | dict( 257 | x=dims, 258 | y=np.mean(np.array(tr_losses[this_mode]), axis=1), 259 | yerr=np.std(np.array(tr_losses[this_mode]), axis=1), 260 | label=this_mode, marker=markers[mode_idx], 261 | linestyle=linestyles[mode_idx] 262 | ) 263 | for mode_idx, this_mode in enumerate(plot_modes) 264 | ), 265 | options=dict(xlabel="$d$", ylabel="train loss") 266 | ) 267 | utils.plot_wrapper( 268 | img_path=utils.join(img_dir, 'trplot'), 269 | suffixes=('.png', '.pdf'), 270 | **tr_plotting, 271 | ) 272 | 273 | te_plotting = dict( 274 | errorbars=tuple( 275 | dict( 276 | x=dims, 277 | y=np.mean(np.array(te_losses[this_mode]), axis=1), 278 | yerr=np.std(np.array(te_losses[this_mode]), axis=1), 279 | label=this_mode, marker=markers[mode_idx], 280 | linestyle=linestyles[mode_idx] 281 | ) 282 | for mode_idx, this_mode in enumerate(plot_modes) 283 | ), 284 | options=dict(xlabel="$d$", ylabel="test loss") 285 | ) 286 | utils.plot_wrapper( 287 | img_path=utils.join(img_dir, 'teplot'), 288 | suffixes=('.png', '.pdf'), 289 | **te_plotting, 290 | ) 291 | 292 | 293 | if __name__ == "__main__": 294 | fire.Fire(main) 295 | -------------------------------------------------------------------------------- /examples/classification/spectral_analysis/rebuttal_neurips_2022.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Xuechen Li. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Experiments ran pre- and post-rebuttals.""" 16 | import logging 17 | import os 18 | from typing import Optional 19 | 20 | import fire 21 | import torch 22 | import tqdm 23 | from ml_swissknife import utils, numerical_distributed 24 | from torch.utils.data import DataLoader, TensorDataset 25 | 26 | 27 | def run_save_grads( 28 | num_train_epochs=60, # This amounts to 4k updates, roughly. 29 | model_name_or_path="roberta-base", 30 | train_dir=None, 31 | per_device_train_batch_size=25, 32 | ): 33 | if train_dir is None: 34 | train_dir = utils.join("/mnt/data1/dump/", 'rebuttal_v2', f'run-{model_name_or_path}') 35 | command = f'''python -m classification.run_wrapper \ 36 | --output_dir {train_dir} \ 37 | --task_name "sst-2" \ 38 | --model_name_or_path "{model_name_or_path}" \ 39 | --attention_only "yes" \ 40 | --static_lm_head "yes" \ 41 | --num_train_epochs {num_train_epochs} \ 42 | --eval_spectrum "no" \ 43 | --non_private "no" \ 44 | --eval_steps 50 \ 45 | --randomly_initialize "no" \ 46 | --per_device_train_batch_size {per_device_train_batch_size} \ 47 | --batch_size 1000 \ 48 | --clipping_mode "default" \ 49 | --store_grads "yes"''' 50 | os.system(command) 51 | 52 | 53 | def run_pca( 54 | # Place where grads are stored and where results will be stored. 55 | train_dir="/mnt/disks/disk-2/dump/privlm/roberta/sst-2", 56 | n=2000, # How many checkpoints? 57 | k=1000, # How many eigenvectors? 58 | num_power_iteration=10, 59 | batch_size=20, # Batch size for processing the checkpoints in matmul. 60 | seed=42, # Controls randomness in sampling the first vector in orthogonal iteration. 61 | start_index=0, # The index of the first checkpoint to be selected. 62 | eval_steps=5, # Evaluate PCA accuracy once this many iterations. 63 | save_steps=5, # Save eigenvalue and eigenvector tensors once this many iterations. 64 | disable_tqdm=False, 65 | dtype="float", # String repr of dtype. 
66 | ): 67 | utils.manual_seed(seed) 68 | 69 | ckpt_dir = utils.join(train_dir, 'grad_trajectory') 70 | dump_dir = utils.join(train_dir, 'orthproj') 71 | 72 | all_ckpts = utils.all_ckpts(ckpt_dir, sort=True) 73 | tgt_ckpts = all_ckpts[start_index:start_index + n] 74 | dataset = torch.stack([ 75 | torch.load(ckpt_path)["flat_grad"] for ckpt_path in tqdm.tqdm(tgt_ckpts, desc="load data") 76 | ]).to(utils.get_dtype(dtype)) 77 | input_mat = DataLoader(dataset=TensorDataset(dataset), batch_size=batch_size) 78 | 79 | def callback(global_step, eigenvalues, eigenvectors): 80 | if global_step % save_steps == 0: 81 | utils.tsave( 82 | dict(eigenvalues=eigenvalues, eigenvectors=eigenvectors), 83 | utils.join(dump_dir, "all", f"global_step_{global_step:06d}.pt") 84 | ) 85 | utils.tsave( 86 | dict(eigenvalues=eigenvalues), 87 | utils.join(dump_dir, "eigenvalues", f"global_step_{global_step:06d}.evals") 88 | ) 89 | if global_step % eval_steps == 0: 90 | err_abs, err_rel = numerical_distributed.check_error( 91 | input_mat=input_mat, eigenvectors=eigenvectors, disable_tqdm=disable_tqdm 92 | ) 93 | logging.warning(f"global_step: {global_step}, abs error: {err_abs:.6f}, rel error: {err_rel:.6f}") 94 | 95 | numerical_distributed.orthogonal_iteration( 96 | input_mat=input_mat, 97 | k=k, 98 | num_power_iteration=num_power_iteration, 99 | callback=callback, 100 | disable_tqdm=disable_tqdm, 101 | ) 102 | 103 | 104 | def run_retrain_single( 105 | output_dir: str, 106 | orthogonal_projection_path: str, 107 | model_name_or_path: str, 108 | rank: Optional[int] = None, 109 | seed=42, 110 | ): 111 | cmd = f'''python -m classification.run_wrapper \ 112 | --output_dir {output_dir} \ 113 | --task_name "sst-2" \ 114 | --model_name_or_path {model_name_or_path} \ 115 | --few_shot_type "prompt" \ 116 | --attention_only "yes" \ 117 | --static_lm_head "yes" \ 118 | --per_device_train_batch_size 25 \ 119 | --batch_size 1000 \ 120 | --clipping_mode "default" \ 121 | --num_train_epochs 4 \ 122 | --eval_spectrum "no" \ 123 | --non_private "no" \ 124 | --eval_steps 25 \ 125 | --randomly_initialize "no" \ 126 | --seed {seed} \ 127 | --orthogonal_projection_path {orthogonal_projection_path}''' 128 | if rank is not None: 129 | cmd += f' --orthogonal_projection_rank {rank}' 130 | os.system(cmd) 131 | 132 | 133 | def main(task, **kwargs): 134 | globals()[task](**kwargs) 135 | 136 | 137 | if __name__ == "__main__": 138 | fire.Fire(main) 139 | -------------------------------------------------------------------------------- /examples/classification/spectral_analysis/rebuttal_plots_neurips_2022.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Xuechen Li. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Plot 1) spectral decay, 2) retrain curves. 
17 | """ 18 | 19 | import math 20 | 21 | import fire 22 | import numpy as np 23 | import scipy.stats 24 | import torch 25 | from ml_swissknife import utils 26 | 27 | from . import density 28 | 29 | 30 | def plot1( 31 | ckpt_path: str, # Path to eigenvalues. 32 | dump_dir="./classification/plots", 33 | img_name="", 34 | k=500, 35 | **kwargs, 36 | ): 37 | """Eigenvalues. 38 | 39 | Run on gvm. 40 | """ 41 | # Roberta-large 42 | # python -m classification.spectral_analysis.rebuttal_plots_neurips_2022 --task "plot1" --ckpt_path "/mnt/data1/dump/rebuttal/run-roberta-large/orthproj/eigenvalues/global_step_000005.evals" --img_name "large" --k 100 43 | if img_name != "": 44 | img_name = f'-{img_name}' 45 | 46 | state_dicts = torch.load(ckpt_path) 47 | eigenvalues = state_dicts["eigenvalues"].numpy() 48 | eigenvalues = -np.sort(-eigenvalues) 49 | k = min(k, len(eigenvalues)) 50 | 51 | # Linear fit. 52 | x = np.arange(1, k + 1) 53 | g = np.sqrt(eigenvalues[:k]) 54 | logg = np.log(g) 55 | logx = np.log(x) 56 | 57 | linfit = scipy.stats.linregress(logx, logg) 58 | g_linfit = np.exp(logx * linfit.slope + linfit.intercept) 59 | 60 | print("slope:", linfit.slope) 61 | print("R value:", linfit.rvalue) 62 | 63 | plots = [ 64 | dict(x=x, y=g, marker='+', linewidth=0, label="estimated values", markersize=8, alpha=0.8), 65 | dict(x=x, y=g_linfit, 66 | label=f"linear fit: $\log y = {linfit.slope:.2f} \log x {linfit.intercept:.2f} $ ($R^2=" 67 | f"{linfit.rvalue ** 2.:.3f}$)"), 68 | ] 69 | utils.plot_wrapper( 70 | img_path=utils.join(dump_dir, f"eigenvalue-linfit{img_name}"), 71 | suffixes=(".png", ".pdf"), 72 | plots=plots, 73 | options=dict(xlabel="$k$", ylabel="$\lambda(H^\\top H)^{1/2}$", xscale='log', yscale='log') 74 | ) 75 | 76 | # Spectral density. 77 | sigma_squared = 1e-6 78 | evals = np.sqrt(eigenvalues[None, :k]) 79 | den, gri = density.eigv_to_density(evals, sigma_squared=sigma_squared, grid_len=300000, grid_expand=3e-4) 80 | utils.plot_wrapper( 81 | img_path=utils.join(dump_dir, f'eigenvalue-density{img_name}'), 82 | suffixes=(".png", ".pdf"), 83 | plots=[dict(x=gri, y=den, label=f"bandwidth $\sigma={math.sqrt(sigma_squared):.5f}$")], 84 | options=dict(xlabel="$\lambda(H^\\top H)^{1/2}$", ylabel="Density of KDE", 85 | ylim=dict(bottom=1e-10, top=2e2), 86 | xscale="log", yscale='log') 87 | ) 88 | 89 | 90 | def plot2( 91 | base_dir: str, 92 | img_name="", 93 | seeds=(42, 9008, 0), 94 | ranks=(10, 20, 100, None), 95 | dump_dir="./classification/plots", 96 | markers=('x', '^', '+', 'o'), 97 | roberta_large=False, 98 | **kwargs, 99 | ): 100 | """Retrain. 101 | 102 | Run locally. 
103 | """ 104 | # Roberta-large 105 | # python -m classification.spectral_analysis.rebuttal_plots_neurips_2022 --task "plot2" --img_name "large" --base_dir "/mnt/data1/dump/rebuttal" --roberta_large True 106 | if img_name != "": 107 | img_name = f'-{img_name}' 108 | 109 | errorbars = [] 110 | for rank, marker in utils.zip_(ranks, markers): 111 | results = [] 112 | for seed in seeds: 113 | if roberta_large: 114 | output_dir = utils.join( 115 | f"{base_dir}/roberta_prompt_large_retrain_{rank}_{seed}/sst-2", 116 | 'log_history.json' 117 | ) 118 | else: 119 | output_dir = utils.join( 120 | f"{base_dir}/roberta_prompt_retrain_{rank}_{seed}/sst-2", 121 | 'log_history.json' 122 | ) 123 | record = utils.jload(output_dir) 124 | results.append([dumpi['dev']['eval_acc'] for dumpi in record]) 125 | steps = [dumpi['step'] for dumpi in record] 126 | 127 | label = f"subspace rank={rank}" if rank is not None else "original" 128 | mu, si = utils.average_over_seed(results) 129 | errorbar = dict(x=steps, y=mu, yerr=si, label=label, marker=marker) 130 | errorbars.append(errorbar) 131 | 132 | img_path = utils.join(dump_dir, f'plot2{img_name}') 133 | utils.plot_wrapper( 134 | img_path=img_path, 135 | suffixes=('.png', '.pdf'), 136 | errorbars=errorbars, 137 | options=dict(xlabel="iteration", ylabel="SST-2 classification accuracy (dev)") 138 | ) 139 | 140 | 141 | def plot_all(**kwargs): 142 | # rebuttal roberta-base experiments. 143 | # python -m classification.spectral_analysis.rebuttal_plots_neurips_2022 --task "plot_all" --base_dir "/mnt/data1/dump/rebuttal" --ckpt_path "/mnt/data1/dump/rebuttal/run-roberta-base/orthproj/eigenvalues/global_step_000010.evals" 144 | plot1(**kwargs) 145 | plot2(**kwargs) 146 | 147 | 148 | def main(task="plot_all", **kwargs): 149 | utils.runs_tasks( 150 | task=task, 151 | task_names=("plot_all", "plot1", "plot2"), 152 | task_callables=(plot_all, plot1, plot2), 153 | **kwargs, 154 | ) 155 | 156 | 157 | if __name__ == "__main__": 158 | fire.Fire(main) 159 | -------------------------------------------------------------------------------- /examples/classification/src/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Xuechen Li. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /examples/classification/src/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Xuechen Li. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import torch
16 | 
17 | task_name2suffix_name = {"sst-2": "GLUE-SST-2", "mnli": "MNLI", "qqp": "QQP", "qnli": "QNLI"}
18 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19 | true_tags = ('y', 'yes', 't', 'true')
-------------------------------------------------------------------------------- /examples/classification/src/compiled_args.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | from dataclasses import dataclass, field
16 | from typing import Optional
17 | 
18 | import transformers
19 | 
20 | from .common import true_tags
21 | 
22 | 
23 | @dataclass
24 | class PrivacyArguments:
25 |     """Arguments for differentially private training."""
26 | 
27 |     per_example_max_grad_norm: float = field(
28 |         default=.1, metadata={
29 |             "help": "Clipping 2-norm of per-sample gradients."
30 |         }
31 |     )
32 |     noise_multiplier: float = field(
33 |         default=None, metadata={
34 |             "help": "Standard deviation of noise added for privacy; if `target_epsilon` is specified, "
35 |                     "this is overridden by the value found by searching against that budget."
36 |         }
37 |     )
38 |     target_epsilon: float = field(
39 |         default=None, metadata={
40 |             "help": "Privacy budget; if `None`, use the specified noise multiplier."
41 |         }
42 |     )
43 |     target_delta: float = field(
44 |         default=None, metadata={
45 |             "help": "Slack probability (the delta) in approximate differential privacy; if `None`, use 1 / len(train_data)."
46 |         }
47 |     )
48 |     non_private: str = field(
49 |         default="yes", metadata={"help": "Train non-privately if True."}
50 |     )
51 |     accounting_mode: str = field(
52 |         default="rdp", metadata={"help": "One of (`rdp`, `glw`, `all`)."}
53 |     )
54 |     clipping_mode: str = field(
55 |         default="default"
56 |     )
57 | 
58 |     def __post_init__(self):
59 |         self.non_private = self.non_private.lower() in true_tags  # noqa
60 | 
61 | 
62 | @dataclass
63 | class TrainingArguments(transformers.TrainingArguments):
64 |     eval_epochs: int = field(default=10, metadata={"help": "Evaluate once every this many epochs."})
65 |     evaluate_before_training: bool = field(default=False, metadata={"help": "Run evaluation before training."})
66 |     lr_decay: str = field(
67 |         default="no", metadata={"help": "Apply the usual linear decay if `yes`, otherwise no decay."}
68 |     )
69 |     evaluate_test_split: bool = field(default=False, metadata={"help": "Run evaluation on the test split."})
70 | 
71 |     def __post_init__(self):
72 |         super(TrainingArguments, self).__post_init__()
73 |         self.lr_decay = self.lr_decay.lower() in true_tags  # noqa
74 | 
75 | 
76 | @dataclass
77 | class AuxiliaryArguments:
78 |     eval_spectrum: str = field(default="no")
79 |     max_spectrum_batches: int = field(default=100)
80 |     max_lanczos_iter: int = field(default=100)
81 | 
82 |     store_grads: str = field(default="no")
83 |     orthogonal_projection_path: Optional[str] = field(default=None)
84 |     orthogonal_projection_rank: int = field(default=100)
85 | 
86 |     def __post_init__(self):
87 |         self.eval_spectrum = self.eval_spectrum.lower() in true_tags  # noqa
88 |         self.store_grads = self.store_grads.lower() in true_tags  # noqa
-------------------------------------------------------------------------------- /examples/classification/src/label_search.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 15 | """Automatic label search helpers.""" 16 | 17 | import itertools 18 | import logging 19 | import multiprocessing 20 | 21 | import numpy as np 22 | import scipy.spatial as spatial 23 | import scipy.special as special 24 | import scipy.stats as stats 25 | import tqdm 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | def select_likely_words(train_logits, train_labels, k_likely=1000, vocab=None, is_regression=False): 31 | """Pre-select likely words based on conditional likelihood.""" 32 | indices = [] 33 | if is_regression: 34 | median = np.median(train_labels) 35 | train_labels = (train_labels > median).astype(np.int) 36 | num_labels = np.max(train_labels) + 1 37 | for idx in range(num_labels): 38 | label_logits = train_logits[train_labels == idx] 39 | scores = label_logits.mean(axis=0) 40 | kept = [] 41 | for i in np.argsort(-scores): 42 | text = vocab[i] 43 | if not text.startswith("Ġ"): 44 | continue 45 | kept.append(i) 46 | indices.append(kept[:k_likely]) 47 | return indices 48 | 49 | 50 | def select_neighbors(distances, k_neighbors, valid): 51 | """Select k nearest neighbors based on distance (filtered to be within the 'valid' set).""" 52 | indices = np.argsort(distances) 53 | neighbors = [] 54 | for i in indices: 55 | if i not in valid: 56 | continue 57 | neighbors.append(i) 58 | if k_neighbors > 0: 59 | return neighbors[:k_neighbors] 60 | return neighbors 61 | 62 | 63 | def init(train_logits, train_labels): 64 | global logits, labels 65 | logits = train_logits 66 | labels = train_labels 67 | 68 | 69 | def eval_pairing_acc(pairing): 70 | global logits, labels 71 | label_logits = np.take(logits, pairing, axis=-1) 72 | preds = np.argmax(label_logits, axis=-1) 73 | correct = np.sum(preds == labels) 74 | return correct / len(labels) 75 | 76 | 77 | def eval_pairing_corr(pairing): 78 | global logits, labels 79 | if pairing[0] == pairing[1]: 80 | return -1 81 | label_logits = np.take(logits, pairing, axis=-1) 82 | label_probs = special.softmax(label_logits, axis=-1)[:, 1] 83 | pearson_corr = stats.pearsonr(label_probs, labels)[0] 84 | return pearson_corr 85 | 86 | 87 | def find_labels( 88 | model, 89 | train_logits, 90 | train_labels, 91 | seed_labels=None, 92 | k_likely=1000, 93 | k_neighbors=None, 94 | top_n=-1, 95 | vocab=None, 96 | is_regression=False, 97 | ): 98 | # Get top indices based on conditional likelihood using the LM. 99 | likely_indices = select_likely_words( 100 | train_logits=train_logits, 101 | train_labels=train_labels, 102 | k_likely=k_likely, 103 | vocab=vocab, 104 | is_regression=is_regression) 105 | 106 | logger.info("Top labels (conditional) per class:") 107 | for i, inds in enumerate(likely_indices): 108 | logger.info("\t| Label %d: %s", i, ", ".join([vocab[i] for i in inds[:10]])) 109 | 110 | # Convert to sets. 111 | valid_indices = [set(inds) for inds in likely_indices] 112 | 113 | # If specified, further re-rank according to nearest neighbors of seed labels. 114 | # Otherwise, keep ranking as is (based on conditional likelihood only). 115 | if seed_labels: 116 | assert (vocab is not None) 117 | seed_ids = [vocab.index(l) for l in seed_labels] 118 | vocab_vecs = model.lm_head.decoder.weight.detach().cpu().numpy() 119 | seed_vecs = np.take(vocab_vecs, seed_ids, axis=0) 120 | 121 | # [num_labels, vocab_size] 122 | label_distances = spatial.distance.cdist(seed_vecs, vocab_vecs, metric="cosine") 123 | 124 | # Establish label candidates (as k nearest neighbors). 
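    # For each class, keep only the pre-selected likely words that also rank among
    # the k nearest neighbors (by cosine distance in the LM-head embedding space)
    # of that class's seed label.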
125 | label_candidates = [] 126 | logger.info("Re-ranked by nearest neighbors:") 127 | for i, distances in enumerate(label_distances): 128 | label_candidates.append(select_neighbors(distances, k_neighbors, valid_indices[i])) 129 | logger.info("\t| Label: %s", seed_labels[i]) 130 | logger.info("\t| Neighbors: %s", " ".join([vocab[idx] for idx in label_candidates[i]])) 131 | else: 132 | label_candidates = likely_indices 133 | 134 | # Brute-force search all valid pairings. 135 | pairings = list(itertools.product(*label_candidates)) 136 | 137 | if is_regression: 138 | eval_pairing = eval_pairing_corr 139 | metric = "corr" 140 | else: 141 | eval_pairing = eval_pairing_acc 142 | metric = "acc" 143 | 144 | # Score each pairing. 145 | pairing_scores = [] 146 | with multiprocessing.Pool(initializer=init, initargs=(train_logits, train_labels)) as workers: 147 | with tqdm.tqdm(total=len(pairings)) as pbar: 148 | chunksize = max(10, int(len(pairings) / 1000)) 149 | for score in workers.imap(eval_pairing, pairings, chunksize=chunksize): 150 | pairing_scores.append(score) 151 | pbar.update() 152 | 153 | # Take top-n. 154 | best_idx = np.argsort(-np.array(pairing_scores))[:top_n] 155 | best_scores = [pairing_scores[i] for i in best_idx] 156 | best_pairings = [pairings[i] for i in best_idx] 157 | 158 | logger.info("Automatically searched pairings:") 159 | for i, indices in enumerate(best_pairings): 160 | logger.info("\t| %s (%s = %2.2f)", " ".join([vocab[j] for j in indices]), metric, best_scores[i]) 161 | 162 | return best_pairings 163 | -------------------------------------------------------------------------------- /examples/image_classification/README.md: -------------------------------------------------------------------------------- 1 | ## Image classification with vision Transformers 2 | 3 | ### Notes 4 | 5 | `main.py` contains simple training code that enables either full fine-tuning (up to isolated embedding parameters) or 6 | linear probing on CIFAR-10. For image classification, linear probing tends to perform generally better than full 7 | fine-tuning. This is confirmed in my personal experiments and also recent works (e.g., [[1]](https://arxiv.org/pdf/2204.13650.pdf) [[2]](https://arxiv.org/pdf/2205.02973.pdf)). 8 | 9 | [1] De, Soham, et al. "Unlocking high-accuracy differentially private image classification through scale." arXiv preprint arXiv:2204.13650 (2022). 10 | 11 | [2] Mehta, Harsh, et al. "Large scale transfer learning for differentially private image classification." arXiv preprint arXiv:2205.02973 (2022). 12 | -------------------------------------------------------------------------------- /examples/image_classification/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Xuechen Li. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /examples/image_classification/main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Xuechen Li. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """CIFAR-10 classification with Vi-T.""" 16 | import logging 17 | 18 | import fire 19 | import torch 20 | import torch.nn.functional as F 21 | import tqdm 22 | import transformers 23 | from ml_swissknife import utils 24 | from torchvision import transforms 25 | 26 | import private_transformers 27 | 28 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 29 | 30 | 31 | @torch.no_grad() 32 | def evaluate(loader, model): 33 | model.eval() 34 | xents, zeons = [], [] 35 | for i, (images, labels) in enumerate(loader): 36 | images, labels = tuple(t.to(device) for t in (images, labels)) 37 | logits = model(pixel_values=images).logits 38 | xents.append(F.cross_entropy(logits, labels, reduction='none')) 39 | zeons.append(logits.argmax(dim=-1).ne(labels).float()) 40 | return tuple(torch.cat(lst).mean().item() for lst in (xents, zeons)) 41 | 42 | 43 | def main( 44 | model_name_or_path='google/vit-base-patch16-224', 45 | train_batch_size=1000, 46 | per_device_train_batch_size=50, 47 | test_batch_size=500, 48 | epochs=10, 49 | target_epsilon=2, 50 | lr=2e-3, 51 | max_grad_norm=0.1, 52 | linear_probe=True, 53 | ): 54 | gradient_accumulation_steps = train_batch_size // per_device_train_batch_size 55 | 56 | image_transform = transforms.Compose([ 57 | transforms.Resize((224, 224)), 58 | transforms.ToTensor(), 59 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) 60 | ]) 61 | train_loader, test_loader = utils.get_loader( 62 | data_name='cifar10', 63 | task="classification", 64 | train_batch_size=per_device_train_batch_size, 65 | test_batch_size=test_batch_size, 66 | data_aug=False, 67 | train_transform=image_transform, 68 | test_transform=image_transform, 69 | ) 70 | 71 | config = transformers.AutoConfig.from_pretrained(model_name_or_path) 72 | config.num_labels = 10 73 | model = transformers.ViTForImageClassification.from_pretrained( 74 | model_name_or_path, 75 | config=config, 76 | ignore_mismatched_sizes=True # Default pre-trained model has 1k classes; we only have 10. 
77 | ).to(device) 78 | if linear_probe: 79 | model.requires_grad_(False) 80 | model.classifier.requires_grad_(True) 81 | logging.warning("Linear probe classification head.") 82 | else: 83 | private_transformers.freeze_isolated_params_for_vit(model) 84 | logging.warning("Full fine-tune up to isolated embedding parameters.") 85 | 86 | optimizer = torch.optim.Adam(params=model.parameters(), lr=lr) 87 | privacy_engine = private_transformers.PrivacyEngine( 88 | model, 89 | batch_size=train_batch_size, 90 | sample_size=50000, 91 | epochs=epochs, 92 | max_grad_norm=max_grad_norm, 93 | target_epsilon=target_epsilon, 94 | ) 95 | privacy_engine.attach(optimizer) 96 | 97 | train_loss_meter = utils.AvgMeter() 98 | for epoch in range(epochs): 99 | optimizer.zero_grad() 100 | pbar = tqdm.tqdm(enumerate(train_loader, 1), total=len(train_loader)) 101 | for global_step, (images, labels) in pbar: 102 | model.train() 103 | images, labels = tuple(t.to(device) for t in (images, labels)) 104 | logits = model(pixel_values=images).logits 105 | loss = F.cross_entropy(logits, labels, reduction="none") 106 | train_loss_meter.step(loss.mean().item()) 107 | if global_step % gradient_accumulation_steps == 0: 108 | optimizer.step(loss=loss) 109 | optimizer.zero_grad() 110 | else: 111 | optimizer.virtual_step(loss=loss) 112 | pbar.set_description(f"Train loss running average: {train_loss_meter.item():.4f}") 113 | avg_xent, avg_zeon = evaluate(test_loader, model) 114 | logging.warning( 115 | f"Epoch: {epoch}, average cross ent loss: {avg_xent:.4f}, average zero one loss: {avg_zeon:.4f}" 116 | ) 117 | 118 | 119 | if __name__ == "__main__": 120 | fire.Fire(main) 121 | -------------------------------------------------------------------------------- /examples/table2text/README.md: -------------------------------------------------------------------------------- 1 | ## Reproducing results for table-to-text generation 2 | 3 | ### Requirements 4 | 5 | In addition to requirements of the `private-transformers` package, install additional requirements by running the 6 | following from the `examples` folder of this repo: 7 | 8 | ```bash 9 | pip install -r table2text/requirements.txt 10 | ``` 11 | 12 | ### Getting the data 13 | 14 | We host the datasets for E2E and DART on Google drive at 15 | this [link](https://drive.google.com/file/d/1Re1wyUPtS3IalSsVVJhSg2sn8UNa7DM7/view?usp=sharing). Download and unzip the 16 | folder to a reasonable location. The unzipped folder is named `prefix-tuning`, since it's adapted from data used in the 17 | prefix-tuning paper. 18 | 19 | ### Running 20 | 21 | Use the `run.sh` script in the folder. 22 | 23 | Supply at least 3 arguments: 24 | 25 | - `--output_dir`: path to a folder where results will be written 26 | - `--data_folder`: path to the unzipped data folder 27 | - `--task_mode`: name of task; one of `e2e` and `dart` 28 | 29 | For instance, to fine-tune GPT-2 on E2E at ε = 8, run the following from the `examples` folder of this repo: 30 | 31 | ```bash 32 | bash table2text/run.sh "e2e" 33 | ``` 34 | 35 | The script by default uses ghost clipping, and the micro batch size is tweaked so that things should run smoothly even 36 | on a Titan Xp with 12Gigs of VRAM. For E2E, the run-time of this script on an RTX 3090 is roughly less than one and a 37 | half hours. 38 | 39 | Feel free to toggle other arguments like `target_epsilon` and `model_name_or_path` of the `run.sh` script to use 40 | different privacy levels and models. 
The other hyperparameters should still mostly work
41 | across different models and privacy levels.
42 | 
43 | ### Automatic evaluation
44 | 
45 | While the runs automatically decode from the model (via beam-search) once in a while during training, the script does
46 | not run any evaluation on top of the generations. For our paper, we ran the
47 | official [e2e-metrics](https://github.com/tuetschek/e2e-metrics) for evaluating common metrics (e.g., BLEU, ROUGE) on
48 | E2E, and used the [evaluation pipeline in the GEM-benchmark](https://github.com/GEM-benchmark/GEM-metrics) for DART.
49 | 
--------------------------------------------------------------------------------
/examples/table2text/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
--------------------------------------------------------------------------------
/examples/table2text/compiled_args.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Compilation of all the arguments."""
16 | import logging
17 | import os
18 | import sys
19 | from dataclasses import dataclass, field
20 | from typing import Optional
21 | 
22 | import transformers
23 | 
24 | MODEL_CONFIG_CLASSES = list(transformers.MODEL_WITH_LM_HEAD_MAPPING.keys())
25 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
26 | 
27 | TRUE_TAGS = ('y', 'yes', 't', 'true')
28 | 
29 | 
30 | # See all possible arguments in src/transformers/training_args.py
31 | # or by passing the --help flag to this script.
32 | # We now keep distinct sets of args, for a cleaner separation of concerns.
33 | @dataclass
34 | class ModelArguments:
35 |     """
36 |     Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
37 |     """
38 |     model_name_or_path: Optional[str] = field(
39 |         default=None,
40 |         metadata={
41 |             "help": "The model checkpoint for weights initialization. Leave None if you want to train a model from "
42 |                     "scratch."
43 |         },
44 |     )
45 |     model_type: Optional[str] = field(
46 |         default=None,
47 |         metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
48 |     )
49 |     config_name: Optional[str] = field(
50 |         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
51 |     )
52 |     tokenizer_name: Optional[str] = field(
53 |         default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
54 |     )
55 |     cache_dir: Optional[str] = field(
56 |         default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
57 |     )
58 | 
59 |     static_lm_head: str = field(default='no')
60 |     static_embedding: str = field(default='no')
61 |     attention_only: str = field(default="no")
62 | 
63 |     def __post_init__(self):
64 |         self.static_lm_head = self.static_lm_head.lower() in TRUE_TAGS
65 |         self.static_embedding = self.static_embedding.lower() in TRUE_TAGS
66 |         self.attention_only = self.attention_only.lower() in TRUE_TAGS
67 | 
68 | 
69 | @dataclass
70 | class DataTrainingArguments:
71 |     """
72 |     Arguments pertaining to what data we are going to input to our model for training and eval.
73 |     """
74 |     data_folder: Optional[str] = field(default=None, metadata={"help": "Path to folder with all the data."})
75 | 
76 |     # Useful for truncating the dataset.
77 |     max_train_examples: Optional[int] = field(default=sys.maxsize)
78 |     max_valid_examples: Optional[int] = field(default=sys.maxsize)
79 |     max_eval_examples: Optional[int] = field(default=sys.maxsize)
80 | 
81 |     line_by_line: bool = field(
82 |         default=True,
83 |         metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
84 |     )
85 |     task_mode: Optional[str] = field(
86 |         default=None, metadata={"help": "The name of the task."}
87 |     )
88 |     format_mode: Optional[str] = field(
89 |         default='cat', metadata={"help": "The mode of data2text format (cat, peek, nopeek)."}
90 |     )
91 |     max_source_length: Optional[int] = field(
92 |         default=512, metadata={"help": "The max source length of summarization data."}
93 |     )
94 |     train_max_target_length: Optional[int] = field(
95 |         default=100, metadata={"help": "The max target length for training data."}
96 |     )
97 |     val_max_target_length: Optional[int] = field(
98 |         default=100, metadata={"help": "The max target length for dev data."}
99 |     )
100 |     block_size: int = field(
101 |         default=-1,
102 |         metadata={
103 |             "help": "Optional input sequence length after tokenization. "
104 |                     "The training dataset will be truncated into blocks of this size for training. "
105 |                     "Defaults to the model max input length for single sentence inputs (taking into account special "
106 |                     "tokens)."
107 |         },
108 |     )
109 |     overwrite_cache: bool = field(
110 |         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
111 |     )
112 |     max_seq_len: int = field(default=sys.maxsize)
113 | 
114 |     def __post_init__(self):
115 |         if self.data_folder is not None:
116 |             logging.warning('Overriding dataset paths using those given in `data_folder`')
117 | 
118 |             if self.task_mode == "e2e":
119 |                 self.train_data_file = os.path.join(self.data_folder, 'src1_train.txt')
120 |                 self.valid_data_file = os.path.join(self.data_folder, 'src1_valid.txt')
121 |                 self.eval_data_file = os.path.join(self.data_folder, 'src1_test.txt')
122 | 
123 |                 self.train_prompt_file = os.path.join(self.data_folder, 'prompts_train.txt')
124 |                 self.val_prompt_file = os.path.join(self.data_folder, 'prompts_valid.txt')
125 |                 self.eval_prompt_file = os.path.join(self.data_folder, 'prompts_test.txt')
126 | 
127 |             elif self.task_mode == "dart":
128 |                 self.train_data_file = os.path.join(self.data_folder, 'dart-v1.1.1-full-train.json')
129 |                 self.valid_data_file = os.path.join(self.data_folder, 'dart-v1.1.1-full-dev.json')
130 |                 self.eval_data_file = os.path.join(self.data_folder, 'dart-v1.1.1-full-test.json')
131 | 
132 |                 self.train_prompt_file = os.path.join(self.data_folder, 'prompts_train.txt')
133 |                 self.val_prompt_file = os.path.join(self.data_folder, 'prompts_valid.txt')
134 |                 self.eval_prompt_file = os.path.join(self.data_folder, 'prompts_test.txt')
135 | 
136 | 
137 | @dataclass
138 | class TrainingArguments(transformers.TrainingArguments):
139 |     max_eval_batches: int = field(default=-1, metadata={"help": "Maximum number of evaluation steps to run."})
140 |     max_generations: int = field(default=sys.maxsize)
141 |     max_generations_train: int = field(default=10)
142 |     max_generations_valid: int = field(default=10)
143 |     skip_generation: str = field(default="no")
144 | 
145 |     ema_model_averaging: str = field(default="no")
146 |     ema_model_gamma: float = field(default=0.99)
147 |     ema_model_start_from: int = field(default=1000)
148 |     lr_decay: str = field(default="yes")
149 |     eval_epochs: int = field(default=10)
150 | 
151 |     evaluate_during_training: str = field(
152 |         default="yes",
153 |         metadata={"help": "Run evaluation during training at each logging step."},
154 |     )
155 |     evaluate_before_training: str = field(
156 |         default="yes",
157 |         metadata={"help": "Run evaluation before training."},
158 |     )
159 |     save_at_last: str = field(default="no", metadata={"help": "Save at the end of training."})
160 | 
161 |     def __post_init__(self):
162 |         super(TrainingArguments, self).__post_init__()
163 |         self.skip_generation = self.skip_generation.lower() in ('y', 'yes')
164 |         self.ema_model_averaging = (self.ema_model_averaging.lower() in ('y', 'yes'))
165 |         self.lr_decay = (self.lr_decay.lower() in ('y', 'yes'))
166 |         self.evaluate_during_training = (self.evaluate_during_training.lower() in ('y', 'yes'))
167 |         self.evaluate_before_training = (self.evaluate_before_training.lower() in ('y', 'yes'))
168 |         self.save_at_last = (self.save_at_last.lower() in ('y', 'yes'))
169 | 
170 | 
171 | @dataclass
172 | class PrivacyArguments:
173 |     """Arguments for differentially private training."""
174 |     per_example_max_grad_norm: float = field(
175 |         default=.1, metadata={
176 |             "help": "Clipping 2-norm of per-sample gradients."
177 |         }
178 |     )
179 |     noise_multiplier: float = field(
180 |         default=None, metadata={
181 |             "help": "Standard deviation of the noise added for privacy; if `target_epsilon` is specified, "
182 |                     "the value searched for based on the budget is used instead."
183 |         }
184 |     )
185 |     target_epsilon: float = field(
186 |         default=None, metadata={
187 |             "help": "Privacy budget; if `None` use the noise multiplier specified."
188 |         }
189 |     )
190 |     target_delta: float = field(
191 |         default=None, metadata={
192 |             "help": "Failure probability in approximate differential privacy; if `None` use 1 / len(train_data)."
193 |         }
194 |     )
195 |     accounting_mode: str = field(
196 |         default="rdp", metadata={"help": "One of `rdp`, `glw`, `all`."}
197 |     )
198 |     non_private: str = field(default="no")
199 |     clipping_mode: str = field(default="default")
200 | 
201 |     def __post_init__(self):
202 |         self.non_private = self.non_private.lower() in ('y', 'yes')
203 | 
204 | 
205 | @dataclass
206 | class AuxiliaryArguments:
207 |     eval_spectrum: str = field(default="no")
208 |     max_spectrum_batches: int = field(default=100)
209 |     max_lanczos_iter: int = field(default=100)
210 | 
211 |     store_grads: str = field(default="no")
212 |     orthogonal_projection_path: Optional[str] = field(default=None)
213 |     orthogonal_projection_rank: int = field(default=100)
214 | 
215 |     def __post_init__(self):
216 |         self.eval_spectrum = self.eval_spectrum.lower() in TRUE_TAGS  # noqa
217 |         self.store_grads = self.store_grads.lower() in TRUE_TAGS  # noqa
218 | 
--------------------------------------------------------------------------------
/examples/table2text/data_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
--------------------------------------------------------------------------------
/examples/table2text/data_utils/data_collator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
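
# The two collators below build batches for decoder-only language modeling: the Data2Text variant
# serves the E2E/DART table-to-text tasks, and the Sum variant serves summarization-style data.
# Both pad with the tokenizer's pad token, set pad positions in `labels` to -100 so the loss
# ignores them, and return attention masks for the source and target segments.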
14 | 
15 | from dataclasses import dataclass
16 | from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
17 | 
18 | import torch
19 | from torch.nn.utils.rnn import pad_sequence
20 | 
21 | from transformers.tokenization_utils import PreTrainedTokenizer
22 | from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy
23 | from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
24 | 
25 | 
26 | InputDataClass = NewType("InputDataClass", Any)
27 | 
28 | """
29 | A DataCollator is a function that takes a list of samples from a Dataset
30 | and collates them into a batch, as a dictionary of Tensors.
31 | """
32 | DataCollator = NewType("DataCollator", Callable[[List[InputDataClass]], Dict[str, torch.Tensor]])
33 | 
34 | 
35 | @dataclass
36 | class DataCollatorForData2TextLanguageModeling:
37 |     """
38 |     Data collator used for language modeling.
39 |     - collates batches of tensors, honoring their tokenizer's pad_token
40 |     - preprocesses batches for masked language modeling
41 |     """
42 |     tokenizer: PreTrainedTokenizer
43 |     mlm: bool = True
44 |     format_mode: str = 'cat'
45 |     mlm_probability: float = 0.15
46 | 
47 |     def __call__(
48 |         self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
49 |     ) -> Dict[str, torch.Tensor]:
50 |         if isinstance(examples[0], (dict, BatchEncoding)):
51 |             examples = [e["input_ids"] for e in examples]
52 |         input_ids, labels, src, tgt, cate = zip(*examples)
53 |         if self.mlm:
54 |             inputs, labels = self.mask_tokens(self._tensorize_batch(input_ids))
55 |             return {"input_ids": inputs, "labels": labels}
56 |         else:
57 |             if self.format_mode == 'cat':
58 |                 mode_input = 3
59 |             elif self.format_mode == 'peek':
60 |                 mode_input = 1
61 |             elif self.format_mode == 'nopeek':
62 |                 mode_input = 2
63 |             elif self.format_mode == 'infix':
64 |                 mode_input = 4
65 | 
66 |             # mode_input = 1 # means that we take the input again.
67 |             # mode_input = 2 # means that we do not peek at src again.
68 |             # mode_input = 3 # means that we look at the categories, and see the input again.
69 | 
70 |             if mode_input == 1:
71 |                 # input, batch
72 |                 batch = self._tensorize_batch(input_ids)
73 |                 labels = self._tensorize_batch(labels)
74 |                 src = self._tensorize_batch(src)
75 |                 cate_batch, cate_attn = None, None
76 |                 # tgt = self._tensorize_batch(tgt)
77 |             elif mode_input == 2:
78 |                 # nopeek.
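                # In 'nopeek' mode the model trains on the target sequence alone: `batch` is built
                # from `tgt` and `labels` is a clone of it, so the source is never used as context.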
79 | batch = self._tensorize_batch(tgt) 80 | labels = batch.clone() 81 | src = self._tensorize_batch(src) 82 | cate_batch, cate_attn = None, None 83 | elif mode_input == 3: 84 | batch = self._tensorize_batch(input_ids) 85 | labels = self._tensorize_batch(labels) 86 | src = self._tensorize_batch(cate) 87 | cate_batch, cate_attn = None, None 88 | elif mode_input == 4: 89 | batch = self._tensorize_batch(tgt) 90 | labels = batch.clone() 91 | src = self._tensorize_batch(src) 92 | 93 | cate_batch = self._tensorize_batch(cate) 94 | cate_attn = (cate_batch != self.tokenizer.pad_token_id) 95 | 96 | labels[labels == self.tokenizer.pad_token_id] = -100 # tgt 97 | src_attn = (src != self.tokenizer.pad_token_id) # src 98 | tgt_attn = (batch != self.tokenizer.pad_token_id) # tgt 99 | 100 | if cate_batch is None: 101 | return {"input_ids": batch, "labels": labels, 'src_attn': src_attn, 'tgt_attn':tgt_attn, 102 | 'src':src} 103 | else: 104 | return {"input_ids": batch, "labels": labels, 'src_attn': src_attn, 'tgt_attn': tgt_attn, 105 | 'src': src, "cate_batch":cate_batch, "cate_attn":cate_attn} 106 | 107 | def _tensorize_batch( 108 | self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]] 109 | ) -> torch.Tensor: 110 | # In order to accept both lists of lists and lists of Tensors 111 | if isinstance(examples[0], (list, tuple)): 112 | examples = [torch.tensor(e, dtype=torch.long) for e in examples] 113 | length_of_first = examples[0].size(0) 114 | are_tensors_same_length = all(x.size(0) == length_of_first for x in examples) 115 | if are_tensors_same_length: 116 | return torch.stack(examples, dim=0) 117 | else: 118 | if self.tokenizer._pad_token is None: 119 | raise ValueError( 120 | "You are attempting to pad samples but the tokenizer you are using" 121 | f" ({self.tokenizer.__class__.__name__}) does not have one." 122 | ) 123 | return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id) 124 | 125 | def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 126 | """ 127 | Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. 128 | """ 129 | 130 | if self.tokenizer.mask_token is None: 131 | raise ValueError( 132 | "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer." 
133 |             )
134 | 
135 |         labels = inputs.clone()
136 |         # We sample a few tokens in each sequence for masked-LM training (with probability `mlm_probability`, which defaults to 0.15 as in BERT/RoBERTa).
137 |         probability_matrix = torch.full(labels.shape, self.mlm_probability)
138 |         special_tokens_mask = [
139 |             self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
140 |         ]
141 |         probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
142 |         if self.tokenizer._pad_token is not None:
143 |             padding_mask = labels.eq(self.tokenizer.pad_token_id)
144 |             probability_matrix.masked_fill_(padding_mask, value=0.0)
145 |         masked_indices = torch.bernoulli(probability_matrix).bool()
146 |         labels[~masked_indices] = -100  # We only compute loss on masked tokens.
147 | 
148 |         # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]).
149 |         indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
150 |         inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
151 | 
152 |         # 10% of the time, we replace masked input tokens with a random word.
153 |         indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
154 |         random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
155 |         inputs[indices_random] = random_words[indices_random]
156 | 
157 |         # The rest of the time (10% of the time) we keep the masked input tokens unchanged.
158 |         return inputs, labels
159 | 
160 | 
161 | @dataclass
162 | class DataCollatorForSumLanguageModeling:
163 |     """
164 |     Data collator used for language modeling.
165 |     - collates batches of tensors, honoring their tokenizer's pad_token
166 |     - preprocesses batches for masked language modeling
167 |     """
168 |     tokenizer: PreTrainedTokenizer
169 |     mlm: bool = True
170 |     format_mode: str = 'cat'
171 |     mlm_probability: float = 0.15
172 | 
173 |     def __call__(
174 |         self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
175 |     ) -> Dict[str, torch.Tensor]:
176 |         if isinstance(examples[0], (dict, BatchEncoding)):
177 |             examples = [e["input_ids"] for e in examples]
178 |         # print(examples[0])
179 |         # print(len(examples))
180 |         input_ids, labels, src, tgt = zip(*examples)
181 |         # print(len(input_ids), len(labels), len(weights))
182 |         if self.mlm:
183 |             inputs, labels = self.mask_tokens(self._tensorize_batch(input_ids))
184 |             return {"input_ids": inputs, "labels": labels}
185 |         else:
186 | 
187 |             # print(self.format_mode)
188 | 
189 |             if self.format_mode == 'peek' or self.format_mode == 'cat':
190 |                 mode_input = 1
191 |             elif self.format_mode == 'nopeek':
192 |                 assert False, 'should use format_mode = peek or cat.'
193 |                 mode_input = 2
194 |             elif self.format_mode == 'infix':
195 |                 assert False, 'should use format_mode = peek or cat.'
196 |                 mode_input = 4
197 | 
198 |             # mode_input = 1 # means that we take the input again.
199 |             # mode_input = 2 # means that we do not peek at src again.
200 |             # mode_input = 3 # means that we look at the categories, and see the input again.
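            # Only mode 1 ('peek'/'cat') is reachable in this collator; the 'nopeek' and 'infix'
            # branches above assert False, so the dictionary returned below always carries
            # (input_ids, labels, src_attn, tgt_attn, src).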
201 | 202 | # print(self.format_mode, mode_input) 203 | 204 | if mode_input == 1: 205 | # input, batch 206 | batch = self._tensorize_batch(input_ids) 207 | labels = self._tensorize_batch(labels) 208 | src = self._tensorize_batch(src) 209 | 210 | labels[labels == self.tokenizer.pad_token_id] = -100 # tgt 211 | src_attn = (src != self.tokenizer.pad_token_id) # src 212 | tgt_attn = (batch != self.tokenizer.pad_token_id) # tgt 213 | 214 | return {"input_ids": batch, "labels": labels, 'src_attn': src_attn, 'tgt_attn':tgt_attn, 215 | 'src':src} 216 | 217 | 218 | def _tensorize_batch( 219 | self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]] 220 | ) -> torch.Tensor: 221 | # In order to accept both lists of lists and lists of Tensors 222 | if isinstance(examples[0], (list, tuple)): 223 | examples = [torch.tensor(e, dtype=torch.long) for e in examples] 224 | length_of_first = examples[0].size(0) 225 | are_tensors_same_length = all(x.size(0) == length_of_first for x in examples) 226 | if are_tensors_same_length: 227 | return torch.stack(examples, dim=0) 228 | else: 229 | if self.tokenizer._pad_token is None: 230 | raise ValueError( 231 | "You are attempting to pad samples but the tokenizer you are using" 232 | f" ({self.tokenizer.__class__.__name__}) does not have one." 233 | ) 234 | return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id) 235 | -------------------------------------------------------------------------------- /examples/table2text/decoding_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Xuechen Li. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities for generation.""" 16 | import logging 17 | import sys 18 | from typing import Optional 19 | 20 | import tqdm 21 | import transformers 22 | 23 | 24 | def generate( 25 | model: transformers.PreTrainedModel, 26 | tokenizer: transformers.PreTrainedTokenizer, 27 | loader=None, 28 | prompt_dataset=None, 29 | max_length=100, 30 | min_length=5, 31 | top_k=0, 32 | top_p=0.9, # Only filter with top_p. 33 | repetition_penalty=1, 34 | do_sample=False, 35 | num_beams=5, 36 | bad_words_ids=None, 37 | dummy_token_id=-100, # Used as mask. 38 | num_return_sequences=1, 39 | max_generations=sys.maxsize, 40 | device=None, 41 | padding_token="[PAD]", 42 | **kwargs, 43 | ): 44 | assert not model.training, "Generation must be when `model` is in eval mode." 45 | if kwargs: 46 | logging.warning(f"Unknown kwargs: {kwargs}") 47 | 48 | # These are linebreaks; generating these will mess up the evaluation, since those files assume one example per-line. 
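    # For the stock GPT-2 BPE vocabulary, id 628 appears to decode to "\n\n" and id 198 to "\n";
    # banning these keeps each generation on a single line so the per-line evaluation files stay valid.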
49 | if bad_words_ids is None: 50 | bad_words_ids = [[628], [198]] 51 | if padding_token in tokenizer.get_vocab(): 52 | bad_words_ids.append(tokenizer.encode(padding_token)) 53 | 54 | kwargs = dict( 55 | model=model, 56 | tokenizer=tokenizer, 57 | max_length=max_length, 58 | min_length=min_length, 59 | top_k=top_k, 60 | top_p=top_p, 61 | repetition_penalty=repetition_penalty, 62 | do_sample=do_sample, 63 | num_beams=num_beams, 64 | bad_words_ids=bad_words_ids, 65 | dummy_token_id=dummy_token_id, 66 | num_return_sequences=num_return_sequences, 67 | max_generations=max_generations, 68 | device=device, 69 | padding_token=padding_token, 70 | ) 71 | if loader is not None: 72 | result = _generate_with_loader(loader=loader, **kwargs) 73 | elif prompt_dataset is not None: 74 | result = _generate_with_prompt_dataset(prompt_dataset=prompt_dataset, **kwargs) 75 | else: 76 | raise ValueError(f"`loader` and `prompt_dataset` cannot both be `None`.") 77 | 78 | return result 79 | 80 | 81 | def _generate_with_loader( 82 | loader, 83 | 84 | model, 85 | tokenizer: transformers.PreTrainedTokenizer, 86 | max_length, 87 | min_length, 88 | top_k, 89 | top_p, 90 | repetition_penalty, 91 | do_sample, 92 | num_beams, 93 | bad_words_ids, 94 | dummy_token_id, 95 | num_return_sequences, 96 | max_generations, 97 | device, 98 | padding_token, 99 | ): 100 | references = [] 101 | full_generations = [] # Sentences including the prompt part. 102 | unstripped_generations = [] 103 | generations = [] 104 | 105 | stop_generation = False 106 | for batch_idx, batch in tqdm.tqdm(enumerate(loader), desc="generation"): 107 | if stop_generation: 108 | break 109 | 110 | batch_input_ids, batch_labels = batch["input_ids"], batch["labels"] 111 | # e.g., inputs_ids may be [[95, 123, 32], [198, 19, 120]], and 112 | # labels may be [[-100, 123, 32], [-100, -100, 120] 113 | 114 | for input_ids, labels in zip(batch_input_ids, batch_labels): 115 | if stop_generation: 116 | break 117 | 118 | # Find the first pad token and end the sentence from there! 119 | if padding_token in tokenizer.get_vocab(): 120 | pad_positions, = ( 121 | input_ids == tokenizer.encode(padding_token, return_tensors="pt").squeeze() 122 | ).nonzero(as_tuple=True) 123 | # Some sentences might have padding; others might not. 124 | if pad_positions.numel() == 0: 125 | first_pad_position = None 126 | else: 127 | first_pad_position = pad_positions[0] 128 | reference_str: str = tokenizer.decode(input_ids[:first_pad_position], clean_up_tokenization_spaces=True) 129 | else: 130 | reference_str: str = tokenizer.decode(input_ids, clean_up_tokenization_spaces=True) 131 | references.append(reference_str) 132 | 133 | # Find the first non- -100 position. Note there are trailing -100s. 134 | non_prompt_positions, = (labels != dummy_token_id).nonzero(as_tuple=True) 135 | first_non_prompt_position = non_prompt_positions[0].item() 136 | prompt_len = first_non_prompt_position 137 | prompt_ids = input_ids[:prompt_len] 138 | 139 | output_ids = model.generate( 140 | input_ids=prompt_ids[None, ...].to(device), 141 | max_length=max_length + prompt_len, # This cannot be a 0-D tensor! 142 | min_length=min_length, 143 | top_k=top_k, 144 | top_p=top_p, 145 | repetition_penalty=repetition_penalty, 146 | do_sample=do_sample, 147 | bad_words_ids=bad_words_ids, 148 | num_return_sequences=num_return_sequences, 149 | num_beams=num_beams, 150 | pad_token_id=tokenizer.eos_token_id, # Stop the stupid logging... 151 | ) 152 | output_ids = output_ids.squeeze(dim=0) # Throw away batch dimension. 
153 | 154 | whole_str: str = tokenizer.decode(output_ids, clean_up_tokenization_spaces=True) 155 | prompt_str: str = tokenizer.decode(prompt_ids, clean_up_tokenization_spaces=True) 156 | output_str: str = whole_str[len(prompt_str):] 157 | 158 | full_generations.append(whole_str) 159 | del whole_str, prompt_str 160 | 161 | # Remove potential eos_token at the end. 162 | eos_position: Optional[int] = output_str.find(tokenizer.eos_token) 163 | if eos_position == -1: # Didn't generate eos_token; that's okay -- just skip! 164 | eos_position = None 165 | output_str = output_str[:eos_position] 166 | unstripped_generations.append(output_str) 167 | 168 | # Removing leading and trailing spaces. 169 | output_str = output_str.strip() 170 | 171 | generations.append(output_str) 172 | 173 | if len(generations) >= max_generations: 174 | stop_generation = True 175 | 176 | return full_generations, unstripped_generations, generations, references 177 | 178 | 179 | def _generate_with_prompt_dataset( 180 | prompt_dataset, 181 | 182 | model, 183 | tokenizer, 184 | max_length, 185 | min_length, 186 | top_k, 187 | top_p, 188 | repetition_penalty, 189 | do_sample, 190 | num_beams, 191 | bad_words_ids, 192 | dummy_token_id, 193 | num_return_sequences, 194 | max_generations, 195 | device, 196 | padding_token, 197 | ): 198 | references = [] 199 | full_generations = [] # Sentences including the prompt part. 200 | unstripped_generations = [] 201 | generations = [] 202 | 203 | stop_generation = False 204 | for input_ids in tqdm.tqdm(prompt_dataset, desc="generation"): 205 | if stop_generation: 206 | break 207 | 208 | prompt_len = len(input_ids[0]) 209 | output_ids = model.generate( 210 | input_ids=input_ids.to(device), 211 | max_length=max_length + prompt_len, # This cannot be a 0-D tensor! 212 | min_length=min_length, 213 | top_k=top_k, 214 | top_p=top_p, 215 | repetition_penalty=repetition_penalty, 216 | do_sample=do_sample, 217 | bad_words_ids=bad_words_ids, 218 | num_return_sequences=num_return_sequences, 219 | num_beams=num_beams, 220 | pad_token_id=tokenizer.eos_token_id, # Stop the stupid logging... 221 | ) 222 | output_ids = output_ids.squeeze(dim=0) # Throw away batch dimension. 223 | input_ids = input_ids.squeeze(dim=0) 224 | 225 | whole_str: str = tokenizer.decode(output_ids, clean_up_tokenization_spaces=True) 226 | prompt_str: str = tokenizer.decode(input_ids, clean_up_tokenization_spaces=True) 227 | output_str: str = whole_str[len(prompt_str):] 228 | 229 | full_generations.append(whole_str) 230 | del whole_str, prompt_str 231 | 232 | # Remove potential eos_token at the end. 233 | eos_position: Optional[int] = output_str.find(tokenizer.eos_token) 234 | if eos_position == -1: # Didn't generate eos_token; that's okay -- just skip! 235 | eos_position = None 236 | output_str = output_str[:eos_position] 237 | unstripped_generations.append(output_str) 238 | 239 | # Removing leading and trailing spaces. 240 | output_str = output_str.strip() 241 | 242 | generations.append(output_str) 243 | 244 | if len(generations) >= max_generations: 245 | stop_generation = True 246 | return full_generations, unstripped_generations, generations, references 247 | -------------------------------------------------------------------------------- /examples/table2text/density.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Xuechen Li. All Rights Reserved. 
2 | # Copyright 2019 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Code for converting Lanczos outputs to densities.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import math 23 | 24 | import numpy as np 25 | 26 | 27 | def eigv_to_density(eig_vals, all_weights=None, grids=None, 28 | grid_len=10000, sigma_squared=None, grid_expand=1e-2): 29 | """Compute the smoothed spectral density from a set of eigenvalues. 30 | 31 | Convolves the given eigenvalues with a Gaussian kernel, weighting the values 32 | by all_weights (or uniform weighting if all_weights is None). Example output 33 | can be seen in Figure 1 of https://arxiv.org/pdf/1901.10159.pdf. Visualizing 34 | the estimated density can be done by calling plt.plot(grids, density). There 35 | is likely not a best value of sigma_squared that works for all use cases, 36 | so it is recommended to try multiple values in the range [1e-5,1e-1]. 37 | 38 | Args: 39 | eig_vals: Array of shape [num_draws, order] 40 | all_weights: Array of shape [num_draws, order], if None then weights will be 41 | taken to be uniform. 42 | grids: Array of shape [grid_len], the smoothed spectrum will be plotted 43 | in the interval [grids[0], grids[-1]]. If None then grids will be 44 | computed based on max and min eigenvalues and grid length. 45 | grid_len: Integer specifying number of grid cells to use, only used if 46 | grids is None 47 | sigma_squared: Scalar. Controls the smoothing of the spectrum estimate. 48 | If None, an appropriate value is inferred. 49 | grid_expand: Controls the window of values that grids spans. 50 | grids[0] = smallest eigenvalue - grid_expand. 51 | grids[-1] = largest_eigenvalue + grid_expand. 52 | 53 | Returns: 54 | density: Array of shape [grid_len], the estimated density, averaged over 55 | all draws. 56 | grids: Array of shape [grid_len]. The values the density is estimated on. 57 | """ 58 | if all_weights is None: 59 | all_weights = np.ones(eig_vals.shape) * 1.0 / float(eig_vals.shape[1]) 60 | num_draws = eig_vals.shape[0] 61 | 62 | lambda_max = np.nanmean(np.max(eig_vals, axis=1), axis=0) + grid_expand 63 | lambda_min = np.nanmean(np.min(eig_vals, axis=1), axis=0) - grid_expand 64 | 65 | if grids is None: 66 | assert grid_len is not None, 'grid_len is required if grids is None.' 
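        # Smearing step: each eigenvalue is broadened by a Gaussian (see `_kernel` below), the
        # weighted sums are averaged over draws, and the result is renormalized further down so
        # the density integrates to one over the grid.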
67 |         grids = np.linspace(lambda_min, lambda_max, num=grid_len)
68 | 
69 |     grid_len = grids.shape[0]
70 |     if sigma_squared is None:
71 |         sigma = 10 ** -5 * max(1, (lambda_max - lambda_min))
72 |     else:
73 |         sigma = sigma_squared * max(1, (lambda_max - lambda_min))
74 | 
75 |     density_each_draw = np.zeros((num_draws, grid_len))
76 |     for i in range(num_draws):
77 | 
78 |         if np.isnan(eig_vals[i, 0]):
79 |             raise ValueError('tridiag has nan values.')
80 |         else:
81 |             for j in range(grid_len):
82 |                 x = grids[j]
83 |                 vals = _kernel(eig_vals[i, :], x, sigma)
84 |                 density_each_draw[i, j] = np.sum(vals * all_weights[i, :])
85 |     density = np.nanmean(density_each_draw, axis=0)
86 |     norm_fact = np.sum(density) * (grids[1] - grids[0])
87 |     density = density / norm_fact
88 |     return density, grids
89 | 
90 | 
91 | def tridiag_to_eigv(tridiag_list):
92 |     """Preprocess the tridiagonal matrices for density estimation.
93 | 
94 |     Args:
95 |         tridiag_list: Array of shape [num_draws, order, order] List of the
96 |             tridiagonal matrices computed from running num_draws independent runs
97 |             of lanczos. The output of this function can be fed directly into
98 |             eigv_to_density.
99 | 
100 |     Returns:
101 |         eig_vals: Array of shape [num_draws, order]. The eigenvalues of the
102 |             tridiagonal matrices.
103 |         all_weights: Array of shape [num_draws, order]. The weights associated with
104 |             each eigenvalue. These weights are to be used in the kernel density
105 |             estimate.
106 |     """
107 |     # Calculating the nodes / weights from Jacobi matrices.
108 |     num_draws = len(tridiag_list)
109 |     num_lanczos = tridiag_list[0].shape[0]
110 |     eig_vals = np.zeros((num_draws, num_lanczos))
111 |     all_weights = np.zeros((num_draws, num_lanczos))
112 |     for i in range(num_draws):
113 |         nodes, evecs = np.linalg.eigh(tridiag_list[i])
114 |         index = np.argsort(nodes)
115 |         nodes = nodes[index]
116 |         evecs = evecs[:, index]
117 |         eig_vals[i, :] = nodes
118 |         all_weights[i, :] = evecs[0] ** 2
119 |     return eig_vals, all_weights
120 | 
121 | 
122 | def tridiag_to_density(tridiag_list, sigma_squared=1e-5, grid_len=10000):
123 |     """This function estimates the smoothed density from the output of lanczos.
124 | 
125 |     Args:
126 |         tridiag_list: Array of shape [num_draws, order, order] List of the
127 |             tridiagonal matrices computed from running num_draws independent runs
128 |             of lanczos.
129 |         sigma_squared: Controls the smoothing of the density.
130 |         grid_len: Controls the granularity of the density.
131 | 
132 |     Returns:
133 |         density: Array of size [grid_len]. The smoothed density estimate averaged
134 |             over all num_draws.
135 |         grids: Array of size [grid_len]. The values the density estimate is on.
136 |     """
137 |     eig_vals, all_weights = tridiag_to_eigv(tridiag_list)
138 |     density, grids = eigv_to_density(eig_vals, all_weights,
139 |                                      grid_len=grid_len,
140 |                                      sigma_squared=sigma_squared)
141 |     return density, grids
142 | 
143 | 
144 | def _kernel(x, x0, variance):
145 |     """Point estimate of the Gaussian kernel.
146 | 
147 |     This function computes the Gaussian kernel for
148 |     C exp(-(x - x0) ^ 2 / (2 * variance)) where C is the appropriate normalization.
149 |     variance should be a scalar. Either x0 or x should be a scalar. Only
150 |     one of x or x0 can be a numpy array.
151 | 
152 |     Args:
153 |         x: Can be either scalar or array of shape [order]. Points to estimate
154 |             the kernel on.
155 |         x0: Scalar. Mean of the kernel.
156 |         variance: Scalar. Variance of the kernel.
157 | 
158 |     Returns:
159 |         point_estimate: A scalar corresponding to
160 |             C exp(-(x - x0) ^ 2 / (2 * variance)).
161 |     """
162 |     coeff = 1.0 / np.sqrt(2 * math.pi * variance)
163 |     val = -(x0 - x) ** 2
164 |     val = val / (2.0 * variance)
165 |     val = np.exp(val)
166 |     point_estimate = coeff * val
167 |     return point_estimate
168 | 
--------------------------------------------------------------------------------
/examples/table2text/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Miscellaneous utilities.
16 | 
17 | Mostly bespoke data loaders at the moment.
18 | """
19 | 
20 | from transformers import (
21 |     DataCollatorForLanguageModeling,
22 |     DataCollatorForPermutationLanguageModeling,
23 |     PreTrainedTokenizer
24 | )
25 | 
26 | from .compiled_args import DataTrainingArguments
27 | from .data_utils.data_collator import DataCollatorForData2TextLanguageModeling
28 | from .data_utils.language_modeling import LineByLineE2ETextDataset, LineByLineTriplesTextDataset
29 | 
30 | 
31 | def get_dataset_with_path(
32 |     data_args: DataTrainingArguments,
33 |     tokenizer: PreTrainedTokenizer,
34 |     file_path: str,
35 |     max_examples: int,
36 |     **_,
37 | ):
38 |     if data_args.line_by_line:
39 |         if data_args.task_mode == 'e2e':
40 |             dataset = LineByLineE2ETextDataset(
41 |                 tokenizer=tokenizer,
42 |                 file_path=file_path,
43 |                 block_size=data_args.block_size,
44 |                 bos_tok=tokenizer.bos_token,
45 |                 eos_tok=tokenizer.eos_token,
46 |                 max_seq_len=data_args.max_seq_len,
47 |                 max_examples=max_examples,
48 |             )
49 |         elif data_args.task_mode == 'dart':
50 |             dataset = LineByLineTriplesTextDataset(
51 |                 tokenizer=tokenizer,
52 |                 file_path=file_path,
53 |                 block_size=data_args.block_size,
54 |                 bos_tok=tokenizer.bos_token,
55 |                 eos_tok=tokenizer.eos_token,
56 |                 max_seq_len=data_args.max_seq_len,
57 |                 max_examples=max_examples,
58 |             )
59 |         else:
60 |             raise ValueError(f"Unknown `args.task_mode`: {data_args.task_mode}")
61 | 
62 |     else:
63 |         raise ValueError("The table2text task doesn't support anything other than line_by_line!")
64 |     return dataset
65 | 
66 | 
67 | def get_prompt_dataset(file_path, tokenizer):
68 |     with open(file_path, 'r') as f:
69 |         lines = f.readlines()
70 |     encoded_lines = [
71 |         tokenizer.encode(line.strip(), add_special_tokens=False, return_tensors="pt")
72 |         for line in lines
73 |     ]
74 |     return encoded_lines
75 | 
76 | 
77 | def get_all_datasets(config, tokenizer, data_args, model_args, **_):
78 |     kwargs = dict(data_args=data_args, tokenizer=tokenizer, cache_dir=model_args.cache_dir)
79 |     train_dataset = get_dataset_with_path(
80 |         **kwargs, file_path=data_args.train_data_file, max_examples=data_args.max_train_examples
81 |     )
82 |     valid_dataset = get_dataset_with_path(
83 |         **kwargs, file_path=data_args.valid_data_file, max_examples=data_args.max_valid_examples
84 |     )
85 |     eval_dataset = get_dataset_with_path(
86 |         **kwargs, file_path=data_args.eval_data_file, max_examples=data_args.max_eval_examples
87 |     )
88 | 
89 |     if config.model_type == "xlnet":
90 |         data_collator = DataCollatorForPermutationLanguageModeling(
91 |             tokenizer=tokenizer,
92 |             plm_probability=data_args.plm_probability,
93 |             max_span_length=data_args.max_span_length,
94 |         )
95 |     else:
96 |         if data_args.task_mode == 'e2e' or data_args.task_mode == 'dart':
97 |             data_collator = DataCollatorForData2TextLanguageModeling(
98 |                 tokenizer=tokenizer, mlm=False, format_mode=data_args.format_mode
99 |             )
100 |         else:
101 |             data_collator = DataCollatorForLanguageModeling(
102 |                 tokenizer=tokenizer, mlm=False,
103 |             )
104 | 
105 |     return train_dataset, valid_dataset, eval_dataset, data_collator
106 | 
--------------------------------------------------------------------------------
/examples/table2text/models.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import torch
16 | from torch import nn
17 | from transformers import GPT2PreTrainedModel, GPT2LMHeadModel
18 | 
19 | 
20 | class _View(nn.Module):
21 |     def __init__(self, shape):
22 |         super(_View, self).__init__()
23 |         self.shape = shape
24 | 
25 |     def forward(self, x):
26 |         return x.reshape(*self.shape)
27 | 
28 | 
29 | class PrefixTuner(GPT2PreTrainedModel):
30 |     """A minimalistic implementation of the core components."""
31 | 
32 |     def __init__(self, config, model_args, gpt2=None):
33 |         super(PrefixTuner, self).__init__(config=config)
34 | 
35 |         # Instantiate a GPT-2, and DON'T optimize it!
36 |         if gpt2 is None:
37 |             self.gpt2 = GPT2LMHeadModel.from_pretrained(
38 |                 model_args.model_name_or_path, config=config, cache_dir=model_args.cache_dir,
39 |             )
40 |         else:
41 |             self.gpt2 = gpt2
42 | 
43 |         self.register_buffer('extra_prefix_ids', torch.arange(model_args.prefix_len))
44 |         # TODO: Also introduce the easier net.
45 |         self.extra_prefix_net = nn.Sequential(
46 |             nn.Embedding(model_args.prefix_len, config.n_embd),
47 |             nn.Linear(config.n_embd, model_args.mid_dim),
48 |             nn.Tanh(),
49 |             nn.Linear(model_args.mid_dim, config.n_layer * 2 * config.n_embd),
50 |             _View((-1, model_args.prefix_len, config.n_layer * 2, config.n_head, config.n_embd // config.n_head)),
51 |             nn.Dropout(model_args.prefix_dropout),
52 |         )
53 | 
54 |     def make_past_key_values(self, bsz=None):
55 |         extra_prefix_ids = self.extra_prefix_ids[None, :].expand(bsz, -1)
56 |         past_key_values = self.extra_prefix_net(extra_prefix_ids)
57 |         # Each layer gets a tensor of shape (2, batch_size, n_head, prefix_len, n_embd // n_head).
58 |         # e.g., (2, 1, 12, 5, 64,).
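        # The prefix net emits (batch_size, prefix_len, n_layer * 2, n_head, head_dim). Permuting with
        # [2, 0, 3, 1, 4] moves layers to the front, and `.split(2, dim=0)` then yields one tensor of
        # shape (2, batch_size, n_head, prefix_len, head_dim) per layer, i.e., a (key, value) pair in
        # the `past_key_values` layout that GPT-2 consumes.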
59 |         past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2, dim=0)
60 |         return past_key_values
61 | 
62 |     def state_dict(self):
63 |         """Avoid storing GPT-2, since it's not even trained."""
64 |         return self.extra_prefix_net.state_dict()
65 | 
66 |     def load_state_dict(self, state_dict):
67 |         """Avoid loading GPT-2, since it's not even trained."""
68 |         self.extra_prefix_net.load_state_dict(state_dict)
69 | 
70 |     @property
71 |     def major_device(self):
72 |         """Returns the device where the parameters are on."""
73 |         return next(self.parameters()).device
74 | 
75 |     def forward(
76 |         self,
77 |         input_ids,
78 |         attention_mask=None,
79 |         token_type_ids=None,
80 |         position_ids=None,
81 |         head_mask=None,
82 |         inputs_embeds=None,
83 |         encoder_hidden_states=None,
84 |         encoder_attention_mask=None,
85 |         labels=None,
86 |         use_cache=None,
87 |         output_attentions=None,
88 |         output_hidden_states=None,
89 |         return_dict=None,
90 |         **kwargs,
91 |     ):
92 |         past_key_values = self.make_past_key_values(bsz=input_ids.size(0))
93 |         return self.gpt2(
94 |             input_ids=input_ids,
95 |             past_key_values=past_key_values,
96 |             attention_mask=attention_mask,
97 |             token_type_ids=token_type_ids,
98 |             position_ids=position_ids,
99 |             head_mask=head_mask,
100 |             inputs_embeds=inputs_embeds,
101 |             encoder_hidden_states=encoder_hidden_states,
102 |             encoder_attention_mask=encoder_attention_mask,
103 |             labels=labels,
104 |             use_cache=use_cache,
105 |             output_attentions=output_attentions,
106 |             output_hidden_states=output_hidden_states,
107 |             return_dict=return_dict,
108 |             **kwargs
109 |         )
110 | 
111 |     def generate(self, input_ids, num_beams, **kwargs):
112 |         # Additional files also changed:
113 |         #   src/transformers/generation_utils.py
114 |         #   src/transformers/models/gpt2/modeling_gpt2.py
115 | 
116 |         # --- lxuechen: This part is really error-prone!
117 |         # A sanity check is to optimize the model for a few updates and check if the beam-search generations changed.
118 |         # The confusing logic in generation_utils:
119 |         #   1) `past` is used in `GPT2LMHeadModel:prepare_inputs_for_generation`,
120 |         #   2) it's converted to `past_key_values` in that function,
121 |         #   3) `past_key_values` is then updated in forward due to return_dict,
122 |         #   4) `past` is set to `past_key_values` in `generation_utils:_update_model_kwargs_for_generation`
123 | 
124 |         # This expansion step is important for generation, since otherwise the shapes are wrong.
125 |         past_key_values = self.make_past_key_values(bsz=input_ids.size(0) * num_beams)
126 |         # ---
127 | 
128 |         return self.gpt2.generate(
129 |             input_ids=input_ids,
130 |             num_beams=num_beams,
131 |             past_key_values=past_key_values,
132 | 
133 |             use_cache=True,
134 |             position_ids=None,
135 | 
136 |             # --- lxuechen: I created these arguments to make sure prefix-tuning works correctly.
137 |             # The logic: At the beginning, past=None, and then it gets replaced with past_key_values.
138 |             # Can't directly pass in `past`, since otherwise `input_ids` gets truncated to the last index.
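            # The two flags below only exist in the locally patched `generation_utils.py` and
            # `modeling_gpt2.py` listed above; a stock copy of transformers will not recognize them.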
139 | use_past_key_values_as_past_at_init=True, 140 | nullify_attention_mask=True, 141 | # --- 142 | 143 | **kwargs 144 | ) 145 | -------------------------------------------------------------------------------- /examples/table2text/requirements.txt: -------------------------------------------------------------------------------- 1 | fire 2 | datasets 3 | -------------------------------------------------------------------------------- /examples/table2text/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | output_dir=${1} 4 | data_dir=${2} 5 | task_mode=${3} 6 | model_name_or_path=${4:-"gpt2"} # One of distilgpt2, gpt2, gpt2-medium, gpt2-large 7 | target_epsilon=${5:-"8"} 8 | cache_dir=${6} 9 | clipping_mode=${7:-"ghost"} # Fill 'default' to turn this off. 10 | non_private=${8:-"no"} 11 | 12 | if [[ ${task_mode} == "e2e" ]]; then 13 | data_dir="${data_dir}/data/e2e_data" 14 | target_delta=8e-6 15 | num_train_epochs=10 16 | learning_rate=2e-3 17 | max_seq_len=100 18 | else 19 | if [[ ${task_mode} == "dart" ]]; then 20 | target_delta=1e-5 21 | data_dir="${data_dir}/data/dart" 22 | num_train_epochs=15 # Approximately same number of updates. 23 | learning_rate=5e-4 # Lower learning rate for stability in large models. 24 | max_seq_len=120 25 | else 26 | echo "Unknown task: ${task_mode}" 27 | exit 1 28 | fi 29 | fi 30 | 31 | # Arguments in the last two lines are the most important. 32 | python -m table2text.run_language_modeling \ 33 | --output_dir ${output_dir} --overwrite_output_dir \ 34 | --task_mode ${task_mode} \ 35 | --model_name_or_path ${model_name_or_path} \ 36 | --tokenizer_name ${model_name_or_path} \ 37 | --do_train --do_eval \ 38 | --line_by_line \ 39 | --save_steps 100 --save_total_limit 1 --save_at_last no \ 40 | --logging_dir ${output_dir} --logging_steps -1 \ 41 | --seed 0 \ 42 | --eval_steps 100 --eval_epochs 2 --max_eval_batches 100 --evaluation_strategy epoch --evaluate_before_training "no" --evaluate_during_training "yes" --per_device_eval_batch_size 10 \ 43 | --max_generations 9223372036854775807 --max_generations_train 10 --max_generations_valid 9223372036854775807 \ 44 | --max_train_examples 9223372036854775807 --max_valid_examples 9223372036854775807 --max_eval_examples 9223372036854775807 \ 45 | --data_folder ${data_dir} --max_seq_len ${max_seq_len} --format_mode cat \ 46 | --per_example_max_grad_norm 0.1 --target_delta ${target_delta} --target_epsilon ${target_epsilon} \ 47 | --learning_rate ${learning_rate} --lr_decay "no" --num_train_epochs ${num_train_epochs} --per_device_train_batch_size 16 --gradient_accumulation_steps 64 \ 48 | --non_private ${non_private} \ 49 | --clipping_mode "${clipping_mode}" \ 50 | --cache_dir ${cache_dir} 51 | -------------------------------------------------------------------------------- /examples/table2text/run_language_modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Xuechen Li. All Rights Reserved. 3 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 4 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 
8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | """ 18 | Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, CTRL, BERT, RoBERTa, XLNet). 19 | GPT, GPT-2 and CTRL are fine-tuned using a causal language modeling (CLM) loss. BERT and RoBERTa are fine-tuned 20 | using a masked language modeling (MLM) loss. XLNet is fine-tuned using a permutation language modeling (PLM) loss. 21 | """ 22 | 23 | import json 24 | import logging 25 | import os 26 | 27 | import torch 28 | from ml_swissknife import utils 29 | from transformers import HfArgumentParser, MODEL_WITH_LM_HEAD_MAPPING, set_seed 30 | from transformers.models.gpt2 import GPT2Tokenizer 31 | from transformers.optimization import get_linear_schedule_with_warmup 32 | 33 | from private_transformers import PrivacyEngine 34 | from .compiled_args import (AuxiliaryArguments, DataTrainingArguments, ModelArguments, PrivacyArguments, 35 | TrainingArguments) 36 | from .misc import get_all_datasets, get_prompt_dataset 37 | from .trainer import Trainer 38 | 39 | logger = logging.getLogger(__name__) 40 | 41 | MODEL_CONFIG_CLASSES = list(MODEL_WITH_LM_HEAD_MAPPING.keys()) 42 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 43 | 44 | 45 | def main(): 46 | parser = HfArgumentParser( 47 | (ModelArguments, DataTrainingArguments, TrainingArguments, PrivacyArguments, AuxiliaryArguments) 48 | ) 49 | model_args, data_args, training_args, privacy_args, auxiliary_args = parser.parse_args_into_dataclasses() 50 | 51 | model_args: ModelArguments 52 | data_args: DataTrainingArguments 53 | training_args: TrainingArguments 54 | privacy_args: PrivacyArguments 55 | auxiliary_args: AuxiliaryArguments 56 | 57 | if data_args.eval_data_file is None and training_args.do_eval: 58 | raise ValueError( 59 | "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file " 60 | "or remove the --do_eval argument." 61 | ) 62 | 63 | if ( 64 | os.path.exists(training_args.output_dir) 65 | and os.listdir(training_args.output_dir) 66 | and training_args.do_train 67 | and not training_args.overwrite_output_dir 68 | ): 69 | raise ValueError( 70 | f"Output directory ({training_args.output_dir}) already exists and is not empty. Use " 71 | f"--overwrite_output_dir to overcome." 72 | ) 73 | 74 | # Setup logging 75 | logging.basicConfig( 76 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 77 | datefmt="%m/%d/%Y %H:%M:%S", 78 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 79 | ) 80 | logger.warning( 81 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 82 | training_args.local_rank, 83 | training_args.device, 84 | training_args.n_gpu, 85 | bool(training_args.local_rank != -1), 86 | training_args.fp16, 87 | ) 88 | logger.info("Training/evaluation parameters %s", training_args) 89 | 90 | # Set seed 91 | set_seed(training_args.seed) 92 | 93 | # Debug mode 94 | if training_args.debug: 95 | import warnings 96 | warnings.filterwarnings("error") 97 | 98 | # Low rank models need special models! 
99 | from transformers.models.gpt2 import GPT2Config, GPT2LMHeadModel 100 | 101 | # Config. 102 | config = GPT2Config.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir) 103 | config.return_dict = True 104 | config.tie_word_embeddings = False 105 | 106 | # Tokenizer; `bos_token` and `eos_token` is the same for GPT2; both are 50256. 107 | tokenizer = GPT2Tokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir) 108 | 109 | # Model. 110 | gpt2 = GPT2LMHeadModel.from_pretrained( 111 | model_args.model_name_or_path, 112 | config=config, 113 | cache_dir=model_args.cache_dir, 114 | ) 115 | print(f'base gpt2 model: {model_args.model_name_or_path}') 116 | print(gpt2) 117 | 118 | # Clone the embedding into the lm_head for better initialization. 119 | lm_head = gpt2.get_output_embeddings() 120 | embedding = gpt2.get_input_embeddings() 121 | lm_head.weight.data.copy_(embedding.weight.data) 122 | print(f'Cloning initial embedding into lm_head, ' 123 | f'checking norms... \n' 124 | f'\tlm_head: {lm_head.weight.norm()}, embedding: {embedding.weight.norm()}') 125 | torch.testing.assert_allclose(lm_head.weight, embedding.weight) 126 | del lm_head, embedding 127 | 128 | if data_args.block_size <= 0: 129 | data_args.block_size = tokenizer.model_max_length 130 | else: 131 | data_args.block_size = min(data_args.block_size, tokenizer.model_max_length) 132 | 133 | # Adjust tokenizer and model embeddings. 134 | print('adapt tokenizer to include [PAD]') 135 | print(f'before len(tokenizer) = {len(tokenizer)}') 136 | tokenizer.add_special_tokens({'pad_token': '[PAD]'}) 137 | print(f'after len(tokenizer) = {len(tokenizer)}') 138 | print('tokenizer.eos_token:', tokenizer.eos_token, tokenizer.eos_token_id) 139 | print('tokenizer.bos_token:', tokenizer.bos_token, tokenizer.bos_token_id) 140 | 141 | print('adapt the size of lm_head and input_embeddings to include [PAD]') 142 | print('use avg-based initialization') 143 | 144 | input_embeddings_before = gpt2.get_input_embeddings().weight 145 | lm_head_before = gpt2.get_output_embeddings().weight 146 | gpt2.resize_token_embeddings(len(tokenizer)) 147 | 148 | input_embeddings_after = gpt2.get_input_embeddings().weight 149 | lm_head_after = gpt2.get_output_embeddings().weight 150 | print( 151 | f'before lm_head.weight.size() = {lm_head_before.size()}, ' 152 | f'input_embeddings_before.size() = {input_embeddings_before.size()}' 153 | ) 154 | print( 155 | f'after lm_head.weight.size() = {lm_head_after.size()}, ' 156 | f'after input_embeddings_after.size() = {input_embeddings_after.size()}' 157 | ) 158 | torch.testing.assert_allclose(lm_head_before, lm_head_after[:-1]) 159 | print('pre-chunk equal for lm_head') 160 | torch.testing.assert_allclose(input_embeddings_before, input_embeddings_after[:-1]) 161 | print('pre-chunk equal for input_embeddings') 162 | lm_head_after.data[-1] = lm_head_before.mean(dim=0) 163 | input_embeddings_after.data[-1] = input_embeddings_before.mean(dim=0) 164 | 165 | print('double check: ') 166 | print('embedding size', gpt2.get_input_embeddings().weight.size()) 167 | print('lm_head size', gpt2.get_output_embeddings().weight.size()) 168 | model = gpt2 169 | 170 | train_dataset, val_dataset, eval_dataset, data_collator = get_all_datasets( 171 | config=config, 172 | tokenizer=tokenizer, 173 | data_args=data_args, 174 | training_args=training_args, 175 | model_args=model_args, 176 | ) 177 | 178 | # Materialize the prompts. 
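    # Each prompt file holds one linearized table (or triple set) per line; `get_prompt_dataset`
    # encodes them up front so the trainer can periodically run beam-search decoding on fixed prompts.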
179 | generation_stuff = dict( 180 | train_prompts=get_prompt_dataset(file_path=data_args.train_prompt_file, tokenizer=tokenizer), 181 | val_prompts=get_prompt_dataset(file_path=data_args.val_prompt_file, tokenizer=tokenizer), 182 | eval_prompts=get_prompt_dataset(file_path=data_args.eval_prompt_file, tokenizer=tokenizer), 183 | ) 184 | 185 | trainer = Trainer( 186 | model=model, 187 | tokenizer=tokenizer, 188 | args=training_args, 189 | model_args=model_args, 190 | data_args=data_args, 191 | privacy_args=privacy_args, 192 | auxiliary_args=auxiliary_args, 193 | train_dataset=train_dataset, 194 | val_dataset=val_dataset, 195 | eval_dataset=eval_dataset, 196 | data_collator=data_collator, 197 | generation_stuff=generation_stuff, 198 | ) 199 | 200 | # Massage the parameters. 201 | if model_args.attention_only: 202 | model.requires_grad_(False) 203 | for name, param in model.named_parameters(): 204 | if 'c_attn.weight' in name: 205 | param.requires_grad_(True) 206 | else: 207 | model.requires_grad_(True) 208 | if model_args.static_lm_head: 209 | model.get_output_embeddings().requires_grad_(False) 210 | if model_args.static_embedding: 211 | model.get_input_embeddings().requires_grad_(False) 212 | model.transformer.wpe.requires_grad_(False) 213 | params = tuple(param for param in model.parameters() if param.requires_grad) 214 | names = tuple(name for name, param in model.named_parameters() if param.requires_grad) 215 | num_trainable_params = sum(param.numel() for param in params) 216 | print(f"Number of trainable params: {num_trainable_params / 1e6:.4f} million") 217 | print(json.dumps(names, indent=4)) 218 | 219 | # TODO: Using a single gigantic parameter group is okay only when `weight_decay` is 0. 220 | # Biases and LM parameters should not be decayed perhaps even with privacy. 221 | optimizer = torch.optim.AdamW( 222 | params=params, 223 | lr=training_args.learning_rate, 224 | betas=(training_args.adam_beta1, training_args.adam_beta2), 225 | eps=training_args.adam_epsilon, 226 | ) 227 | trainer.optimizer = optimizer 228 | 229 | # Create the lr_scheduler. 230 | num_update_steps_per_epoch = len(trainer.get_train_dataloader()) // trainer.args.gradient_accumulation_steps 231 | num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) 232 | t_total = int(num_update_steps_per_epoch * trainer.args.num_train_epochs) 233 | if training_args.lr_decay: 234 | trainer.lr_scheduler = get_linear_schedule_with_warmup( 235 | trainer.optimizer, 236 | num_warmup_steps=training_args.warmup_steps, 237 | num_training_steps=t_total, 238 | ) 239 | else: 240 | trainer.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(trainer.optimizer, lambda _: 1.) 241 | 242 | # Hacky way to set noise_multiplier. 243 | if privacy_args.non_private: 244 | privacy_args.noise_multiplier = 0. 245 | privacy_args.per_example_max_grad_norm = None 246 | else: 247 | actual_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps 248 | privacy_engine = PrivacyEngine( 249 | module=model, 250 | batch_size=actual_batch_size, 251 | sample_size=len(train_dataset), 252 | epochs=training_args.num_train_epochs, 253 | max_grad_norm=privacy_args.per_example_max_grad_norm, 254 | noise_multiplier=privacy_args.noise_multiplier, 255 | target_epsilon=privacy_args.target_epsilon, 256 | target_delta=privacy_args.target_delta, 257 | accounting_mode=privacy_args.accounting_mode, 258 | clipping_mode=privacy_args.clipping_mode, 259 | ) 260 | # Originally, these could have been null. 
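To make the accounting inputs constructed above concrete, here is the arithmetic with hypothetical numbers (single-device; the script does not fold `n_gpu` into the batch size):

```python
# Hypothetical settings, only to illustrate the sampling-rate arithmetic.
per_device_train_batch_size = 8
gradient_accumulation_steps = 64
sample_size = 42_000  # i.e., len(train_dataset)

actual_batch_size = per_device_train_batch_size * gradient_accumulation_steps  # 512
sample_rate = actual_batch_size / sample_size  # ~0.0122; the q used by the accountant
print(actual_batch_size, round(sample_rate, 4))
```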
261 | privacy_args.noise_multiplier = privacy_engine.noise_multiplier 262 | privacy_args.target_delta = privacy_engine.target_delta 263 | 264 | print('privacy_args: ') 265 | print(json.dumps(privacy_args.__dict__, indent=4)) 266 | privacy_engine.attach(optimizer) 267 | 268 | # Training. 269 | if training_args.do_train: 270 | all_args = { 271 | **training_args.__dict__, 272 | **data_args.__dict__, 273 | **model_args.__dict__, 274 | **privacy_args.__dict__, 275 | } 276 | utils.jdump( 277 | all_args, 278 | os.path.join(training_args.output_dir, 'argparse.json'), 279 | default=lambda x: str(x), 280 | ) 281 | 282 | # For convenience, we also re-save the tokenizer to the same directory, 283 | # so that you can share your model easily on huggingface.co/models =) 284 | if trainer.is_world_master(): 285 | tokenizer.save_pretrained(training_args.output_dir) 286 | 287 | logger.info("*** Train ***") 288 | logger.info( 289 | f"Training set size: {len(train_dataset)}, " 290 | f"per_device_train_batch_size: {training_args.per_device_train_batch_size}, " 291 | f"gradient_accumulation_steps: {training_args.gradient_accumulation_steps}" 292 | ) 293 | # lxuechen: Especially so for the restored checkpoints. Don't resume... 294 | trainer.train(model_path=None) 295 | if training_args.save_at_last: 296 | trainer.save_model() 297 | 298 | # Evaluation 299 | if training_args.do_eval: 300 | logger.info("*** Evaluate ***") 301 | 302 | output = trainer.evaluate(log_results=False) 303 | utils.jdump( 304 | output, 305 | os.path.join(training_args.output_dir, "final_results.json"), 306 | ) 307 | 308 | logger.info("***** Eval results *****") 309 | logger.info(output) 310 | 311 | 312 | if __name__ == "__main__": 313 | main() 314 | -------------------------------------------------------------------------------- /private_transformers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Xuechen Li. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from . import lora_utils 16 | from .privacy_engine import PrivacyEngine 17 | from .transformers_support import freeze_isolated_params_for_vit 18 | 19 | __version__ = '0.2.3' 20 | -------------------------------------------------------------------------------- /private_transformers/accounting/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Xuechen Li. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /private_transformers/accounting/accounting_manager.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Xuechen Li. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import abc 16 | import math 17 | from typing import Dict, Optional, Union 18 | 19 | from . import rdp_accounting 20 | 21 | DEFAULT_ALPHAS = tuple(1 + x / 10.0 for x in range(1, 100)) + tuple(range(12, 64)) # RDP. 22 | 23 | 24 | class AccountingManager(abc.ABC): 25 | def _get_sigma_with_target_epsilon( 26 | self, 27 | target_epsilon, 28 | target_delta, 29 | sample_rate, 30 | steps, 31 | threshold, 32 | sigma_hi_init, 33 | sigma_lo_init, 34 | ): 35 | """Binary search σ given ε and δ.""" 36 | if sigma_lo_init > sigma_hi_init: 37 | raise ValueError("`sigma_lo` should be smaller than `sigma_hi`.") 38 | 39 | # Find an appropriate region for binary search. 40 | sigma_hi = sigma_hi_init 41 | sigma_lo = sigma_lo_init 42 | 43 | # Ensure sigma_hi isn't too small. 44 | while True: 45 | eps = self._compute_epsilon_from_sigma(sigma_hi, sample_rate, target_delta, steps) 46 | if eps < target_epsilon: 47 | break 48 | sigma_hi *= 2 49 | 50 | # Ensure sigma_lo isn't too large. 51 | while True: 52 | eps = self._compute_epsilon_from_sigma(sigma_lo, sample_rate, target_delta, steps) 53 | if eps > target_epsilon: 54 | break 55 | sigma_lo /= 2 56 | 57 | # Binary search. 58 | while sigma_hi - sigma_lo > threshold: 59 | sigma = (sigma_hi + sigma_lo) / 2 60 | eps = self._compute_epsilon_from_sigma(sigma, sample_rate, target_delta, steps) 61 | if eps < target_epsilon: 62 | sigma_hi = sigma 63 | else: 64 | sigma_lo = sigma 65 | 66 | # Conservative estimate. 
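        # (Epsilon is decreasing in sigma, so returning the larger endpoint sigma_hi guarantees eps <= target_epsilon.)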
67 | return sigma_hi 68 | 69 | @abc.abstractmethod 70 | def compute_epsilon(self, sigma, sample_rate, target_delta, steps) -> Dict: 71 | """Override for reporting results.""" 72 | raise NotImplementedError 73 | 74 | @abc.abstractmethod 75 | def _compute_epsilon_from_sigma(self, sigma, sample_rate, target_delta, steps) -> float: 76 | """Override for binary sigma search.""" 77 | raise NotImplementedError 78 | 79 | def compute_sigma( 80 | self, 81 | target_epsilon: float, 82 | target_delta: float, 83 | sample_rate: float, 84 | epochs: Optional[Union[float, int]] = None, 85 | steps=None, 86 | threshold=1e-3, 87 | sigma_hi_init=4, 88 | sigma_lo_init=0.1, 89 | ) -> float: 90 | if steps is None: 91 | if epochs is None: 92 | raise ValueError("Epochs and steps cannot both be None.") 93 | steps = math.ceil(epochs / sample_rate) 94 | return self._get_sigma_with_target_epsilon( 95 | target_epsilon=target_epsilon, 96 | target_delta=target_delta, 97 | sample_rate=sample_rate, 98 | steps=steps, 99 | threshold=threshold, 100 | sigma_hi_init=sigma_hi_init, 101 | sigma_lo_init=sigma_lo_init, 102 | ) 103 | 104 | 105 | class RDPManager(AccountingManager): 106 | def __init__(self, alphas): 107 | super(RDPManager, self).__init__() 108 | self._alphas = alphas 109 | 110 | def _compute_epsilon_from_sigma(self, sigma, sample_rate, target_delta, steps): 111 | return self.compute_epsilon(sigma, sample_rate, target_delta, steps)["eps_rdp"] 112 | 113 | def compute_epsilon(self, sigma, sample_rate, target_delta, steps) -> Dict: 114 | """Compute RDP as usual, but convert to (ε, δ)-DP based on the result by Canonne, Kamath, Steinke.""" 115 | rdp = rdp_accounting.compute_rdp(q=sample_rate, noise_multiplier=sigma, steps=steps, orders=self._alphas) 116 | eps, alpha = rdp_accounting.get_privacy_spent(orders=self._alphas, rdp=rdp, delta=target_delta) 117 | return dict(eps_rdp=eps, alpha_rdp=alpha) 118 | 119 | 120 | class GLWManager(AccountingManager): 121 | def __init__(self, eps_error=0.05): 122 | super(GLWManager, self).__init__() 123 | self._eps_error = eps_error 124 | 125 | def _compute_epsilon_from_sigma(self, sigma, sample_rate, target_delta, steps): 126 | return self.compute_epsilon(sigma, sample_rate, target_delta, steps)["eps_upper"] # Be conservative. 127 | 128 | def compute_epsilon(self, sigma, sample_rate, target_delta, steps) -> Dict: 129 | if steps == 0: 130 | return dict(eps_low=None, eps_estimate=None, eps_upper=None) 131 | 132 | from prv_accountant import Accountant 133 | accountant = Accountant( 134 | noise_multiplier=sigma, 135 | sampling_probability=sample_rate, 136 | delta=target_delta, 137 | eps_error=self._eps_error, 138 | max_compositions=steps 139 | ) 140 | eps_low, eps_estimate, eps_upper = accountant.compute_epsilon(num_compositions=steps) 141 | return dict(eps_low=eps_low, eps_estimate=eps_estimate, eps_upper=eps_upper) 142 | -------------------------------------------------------------------------------- /private_transformers/accounting/rdp_accounting.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Xuechen Li. All Rights Reserved. 2 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
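A usage sketch for the managers defined above, with hypothetical values:

```python
from private_transformers.accounting.accounting_manager import DEFAULT_ALPHAS, RDPManager

manager = RDPManager(alphas=DEFAULT_ALPHAS)

# Epsilon spent after 1000 steps of noisy updates at noise multiplier 1.0.
print(manager.compute_epsilon(sigma=1.0, sample_rate=0.01, target_delta=1e-5, steps=1000))

# Or invert: the noise multiplier needed to meet a target (epsilon, delta).
sigma = manager.compute_sigma(target_epsilon=3.0, target_delta=1e-5, sample_rate=0.01, epochs=3)
```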
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | r""" 17 | This file is adapted from the privacy accounting procedure in Opacus, which in turn is adapted from tf-privacy. 18 | Below is the original documentation in Opacus. 19 | 20 | *Based on Google's TF Privacy:* https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/analysis 21 | /rdp_accountant.py. 22 | *Here, we update this code to Python 3, and optimize dependencies.* 23 | 24 | Functionality for computing Renyi Differential Privacy (RDP) of an additive 25 | Sampled Gaussian Mechanism (SGM). 26 | 27 | Example: 28 | Suppose that we have run an SGM applied to a function with L2-sensitivity of 1. 29 | 30 | Its parameters are given as a list of tuples 31 | ``[(q_1, sigma_1, steps_1), ..., (q_k, sigma_k, steps_k)],`` 32 | and we wish to compute epsilon for a given target delta. 33 | 34 | The example code would be: 35 | 36 | >>> max_order = 32 37 | >>> orders = range(2, max_order + 1) 38 | >>> rdp = np.zeros_like(orders, dtype=float) 39 | >>> for q, sigma, steps in parameters: 40 | ...     rdp += privacy_analysis.compute_rdp(q, sigma, steps, orders) 41 | >>> epsilon, opt_order = privacy_analysis.get_privacy_spent(orders, rdp, delta) 42 | """ 43 | 44 | import math 45 | import numpy as np 46 | from scipy import special 47 | from typing import List, Sequence, Union 48 | 49 | 50 | ######################## 51 | # LOG-SPACE ARITHMETIC # 52 | ######################## 53 | 54 | 55 | def _log_add(logx: float, logy: float) -> float: 56 | r"""Adds two numbers in the log space. 57 | 58 | Args: 59 | logx: First term in log space. 60 | logy: Second term in log space. 61 | 62 | Returns: 63 | Sum of numbers in log space. 64 | """ 65 | a, b = min(logx, logy), max(logx, logy) 66 | if a == -np.inf: # adding 0 67 | return b 68 | # Use exp(a) + exp(b) = (exp(a - b) + 1) * exp(b) 69 | return math.log1p(math.exp(a - b)) + b # log1p(x) = log(x + 1) 70 | 71 | 72 | def _log_sub(logx: float, logy: float) -> float: 73 | r"""Subtracts two numbers in the log space. 74 | 75 | Args: 76 | logx: First term in log space. Expected to be greater than the second term. 77 | logy: Second term in log space. Expected to be less than the first term. 78 | 79 | Returns: 80 | Difference of numbers in log space. 81 | 82 | Raises: 83 | ValueError 84 | If the result is negative. 85 | """ 86 | if logx < logy: 87 | raise ValueError("The result of subtraction must be non-negative.") 88 | if logy == -np.inf: # subtracting 0 89 | return logx 90 | if logx == logy: 91 | return -np.inf # 0 is represented as -np.inf in the log space. 92 | 93 | try: 94 | # Use exp(x) - exp(y) = (exp(x - y) - 1) * exp(y). 95 | return math.log(math.expm1(logx - logy)) + logy # expm1(x) = exp(x) - 1 96 | except OverflowError: 97 | return logx 98 | 99 | 100 | def _compute_log_a_for_int_alpha(q: float, sigma: float, alpha: int) -> float: 101 | r"""Computes :math:`log(A_\alpha)` for integer ``alpha``. 102 | 103 | Notes: 104 | Note that 105 | :math:`A_\alpha` is real valued function of ``alpha`` and ``q``, 106 | and that 0 < ``q`` < 1.
107 | 108 | Refer to Section 3.3 of https://arxiv.org/pdf/1908.10530.pdf for details. 109 | 110 | Args: 111 | q: Sampling rate of SGM. 112 | sigma: The standard deviation of the additive Gaussian noise. 113 | alpha: The order at which RDP is computed. 114 | 115 | Returns: 116 | :math:`log(A_\alpha)` as defined in Section 3.3 of 117 | https://arxiv.org/pdf/1908.10530.pdf. 118 | """ 119 | 120 | # Initialize with 0 in the log space. 121 | log_a = -np.inf 122 | 123 | for i in range(alpha + 1): 124 | log_coef_i = ( 125 | math.log(special.binom(alpha, i)) 126 | + i * math.log(q) 127 | + (alpha - i) * math.log(1 - q) 128 | ) 129 | 130 | s = log_coef_i + (i * i - i) / (2 * (sigma ** 2)) 131 | log_a = _log_add(log_a, s) 132 | 133 | return float(log_a) 134 | 135 | 136 | def _compute_log_a_for_frac_alpha(q: float, sigma: float, alpha: float) -> float: 137 | r"""Computes :math:`log(A_\alpha)` for fractional ``alpha``. 138 | 139 | Notes: 140 | Note that 141 | :math:`A_\alpha` is real valued function of ``alpha`` and ``q``, 142 | and that 0 < ``q`` < 1. 143 | 144 | Refer to Section 3.3 of https://arxiv.org/pdf/1908.10530.pdf for details. 145 | 146 | Args: 147 | q: Sampling rate of SGM. 148 | sigma: The standard deviation of the additive Gaussian noise. 149 | alpha: The order at which RDP is computed. 150 | 151 | Returns: 152 | :math:`log(A_\alpha)` as defined in Section 3.3 of 153 | https://arxiv.org/pdf/1908.10530.pdf. 154 | """ 155 | # The two parts of A_alpha, integrals over (-inf,z0] and [z0, +inf), are 156 | # initialized to 0 in the log space: 157 | log_a0, log_a1 = -np.inf, -np.inf 158 | i = 0 159 | 160 | z0 = sigma ** 2 * math.log(1 / q - 1) + 0.5 161 | 162 | while True: # do ... until loop 163 | coef = special.binom(alpha, i) 164 | log_coef = math.log(abs(coef)) 165 | j = alpha - i 166 | 167 | log_t0 = log_coef + i * math.log(q) + j * math.log(1 - q) 168 | log_t1 = log_coef + j * math.log(q) + i * math.log(1 - q) 169 | 170 | log_e0 = math.log(0.5) + _log_erfc((i - z0) / (math.sqrt(2) * sigma)) 171 | log_e1 = math.log(0.5) + _log_erfc((z0 - j) / (math.sqrt(2) * sigma)) 172 | 173 | log_s0 = log_t0 + (i * i - i) / (2 * (sigma ** 2)) + log_e0 174 | log_s1 = log_t1 + (j * j - j) / (2 * (sigma ** 2)) + log_e1 175 | 176 | if coef > 0: 177 | log_a0 = _log_add(log_a0, log_s0) 178 | log_a1 = _log_add(log_a1, log_s1) 179 | else: 180 | log_a0 = _log_sub(log_a0, log_s0) 181 | log_a1 = _log_sub(log_a1, log_s1) 182 | 183 | i += 1 184 | if max(log_s0, log_s1) < -30: 185 | break 186 | 187 | return _log_add(log_a0, log_a1) 188 | 189 | 190 | def _compute_log_a(q: float, sigma: float, alpha: float) -> float: 191 | r"""Computes :math:`log(A_\alpha)` for any positive finite ``alpha``. 192 | 193 | Notes: 194 | Note that 195 | :math:`A_\alpha` is real valued function of ``alpha`` and ``q``, 196 | and that 0 < ``q`` < 1. 197 | 198 | Refer to Section 3.3 of https://arxiv.org/pdf/1908.10530.pdf 199 | for details. 200 | 201 | Args: 202 | q: Sampling rate of SGM. 203 | sigma: The standard deviation of the additive Gaussian noise. 204 | alpha: The order at which RDP is computed. 205 | 206 | Returns: 207 | :math:`log(A_\alpha)` as defined in the paper mentioned above. 208 | """ 209 | if float(alpha).is_integer(): 210 | return _compute_log_a_for_int_alpha(q, sigma, int(alpha)) 211 | else: 212 | return _compute_log_a_for_frac_alpha(q, sigma, alpha) 213 | 214 | 215 | def _log_erfc(x: float) -> float: 216 | r"""Computes :math:`log(erfc(x))` with high accuracy for large ``x``. 
217 | 218 | Helper function used in computation of :math:`log(A_\alpha)` 219 | for a fractional alpha. 220 | 221 | Args: 222 | x: The input to the function 223 | 224 | Returns: 225 | :math:`log(erfc(x))` 226 | """ 227 | return math.log(2) + special.log_ndtr(-x * 2 ** 0.5) 228 | 229 | 230 | def _compute_rdp(q: float, sigma: float, alpha: float) -> float: 231 | r"""Computes RDP of the Sampled Gaussian Mechanism at order ``alpha``. 232 | 233 | Args: 234 | q: Sampling rate of SGM. 235 | sigma: The standard deviation of the additive Gaussian noise. 236 | alpha: The order at which RDP is computed. 237 | 238 | Returns: 239 | RDP at order ``alpha``; can be np.inf. 240 | """ 241 | if q == 0: 242 | return 0 243 | 244 | # no privacy 245 | if sigma == 0: 246 | return np.inf 247 | 248 | if q == 1.0: 249 | return alpha / (2 * sigma ** 2) 250 | 251 | if np.isinf(alpha): 252 | return np.inf 253 | 254 | return _compute_log_a(q, sigma, alpha) / (alpha - 1) 255 | 256 | 257 | def compute_rdp( 258 | q: float, noise_multiplier: float, steps: int, orders: Union[Sequence[float], float] 259 | ) -> Union[List[float], float]: 260 | r"""Computes Renyi Differential Privacy (RDP) guarantees of the 261 | Sampled Gaussian Mechanism (SGM) iterated ``steps`` times. 262 | 263 | Args: 264 | q: Sampling rate of SGM. 265 | noise_multiplier: The ratio of the standard deviation of the 266 | additive Gaussian noise to the L2-sensitivity of the function 267 | to which it is added. Note that this is same as the standard 268 | deviation of the additive Gaussian noise when the L2-sensitivity 269 | of the function is 1. 270 | steps: The number of iterations of the mechanism. 271 | orders: An array (or a scalar) of RDP orders. 272 | 273 | Returns: 274 | The RDP guarantees at all orders; can be ``np.inf``. 275 | """ 276 | if isinstance(orders, float): 277 | rdp = _compute_rdp(q, noise_multiplier, orders) 278 | else: 279 | rdp = np.array([_compute_rdp(q, noise_multiplier, order) for order in orders]) 280 | 281 | return rdp * steps 282 | 283 | 284 | # Based on 285 | # https://github.com/tensorflow/privacy/blob/5f07198b66b3617b22609db983926e3ba97cd905/tensorflow_privacy/privacy/analysis/rdp_accountant.py#L237 286 | def get_privacy_spent(orders, rdp, delta): 287 | """Compute epsilon given a list of RDP values and target delta. 288 | Args: 289 | orders: An array (or a scalar) of orders. 290 | rdp: A list (or a scalar) of RDP guarantees. 291 | delta: The target delta. 292 | Returns: 293 | Pair of (eps, optimal_order). 294 | Raises: 295 | ValueError: If input is malformed. 296 | """ 297 | orders_vec = np.atleast_1d(orders) 298 | rdp_vec = np.atleast_1d(rdp) 299 | 300 | if delta <= 0: 301 | raise ValueError("Privacy failure probability bound delta must be >0.") 302 | if len(orders_vec) != len(rdp_vec): 303 | raise ValueError("Input lists must have the same length.") 304 | 305 | # Basic bound (see https://arxiv.org/abs/1702.07476 Proposition 3 in v3): 306 | # eps = min( rdp_vec - math.log(delta) / (orders_vec - 1) ) 307 | 308 | # Improved bound from https://arxiv.org/abs/2004.00010 Proposition 12 (in v4). 309 | # Also appears in https://arxiv.org/abs/2001.05990 Equation 20 (in v1). 310 | eps_vec = [] 311 | for (a, r) in zip(orders_vec, rdp_vec): 312 | if a < 1: 313 | raise ValueError("Renyi divergence order must be >=1.") 314 | if r < 0: 315 | raise ValueError("Renyi divergence must be >=0.") 316 | 317 | if delta ** 2 + math.expm1(-r) >= 0: 318 | # In this case, we can simply bound via KL divergence: 319 | # delta <= sqrt(1-exp(-KL)). 
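            # Since the Renyi divergence is nondecreasing in the order and KL is the order-1
            # divergence, KL <= r for a >= 1; the branch condition reads delta**2 >= 1 - exp(-r),
            # hence delta >= sqrt(1 - exp(-KL)) and (0, delta)-DP already holds.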
320 | eps = 0 # No need to try further computation if we have eps = 0. 321 | elif a > 1.01: 322 | # This bound is not numerically stable as alpha->1. 323 | # Thus we have a min value of alpha. 324 | # The bound is also not useful for small alpha, so doesn't matter. 325 | eps = r + math.log1p(-1 / a) - math.log(delta * a) / (a - 1) 326 | else: 327 | # In this case we can't do anything. E.g., asking for delta = 0. 328 | eps = np.inf 329 | eps_vec.append(eps) 330 | 331 | idx_opt = np.argmin(eps_vec) 332 | return max(0, eps_vec[idx_opt]), orders_vec[idx_opt] 333 | -------------------------------------------------------------------------------- /private_transformers/autograd_grad_sample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Xuechen Li. All Rights Reserved. 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """ 17 | A large portion of this code is adapted from Opacus (https://github.com/pytorch/opacus), 18 | which is licensed under Apache License 2.0. 19 | 20 | We have modified it considerably to support ghost clipping. 21 | """ 22 | 23 | from typing import Tuple 24 | 25 | import torch 26 | import torch.nn as nn 27 | 28 | from .settings import BackwardHookMode 29 | from .supported_layers_grad_samplers import _supported_layers_grad_samplers 30 | 31 | # TODO: hooks mode should be settable based on the module. 32 | _hooks_disabled: bool = False 33 | _hooks_mode = BackwardHookMode.default 34 | 35 | 36 | def set_hooks_mode(mode): 37 | if mode not in BackwardHookMode.all(): 38 | raise ValueError(f"Unknown mode for hooks: {mode}; expected one of {BackwardHookMode.all()}.") 39 | 40 | global _hooks_mode 41 | _hooks_mode = mode # Set mode. 42 | 43 | if _hooks_mode == BackwardHookMode.ghost_grad: # Second backward pass of ghost clipping doesn't need hooks. 44 | disable_hooks() 45 | elif _hooks_mode == BackwardHookMode.ghost_norm: # First backward pass of ghost clipping needs to accumulate norms. 46 | enable_hooks() 47 | 48 | 49 | def get_hooks_mode(): 50 | global _hooks_mode 51 | return _hooks_mode 52 | 53 | 54 | def requires_grad(module: nn.Module, recurse: bool = False) -> bool: 55 | """ 56 | Checks if any parameters in a specified module require gradients. 57 | 58 | Args: 59 | module: PyTorch module whose parameters are examined 60 | recurse: Flag specifying if the gradient requirement check should 61 | be applied recursively to sub-modules of the specified module 62 | 63 | Returns: 64 | Flag indicating if any parameters require gradients 65 | """ 66 | return any(p.requires_grad for p in module.parameters(recurse)) 67 | 68 | 69 | def add_hooks(model: nn.Module, loss_reduction: str = "mean"): 70 | r""" 71 | Adds hooks to model to save activations and backprop values. 72 | The hooks will 73 | 74 | 1. save activations into ``layer.activations`` during forward pass. 75 | 2.
compute per-sample gradients and save them in ``param.grad_sample`` during backward pass. 76 | 77 | Args: 78 | model: Model to which hooks are added. 79 | loss_reduction: Indicates if the loss reduction (for aggregating the gradients) is a sum or a mean operation. 80 | Can take values ``sum`` or ``mean``. 81 | """ 82 | if hasattr(model, "autograd_grad_sample_hooks"): 83 | raise ValueError("Trying to add hooks twice to the same model") 84 | 85 | enable_hooks() 86 | 87 | handles = [] 88 | for name, layer in model.named_modules(): 89 | if type(layer) in _supported_layers_grad_samplers: 90 | if requires_grad(layer, recurse=False): 91 | handles.append(layer.register_forward_hook(_capture_activations)) 92 | 93 | def this_backward(this_layer, grad_input, grad_output): 94 | return _capture_backprops(this_layer, grad_input, grad_output, loss_reduction) 95 | 96 | # Starting with 1.8.0, use `register_full_backward_hook`. 97 | handles.append(layer.register_backward_hook(this_backward)) 98 | 99 | model.__dict__.setdefault("autograd_grad_sample_hooks", []).extend(handles) 100 | 101 | 102 | def remove_hooks(model: nn.Module): 103 | """Removes hooks added by `add_hooks()`.""" 104 | if not hasattr(model, "autograd_grad_sample_hooks"): 105 | raise ValueError("Asked to remove hooks, but no hooks found") 106 | else: 107 | for handle in model.autograd_grad_sample_hooks: 108 | handle.remove() 109 | del model.autograd_grad_sample_hooks 110 | 111 | 112 | def disable_hooks(): 113 | """Globally disables all hooks installed by this library.""" 114 | global _hooks_disabled 115 | _hooks_disabled = True 116 | 117 | 118 | def enable_hooks(): 119 | """Globally enables all hooks installed by this library.""" 120 | global _hooks_disabled 121 | _hooks_disabled = False 122 | 123 | 124 | def _capture_activations(layer: nn.Module, inputs: Tuple, outputs: Tuple): 125 | """Forward hook handler captures and saves activations.""" 126 | if not requires_grad(layer) or not layer.training or _hooks_disabled: 127 | return 128 | 129 | if not hasattr(layer, "activations"): 130 | layer.activations = [] 131 | 132 | # This improves on original Opacus and supports additional arguments on top of the (first) activation tensor. 133 | stored_inputs = tuple(input_i.detach() if torch.is_tensor(input_i) else input_i for input_i in inputs) 134 | layer.activations.append(stored_inputs) 135 | 136 | 137 | def _capture_backprops( 138 | layer: nn.Module, 139 | inputs: Tuple[torch.Tensor], 140 | outputs: Tuple[torch.Tensor], 141 | loss_reduction: str 142 | ): 143 | """Backward hook handler captures grad_outputs.""" 144 | # This improves on the original Opacus codebase and supports multiple outputs. 
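Zooming out from the capture helpers, a minimal end-to-end sketch of what `add_hooks` enables (assuming `nn.Linear` is among the supported layers and, per the `add_hooks` docstring above, that the default mode stores per-sample gradients in `param.grad_sample`):

```python
import torch
import torch.nn as nn

from private_transformers import autograd_grad_sample

model = nn.Sequential(nn.Linear(4, 2))
autograd_grad_sample.add_hooks(model, loss_reduction="mean")

x = torch.randn(8, 4)
loss = model(x).sum(dim=1).mean()  # Mean over the batch, matching loss_reduction.
loss.backward()

# Per-sample gradients: one (2, 4) weight gradient per example.
print(model[0].weight.grad_sample.shape)  # Expected: torch.Size([8, 2, 4])
```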
145 | backprops = tuple(output_i.detach() if torch.is_tensor(output_i) else output_i for output_i in outputs) 146 | _compute_grad_sample(layer, backprops, loss_reduction) 147 | 148 | 149 | def _compute_grad_sample(layer: nn.Module, backprops: Tuple, loss_reduction: str): 150 | """Computes per-sample gradients with respect to the parameters.""" 151 | if not requires_grad(layer) or not layer.training or _hooks_disabled: 152 | return 153 | 154 | if not hasattr(layer, "activations"): 155 | raise ValueError(f"No activations detected for {type(layer)}, run forward after add_hooks(model)") 156 | 157 | # Outside of the LSTM there is `batch_first`, but not for the Linear inside the LSTM. 158 | if isinstance(layer.activations, list): 159 | A = layer.activations.pop() 160 | else: 161 | A = layer.activations 162 | 163 | if not hasattr(layer, "max_batch_len"): 164 | assert torch.is_tensor(A[0]), f"Internal error: first input of the following layer isn't a Tensor. \n{layer}" 165 | layer.max_batch_len = _get_batch_size(layer, A[0]) 166 | 167 | n = layer.max_batch_len 168 | if loss_reduction == "mean": 169 | B = tuple(B_i * n if torch.is_tensor(B_i) else B_i for B_i in backprops) 170 | elif loss_reduction == "sum": 171 | B = backprops 172 | else: 173 | raise ValueError(f"loss_reduction = {loss_reduction}. Only 'sum' and 'mean' losses are supported") 174 | 175 | # compute grad sample for individual layers 176 | compute_layer_grad_sample = _supported_layers_grad_samplers.get(type(layer)) 177 | compute_layer_grad_sample(layer, A, B) 178 | 179 | if (not isinstance(layer.activations, list) or len(layer.activations) == 0) and hasattr(layer, "max_batch_len"): 180 | del layer.max_batch_len 181 | 182 | 183 | def _get_batch_size(layer: nn.Module, grad_sample: torch.Tensor) -> int: 184 | r""" 185 | Computes and returns the maximum batch size which is the maximum of the dimension values 186 | along 'batch_dim' axis over layer.activations + [grad_sample], where layer.activations is 187 | a list. If layer.activations is not a list, then return grad_sample.shape[batch_dim]. 188 | """ 189 | 190 | batch_dim = 0 191 | max_batch_len = 0 192 | if isinstance(layer.activations, list): 193 | for out in layer.activations: 194 | assert torch.is_tensor(out[0]), ( 195 | f"Internal error: first input of the following layer isn't a Tensor. \n{layer}" 196 | ) 197 | if out[0].shape[batch_dim] > max_batch_len: 198 | max_batch_len = out[0].shape[batch_dim] 199 | 200 | max_batch_len = max(max_batch_len, grad_sample.shape[batch_dim]) 201 | return max_batch_len 202 | -------------------------------------------------------------------------------- /private_transformers/lora_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Xuechen Li. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | LoRA layers. 17 | 18 | This version does not have merged weights for zero latency inference.
It makes the code easier to read and maintain. 19 | Adapted from 20 | https://github.com/microsoft/LoRA 21 | https://www.microsoft.com/en-us/research/project/dp-transformers/ 22 | """ 23 | 24 | import torch 25 | import transformers 26 | from torch import nn 27 | 28 | 29 | class DPMergedLinear(nn.Module): 30 | def __init__( 31 | self, 32 | in_features: int, 33 | out_features: int, 34 | lora_r=0, 35 | lora_alpha=1., 36 | lora_dropout=0., 37 | ): 38 | super(DPMergedLinear, self).__init__() 39 | self.linear = nn.Linear(in_features=in_features, out_features=out_features) 40 | self.lora_r = lora_r 41 | self.lora_alpha = lora_alpha 42 | self.lora_dropout = nn.Dropout(p=lora_dropout) 43 | if self.lora_r > 0: 44 | self.lora_A = nn.Linear(in_features=in_features, out_features=lora_r, bias=False) 45 | self.lora_B = nn.Linear(in_features=lora_r, out_features=out_features, bias=False) 46 | self.scaling = self.lora_alpha / lora_r 47 | self.reset_parameters() 48 | 49 | def forward(self, x: torch.Tensor): 50 | result = self.linear(x) 51 | if self.lora_r > 0: 52 | after_dropout = self.lora_dropout(x) 53 | after_A = self.lora_A(after_dropout) 54 | after_B = self.lora_B(after_A) 55 | result += after_B * self.scaling 56 | return result 57 | 58 | def reset_parameters(self): 59 | self.linear.reset_parameters() 60 | if self.lora_r > 0: 61 | self.lora_A.reset_parameters() 62 | self.lora_B.weight.data.zero_() 63 | 64 | @staticmethod 65 | def from_transformers_conv1d( 66 | original_layer, 67 | lora_r=0, 68 | lora_alpha=1., 69 | lora_dropout=0., 70 | ) -> "DPMergedLinear": 71 | lora_layer = DPMergedLinear( 72 | in_features=original_layer.weight.shape[0], 73 | out_features=original_layer.weight.shape[1], 74 | lora_r=lora_r, 75 | lora_alpha=lora_alpha, 76 | lora_dropout=lora_dropout, 77 | ).to(original_layer.weight.device) 78 | lora_layer.linear.weight.data.copy_(original_layer.weight.T.data) 79 | lora_layer.linear.bias.data.copy_(original_layer.bias.data) 80 | return lora_layer 81 | 82 | 83 | def convert_gpt2_attention_to_lora( 84 | model: transformers.GPT2PreTrainedModel, 85 | lora_r=0, 86 | lora_alpha=1., 87 | lora_dropout=0., 88 | ) -> transformers.GPT2PreTrainedModel: 89 | if not isinstance(model, transformers.GPT2PreTrainedModel): 90 | raise TypeError("Requires a GPT2 model") 91 | 92 | if not hasattr(model, "h") and hasattr(model, "transformer"): 93 | transformer = model.transformer 94 | else: 95 | transformer = model 96 | 97 | for h_i in transformer.h: 98 | new_layer = DPMergedLinear.from_transformers_conv1d( 99 | original_layer=h_i.attn.c_attn, 100 | lora_r=lora_r, 101 | lora_alpha=lora_alpha, 102 | lora_dropout=lora_dropout, 103 | ) 104 | h_i.attn.c_attn = new_layer 105 | 106 | return model 107 | 108 | 109 | def mark_only_lora_as_trainable(model: torch.nn.Module) -> None: 110 | model.requires_grad_(True) 111 | for n, p in model.named_parameters(): 112 | if 'lora_' not in n: 113 | p.requires_grad = False 114 | -------------------------------------------------------------------------------- /private_transformers/settings.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Xuechen Li. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
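A usage sketch for the LoRA utilities above; the hyperparameters are illustrative:

```python
import transformers

from private_transformers.lora_utils import convert_gpt2_attention_to_lora, mark_only_lora_as_trainable

model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")
model = convert_gpt2_attention_to_lora(model, lora_r=8, lora_alpha=16, lora_dropout=0.0)
mark_only_lora_as_trainable(model)

# Only the adapter matrices remain trainable; lora_B starts at zero, so the
# adapted model initially computes exactly the same function as the base model.
print(sum(p.numel() for p in model.parameters() if p.requires_grad))
```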
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import transformers 16 | from ml_swissknife import utils 17 | 18 | 19 | class BackwardHookMode(metaclass=utils.ContainerMeta): 20 | ghost_norm = "ghost_norm" 21 | ghost_grad = "ghost_grad" 22 | default = "default" 23 | 24 | 25 | class ClippingMode(metaclass=utils.ContainerMeta): 26 | default = "default" # Global fixed. 27 | ghost = "ghost" # Global fixed clipping with ghost clipping. 28 | per_layer = "per_layer" # Per layer fixed clipping. 29 | per_layer_percentile = "per_layer_percentile" # Clip gradient per-layer based on gradient norm percentile. 30 | 31 | 32 | class AccountingMode(metaclass=utils.ContainerMeta): 33 | rdp = "rdp" 34 | glw = "glw" 35 | all_ = "all" 36 | 37 | 38 | SUPPORTED_TRANSFORMERS = ( 39 | transformers.models.openai.modeling_openai.OpenAIGPTLMHeadModel, 40 | transformers.models.openai.modeling_openai.OpenAIGPTDoubleHeadsModel, 41 | transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel, 42 | transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel, 43 | transformers.models.bert.modeling_bert.BertForSequenceClassification, 44 | transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification, 45 | transformers.models.albert.modeling_albert.AlbertForSequenceClassification, 46 | transformers.models.bart.modeling_bart.BartForConditionalGeneration, 47 | transformers.models.t5.modeling_t5.T5ForConditionalGeneration, 48 | transformers.models.opt.modeling_opt.OPTForCausalLM, 49 | transformers.models.vit.modeling_vit.ViTForImageClassification, 50 | transformers.models.deit.modeling_deit.DeiTForImageClassification, 51 | transformers.models.beit.modeling_beit.BeitForImageClassification, 52 | ) 53 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Xuechen Li. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
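The mode constants above are the values `PrivacyEngine` accepts for its `clipping_mode` and `accounting_mode` arguments (see the construction in the table2text script earlier). A hedged sketch with placeholder numbers:

```python
import transformers

from private_transformers import PrivacyEngine
from private_transformers.settings import AccountingMode, ClippingMode

model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")  # Placeholder model.
engine = PrivacyEngine(
    module=model,
    batch_size=512,
    sample_size=50_000,
    epochs=3,
    max_grad_norm=0.1,
    target_epsilon=3.0,
    target_delta=1e-5,
    clipping_mode=ClippingMode.default,
    accounting_mode=AccountingMode.rdp,
)
```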
14 | 15 | import os 16 | import re 17 | 18 | import setuptools 19 | 20 | # for simplicity we actually store the version in the __version__ attribute in the source 21 | here = os.path.realpath(os.path.dirname(__file__)) 22 | with open(os.path.join(here, 'private_transformers', '__init__.py')) as f: 23 | meta_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", f.read(), re.M) 24 | if meta_match: 25 | version = meta_match.group(1) 26 | else: 27 | raise RuntimeError("Unable to find __version__ string.") 28 | 29 | with open(os.path.join(here, 'README.md')) as f: 30 | readme = f.read() 31 | 32 | setuptools.setup( 33 | name="private_transformers", 34 | version=version, 35 | author="Xuechen Li", 36 | author_email="lxuechen@cs.toronto.edu", 37 | description="Train Hugging Face transformers with differential privacy.", 38 | long_description=readme, 39 | url="https://github.com/lxuechen/private-transformers", 40 | packages=setuptools.find_packages(exclude=['examples', 'tests']), 41 | install_requires=[ 42 | "torch>=1.8.0", 43 | "prv-accountant", 44 | "transformers>=4.20.1", # v0.1.0 uses 4.16.2. 45 | "numpy", 46 | "scipy", 47 | "jupyterlab", 48 | "jupyter", 49 | "ml-swissknife", 50 | "opt_einsum", 51 | "pytest" 52 | ], 53 | python_requires='~=3.8', 54 | classifiers=[ 55 | "Programming Language :: Python :: 3", 56 | "License :: OSI Approved :: Apache Software License", 57 | ], 58 | ) 59 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Xuechen Li. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | --------------------------------------------------------------------------------