├── .gitignore
├── FAQ.md
├── LICENSE
├── README.md
├── assets
│   └── fig1.png
├── examples
│   ├── __init__.py
│   ├── classification
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── data
│   │   │   ├── download_dataset.sh
│   │   │   ├── k-shot
│   │   │   │   └── checksum
│   │   │   ├── make_k_shot_without_dev.py
│   │   │   └── make_valid_data.py
│   │   ├── requirements.txt
│   │   ├── run_classification.py
│   │   ├── run_wrapper.py
│   │   ├── spectral_analysis
│   │   │   ├── 3d_surface.png
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── density.py
│   │   │   ├── geometric_median.py
│   │   │   ├── rebuttal_neurips_2022.py
│   │   │   ├── rebuttal_plots_neurips_2022.py
│   │   │   └── visuals.ipynb
│   │   └── src
│   │       ├── __init__.py
│   │       ├── common.py
│   │       ├── compiled_args.py
│   │       ├── dataset.py
│   │       ├── label_search.py
│   │       ├── models.py
│   │       ├── processors.py
│   │       └── trainer.py
│   ├── image_classification
│   │   ├── README.md
│   │   ├── __init__.py
│   │   └── main.py
│   └── table2text
│       ├── README.md
│       ├── __init__.py
│       ├── compiled_args.py
│       ├── data_utils
│       │   ├── __init__.py
│       │   ├── data_collator.py
│       │   └── language_modeling.py
│       ├── decoding_utils.py
│       ├── density.py
│       ├── misc.py
│       ├── models.py
│       ├── requirements.txt
│       ├── run.sh
│       ├── run_language_modeling.py
│       └── trainer.py
├── private_transformers
│   ├── __init__.py
│   ├── accounting
│   │   ├── __init__.py
│   │   ├── accounting_manager.py
│   │   └── rdp_accounting.py
│   ├── autograd_grad_sample.py
│   ├── lora_utils.py
│   ├── privacy_engine.py
│   ├── settings.py
│   ├── supported_layers_grad_samplers.py
│   └── transformers_support.py
├── setup.py
└── tests
    ├── __init__.py
    └── test_privacy_engine.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | .idea/
132 | .DS_Store
133 |
134 | # Ignore figures and plots when uploading to github.
135 | **.png
136 | **.pdf
137 |
--------------------------------------------------------------------------------
/FAQ.md:
--------------------------------------------------------------------------------
1 | ## FAQ
2 |
3 | ### How do I perform gradient accumulation?
4 |
5 | Use `virtual_step` in combination with `step`. For example, the following gives a simplified demo of the structure:
6 |
7 | ```python
8 | import torch, transformers, private_transformers
9 |
10 | gradient_accumulation_steps = 10  # Take an update once every this many iterations.
11 |
12 | batches = ... # Data.
13 | model = transformers.AutoModelWithLMHead.from_pretrained('gpt2')
14 | optimizer = torch.optim.Adam(model.parameters())
15 | privacy_engine = private_transformers.PrivacyEngine(...)
16 | privacy_engine.attach(optimizer)
17 |
18 | for i, batch in enumerate(batches, 1):
19 |     loss = model(batch)  # Simplified; stands in for computing the per-example loss (a 1-D tensor).
20 |     if i % gradient_accumulation_steps == 0:
21 |         optimizer.step(loss=loss)
22 |         optimizer.zero_grad()
23 |     else:
24 |         optimizer.virtual_step(loss=loss)
25 | ```
26 |
27 | ### What is ghost clipping?
28 |
29 | It's a per-example gradient clipping (then summing) technique that avoids instantiating per-example gradients. It can
30 | make private training have almost the same memory cost as non-private training.
31 | The method is based on accumulating gradient norms on a layer-by-layer basis first demonstrated
32 | in [this work](https://arxiv.org/abs/2009.03106).
33 | We implemented and extended this method so that computing gradient norms for linear layers can be cheap; this is based on a
34 | linear algebra identity that we derived in [this work](https://arxiv.org/pdf/2110.05679.pdf).
35 | [Subsequent work](https://arxiv.org/abs/2205.10683) adapted the overall approach to suit training convolutional layers.
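
To make the linear-layer identity concrete, below is a small self-contained sketch (illustrative only, with made-up sizes; not code from this repo). For a linear layer, the per-example gradient is `G_i = sum_t g_{it} a_{it}^T`, and `||G_i||_F^2 = <a_i a_i^T, g_i g_i^T>`, so two small Gram matrices suffice and no per-example gradient is ever materialized:

```python
import torch

# Made-up sizes: batch B, sequence length T, input dim d, output dim p.
B, T, d, p = 4, 8, 16, 32
a = torch.randn(B, T, d)  # inputs to the linear layer (activations)
g = torch.randn(B, T, p)  # gradients w.r.t. the layer's outputs

# Naive route: materialize each per-example gradient G_i (B x p x d memory),
# then take squared Frobenius norms.
per_example_grads = torch.einsum('bti,btj->bij', g, a)  # (B, p, d)
naive_sq_norms = per_example_grads.flatten(start_dim=1).pow(2).sum(dim=1)

# Ghost route: only two (T, T) Gram matrices per example are needed.
ghost_sq_norms = ((a @ a.transpose(1, 2)) * (g @ g.transpose(1, 2))).sum(dim=(1, 2))

assert torch.allclose(naive_sq_norms, ghost_sq_norms, atol=1e-4)
```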
36 |
37 | ### How did you test that ghost clipping gives the 'right' gradients?
38 |
39 | We ran stringent numerical tests to ensure the double-backward implementation is correct (e.g., remove sources of
40 | randomness like dropout and compare gradients from double backward against gradients from autodiff + for loop).
41 | Check out files in the `tests` folder for more on this.
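
As a flavor of what such a test looks like, here's a minimal sketch (not the repo's actual test code) of the "autodiff + for loop" reference that a vectorized per-example gradient implementation can be checked against:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 3)
xs = torch.randn(8, 4)

# Reference: per-example gradients computed with one backward pass per example.
reference_grads = []
for x in xs:
    model.zero_grad()
    model(x.unsqueeze(0)).pow(2).mean().backward()
    reference_grads.append(model.weight.grad.clone())
reference_grads = torch.stack(reference_grads)  # (8, 3, 4)

# Any faster scheme (e.g., a double-backward implementation) should reproduce
# these gradients up to numerical tolerance.
```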
42 |
43 | ### When can't I use ghost clipping?
44 |
45 | Ghost clipping can't handle parameter sharing, which is why, in our code, we separate the lm-head from the embedding
46 | layer for generation tasks. Similarly, it can't be applied to fine-tuning ALBERT, which ties weights across many layers
47 | of the model.
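
For concreteness, one hypothetical way to break the tie in a GPT-2-style model is to replace the lm-head with a fresh, independent linear layer initialized from a copy of the embedding weights (a sketch assuming the standard HF weight tying; details vary by model):

```python
import torch
import torch.nn as nn
import transformers

model = transformers.GPT2LMHeadModel.from_pretrained('distilgpt2')
# By default, lm_head.weight and transformer.wte.weight are the same tensor.
# Replace lm_head with an independent layer initialized from the embeddings.
model.lm_head = nn.Linear(model.lm_head.in_features, model.lm_head.out_features, bias=False)
with torch.no_grad():
    model.lm_head.weight.copy_(model.transformer.wte.weight)
```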
48 |
49 | ### What if I want to freeze some parameters of the network while updating all others?
50 |
51 | Before creating the privacy engine and optimizer, set parts of the model which won't be optimized to
52 | have `.requires_grad=False`. The privacy engine will do the rest for you. For instance:
53 |
54 | ```python
55 | import transformers, private_transformers
56 |
57 | model = transformers.AutoModelWithLMHead.from_pretrained('gpt2')
58 | # Input embeddings aren't optimized; this line needs to precede privacy engine creation.
59 | model.get_input_embeddings().requires_grad_(False)
60 | privacy_engine = private_transformers.PrivacyEngine(model, ...)
61 | ```
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # private-transformers
2 |
3 | This codebase facilitates fast experimentation of differentially private training
4 | of [Hugging Face transformers](https://huggingface.co/transformers/).
5 |
6 | ---
7 |
8 |
9 |
10 |
11 | ## What is this? Why an extra codebase?
12 |
13 | - This codebase provides a privacy engine that builds on and rewrites [Opacus](https://github.com/pytorch/opacus) so
14 | that integration with
15 | [Hugging Face's transformers library](https://github.com/huggingface/transformers) is easy.
16 | - Additionally, we support the *ghost clipping* technique (see Section 4 of [this](https://arxiv.org/pdf/2110.05679.pdf)
17 | preprint on how it works) which allows privately training large transformers with considerably reduced memory cost --
18 | in many cases, almost as light as non-private training -- at a modest run-time overhead.
19 | - **With this codebase, we have fine-tuned very large pretrained models, yielding some of the best performing
20 | differentially private NLP models to date. Some of these models have performance matching strong non-private baseline
21 | approaches. We see strong empirical evidence that highly performant DP NLP models could be built on modest datasets.**
22 |
23 | ## Installation
24 |
25 | Make sure you have python>=3.8; run the following command:
26 |
27 | ```bash
28 | pip install git+https://github.com/lxuechen/private-transformers.git
29 | ```
30 |
31 | To check that the package is installed properly, run the test suite (requires pytest and a GPU) with the following
32 | command:
33 |
34 | ```bash
35 | pytest -s tests
36 | ```
37 |
38 | ## Usage
39 |
40 | ### Basic usage
41 |
42 | Privately training Hugging Face transformers with our codebase simply consists of 4 steps:
43 |
44 | 1. Create your favourite transformer model and optimizer; attach this optimizer to a `PrivacyEngine`
45 | 2. Compute a per-example loss (1-D tensor) for a mini-batch of data
46 | 3. Pass the loss to `optimizer.step` or `optimizer.virtual_step` as a keyword argument
47 | 4. Repeat from step 2
48 |
49 | Below is a quick example:
50 |
51 | ```python
52 | import transformers, torch
53 | from private_transformers import PrivacyEngine
54 | import torch.nn.functional as F
55 |
56 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
57 | model = transformers.GPT2LMHeadModel.from_pretrained('distilgpt2').to(device)
58 | optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
59 | privacy_engine = PrivacyEngine(
60 |     model,
61 |     batch_size=10,
62 |     sample_size=50000,
63 |     epochs=3,
64 |     max_grad_norm=0.1,
65 |     target_epsilon=3,
66 | )
67 | privacy_engine.attach(optimizer)
68 |
69 | batch_size, seq_len = 10, 20
70 | # Inputs are in batch-first format, i.e., the first dimension of tensors must be the batch dimension.
71 | input_ids = torch.randint(size=[batch_size, seq_len], low=0, high=100, device=device)
72 | # Calling `.train()` is very important; otherwise underlying forward and backward hooks don't run.
73 | model.train()
74 | outputs = model(input_ids=input_ids, return_dict=True)
75 | labels = input_ids[:, 1:]
76 | logits = outputs.logits[:, :-1, :].permute(0, 2, 1)
77 | # `loss` is a 1-D tensor of shape (batch_size,).
78 | loss = F.cross_entropy(logits, labels, reduction="none").mean(dim=1)
79 | # This step is different from existing workflows:
80 | # Don't call `loss.backward`; leave it to `optimizer.step` to handle backward.
81 | optimizer.step(loss=loss)
82 | ```
83 |
84 | The biggest differences compared to Opacus are:
85 |
86 | - We require the per-example loss (a 1-D tensor) be passed into `optimizer.step` (or `optimizer.virtual_step`).
87 | - The per-example loss must be passed in as a *keyword argument*.
88 | - `loss.backward()` shouldn't be called on the user end; it's called internally in `optimizer.step` (
89 | or `optimizer.virtual_step`).
90 | - Inputs should be in batch-first format; there isn't a toggle to switch between different formats in the engine.
91 |
92 | ### Ghost clipping: memory saving differentially private learning
93 |
94 | Turning on ghost clipping requires changing only 1 line. You should notice a drastic reduction in peak GPU memory usage
95 | once this is turned on, at the potential cost of slower training. This is especially useful when constrained to
96 | older GPUs with limited VRAM, or when fitting very large models.
97 |
98 | ```python
99 | import transformers, torch
100 | from private_transformers import PrivacyEngine
101 |
102 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
103 | model = transformers.GPT2LMHeadModel.from_pretrained('distilgpt2').to(device)
104 | optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
105 | privacy_engine = PrivacyEngine(
106 |     model,
107 |     batch_size=10,
108 |     sample_size=50000,
109 |     epochs=3,
110 |     max_grad_norm=0.1,
111 |     target_epsilon=3,
112 |     clipping_mode="ghost",  # The only change you need to make!
113 | )
114 | privacy_engine.attach(optimizer)
115 | ```
116 |
117 | ### Examples
118 |
119 | Code in the `examples` folder roughly reproduces our results for the table-to-text and classification tasks. There may
120 | be some minor discrepancies, since the hyperparameters there aren't exactly those used in the paper. Nevertheless, it
121 | should be sufficient to get things started. Detailed instructions are in the readme file of each subfolder.
122 |
123 | ### Currently supported [Hugging Face models](https://huggingface.co/transformers/pretrained_models.html)
124 |
125 | - [OpenAIGPTLMHeadModel](https://huggingface.co/docs/transformers/model_doc/openai-gpt#transformers.OpenAIGPTLMHeadModel)
126 | - [OpenAIGPTDoubleHeadsModel](https://huggingface.co/docs/transformers/model_doc/openai-gpt#transformers.OpenAIGPTDoubleHeadsModel)
127 | - [GPT2LMHeadModel](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2LMHeadModel)
128 | - [GPT2DoubleHeadsModel](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2DoubleHeadsModel)
129 | - [BertForSequenceClassification](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertForSequenceClassification)
130 | - [RobertaForSequenceClassification](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaForSequenceClassification)
131 | - [AlbertForSequenceClassification](https://huggingface.co/docs/transformers/model_doc/albert#transformers.AlbertForSequenceClassification)
132 | - [BartForConditionalGeneration](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartForConditionalGeneration)
133 | (when positional embedding layers are frozen)
134 | - [T5ForConditionalGeneration](https://huggingface.co/docs/transformers/v4.20.1/en/model_doc/t5#transformers.T5ForConditionalGeneration)
135 | - [OPTForCausalLM](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTForCausalLM)
136 | - [ViTForImageClassification](https://huggingface.co/docs/transformers/v4.20.1/en/model_doc/vit#transformers.ViTForImageClassification)
137 | (when isolated parameters are frozen; see [this example](examples/image_classification/main.py))
138 | - [DeiTForImageClassification](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTForImageClassification)
139 | (when isolated parameters are frozen)
140 | - [BeitForImageClassification](https://huggingface.co/docs/transformers/model_doc/beit#transformers.BeitForImageClassification)
141 | (when isolated parameters are frozen)
142 |
143 | Not all models in the Hugging Face library are supported. The main additional work to support a model is to
144 |
145 | 1. Support per-example gradients for bespoke modules
146 | (e.g., [T5LayerNorm](https://huggingface.co/transformers/_modules/transformers/modeling_t5.html)), and
147 | 2. Ensure `position_ids` are repeated (duplicated along batch dim 0). Normally, to save memory, one creates positional
148 | embeddings for a single instance and relies on broadcasting when there are multiple instances in a batch. This creates a
149 | problem for per-sample gradient accumulation, so we instead duplicate the inputs to positional embeddings, as sketched below.
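
As a minimal sketch of what "repeated `position_ids`" means (illustrative; not this package's internal code):

```python
import torch

batch_size, seq_len = 4, 10
# Broadcast-style position ids: a single row shared across the whole batch.
position_ids = torch.arange(seq_len).unsqueeze(0)  # (1, seq_len)
# Per-example gradients need one row per example, so duplicate along batch dim 0.
position_ids = position_ids.expand(batch_size, -1).contiguous()  # (batch_size, seq_len)
```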
150 |
151 | We plan to support more models in the future if there's such a need. Feel free to open an issue if you want to try
152 | out specific models that aren't on the current list.
153 |
154 | ## FAQ
155 |
156 | I wrote up answers to some potential questions [here](https://github.com/lxuechen/private-transformers/blob/main/FAQ.md).
157 | These cover gradient accumulation, ghost clipping, and freezing parts of a model.
158 |
159 | ## Acknowledgements
160 |
161 | It would have been impossible to develop this codebase without cool past works and existing codebases. We roughly follow
162 | the `PrivacyEngine` design in `Opacus==0.13.0`. We directly use
163 | an [off-the-shelf package](https://github.com/microsoft/prv_accountant) for tightly tracking tradeoff functions while
164 | composing multiple private mechanisms.
165 |
166 | ## Disclaimer
167 |
168 | - This codebase is not yet production-grade, e.g., cryptographically secure PRNGs are required for sampling noise -- our
169 | codebase currently does not use these strong PRNGs as they tend to slow down training. This codebase also isn't immune
170 | to [floating point representation attacks](https://github.com/pytorch/opacus/pull/260).
171 | - This codebase was born out of the need to rapidly experiment with various ideas in differentially private NLP. I've
172 | tried my best to write clean code, though parts of this codebase may be less tidy than I had hoped
173 | given the extremely tight timeline.
174 |
175 | ## Citation
176 |
177 | If you found this codebase useful in your research, please consider citing:
178 |
179 | ```
180 | @inproceedings{
181 | li2022large,
182 | title={Large Language Models Can Be Strong Differentially Private Learners},
183 | author={Xuechen Li and Florian Tramer and Percy Liang and Tatsunori Hashimoto},
184 | booktitle={International Conference on Learning Representations},
185 | year={2022},
186 | url={https://openreview.net/forum?id=bVuP3ltATMz}
187 | }
188 |
189 | @inproceedings{
190 | li2022when,
191 | title={When Does Differentially Private Learning Not Suffer in High Dimensions?},
192 | author={Xuechen Li and Daogao Liu and Tatsunori Hashimoto and Huseyin A Inan and Janardhan Kulkarni and YinTat Lee and Abhradeep Guha Thakurta},
193 | booktitle={Advances in Neural Information Processing Systems},
194 | editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho},
195 | year={2022},
196 | url={https://openreview.net/forum?id=FR--mkQu0dw}
197 | }
198 | ```
199 |
--------------------------------------------------------------------------------
/assets/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lxuechen/private-transformers/18ccc4eab7355e4ac96051a82434796f6aa4624b/assets/fig1.png
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/examples/classification/README.md:
--------------------------------------------------------------------------------
1 | ## Reproducing results for sentence classification
2 |
3 | ### Requirements
4 |
5 | In addition to requirements of the `private-transformers` package, install additional requirements by running the
6 | following from the `examples` folder of this repo:
7 |
8 | ```bash
9 | pip install -r classification/requirements.txt
10 | ```
11 |
12 | This code is tested against `transformers==4.11.3`, but should also work for slightly earlier versions.
13 |
14 | ### Getting the data
15 |
16 | This part of the codebase is adapted from the excellent work
17 | by [[Gao et al., 2021](https://arxiv.org/pdf/2012.15723.pdf)]. We reuse their data pipeline. To obtain the data, run the
18 | following:
19 |
20 | ```bash
21 | cd data
22 | bash download_dataset.sh
23 | ```
24 |
25 | This should produce a `data/original` subfolder that contains all the data that we need.
26 |
27 | ### Running
28 |
29 | Use the `run_wrapper.py` script in the folder. This Python script composes the full command as a text string and runs it.
30 |
31 | Supply at least 2 arguments:
32 |
33 | - `--output_dir`: path to a folder where results will be written
34 | - `--task_name`: name of task; one of `sst-2`, `qnli`, `qqp`, `mnli`
35 |
36 | For instance, run the following under the `examples/` folder:
37 |
38 | ```bash
39 | python -m classification.run_wrapper --output_dir <output_dir> --task_name <task_name>
40 | ```
41 |
42 | The script by default uses ghost clipping, and the micro batch size is tweaked so that things should run smoothly even
43 | on a Titan Xp with 12 GB of VRAM. For SST-2, this script takes under one and a half hours on an RTX 3090. Larger
44 | datasets take longer to train.
45 |
46 | Additional arguments:
47 |
48 | - `--target_epsilon`: Target privacy spending
49 | - `--model_name_or_path`: The pretrained model; one of `distilbert-base-uncased`, `bert-base-uncased`
50 | , `bert-large-uncased`, `distilroberta-base`, `roberta-base`, `roberta-large`
51 | - `--few_shot_type`: Whether to use the generic prompt formatter described in Section 3.2 of our paper; `prompt` uses
52 | the formatter, `finetune` does not.
53 | - `--ghost_clipping`: Whether to use ghost clipping for memory saving; one of `yes`, `no`. Note that with other
54 | training hyperparameters (e.g., number of training epochs, clipping norm, learning rate) kept the same,
55 | things should still work.
56 | - `--data_dir`: Path to where data is stored; if data is obtained via the procedure described above, just stick to the
57 | defaults.
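
For instance, a full invocation might look like the following (the flag values here are illustrative, not prescriptive):

```bash
python -m classification.run_wrapper \
  --output_dir outputs/sst-2 \
  --task_name sst-2 \
  --model_name_or_path roberta-base \
  --target_epsilon 8 \
  --few_shot_type prompt \
  --ghost_clipping yes
```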
58 |
59 | Training on the larger datasets for even more epochs should bring further performance gains.
60 |
61 | ### Notes
62 |
63 | - We have reproduced some results in our paper with the codebase of
64 | a [concurrent anonymous submission](https://openreview.net/pdf?id=Q42f0dfjECO). Our modified version of their codebase
65 | is located at [this link](https://github.com/lxuechen/Differentially-Private-Fine-tuning-of-Language-Models). It
66 | only optimizes the dense/linear layers in a Transformer model, and
67 | hence is not strictly full fine-tuning (the embedding and LayerNorm layers aren't updated). The main difference
68 | from their original setup is that we run everything in full precision (i.e., fp32), not mixed precision.
69 | - We got results similar to those reported in the paper with Opacus, but with the embedding subnetworks (word embedding,
70 | positional embedding, token type embedding) frozen. Note that unfreezing the embedding subnetworks and plugging such a
71 | model (from HF) into Opacus would result in errors, due to how HF transformers are implemented.
72 |
--------------------------------------------------------------------------------
/examples/classification/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/examples/classification/data/download_dataset.sh:
--------------------------------------------------------------------------------
1 | wget https://nlp.cs.princeton.edu/projects/lm-bff/datasets.tar
2 | tar xvf datasets.tar
3 |
--------------------------------------------------------------------------------
/examples/classification/data/k-shot/checksum:
--------------------------------------------------------------------------------
1 | fe166f3b9e9952cb729a8d2c3bd682ae SST-2/16-13/train.tsv
2 | 0eeeb2195d82899d15ab5eeaddcaa944 SST-2/16-13/dev.tsv
3 | aeac355ccdb43bc747e816eff2d74aa4 SST-2/16-13/test.tsv
4 | 6661938fcefe356e8c9c3f866b67c047 SST-2/16-21/train.tsv
5 | 74a12c79dea73deb5526c398b1bc7bc0 SST-2/16-21/dev.tsv
6 | aeac355ccdb43bc747e816eff2d74aa4 SST-2/16-21/test.tsv
7 | edce4c52734c2084476ac03bfdfa5c77 SST-2/16-42/train.tsv
8 | 4ba7f2ffe0fc89844d1184f294f15774 SST-2/16-42/dev.tsv
9 | aeac355ccdb43bc747e816eff2d74aa4 SST-2/16-42/test.tsv
10 | 783b9e625ec6dea210cc4062d464aba2 SST-2/16-87/train.tsv
11 | 2972df0d5eace29b48eb8ce77770c3d5 SST-2/16-87/dev.tsv
12 | aeac355ccdb43bc747e816eff2d74aa4 SST-2/16-87/test.tsv
13 | f5079b5b5bd27087a7d836e7e8faedc7 SST-2/16-100/train.tsv
14 | 1d65f0f67026bbdcbfef6ed3ada2510a SST-2/16-100/dev.tsv
15 | aeac355ccdb43bc747e816eff2d74aa4 SST-2/16-100/test.tsv
16 | c446d010f7ea4846be28d58d81169648 sst-5/16-13/train.csv
17 | aca7b0130f532a7a3289feaf2c9cf510 sst-5/16-13/dev.csv
18 | 08e609c5c95d32903967c6af50a36d40 sst-5/16-13/test.csv
19 | e3d8f0d8ee2c3d19cbb8cf370a80c0ff sst-5/16-21/train.csv
20 | a673a34db10fa7ea40bee8d635a58d71 sst-5/16-21/dev.csv
21 | 08e609c5c95d32903967c6af50a36d40 sst-5/16-21/test.csv
22 | c53271e5fbb67f49ff6ccec9058f4198 sst-5/16-42/train.csv
23 | 040144885c068d74bfee3415ee77a9f9 sst-5/16-42/dev.csv
24 | 08e609c5c95d32903967c6af50a36d40 sst-5/16-42/test.csv
25 | c77128dc79022b7d645e587823c90fa9 sst-5/16-87/train.csv
26 | 74a356f47dbeaaba86dc87d6de8f822b sst-5/16-87/dev.csv
27 | 08e609c5c95d32903967c6af50a36d40 sst-5/16-87/test.csv
28 | bee588d86301e040f13b2e26d3cbd3bc sst-5/16-100/train.csv
29 | 3de890aa28b082bb00c88e92329d9736 sst-5/16-100/dev.csv
30 | 08e609c5c95d32903967c6af50a36d40 sst-5/16-100/test.csv
31 | 77e589985c38f642b743c00f454a8f72 mr/16-13/train.csv
32 | 0487ece2b558b211787050177529cc15 mr/16-13/dev.csv
33 | 2ab6b6176555e141e7682bfbd17626c4 mr/16-13/test.csv
34 | 50f61d4ad574c862103a281cfb277568 mr/16-21/train.csv
35 | 36a5b9a0aef14b5ef85a794a79f7a670 mr/16-21/dev.csv
36 | 2ab6b6176555e141e7682bfbd17626c4 mr/16-21/test.csv
37 | 7f2be6d8286c853be70d6b731529110e mr/16-42/train.csv
38 | f331ad440925ff5550065bafa4d1e14e mr/16-42/dev.csv
39 | 2ab6b6176555e141e7682bfbd17626c4 mr/16-42/test.csv
40 | 9ad43ad9a35fd4fe88a8d74c164f3056 mr/16-87/train.csv
41 | 7650a096ae10bc0d22de08d28b080736 mr/16-87/dev.csv
42 | 2ab6b6176555e141e7682bfbd17626c4 mr/16-87/test.csv
43 | edab3bd41c79c073dda7469f31761202 mr/16-100/train.csv
44 | d241fa66a9302aa417876971fcc309a1 mr/16-100/dev.csv
45 | 2ab6b6176555e141e7682bfbd17626c4 mr/16-100/test.csv
46 | fef4839c9b1532845efe85f459c4657a cr/16-13/train.csv
47 | 543574c5b1fe954e1d82f109a060b8a1 cr/16-13/dev.csv
48 | 829c32e754d756b5da2b6d88cde7bc1a cr/16-13/test.csv
49 | 890b0ed3d8f57ed19e90884580cc08af cr/16-21/train.csv
50 | 02246fc77e862f8d973058478d89bfb9 cr/16-21/dev.csv
51 | 829c32e754d756b5da2b6d88cde7bc1a cr/16-21/test.csv
52 | 09260d2c41f4fbf48fad2cdb09b9bd19 cr/16-42/train.csv
53 | 23f70a47d55def2c4308cebcdaf1f5b9 cr/16-42/dev.csv
54 | 829c32e754d756b5da2b6d88cde7bc1a cr/16-42/test.csv
55 | 2c47c48c1223ff979f2e6e8243fa3219 cr/16-87/train.csv
56 | 1f0f5e7a984740751ea2b5314f847566 cr/16-87/dev.csv
57 | 829c32e754d756b5da2b6d88cde7bc1a cr/16-87/test.csv
58 | 6c11804b15bbe6582ebf0eab71513d0d cr/16-100/train.csv
59 | 844e4d9cbc956e9495dfd77e3881caba cr/16-100/dev.csv
60 | 829c32e754d756b5da2b6d88cde7bc1a cr/16-100/test.csv
61 | 18764f2c39c6fa88b4c3871d79ea3a99 mpqa/16-13/train.csv
62 | 5d3c82a72ae0319afbe940e4d35ec402 mpqa/16-13/dev.csv
63 | b69942543fec7bb2ac8ba1c5f08d6404 mpqa/16-13/test.csv
64 | 1e71fb7b6def9a89678f350c92f386f7 mpqa/16-21/train.csv
65 | 37a581edfaa917b4975777d1bb531dad mpqa/16-21/dev.csv
66 | b69942543fec7bb2ac8ba1c5f08d6404 mpqa/16-21/test.csv
67 | 4d889a2e53e2fc9f89aef415e52c4376 mpqa/16-42/train.csv
68 | d58bd2e6128dd61460e597b41aed4073 mpqa/16-42/dev.csv
69 | b69942543fec7bb2ac8ba1c5f08d6404 mpqa/16-42/test.csv
70 | a11bd12d360e1a0c66fe5823b95d40db mpqa/16-87/train.csv
71 | a9aea0b29f42d955c430249b5aba3ff4 mpqa/16-87/dev.csv
72 | b69942543fec7bb2ac8ba1c5f08d6404 mpqa/16-87/test.csv
73 | d8b8864485fef5582b862e999f1d78fa mpqa/16-100/train.csv
74 | d05230e664b98d5c368a4616fed106b8 mpqa/16-100/dev.csv
75 | b69942543fec7bb2ac8ba1c5f08d6404 mpqa/16-100/test.csv
76 | 9007424294b4a73428032b4ba872d798 subj/16-13/train.csv
77 | ed1c7a9a31cb8ed53798e5d5c0a0ba6a subj/16-13/dev.csv
78 | 94c0bea4a5ff6092de1dcda7c46c7b65 subj/16-13/test.csv
79 | e91660061eac76f4d2496f0ee3a17a55 subj/16-21/train.csv
80 | 9bb1c6a470f9c437734d5000a9887298 subj/16-21/dev.csv
81 | 94c0bea4a5ff6092de1dcda7c46c7b65 subj/16-21/test.csv
82 | 5ad5810af6f69c0cae872f9b8ac96e4b subj/16-42/train.csv
83 | d689743060587300b50ba25303cee457 subj/16-42/dev.csv
84 | 94c0bea4a5ff6092de1dcda7c46c7b65 subj/16-42/test.csv
85 | 116db5d05fb7152af549b322138ddf0c subj/16-87/train.csv
86 | 8fcbfc9937902ed65b5f457db143eeae subj/16-87/dev.csv
87 | 94c0bea4a5ff6092de1dcda7c46c7b65 subj/16-87/test.csv
88 | dafc9f79ea2643f9d0db4bb9cd50ed68 subj/16-100/train.csv
89 | d205bd3e890483010b73b40fe2341874 subj/16-100/dev.csv
90 | 94c0bea4a5ff6092de1dcda7c46c7b65 subj/16-100/test.csv
91 | 50cff1542195f5321cb951ca1d980ec0 trec/16-13/train.csv
92 | 97ca7ed4ec756c13b6233c2d66c97ee5 trec/16-13/dev.csv
93 | b4e08b69eb3aae325197216450e3a28b trec/16-13/test.csv
94 | 178b968a79785571398a99082f2059aa trec/16-21/train.csv
95 | df1be41d4a11a1f10b99bfa9dab069fd trec/16-21/dev.csv
96 | b4e08b69eb3aae325197216450e3a28b trec/16-21/test.csv
97 | f17a8bb90b8110ad37632ded1dc0b515 trec/16-42/train.csv
98 | 682e1b9724b472132249396e826c7741 trec/16-42/dev.csv
99 | b4e08b69eb3aae325197216450e3a28b trec/16-42/test.csv
100 | efe409aae7de9e639e5da33430626e83 trec/16-87/train.csv
101 | 62eb496a54c2c810703c8a3d47213237 trec/16-87/dev.csv
102 | b4e08b69eb3aae325197216450e3a28b trec/16-87/test.csv
103 | 5df9333e324779c32bdc14eed7a3f51c trec/16-100/train.csv
104 | 983ce074f8cc101e7539122d20dc2815 trec/16-100/dev.csv
105 | b4e08b69eb3aae325197216450e3a28b trec/16-100/test.csv
106 | 8b3b16bdc7ccdf7c5e0397caf1ce61c1 CoLA/16-13/train.tsv
107 | 2a3b6875b91447043e5fc7366d6d6c59 CoLA/16-13/dev.tsv
108 | c5475ccefc9e7ca0917294b8bbda783c CoLA/16-13/test.tsv
109 | 9af338ce1d82a300b1d09e06da913454 CoLA/16-21/train.tsv
110 | ad4bd1c5d1bbc9cb149cc10648ade05f CoLA/16-21/dev.tsv
111 | c5475ccefc9e7ca0917294b8bbda783c CoLA/16-21/test.tsv
112 | 35870db90cf9b4044159d49cf4c4a620 CoLA/16-42/train.tsv
113 | 6f2b371c0fa209c253d49e20941d44c5 CoLA/16-42/dev.tsv
114 | c5475ccefc9e7ca0917294b8bbda783c CoLA/16-42/test.tsv
115 | f0260a425124e254cef6c41e78c3b990 CoLA/16-87/train.tsv
116 | 30975efe3642c02d2afaf6fa6849b5b1 CoLA/16-87/dev.tsv
117 | c5475ccefc9e7ca0917294b8bbda783c CoLA/16-87/test.tsv
118 | c02bce246a4481b2733522d82200d81e CoLA/16-100/train.tsv
119 | fda65a964d3ebeec632f8c45fac8209e CoLA/16-100/dev.tsv
120 | c5475ccefc9e7ca0917294b8bbda783c CoLA/16-100/test.tsv
121 | 6c61998bef38ccadacc2df0013035715 MNLI/16-13/train.tsv
122 | 5e84a205fcb8c20d8706df7e5cc7c03f MNLI/16-13/dev_matched.tsv
123 | c3fa2817007f4cdf1a03663611a8ad23 MNLI/16-13/test_matched.tsv
124 | b219e6fe74e4aa779e2f417ffe713053 MNLI/16-13/test_mismatched.tsv
125 | 139ce8bcc9b973b164fe2065a507819a MNLI/16-21/train.tsv
126 | 5da136973f2ccb8c3cd421c74cf343e8 MNLI/16-21/dev_matched.tsv
127 | c3fa2817007f4cdf1a03663611a8ad23 MNLI/16-21/test_matched.tsv
128 | b219e6fe74e4aa779e2f417ffe713053 MNLI/16-21/test_mismatched.tsv
129 | 5adecf01452cdd271940d5d7f094c77e MNLI/16-42/train.tsv
130 | c51ee1e3afccc39cd4f4fd660f6b547f MNLI/16-42/dev_matched.tsv
131 | c3fa2817007f4cdf1a03663611a8ad23 MNLI/16-42/test_matched.tsv
132 | b219e6fe74e4aa779e2f417ffe713053 MNLI/16-42/test_mismatched.tsv
133 | c2540989877d39d56aa34f503bf8d145 MNLI/16-87/train.tsv
134 | f53efd36c22ed01af9b61ee6f917a714 MNLI/16-87/dev_matched.tsv
135 | c3fa2817007f4cdf1a03663611a8ad23 MNLI/16-87/test_matched.tsv
136 | b219e6fe74e4aa779e2f417ffe713053 MNLI/16-87/test_mismatched.tsv
137 | 288a569addee0a37675ba0d27bb5cf9a MNLI/16-100/train.tsv
138 | 2139f9ecc2aabf761ca904a7ecfa2652 MNLI/16-100/dev_matched.tsv
139 | c3fa2817007f4cdf1a03663611a8ad23 MNLI/16-100/test_matched.tsv
140 | b219e6fe74e4aa779e2f417ffe713053 MNLI/16-100/test_mismatched.tsv
141 | 10672f831d104166efe0d3678303f612 SNLI/16-13/train.tsv
142 | 85e3b9d4646c19709cdd09a36c6dda7e SNLI/16-13/dev.tsv
143 | 39d5338b0c2299bde98a3cdd0c4d04a5 SNLI/16-13/test.tsv
144 | b2dbd54b5152c755cda4c0b3b5f0f0f2 SNLI/16-21/train.tsv
145 | 204397b61efbfe84a731183340c935df SNLI/16-21/dev.tsv
146 | 39d5338b0c2299bde98a3cdd0c4d04a5 SNLI/16-21/test.tsv
147 | 684cdeec26f1150b412a6fe42b203215 SNLI/16-42/train.tsv
148 | fccf683689b3276582041c2b1108f8fe SNLI/16-42/dev.tsv
149 | 39d5338b0c2299bde98a3cdd0c4d04a5 SNLI/16-42/test.tsv
150 | 442e9e4ab7372b1b15ca7f40d87559f0 SNLI/16-87/train.tsv
151 | f4359f6d8c33af80c22e089a63fd4c24 SNLI/16-87/dev.tsv
152 | 39d5338b0c2299bde98a3cdd0c4d04a5 SNLI/16-87/test.tsv
153 | 2b68c0dff2302bd3b013a1de14d20f1b SNLI/16-100/train.tsv
154 | b0bf433e9c0f6b5456ae5833dee41eee SNLI/16-100/dev.tsv
155 | 39d5338b0c2299bde98a3cdd0c4d04a5 SNLI/16-100/test.tsv
156 | 4fd9e83941409d616f67a4fb6e68206b QNLI/16-13/train.tsv
157 | a54657913dd71710bb0eff9b8c044746 QNLI/16-13/dev.tsv
158 | 1e81e211959605f144ba6c0ad7dc948b QNLI/16-13/test.tsv
159 | f316c511c6990c2648b821b99351f185 QNLI/16-21/train.tsv
160 | f7dd6b59f755d86a00395d2ee7fc4938 QNLI/16-21/dev.tsv
161 | 1e81e211959605f144ba6c0ad7dc948b QNLI/16-21/test.tsv
162 | edbea1a802d48db725c61d487597ab60 QNLI/16-42/train.tsv
163 | e0d03f6166b24a8c16985e32a8315427 QNLI/16-42/dev.tsv
164 | 1e81e211959605f144ba6c0ad7dc948b QNLI/16-42/test.tsv
165 | 2bcc47b08f67553f5e9f45991af07d10 QNLI/16-87/train.tsv
166 | 2b7f1fc57c802433c59357f659888a08 QNLI/16-87/dev.tsv
167 | 1e81e211959605f144ba6c0ad7dc948b QNLI/16-87/test.tsv
168 | 85c42afc824ac7a0474ccc42239f664c QNLI/16-100/train.tsv
169 | 9f193ffb9e102d28a25524f9a778b819 QNLI/16-100/dev.tsv
170 | 1e81e211959605f144ba6c0ad7dc948b QNLI/16-100/test.tsv
171 | 0aaae3483435626089adbc4ae851b75d RTE/16-13/train.tsv
172 | 8fc45ab58e72e16cf90547710f3b5d15 RTE/16-13/dev.tsv
173 | 973cb4178d4534cf745a01c309d4a66c RTE/16-13/test.tsv
174 | 79490f931a776fad7f3b2ba31f9142d1 RTE/16-21/train.tsv
175 | 96ac052f30101901e6bc04adb4875edb RTE/16-21/dev.tsv
176 | 973cb4178d4534cf745a01c309d4a66c RTE/16-21/test.tsv
177 | 4d45391333d2f56d16633d7b85d3e051 RTE/16-42/train.tsv
178 | a64069d0ca424995d35e05e67c03ceb9 RTE/16-42/dev.tsv
179 | 973cb4178d4534cf745a01c309d4a66c RTE/16-42/test.tsv
180 | 46af88a12bc44c12a40ea2cb79e94837 RTE/16-87/train.tsv
181 | 19ee5593b53d1b290737eb9de2ca4869 RTE/16-87/dev.tsv
182 | 973cb4178d4534cf745a01c309d4a66c RTE/16-87/test.tsv
183 | 548d8775d98ca0638615b9e0bfbe5c32 RTE/16-100/train.tsv
184 | bc66584b1fa19f9620051e8db7bcc1dc RTE/16-100/dev.tsv
185 | 973cb4178d4534cf745a01c309d4a66c RTE/16-100/test.tsv
186 | 4ee5a1591a6b385ed235dff1fe0bc2ff MRPC/16-13/train.tsv
187 | 55da0a2ef25c0bdda9ba0bbe1a757644 MRPC/16-13/dev.tsv
188 | 185958e46ba556b38c6a7cc63f3a2135 MRPC/16-13/test.tsv
189 | 17f3132a74d23a122c8e8bc6936de492 MRPC/16-21/train.tsv
190 | 2f0e716c6a5ddaf45d03d0b48075fd5a MRPC/16-21/dev.tsv
191 | 185958e46ba556b38c6a7cc63f3a2135 MRPC/16-21/test.tsv
192 | f96a9434f7cdef032f544f405498c266 MRPC/16-42/train.tsv
193 | a2a9648bec87318674338caed2538091 MRPC/16-42/dev.tsv
194 | 185958e46ba556b38c6a7cc63f3a2135 MRPC/16-42/test.tsv
195 | cad49c8459f1eef52093b9527ec87735 MRPC/16-87/train.tsv
196 | b9664f662b04ec18feb8605462a8a1cb MRPC/16-87/dev.tsv
197 | 185958e46ba556b38c6a7cc63f3a2135 MRPC/16-87/test.tsv
198 | 29d9a7598589c1636f64c7cf3dd25de8 MRPC/16-100/train.tsv
199 | f65a7cbec81d738f594db0040fb32610 MRPC/16-100/dev.tsv
200 | 185958e46ba556b38c6a7cc63f3a2135 MRPC/16-100/test.tsv
201 | 89173509ae29f234632fe2e7805860ba QQP/16-13/train.tsv
202 | 14a3c59ae8a5298372923dadc3c86012 QQP/16-13/dev.tsv
203 | cff6a448d1580132367c22fc449ec214 QQP/16-13/test.tsv
204 | ce7ed8d40a8e0c090b8c4634914ef895 QQP/16-21/train.tsv
205 | c29687f39ece2c692efca3ade0b91912 QQP/16-21/dev.tsv
206 | cff6a448d1580132367c22fc449ec214 QQP/16-21/test.tsv
207 | 88f2cc2d8a4c7b691099a47794ce2aee QQP/16-42/train.tsv
208 | c1b2198ba8a84976bb1084c164418234 QQP/16-42/dev.tsv
209 | cff6a448d1580132367c22fc449ec214 QQP/16-42/test.tsv
210 | 7110878494e561015e94fe459115c92c QQP/16-87/train.tsv
211 | cb7a47bad6671d75a93ec2db13c7964a QQP/16-87/dev.tsv
212 | cff6a448d1580132367c22fc449ec214 QQP/16-87/test.tsv
213 | b6a7ea69840e305608faabd068a74f06 QQP/16-100/train.tsv
214 | 136f4a247a897a80285a1e215a850d0a QQP/16-100/dev.tsv
215 | cff6a448d1580132367c22fc449ec214 QQP/16-100/test.tsv
216 | 634047695a8036cfe2523cbbc58f2fbb STS-B/16-13/train.tsv
217 | f17c612f0da73b1599a9946b1e8c7d06 STS-B/16-13/dev.tsv
218 | 5f4d6b0d2a5f268b1b56db773ab2f1fe STS-B/16-13/test.tsv
219 | af75494229a00b1ea764e7e38a3fbcf8 STS-B/16-21/train.tsv
220 | a79b76d7ee8efcb7cdefb4056b0606b8 STS-B/16-21/dev.tsv
221 | 5f4d6b0d2a5f268b1b56db773ab2f1fe STS-B/16-21/test.tsv
222 | c4577d13b90f33d9a61ca7abb8126132 STS-B/16-42/train.tsv
223 | 08b70fdb9484181f9331409329cd2259 STS-B/16-42/dev.tsv
224 | 5f4d6b0d2a5f268b1b56db773ab2f1fe STS-B/16-42/test.tsv
225 | c1a41f8d2a0867802920e447b4e6e73b STS-B/16-87/train.tsv
226 | 430153037049cab16d6141e766b4cca2 STS-B/16-87/dev.tsv
227 | 5f4d6b0d2a5f268b1b56db773ab2f1fe STS-B/16-87/test.tsv
228 | 5f8d5211ff72f49ebf181b63e6390449 STS-B/16-100/train.tsv
229 | d826c4b3fa6d419aed50442b7f1f6adc STS-B/16-100/dev.tsv
230 | 5f4d6b0d2a5f268b1b56db773ab2f1fe STS-B/16-100/test.tsv
231 |
--------------------------------------------------------------------------------
/examples/classification/data/make_k_shot_without_dev.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """The datasets in the k-shot folder contain dev.tsv; we make the test set the dev set in the new k-shot.
16 |
17 | python -m classification.data.make_k_shot_without_dev
18 | """
19 | import os
20 |
21 | from ml_swissknife import utils
22 |
23 | join = os.path.join
24 |
25 | base_dir = '/nlp/scr/lxuechen/data/lm-bff/data/k-shot'
26 | new_dir = '/nlp/scr/lxuechen/data/lm-bff/data/k-shot-no-dev'
27 |
28 | task_names = ("SST-2", "QNLI", "MNLI", "QQP")
29 | for task_name in task_names:
30 |     folder = join(base_dir, task_name)
31 |     new_folder = join(new_dir, task_name)
32 |
33 |     for name in utils.listdir(folder):
34 |         subfolder = join(folder, name)
35 |         new_subfolder = join(new_folder, name)
36 |         os.makedirs(new_subfolder, exist_ok=True)
37 |
38 |         train = join(subfolder, 'train.tsv')
39 |         new_train = join(new_subfolder, 'train.tsv')
40 |         os.system(f'cp {train} {new_train}')
41 |
42 |         if task_name == "MNLI":
43 |             test = join(subfolder, 'test_matched.tsv')
44 |             new_dev = join(new_subfolder, 'dev_matched.tsv')
45 |             os.system(f'cp {test} {new_dev}')
46 |
47 |             test = join(subfolder, 'test_mismatched.tsv')
48 |             new_dev = join(new_subfolder, 'dev_mismatched.tsv')
49 |             os.system(f'cp {test} {new_dev}')
50 |         else:
51 |             test = join(subfolder, 'test.tsv')
52 |             new_dev = join(new_subfolder, 'dev.tsv')
53 |             os.system(f'cp {test} {new_dev}')
54 |
--------------------------------------------------------------------------------
/examples/classification/data/make_valid_data.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Make the separate validation data, so that we don't tune on dev set.
16 |
17 | python -m classification.data.make_valid_data
18 | """
19 | import os
20 |
21 | import fire
22 | import numpy as np
23 | import tqdm
24 |
25 |
26 | def write_lines(path, lines, mode="w"):
27 | os.makedirs(os.path.dirname(path), exist_ok=True)
28 | with open(path, mode) as f:
29 | f.writelines(lines)
30 | print(len(lines))
31 |
32 |
33 | def main():
34 |     valid_percentage = 0.1
35 |     original_dir = "/nlp/scr/lxuechen/data/lm-bff/data/original"
36 |     new_dir = "/nlp/scr/lxuechen/data/lm-bff/data/glue-with-validation"
37 |
38 |     task_folders = ("GLUE-SST-2", "QNLI", "QQP")
39 |     for task_folder in task_folders:
40 |         # Create train and valid splits.
41 |         full_train_path = os.path.join(original_dir, task_folder, 'train.tsv')
42 |         with open(full_train_path, 'r') as f:
43 |             full_train = f.readlines()
44 |
45 |         header = full_train[0]
46 |         full_train = full_train[1:]  # Remove header.
47 |
48 |         indices = np.random.permutation(len(full_train))
49 |         new_valid_size = int(len(indices) * valid_percentage)
50 |         new_train_size = len(indices) - new_valid_size
51 |         new_train_indices = indices[:new_train_size]
52 |         new_valid_indices = indices[new_train_size:]
53 |         assert len(new_train_indices) == new_train_size
54 |         assert len(new_valid_indices) == new_valid_size
55 |
56 |         new_train = [header] + [full_train[i] for i in new_train_indices]
57 |         new_valid = [header] + [full_train[i] for i in new_valid_indices]
58 |
59 |         new_train_path = os.path.join(new_dir, task_folder, 'train.tsv')
60 |         new_valid_path = os.path.join(new_dir, task_folder, 'dev.tsv')
61 |
62 |         write_lines(new_train_path, new_train)
63 |         write_lines(new_valid_path, new_valid)
64 |         del new_train, new_valid, new_train_path, new_valid_path
65 |         del new_train_size, new_train_indices
66 |         del new_valid_size, new_valid_indices
67 |
68 |         # Make test!
69 |         test_path = os.path.join(original_dir, task_folder, 'dev.tsv')
70 |         new_test_path = os.path.join(new_dir, task_folder, 'test.tsv')
71 |         os.system(f'cp {test_path} {new_test_path}')
72 |         del test_path, new_test_path
73 |
74 |     # Make valid set for MNLI; different, since matched/mismatched!
75 |     task_folder = "MNLI"
76 |     matched_genres = ['slate', 'government', 'telephone', 'travel', 'fiction']
77 |     mismatched_genres = ['letters', 'verbatim', 'facetoface', 'oup', 'nineeleven']
78 |     full_train_path = os.path.join(original_dir, task_folder, 'train.tsv')
79 |     with open(full_train_path, 'r') as f:
80 |         full_train = f.readlines()
81 |     full_train_csv = [line.split('\t') for line in full_train]
82 |
83 |     # Check the lengths are correct.
84 |     l = len(full_train_csv[0])
85 |     for line in full_train_csv:
86 |         assert l == len(line)
87 |
88 |     # Remove header.
89 |     header = full_train[0]
90 |     header_csv = full_train_csv[0]
91 |
92 |     full_train = full_train[1:]
93 |     full_train_csv = full_train_csv[1:]
94 |
95 |     # Get index of genre.
96 |     genre_index = header_csv.index('genre')
97 |
98 |     # Shuffle both!
99 |     indices = np.random.permutation(len(full_train))
100 |     full_train = [full_train[i] for i in indices]
101 |     full_train_csv = [full_train_csv[i] for i in indices]
102 |
103 |     # Split validation.
104 |     new_valid_size = int(len(indices) * valid_percentage)
105 |     new_matched_valid_size = new_mismatched_valid_size = new_valid_size // 2
106 |
107 |     # Fetch the indices.
108 |     new_train_indices = []
109 |     new_matched_valid_indices = []
110 |     new_mismatched_valid_indices = []
111 |     matched_count = mismatched_count = 0
112 |     for i, row in enumerate(full_train_csv):
113 |         genre = row[genre_index]
114 |         if genre in matched_genres and matched_count < new_matched_valid_size:
115 |             new_matched_valid_indices.append(i)
116 |             matched_count += 1
117 |         elif genre in mismatched_genres and mismatched_count < new_mismatched_valid_size:
118 |             new_mismatched_valid_indices.append(i)
119 |             mismatched_count += 1
120 |         else:
121 |             new_train_indices.append(i)
122 |
123 |     new_matched_valid_indices = set(new_matched_valid_indices)
124 |     new_mismatched_valid_indices = set(new_mismatched_valid_indices)
125 |
126 |     new_train = [header]
127 |     new_matched_valid = [header]
128 |     new_mismatched_valid = [header]
129 |     for i, line in tqdm.tqdm(enumerate(full_train)):
130 |         if i in new_matched_valid_indices:
131 |             new_matched_valid.append(line)
132 |         elif i in new_mismatched_valid_indices:
133 |             new_mismatched_valid.append(line)
134 |         else:
135 |             new_train.append(line)
136 |
137 |     new_train_path = os.path.join(new_dir, task_folder, 'train.tsv')
138 |     new_matched_valid_path = os.path.join(new_dir, task_folder, 'dev_matched.tsv')
139 |     new_mismatched_valid_path = os.path.join(new_dir, task_folder, 'dev_mismatched.tsv')
140 |
141 |     write_lines(new_train_path, new_train)
142 |     write_lines(new_matched_valid_path, new_matched_valid)
143 |     write_lines(new_mismatched_valid_path, new_mismatched_valid)
144 |
145 |     matched_test_path = os.path.join(original_dir, task_folder, 'dev_matched.tsv')
146 |     new_matched_test_path = os.path.join(new_dir, task_folder, 'test_matched.tsv')
147 |     os.system(f'cp {matched_test_path} {new_matched_test_path}')
148 |
149 |     mismatched_test_path = os.path.join(original_dir, task_folder, 'dev_mismatched.tsv')
150 |     new_mismatched_test_path = os.path.join(new_dir, task_folder, 'test_mismatched.tsv')
151 |     os.system(f'cp {mismatched_test_path} {new_mismatched_test_path}')
152 |
153 |
154 | if __name__ == "__main__":
155 | fire.Fire(main)
156 |
--------------------------------------------------------------------------------
/examples/classification/requirements.txt:
--------------------------------------------------------------------------------
1 | argcomplete==1.12.1
2 | avro-python3==1.9.2.1
3 | azure-storage-blob==12.4.0
4 | bottle==0.12.19
5 | certifi==2021.5.30
6 | chardet==3.0.4
7 | charset-normalizer==2.0.4
8 | click==8.0.1
9 | crcmod==1.7
10 | cycler==0.10.0
11 | diffimg==0.2.3
12 | docopt==0.6.2
13 | fastavro==1.4.1
14 | filelock==3.0.12
15 | fire
16 | fusepy==2.0.4
17 | future==0.18.2
18 | gdown>=3.13.0
19 | httplib2==0.17.4
20 | idna==3.2
21 | imageio==2.9.0
22 | indexed-gzip-fileobj-fork-epicfaace==1.5.4
23 | isodate==0.6.0
24 | joblib==1.0.1
25 | kiwisolver==1.3.1
26 | markdown2==2.3.10
27 | marshmallow==2.15.1
28 | marshmallow-jsonapi==0.15.1
29 | matplotlib==3.4.3
30 | mock==2.0.0
31 | networkx==2.6.2
32 | nltk==3.6.2
33 | numpy==1.21.2
34 | oauth2client==4.1.3
35 | opacus==0.13.0
36 | packaging==21.0
37 | pandas==1.3.2
38 | pathtools==0.1.2
39 | pbr==5.6.0
40 | Pillow==8.3.1
41 | psutil==5.7.2
42 | pyasn1==0.4.8
43 | pyasn1-modules==0.2.8
44 | pycparser==2.20
45 | pydot==1.4.2
46 | pymongo==3.11.4
47 | pyparsing==2.4.7
48 | PySocks==1.7.1
49 | python-dateutil==2.8.*
50 | pytz==2021.1
51 | PyWavelets==1.1.1
52 | PyYAML==5.4.*
53 | regex==2021.8.3
54 | requests
55 | retry==0.9.2
56 | sacremoses==0.0.45
57 | scikit-image==0.18.2
58 | scikit-learn==0.24.2
59 | scipy==1.7.1
60 | seaborn==0.11.2
61 | selenium==3.141.0
62 | sentence-transformers>=2.0.0
63 | sentencepiece==0.1.96
64 | sentry-sdk==0.18.0
65 | six==1.15.0
66 | SQLAlchemy==1.3.19
67 | termcolor==1.1.0
68 | threadpoolctl==2.2.0
69 | tifffile==2021.8.8
70 | tokenizers==0.10.3
71 | tqdm>=4.62.1
72 | typing-extensions==3.7.4.3
73 | urllib3==1.26.*
74 | watchdog==0.10.3
75 | websocket-client==1.0.1
76 | gpytorch
77 | jupyterlab
78 |
--------------------------------------------------------------------------------
/examples/classification/run_wrapper.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Wrapper launcher script."""
16 |
17 | import os
18 |
19 | import fire
20 |
21 | from .src import common
22 |
23 |
24 | def _get_command(
25 | task_name,
26 | output_dir,
27 | model_name_or_path,
28 | data_dir,
29 | learning_rate,
30 | clipping_mode: str,
31 | non_private,
32 | target_epsilon,
33 | few_shot_type,
34 | seed,
35 | attention_only,
36 | static_lm_head,
37 | static_embedding,
38 | randomly_initialize,
39 | per_device_train_batch_size,
40 | batch_size,
41 | num_train_epochs,
42 | eval_steps,
43 | eval_spectrum,
44 | max_spectrum_batches,
45 | max_lanczos_iter,
46 | store_grads,
47 | orthogonal_projection_path,
48 | orthogonal_projection_rank,
49 | ):
50 | task_name_to_factor = {
51 | "sst-2": 1, "qnli": 2, "qqp": 6, "mnli": 6,
52 | }
53 | factor = task_name_to_factor[task_name]
54 |
55 | if batch_size is None:
56 | base_batch_size = 1000
57 | # This batch size selection roughly ensures the sampling rates on different
58 | # datasets are in the same ballpark.
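        # (The sampling rate is q = batch_size / dataset_size; QNLI, QQP, and MNLI
        # are roughly 2x, 6x, and 6x the size of SST-2, hence the factors above.)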
59 | batch_size = int(base_batch_size * factor)
60 | gradient_accumulation_steps = batch_size // per_device_train_batch_size
61 |
62 | if num_train_epochs is None:
63 | base_num_train_epochs = 3
64 | num_train_epochs = int(base_num_train_epochs * factor)
65 |
66 | if learning_rate is None:
67 | if non_private.lower() in ('yes', 'y', 'true', 't'):
68 | learning_rate = 5e-5
69 | else:
70 | learning_rate = 5e-4
71 |
72 | data_dir = f"{data_dir}/{common.task_name2suffix_name[task_name]}"
73 | template = {
74 | "sst-2": "*cls**sent_0*_It_was*mask*.*sep+*",
75 | "mnli": "*cls**sent-_0*?*mask*,*+sentl_1**sep+*",
76 | "qnli": "*cls**sent-_0*?*mask*,*+sentl_1**sep+*",
77 | "qqp": "*cls**sent-_0**mask*,*+sentl_1**sep+*",
78 | }[task_name]
79 |
80 | # Epochs chosen roughly to match e2e number of updates. We didn't hyperparameter tune on classification tasks :)
81 | cmd = f'''
82 | python -m classification.run_classification \
83 | --task_name {task_name} \
84 | --data_dir {data_dir} \
85 | --output_dir {output_dir} \
86 | --overwrite_output_dir \
87 | --model_name_or_path {model_name_or_path} \
88 | --few_shot_type {few_shot_type} \
89 | --num_k 1 \
90 | --num_sample 1 --seed {seed} \
91 | --template {template} \
92 | --non_private {non_private} \
93 | --num_train_epochs {num_train_epochs} \
94 | --target_epsilon {target_epsilon} \
95 | --per_device_train_batch_size {per_device_train_batch_size} \
96 | --gradient_accumulation_steps {gradient_accumulation_steps} \
97 | --per_device_eval_batch_size 8 \
98 | --per_example_max_grad_norm 0.1 --clipping_mode {clipping_mode} \
99 | --learning_rate {learning_rate} \
100 | --lr_decay yes \
101 | --adam_epsilon 1e-08 \
102 | --weight_decay 0 \
103 | --max_seq_len 256 \
104 | --evaluation_strategy steps --eval_steps {eval_steps} --evaluate_before_training True \
105 | --do_train --do_eval \
106 | --first_sent_limit 200 --other_sent_limit 200 --truncate_head yes \
107 | --attention_only {attention_only} --static_lm_head {static_lm_head} --static_embedding {static_embedding} \
108 | --randomly_initialize {randomly_initialize} \
109 | --eval_spectrum {eval_spectrum} --max_spectrum_batches {max_spectrum_batches} --max_lanczos_iter {max_lanczos_iter} \
110 | --store_grads {store_grads}'''
111 | if orthogonal_projection_path is not None:
112 | cmd += f' --orthogonal_projection_path {orthogonal_projection_path}'
113 | cmd += f' --orthogonal_projection_rank {orthogonal_projection_rank}'
114 | return cmd
115 |
116 |
117 | def main(
118 | output_dir,
119 | task_name,
120 | few_shot_type="prompt",
121 | seed=42,
122 | model_name_or_path="roberta-base",
123 | data_dir="classification/data/original",
124 | learning_rate=None,
125 | clipping_mode="ghost",
126 | non_private="no",
127 | target_epsilon=8,
128 | attention_only="no",
129 | static_lm_head="no",
130 | static_embedding="no",
131 | per_device_train_batch_size=20,
132 | eval_steps=10,
133 | eval_spectrum="no",
134 | max_spectrum_batches=2,
135 | max_lanczos_iter=2,
136 | randomly_initialize="no",
137 | batch_size=None,
138 | num_train_epochs=None,
139 | store_grads="no",
140 | orthogonal_projection_path=None,
141 | orthogonal_projection_rank=100,
142 | ):
143 | command = _get_command(
144 | output_dir=output_dir,
145 | task_name=task_name,
146 | model_name_or_path=model_name_or_path,
147 | data_dir=data_dir,
148 | learning_rate=learning_rate,
149 | clipping_mode=clipping_mode,
150 | non_private=non_private,
151 | target_epsilon=target_epsilon,
152 | few_shot_type=few_shot_type,
153 | seed=seed,
154 | attention_only=attention_only,
155 | static_lm_head=static_lm_head,
156 | static_embedding=static_embedding,
157 | per_device_train_batch_size=per_device_train_batch_size,
158 | eval_steps=eval_steps,
159 | eval_spectrum=eval_spectrum,
160 | max_spectrum_batches=max_spectrum_batches,
161 | max_lanczos_iter=max_lanczos_iter,
162 | randomly_initialize=randomly_initialize,
163 | batch_size=batch_size,
164 | num_train_epochs=num_train_epochs,
165 | store_grads=store_grads,
166 | orthogonal_projection_path=orthogonal_projection_path,
167 | orthogonal_projection_rank=orthogonal_projection_rank,
168 | )
169 | print('Running command:')
170 | print(command)
171 | os.system(command)
172 |
173 |
174 | if __name__ == "__main__":
175 | fire.Fire(main)
176 |
--------------------------------------------------------------------------------
/examples/classification/spectral_analysis/3d_surface.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lxuechen/private-transformers/18ccc4eab7355e4ac96051a82434796f6aa4624b/examples/classification/spectral_analysis/3d_surface.png
--------------------------------------------------------------------------------
/examples/classification/spectral_analysis/README.md:
--------------------------------------------------------------------------------
1 | ## Experiments for spectral analysis
2 |
3 | Everything below should be run from the `examples/` folder.
4 |
5 | 1. To run the geometric median example in the paper, use the following command and supply an `<img_dir>` of your
6 |    choice:
7 |
8 | ```bash
9 | python -m classification.spectral_analysis.geometric_median --img_dir <img_dir>
10 | ```
11 |
12 | 2. Spectral analysis.
13 |
14 | 2.1. To reproduce the spectral analysis experiments, one first needs a round of fine-tuning to collect
15 | gradients.
16 | Run the following command with a `<train_dir>` of your choice. Note that everything PCA-related down the line will
17 | be stored here, so make sure you have enough disk space; it's perhaps safe to reserve 500G~1T, since the spectral
18 | analyses are very disk-space intensive. Note below `<model_name_or_path>` can be one of `distilroberta-base`,
19 | `roberta-base`, or `roberta-large`. A sketch of the gradient bookkeeping is shown after the command.
20 |
21 | ```bash
22 | CUDA_VISIBLE_DEVICES=0 python -m classification.spectral_analysis.rebuttal_neurips_2022 \
23 | --task "run_save_grads" \
24 |   --train_dir <train_dir> \
25 |   --model_name_or_path <model_name_or_path>
26 | ```
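With `--store_grads "yes"`, the trainer writes one flattened gradient per update under
`<train_dir>/grad_trajectory/`; step 2.2 later loads the `"flat_grad"` entry of each checkpoint. A minimal
sketch of that bookkeeping (the helper name and exact file layout here are assumptions for illustration;
`model` and `global_step` come from the surrounding training loop):

```python
import torch

def save_flat_grad(model, dump_dir, global_step):
    # Concatenate all parameter gradients into a single flat vector; one file per update.
    flat_grad = torch.cat([
        p.grad.detach().reshape(-1).cpu()
        for p in model.parameters() if p.grad is not None
    ])
    torch.save({"flat_grad": flat_grad}, f"{dump_dir}/global_step_{global_step:06d}.pt")
```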
27 |
28 | 2.2. Now run PCA with orthogonal iteration to extract the top eigenvectors. The command below runs PCA based on 4k
29 | checkpoints (4k gradients stored along the trajectory) and extracts the top 1k eigenvalues and
30 | eigenvectors. `batch_size` can be set small to save memory (it affects the distributed matmul). Note that for
31 | the `roberta-large` experiment, you would likely need several GPUs; for reference, I used 4 A6000s (each with 48G of
32 | VRAM) for that experiment. The code is written so that computation can be distributed across many GPUs on a
33 | single machine, and should be fast with enough accelerators. **`<train_dir>` below must be the same as in the
34 | previous command.** A single-GPU sketch of the underlying subspace iteration follows the command.
35 |
36 | ```bash
37 | python -m classification.spectral_analysis.rebuttal_neurips_2022 \
38 | --task "run_pca" \
39 |   --train_dir <train_dir> \
40 | --n 4000 \
41 | --k 1000 \
42 | --num_power_iteration 10 \
43 | --batch_size 20
44 | ```
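For intuition, orthogonal (subspace) iteration repeatedly multiplies a random orthonormal basis by the Gram
matrix of the gradient trajectory and re-orthonormalizes with QR. A single-GPU sketch, assuming the stacked
gradients fit in memory (the actual implementation streams these matmuls in batches across GPUs):

```python
import torch

def orthogonal_iteration_sketch(grads: torch.Tensor, k: int, num_iters: int = 10):
    """Approximate the top-k eigenpairs of G^T G, where G has shape (n, d)."""
    n, d = grads.shape
    Q = torch.linalg.qr(torch.randn(d, k, dtype=grads.dtype)).Q  # random orthonormal start
    for _ in range(num_iters):
        Z = grads @ Q                               # (n, k)
        Q = torch.linalg.qr(grads.T @ Z).Q          # power step, then re-orthonormalize
    eigenvalues = ((grads @ Q) ** 2).sum(dim=0)     # Rayleigh quotients q^T (G^T G) q
    return eigenvalues, Q                           # eigenvectors are the columns of Q
```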
45 |
46 | 2.3. For re-training in the subspace, we need to point the command at the place where the PCA results are
47 | stored. The PCA results will be in `<train_dir>/orthproj/all/`. There will likely be a couple of
48 | checkpoints in this folder, each of which corresponds to a different iteration of the orthogonal iteration.
49 | Now run the following command. Note that below 1) `<output_dir>` should **not** be the same as
50 | `<train_dir>`, to avoid overwriting previous PCA results, and 2) `<rank>` should be smaller than `k` from
51 | the previous command, since it's the rank of the subspace. A sketch of the projection step follows the
52 | command.
53 |
54 | ```bash
55 | CUDA_VISIBLE_DEVICES=0 python -m classification.spectral_analysis.rebuttal_neurips_2022 \
56 | --task "run_retrain_single" \
57 |   --output_dir <output_dir> \
58 |   --orthogonal_projection_path "<train_dir>/orthproj/all/global_step_x.pt" \
59 |   --rank <rank> \
60 |   --model_name_or_path <model_name_or_path>
61 | ```
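Conceptually, re-training in the subspace projects each privatized gradient onto the span of the top-`<rank>`
eigenvectors before the parameter update. A hedged sketch of just the projection step, where `Q` is the
(d, k) eigenvector matrix loaded from the PCA checkpoint:

```python
import torch

def project_to_subspace(flat_grad: torch.Tensor, Q: torch.Tensor, rank: int) -> torch.Tensor:
    """Project a flat gradient onto the top-`rank` eigen-subspace: g <- Q_r Q_r^T g."""
    Q_r = Q[:, :rank]                 # leading columns = top eigenvectors
    return Q_r @ (Q_r.T @ flat_grad)
```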
62 |
63 | ## Citation
64 |
65 | If you found this codebase useful in your research, please consider citing:
66 |
67 | ```
68 | @misc{li2022when,
69 |   doi = {10.48550/ARXIV.2207.00160},
70 |   url = {https://arxiv.org/abs/2207.00160},
71 |   author = {Li, Xuechen and Liu, Daogao and Hashimoto, Tatsunori and Inan, Huseyin A. and Kulkarni, Janardhan and Lee, Yin Tat and Thakurta, Abhradeep Guha},
72 |   keywords = {Machine Learning (cs.LG), Cryptography and Security (cs.CR), Machine Learning (stat.ML), FOS: Computer and information sciences},
73 |   title = {When Does Differentially Private Learning Not Suffer in High Dimensions?},
74 |   publisher = {arXiv},
75 |   year = {2022},
76 |   copyright = {Creative Commons Attribution 4.0 International}
77 | }
78 | ```
78 |
--------------------------------------------------------------------------------
/examples/classification/spectral_analysis/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/examples/classification/spectral_analysis/density.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | # Copyright 2019 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Code for converting Lanczos outputs to densities."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import math
23 |
24 | import numpy as np
25 |
26 |
27 | def eigv_to_density(eig_vals, all_weights=None, grids=None,
28 | grid_len=10000, sigma_squared=None, grid_expand=1e-2):
29 | """Compute the smoothed spectral density from a set of eigenvalues.
30 |
31 | Convolves the given eigenvalues with a Gaussian kernel, weighting the values
32 | by all_weights (or uniform weighting if all_weights is None). Example output
33 | can be seen in Figure 1 of https://arxiv.org/pdf/1901.10159.pdf. Visualizing
34 | the estimated density can be done by calling plt.plot(grids, density). There
35 | is likely not a best value of sigma_squared that works for all use cases,
36 | so it is recommended to try multiple values in the range [1e-5,1e-1].
37 |
38 | Args:
39 | eig_vals: Array of shape [num_draws, order]
40 | all_weights: Array of shape [num_draws, order], if None then weights will be
41 | taken to be uniform.
42 | grids: Array of shape [grid_len], the smoothed spectrum will be plotted
43 | in the interval [grids[0], grids[-1]]. If None then grids will be
44 | computed based on max and min eigenvalues and grid length.
45 | grid_len: Integer specifying number of grid cells to use, only used if
46 | grids is None
47 | sigma_squared: Scalar. Controls the smoothing of the spectrum estimate.
48 | If None, an appropriate value is inferred.
49 | grid_expand: Controls the window of values that grids spans.
50 | grids[0] = smallest eigenvalue - grid_expand.
51 | grids[-1] = largest_eigenvalue + grid_expand.
52 |
53 | Returns:
54 | density: Array of shape [grid_len], the estimated density, averaged over
55 | all draws.
56 | grids: Array of shape [grid_len]. The values the density is estimated on.
57 | """
58 | if all_weights is None:
59 | all_weights = np.ones(eig_vals.shape) * 1.0 / float(eig_vals.shape[1])
60 | num_draws = eig_vals.shape[0]
61 |
62 | lambda_max = np.nanmean(np.max(eig_vals, axis=1), axis=0) + grid_expand
63 | lambda_min = np.nanmean(np.min(eig_vals, axis=1), axis=0) - grid_expand
64 |
65 | if grids is None:
66 | assert grid_len is not None, 'grid_len is required if grids is None.'
67 | grids = np.linspace(lambda_min, lambda_max, num=grid_len)
68 |
69 | grid_len = grids.shape[0]
70 | if sigma_squared is None:
71 | sigma = 10 ** -5 * max(1, (lambda_max - lambda_min))
72 | else:
73 | sigma = sigma_squared * max(1, (lambda_max - lambda_min))
74 |
75 | density_each_draw = np.zeros((num_draws, grid_len))
76 | for i in range(num_draws):
77 |
78 | if np.isnan(eig_vals[i, 0]):
79 |             raise ValueError('tridiag has nan values.')
80 | else:
81 | for j in range(grid_len):
82 | x = grids[j]
83 | vals = _kernel(eig_vals[i, :], x, sigma)
84 | density_each_draw[i, j] = np.sum(vals * all_weights[i, :])
85 | density = np.nanmean(density_each_draw, axis=0)
86 | norm_fact = np.sum(density) * (grids[1] - grids[0])
87 | density = density / norm_fact
88 | return density, grids
89 |
90 |
91 | def tridiag_to_eigv(tridiag_list):
92 | """Preprocess the tridiagonal matrices for density estimation.
93 |
94 | Args:
95 | tridiag_list: Array of shape [num_draws, order, order] List of the
96 | tridiagonal matrices computed from running num_draws independent runs
97 | of lanczos. The output of this function can be fed directly into
98 | eigv_to_density.
99 |
100 | Returns:
101 | eig_vals: Array of shape [num_draws, order]. The eigenvalues of the
102 |         tridiagonal matrices.
103 | all_weights: Array of shape [num_draws, order]. The weights associated with
104 | each eigenvalue. These weights are to be used in the kernel density
105 | estimate.
106 | """
107 | # Calculating the node / weights from Jacobi matrices.
108 | num_draws = len(tridiag_list)
109 | num_lanczos = tridiag_list[0].shape[0]
110 | eig_vals = np.zeros((num_draws, num_lanczos))
111 | all_weights = np.zeros((num_draws, num_lanczos))
112 | for i in range(num_draws):
113 | nodes, evecs = np.linalg.eigh(tridiag_list[i])
114 | index = np.argsort(nodes)
115 | nodes = nodes[index]
116 | evecs = evecs[:, index]
117 | eig_vals[i, :] = nodes
118 | all_weights[i, :] = evecs[0] ** 2
119 | return eig_vals, all_weights
120 |
121 |
122 | def tridiag_to_density(tridiag_list, sigma_squared=1e-5, grid_len=10000):
123 | """This function estimates the smoothed density from the output of lanczos.
124 |
125 | Args:
126 | tridiag_list: Array of shape [num_draws, order, order] List of the
127 | tridiagonal matrices computed from running num_draws independent runs
128 | of lanczos.
129 | sigma_squared: Controls the smoothing of the density.
130 | grid_len: Controls the granularity of the density.
131 |
132 | Returns:
133 | density: Array of size [grid_len]. The smoothed density estimate averaged
134 | over all num_draws.
135 | grids: Array of size [grid_len]. The values the density estimate is on.
136 | """
137 | eig_vals, all_weights = tridiag_to_eigv(tridiag_list)
138 | density, grids = eigv_to_density(eig_vals, all_weights,
139 | grid_len=grid_len,
140 | sigma_squared=sigma_squared)
141 | return density, grids
142 |
143 |
144 | def _kernel(x, x0, variance):
145 | """Point estimate of the Gaussian kernel.
146 |
147 | This function computes the Gaussian kernel for
148 | C exp(-(x - x0) ^2 /(2 * variance)) where C is the appropriate normalization.
149 | variance should be a list of length 1. Either x0 or x should be a scalar. Only
150 | one of the x or x0 can be a numpy array.
151 |
152 | Args:
153 | x: Can be either scalar or array of shape [order]. Points to estimate
154 | the kernel on.
155 | x0: Scalar. Mean of the kernel.
156 | variance: Scalar. Variance of the kernel.
157 |
158 | Returns:
159 | point_estimate: A scalar corresponding to
160 | C exp(-(x - x0) ^2 /(2 * variance)).
161 | """
162 | coeff = 1.0 / np.sqrt(2 * math.pi * variance)
163 | val = -(x0 - x) ** 2
164 | val = val / (2.0 * variance)
165 | val = np.exp(val)
166 | point_estimate = coeff * val
167 | return point_estimate
168 |
--------------------------------------------------------------------------------
/examples/classification/spectral_analysis/geometric_median.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Toy example on geometric median estimation in the paper.
17 |
18 | python -m classification.spectral_analysis.geometric_median --img_dir "/mnt/disks/disk-2/dump/spectrum/geometric_median"
19 | """
20 | import dataclasses
21 | import logging
22 | import math
23 | import sys
24 | from typing import Tuple
25 |
26 | import fire
27 | import numpy as np
28 | import torch
29 | import tqdm
30 | from ml_swissknife import utils
31 |
32 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
33 |
34 |
35 | @dataclasses.dataclass
36 | class Data:
37 | beta_train: torch.Tensor
38 | beta_test: torch.Tensor
39 | Ar: torch.Tensor # A^{1/2}.
40 | sensitivity: float
41 |
42 | def __post_init__(self):
43 | self.n_train, self.d = self.beta_train.size()
44 | self.n_test = self.beta_test.shape[0]
45 |
46 |
47 | class Modes(metaclass=utils.ContainerMeta):
48 | const = "const"
49 | quarter = "quarter"
50 | sqrt = "sqrt"
51 | linear = "linear"
52 | quadratic = "quadratic"
53 |
54 |
55 | def make_data(
56 | betas=None,
57 | n_train=100000, n_test=100000, d=10, dmin=1, mu_beta=0.2, si_beta=0.1,
58 | mode="linear",
59 | g0=1.,
60 | ):
61 | if betas is None:
62 | beta_train, beta_test = make_beta(
63 | n_train=n_train, n_test=n_test, d=d, dmin=dmin, mu_beta=mu_beta, si_beta=si_beta
64 | )
65 | else:
66 | beta_train, beta_test = betas
67 | n_train, d = beta_train.size()
68 | n_test, _ = beta_test.size()
69 |
70 | if mode == Modes.const:
71 |         Ar = g0 * torch.ones(d, device=device)  # Flat spectrum: exponent 0 in the power-law family below.
72 | elif mode == Modes.quarter:
73 | Ar = g0 * torch.arange(1, d + 1, device=device) ** -.25
74 | elif mode == Modes.sqrt:
75 | Ar = g0 * torch.arange(1, d + 1, device=device) ** -.5
76 | elif mode == Modes.linear:
77 | Ar = g0 * torch.arange(1, d + 1, device=device) ** -1.
78 | elif mode == Modes.quadratic:
79 | Ar = g0 * torch.arange(1, d + 1, device=device) ** -2.
80 | else:
81 | raise ValueError(f"Unknown mode: {mode}")
82 |
83 | sensitivity = 2 * g0 / n_train
84 |
85 | return Data(beta_train=beta_train, beta_test=beta_test, Ar=Ar, sensitivity=sensitivity)
86 |
87 |
88 | def make_beta(n_train, n_test, d, dmin, mu_beta, si_beta):
89 | if d < dmin:
90 |         raise ValueError(f"d ({d}) < dmin ({dmin})")
91 |
92 | beta_train = mu_beta + torch.randn(size=(n_train, d), device=device) * si_beta
93 | beta_train[:, dmin:] = 0. # Ensure init distance to opt is the same.
94 |
95 | beta_test = mu_beta + torch.randn(size=(n_test, d), device=device) * si_beta
96 | beta_test[:, dmin:] = 0. # Same distribution as train.
97 |
98 | return beta_train, beta_test
99 |
100 |
101 | def evaluate(data: Data, beta: torch.Tensor) -> Tuple:
102 | """Compute loss 1 / n sum_i | A^{1/2} (beta - beta_i) |_2 for train and test."""
103 |
104 | def compute_loss(samples):
105 | res = data.Ar[None, :] * (beta - samples) # (n, d).
106 | return res.norm(2, dim=1).mean(dim=0).item()
107 |
108 | return tuple(
109 | compute_loss(samples=samples)
110 | for samples in (data.beta_train, data.beta_test)
111 | )
112 |
113 |
114 | def train_one_step(data: Data, beta, lr, epsilon, delta, weight_decay):
115 | res = data.Ar[None, :] * (beta - data.beta_train) # (n, d).
116 | grad = data.Ar * (res / res.norm(2, dim=1, keepdim=True)).mean(dim=0)
117 |
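    # Calibrate noise via the classic Gaussian mechanism: sigma^2 =
    # 2 ln(1.25/delta) * sensitivity^2 / epsilon^2 gives (epsilon, delta)-DP.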
118 | gaussian_mechanism_variance = 2. * math.log(1.25 / delta) * data.sensitivity ** 2. / epsilon ** 2.
119 | grad_priv = grad + torch.randn_like(grad) * math.sqrt(gaussian_mechanism_variance)
120 | beta = beta - lr * (grad_priv + weight_decay * beta)
121 | return beta
122 |
123 |
124 | @torch.no_grad()
125 | def train(data: Data, num_steps, eval_steps, lr, weight_decay, epsilon, delta, tag, verbose, seed):
126 | utils.manual_seed(seed)
127 |
128 | per_step_epsilon, per_step_delta = make_per_step_privacy_spending(
129 | target_epsilon=epsilon, target_delta=delta, num_steps=num_steps
130 | )
131 |
132 | beta = torch.zeros(size=(1, data.d,), device=device)
133 | beta_avg = beta.clone()
134 |
135 | for global_step in range(0, num_steps):
136 | if global_step % eval_steps == 0:
137 | tr_loss, te_loss = evaluate(data=data, beta=beta_avg)
138 | if verbose:
139 | logging.warning(
140 | f"tag: {tag}, global_step: {global_step}, lr: {lr:.6f}, num_steps: {num_steps}, "
141 | f"train_loss: {tr_loss:.6f}, test_loss: {te_loss:.6f}"
142 | )
143 |
144 | beta = train_one_step(
145 | data=data,
146 | beta=beta,
147 | lr=lr, weight_decay=weight_decay,
148 | epsilon=per_step_epsilon, delta=per_step_delta,
149 | )
150 | beta_avg = beta_avg * global_step / (global_step + 1) + beta / (global_step + 1)
151 |
152 | final_tr_loss, final_te_loss = evaluate(data=data, beta=beta_avg)
153 | if verbose:
154 | logging.warning(
155 | f"tag: {tag}, final, lr: {lr:.6f}, num_steps: {num_steps}, "
156 |             f"train_loss: {final_tr_loss:.6f}, test_loss: {final_te_loss:.6f}"
157 | )
158 |
159 | return beta_avg, (final_tr_loss, final_te_loss)
160 |
161 |
162 | def make_per_step_privacy_spending(
163 | target_epsilon, target_delta, num_steps, threshold=1e-4,
164 | ):
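    # Bisection: find the largest per-step epsilon whose advanced-composition
    # bound over num_steps mechanisms stays within the target budget,
    #   eps_total = sqrt(2 * k * ln(1 / delta')) * eps + k * eps * (exp(eps) - 1).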
165 | per_step_delta = target_delta / (num_steps + 1)
166 |
167 | def adv_composition(per_step_epsilon):
168 | total_epsilon = (
169 | math.sqrt(2 * num_steps * math.log(1 / per_step_delta)) * per_step_epsilon +
170 | num_steps * per_step_epsilon * (math.exp(per_step_epsilon) - 1)
171 | )
172 | return total_epsilon
173 |
174 | minval, maxval = 1e-6, 5
175 | while maxval - minval > threshold:
176 | midval = (maxval + minval) / 2
177 | eps = adv_composition(midval)
178 | if eps > target_epsilon:
179 | maxval = midval
180 | else:
181 | minval = midval
182 | per_step_epsilon = minval
183 | return per_step_epsilon, per_step_delta
184 |
185 |
186 | def main(
187 | img_dir=None, eval_steps=10000, weight_decay=0, epsilon=2, delta=1e-6,
188 | n_train=10000, n_test=10000, dmin=1, mu_beta=1., si_beta=1, g0=3.,
189 | seeds=(42, 96, 10000, 999, 101), # Some arbitrary numbers.
190 | modes=(Modes.const, Modes.sqrt, Modes.linear), # A subset of all possible modes for visualization.
191 | verbose=False,
192 | quick=False, # Use small data if True.
193 | ):
194 | if quick:
195 | dims = (10, 50,)
196 | num_steps_list = (10, 20,)
197 | lrs = (1e-4, 3e-4,)
198 | else:
199 | dims = (20, 50, 100, 200, 500, 1000, 2000)
200 | num_steps_list = (10, 20, 40, 80, 160, 320, 640, 1280, 2560, 5120)
201 | lrs = (1e-4, 3e-4, 1e-3, 3e-3, 1e-2, 3e-2, 1e-1, 3e-1, 1, 3,)
202 |
203 | tr_losses = {mode: [] for mode in modes}
204 | te_losses = {mode: [] for mode in modes}
205 | for dim in tqdm.tqdm(dims, desc="dims"):
206 | betas = make_beta(n_train=n_train, n_test=n_test, d=dim, dmin=dmin, mu_beta=mu_beta, si_beta=si_beta)
207 | data = tuple(make_data(betas=betas, mode=mode, g0=g0) for mode in modes)
208 |
209 | tr_loss = {mode: [sys.maxsize] for mode in modes}
210 | te_loss = {mode: [sys.maxsize] for mode in modes}
211 | for this_data, this_mode in tqdm.tqdm(utils.zip_(data, modes), desc="modes", total=len(data)):
212 |
213 | # Hyperparameter tuning.
214 | for num_steps in num_steps_list:
215 | for lr in lrs:
216 | kwargs = dict(
217 | data=this_data,
218 | num_steps=num_steps,
219 | lr=lr,
220 |
221 | eval_steps=eval_steps,
222 | weight_decay=weight_decay,
223 | epsilon=epsilon,
224 | delta=delta,
225 | tag=this_mode,
226 | verbose=verbose,
227 | )
228 |
229 | tr_results = []
230 | te_results = []
231 | for seed in seeds:
232 | _, (a, b) = train(**kwargs, seed=seed)
233 | tr_results.append(a)
234 | te_results.append(b)
235 |
236 | if np.mean(tr_results) < np.mean(tr_loss[this_mode]):
237 | tr_loss[this_mode] = tr_results
238 | te_loss[this_mode] = te_results
239 |
240 | # update after hp tuning.
241 | for this_mode in modes:
242 | tr_losses[this_mode].append(tr_loss[this_mode])
243 | te_losses[this_mode].append(te_loss[this_mode])
244 |
245 | raw_data = dict(tr_losses=tr_losses, te_losses=te_losses, modes=modes, dims=dims)
246 |
247 | if img_dir is not None:
248 | utils.jdump(raw_data, utils.join(img_dir, 'toyplot.json'))
249 |
250 | plot_modes = modes
251 | linestyles = ("-", "--", ":", "-.")
252 | markers = ("o", "+", "x", "^")
253 |
254 | tr_plotting = dict(
255 | errorbars=tuple(
256 | dict(
257 | x=dims,
258 | y=np.mean(np.array(tr_losses[this_mode]), axis=1),
259 | yerr=np.std(np.array(tr_losses[this_mode]), axis=1),
260 | label=this_mode, marker=markers[mode_idx],
261 | linestyle=linestyles[mode_idx]
262 | )
263 | for mode_idx, this_mode in enumerate(plot_modes)
264 | ),
265 | options=dict(xlabel="$d$", ylabel="train loss")
266 | )
267 | utils.plot_wrapper(
268 | img_path=utils.join(img_dir, 'trplot'),
269 | suffixes=('.png', '.pdf'),
270 | **tr_plotting,
271 | )
272 |
273 | te_plotting = dict(
274 | errorbars=tuple(
275 | dict(
276 | x=dims,
277 | y=np.mean(np.array(te_losses[this_mode]), axis=1),
278 | yerr=np.std(np.array(te_losses[this_mode]), axis=1),
279 | label=this_mode, marker=markers[mode_idx],
280 | linestyle=linestyles[mode_idx]
281 | )
282 | for mode_idx, this_mode in enumerate(plot_modes)
283 | ),
284 | options=dict(xlabel="$d$", ylabel="test loss")
285 | )
286 | utils.plot_wrapper(
287 | img_path=utils.join(img_dir, 'teplot'),
288 | suffixes=('.png', '.pdf'),
289 | **te_plotting,
290 | )
291 |
292 |
293 | if __name__ == "__main__":
294 | fire.Fire(main)
295 |
--------------------------------------------------------------------------------
/examples/classification/spectral_analysis/rebuttal_neurips_2022.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Experiments ran pre- and post-rebuttals."""
16 | import logging
17 | import os
18 | from typing import Optional
19 |
20 | import fire
21 | import torch
22 | import tqdm
23 | from ml_swissknife import utils, numerical_distributed
24 | from torch.utils.data import DataLoader, TensorDataset
25 |
26 |
27 | def run_save_grads(
28 | num_train_epochs=60, # This amounts to 4k updates, roughly.
29 | model_name_or_path="roberta-base",
30 | train_dir=None,
31 | per_device_train_batch_size=25,
32 | ):
33 | if train_dir is None:
34 | train_dir = utils.join("/mnt/data1/dump/", 'rebuttal_v2', f'run-{model_name_or_path}')
35 | command = f'''python -m classification.run_wrapper \
36 | --output_dir {train_dir} \
37 | --task_name "sst-2" \
38 | --model_name_or_path "{model_name_or_path}" \
39 | --attention_only "yes" \
40 | --static_lm_head "yes" \
41 | --num_train_epochs {num_train_epochs} \
42 | --eval_spectrum "no" \
43 | --non_private "no" \
44 | --eval_steps 50 \
45 | --randomly_initialize "no" \
46 | --per_device_train_batch_size {per_device_train_batch_size} \
47 | --batch_size 1000 \
48 | --clipping_mode "default" \
49 | --store_grads "yes"'''
50 | os.system(command)
51 |
52 |
53 | def run_pca(
54 | # Place where grads are stored and where results will be stored.
55 | train_dir="/mnt/disks/disk-2/dump/privlm/roberta/sst-2",
56 | n=2000, # How many checkpoints?
57 | k=1000, # How many eigenvectors?
58 | num_power_iteration=10,
59 | batch_size=20, # Batch size for processing the checkpoints in matmul.
60 | seed=42, # Controls randomness in sampling the first vector in orthogonal iteration.
61 | start_index=0, # The index of the first checkpoint to be selected.
62 | eval_steps=5, # Evaluate PCA accuracy once this many iterations.
63 | save_steps=5, # Save eigenvalue and eigenvector tensors once this many iterations.
64 | disable_tqdm=False,
65 | dtype="float", # String repr of dtype.
66 | ):
67 | utils.manual_seed(seed)
68 |
69 | ckpt_dir = utils.join(train_dir, 'grad_trajectory')
70 | dump_dir = utils.join(train_dir, 'orthproj')
71 |
72 | all_ckpts = utils.all_ckpts(ckpt_dir, sort=True)
73 | tgt_ckpts = all_ckpts[start_index:start_index + n]
74 | dataset = torch.stack([
75 | torch.load(ckpt_path)["flat_grad"] for ckpt_path in tqdm.tqdm(tgt_ckpts, desc="load data")
76 | ]).to(utils.get_dtype(dtype))
77 | input_mat = DataLoader(dataset=TensorDataset(dataset), batch_size=batch_size)
78 |
79 | def callback(global_step, eigenvalues, eigenvectors):
80 | if global_step % save_steps == 0:
81 | utils.tsave(
82 | dict(eigenvalues=eigenvalues, eigenvectors=eigenvectors),
83 | utils.join(dump_dir, "all", f"global_step_{global_step:06d}.pt")
84 | )
85 | utils.tsave(
86 | dict(eigenvalues=eigenvalues),
87 | utils.join(dump_dir, "eigenvalues", f"global_step_{global_step:06d}.evals")
88 | )
89 | if global_step % eval_steps == 0:
90 | err_abs, err_rel = numerical_distributed.check_error(
91 | input_mat=input_mat, eigenvectors=eigenvectors, disable_tqdm=disable_tqdm
92 | )
93 | logging.warning(f"global_step: {global_step}, abs error: {err_abs:.6f}, rel error: {err_rel:.6f}")
94 |
95 | numerical_distributed.orthogonal_iteration(
96 | input_mat=input_mat,
97 | k=k,
98 | num_power_iteration=num_power_iteration,
99 | callback=callback,
100 | disable_tqdm=disable_tqdm,
101 | )
102 |
103 |
104 | def run_retrain_single(
105 | output_dir: str,
106 | orthogonal_projection_path: str,
107 | model_name_or_path: str,
108 | rank: Optional[int] = None,
109 | seed=42,
110 | ):
111 | cmd = f'''python -m classification.run_wrapper \
112 | --output_dir {output_dir} \
113 | --task_name "sst-2" \
114 | --model_name_or_path {model_name_or_path} \
115 | --few_shot_type "prompt" \
116 | --attention_only "yes" \
117 | --static_lm_head "yes" \
118 | --per_device_train_batch_size 25 \
119 | --batch_size 1000 \
120 | --clipping_mode "default" \
121 | --num_train_epochs 4 \
122 | --eval_spectrum "no" \
123 | --non_private "no" \
124 | --eval_steps 25 \
125 | --randomly_initialize "no" \
126 | --seed {seed} \
127 | --orthogonal_projection_path {orthogonal_projection_path}'''
128 | if rank is not None:
129 | cmd += f' --orthogonal_projection_rank {rank}'
130 | os.system(cmd)
131 |
132 |
133 | def main(task, **kwargs):
134 | globals()[task](**kwargs)
135 |
136 |
137 | if __name__ == "__main__":
138 | fire.Fire(main)
139 |
--------------------------------------------------------------------------------
/examples/classification/spectral_analysis/rebuttal_plots_neurips_2022.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Plot 1) spectral decay, 2) retrain curves.
17 | """
18 |
19 | import math
20 |
21 | import fire
22 | import numpy as np
23 | import scipy.stats
24 | import torch
25 | from ml_swissknife import utils
26 |
27 | from . import density
28 |
29 |
30 | def plot1(
31 | ckpt_path: str, # Path to eigenvalues.
32 | dump_dir="./classification/plots",
33 | img_name="",
34 | k=500,
35 | **kwargs,
36 | ):
37 | """Eigenvalues.
38 |
39 | Run on gvm.
40 | """
41 | # Roberta-large
42 | # python -m classification.spectral_analysis.rebuttal_plots_neurips_2022 --task "plot1" --ckpt_path "/mnt/data1/dump/rebuttal/run-roberta-large/orthproj/eigenvalues/global_step_000005.evals" --img_name "large" --k 100
43 | if img_name != "":
44 | img_name = f'-{img_name}'
45 |
46 | state_dicts = torch.load(ckpt_path)
47 | eigenvalues = state_dicts["eigenvalues"].numpy()
48 | eigenvalues = -np.sort(-eigenvalues)
49 | k = min(k, len(eigenvalues))
50 |
51 | # Linear fit.
52 | x = np.arange(1, k + 1)
53 | g = np.sqrt(eigenvalues[:k])
54 | logg = np.log(g)
55 | logx = np.log(x)
56 |
57 | linfit = scipy.stats.linregress(logx, logg)
58 | g_linfit = np.exp(logx * linfit.slope + linfit.intercept)
59 |
60 | print("slope:", linfit.slope)
61 | print("R value:", linfit.rvalue)
62 |
63 | plots = [
64 | dict(x=x, y=g, marker='+', linewidth=0, label="estimated values", markersize=8, alpha=0.8),
65 | dict(x=x, y=g_linfit,
66 |              label=f"linear fit: $\log y = {linfit.slope:.2f} \log x {linfit.intercept:+.2f}$ ($R^2="
67 | f"{linfit.rvalue ** 2.:.3f}$)"),
68 | ]
69 | utils.plot_wrapper(
70 | img_path=utils.join(dump_dir, f"eigenvalue-linfit{img_name}"),
71 | suffixes=(".png", ".pdf"),
72 | plots=plots,
73 | options=dict(xlabel="$k$", ylabel="$\lambda(H^\\top H)^{1/2}$", xscale='log', yscale='log')
74 | )
75 |
76 | # Spectral density.
77 | sigma_squared = 1e-6
78 | evals = np.sqrt(eigenvalues[None, :k])
79 | den, gri = density.eigv_to_density(evals, sigma_squared=sigma_squared, grid_len=300000, grid_expand=3e-4)
80 | utils.plot_wrapper(
81 | img_path=utils.join(dump_dir, f'eigenvalue-density{img_name}'),
82 | suffixes=(".png", ".pdf"),
83 | plots=[dict(x=gri, y=den, label=f"bandwidth $\sigma={math.sqrt(sigma_squared):.5f}$")],
84 | options=dict(xlabel="$\lambda(H^\\top H)^{1/2}$", ylabel="Density of KDE",
85 | ylim=dict(bottom=1e-10, top=2e2),
86 | xscale="log", yscale='log')
87 | )
88 |
89 |
90 | def plot2(
91 | base_dir: str,
92 | img_name="",
93 | seeds=(42, 9008, 0),
94 | ranks=(10, 20, 100, None),
95 | dump_dir="./classification/plots",
96 | markers=('x', '^', '+', 'o'),
97 | roberta_large=False,
98 | **kwargs,
99 | ):
100 | """Retrain.
101 |
102 | Run locally.
103 | """
104 | # Roberta-large
105 | # python -m classification.spectral_analysis.rebuttal_plots_neurips_2022 --task "plot2" --img_name "large" --base_dir "/mnt/data1/dump/rebuttal" --roberta_large True
106 | if img_name != "":
107 | img_name = f'-{img_name}'
108 |
109 | errorbars = []
110 | for rank, marker in utils.zip_(ranks, markers):
111 | results = []
112 | for seed in seeds:
113 | if roberta_large:
114 | output_dir = utils.join(
115 | f"{base_dir}/roberta_prompt_large_retrain_{rank}_{seed}/sst-2",
116 | 'log_history.json'
117 | )
118 | else:
119 | output_dir = utils.join(
120 | f"{base_dir}/roberta_prompt_retrain_{rank}_{seed}/sst-2",
121 | 'log_history.json'
122 | )
123 | record = utils.jload(output_dir)
124 | results.append([dumpi['dev']['eval_acc'] for dumpi in record])
125 | steps = [dumpi['step'] for dumpi in record]
126 |
127 | label = f"subspace rank={rank}" if rank is not None else "original"
128 | mu, si = utils.average_over_seed(results)
129 | errorbar = dict(x=steps, y=mu, yerr=si, label=label, marker=marker)
130 | errorbars.append(errorbar)
131 |
132 | img_path = utils.join(dump_dir, f'plot2{img_name}')
133 | utils.plot_wrapper(
134 | img_path=img_path,
135 | suffixes=('.png', '.pdf'),
136 | errorbars=errorbars,
137 | options=dict(xlabel="iteration", ylabel="SST-2 classification accuracy (dev)")
138 | )
139 |
140 |
141 | def plot_all(**kwargs):
142 | # rebuttal roberta-base experiments.
143 | # python -m classification.spectral_analysis.rebuttal_plots_neurips_2022 --task "plot_all" --base_dir "/mnt/data1/dump/rebuttal" --ckpt_path "/mnt/data1/dump/rebuttal/run-roberta-base/orthproj/eigenvalues/global_step_000010.evals"
144 | plot1(**kwargs)
145 | plot2(**kwargs)
146 |
147 |
148 | def main(task="plot_all", **kwargs):
149 | utils.runs_tasks(
150 | task=task,
151 | task_names=("plot_all", "plot1", "plot2"),
152 | task_callables=(plot_all, plot1, plot2),
153 | **kwargs,
154 | )
155 |
156 |
157 | if __name__ == "__main__":
158 | fire.Fire(main)
159 |
--------------------------------------------------------------------------------
/examples/classification/src/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/examples/classification/src/common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 |
17 | task_name2suffix_name = {"sst-2": "GLUE-SST-2", "mnli": "MNLI", "qqp": "QQP", "qnli": "QNLI"}
18 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19 | true_tags = ('y', 'yes', 't', 'true')
20 |
--------------------------------------------------------------------------------
/examples/classification/src/compiled_args.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from dataclasses import dataclass, field
16 |
17 | import transformers
18 |
19 | from .common import true_tags
20 | from typing import Optional
21 |
22 |
23 | @dataclass
24 | class PrivacyArguments:
25 | """Arguments for differentially private training."""
26 |
27 | per_example_max_grad_norm: float = field(
28 | default=.1, metadata={
29 | "help": "Clipping 2-norm of per-sample gradients."
30 | }
31 | )
32 | noise_multiplier: float = field(
33 | default=None, metadata={
34 | "help": "Standard deviation of noise added for privacy; if `target_epsilon` is specified, "
35 | "use the one searched based budget"
36 | }
37 | )
38 | target_epsilon: float = field(
39 | default=None, metadata={
40 | "help": "Privacy budget; if `None` use the noise multiplier specified."
41 | }
42 | )
43 | target_delta: float = field(
44 | default=None, metadata={
45 | "help": "Lax probability in approximate differential privacy; if `None` use 1 / len(train_data)."
46 | }
47 | )
48 | non_private: str = field(
49 | default="yes", metadata={"help": "Train non-privately if True."}
50 | )
51 | accounting_mode: str = field(
52 | default="rdp", metadata={"help": "One of (`rdp`, `glw`, `all`)."}
53 | )
54 | clipping_mode: str = field(
55 | default="default"
56 | )
57 |
58 | def __post_init__(self):
59 | self.non_private = self.non_private.lower() in true_tags # noqa
60 |
61 |
62 | @dataclass
63 | class TrainingArguments(transformers.TrainingArguments):
64 | eval_epochs: int = field(default=10, metadata={"help": "Evaluate once such epochs"})
65 | evaluate_before_training: bool = field(default=False, metadata={"help": "Run evaluation before training."})
66 | lr_decay: str = field(
67 | default="no", metadata={"help": "Apply the usual linear decay if `yes`, otherwise no deacy."}
68 | )
69 | evaluate_test_split: bool = field(default=False, metadata={"help": "Run evaluation on the test split"})
70 |
71 | def __post_init__(self):
72 | super(TrainingArguments, self).__post_init__()
73 | self.lr_decay = self.lr_decay.lower() in true_tags # noqa
74 |
75 |
76 | @dataclass
77 | class AuxiliaryArguments:
78 | eval_spectrum: str = field(default="no")
79 | max_spectrum_batches: int = field(default=100)
80 | max_lanczos_iter: int = field(default=100)
81 |
82 | store_grads: str = field(default="no")
83 | orthogonal_projection_path: Optional[str] = field(default=None)
84 | orthogonal_projection_rank: int = field(default=100)
85 |
86 | def __post_init__(self):
87 | self.eval_spectrum = self.eval_spectrum.lower() in true_tags # noqa
88 | self.store_grads = self.store_grads.lower() in true_tags # noqa
89 |
--------------------------------------------------------------------------------
/examples/classification/src/label_search.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Automatic label search helpers."""
16 |
17 | import itertools
18 | import logging
19 | import multiprocessing
20 |
21 | import numpy as np
22 | import scipy.spatial as spatial
23 | import scipy.special as special
24 | import scipy.stats as stats
25 | import tqdm
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 |
30 | def select_likely_words(train_logits, train_labels, k_likely=1000, vocab=None, is_regression=False):
31 | """Pre-select likely words based on conditional likelihood."""
32 | indices = []
33 | if is_regression:
34 | median = np.median(train_labels)
35 |         train_labels = (train_labels > median).astype(int)  # np.int is deprecated.
36 | num_labels = np.max(train_labels) + 1
37 | for idx in range(num_labels):
38 | label_logits = train_logits[train_labels == idx]
39 | scores = label_logits.mean(axis=0)
40 | kept = []
41 | for i in np.argsort(-scores):
42 | text = vocab[i]
43 | if not text.startswith("Ġ"):
44 | continue
45 | kept.append(i)
46 | indices.append(kept[:k_likely])
47 | return indices
48 |
49 |
50 | def select_neighbors(distances, k_neighbors, valid):
51 | """Select k nearest neighbors based on distance (filtered to be within the 'valid' set)."""
52 | indices = np.argsort(distances)
53 | neighbors = []
54 | for i in indices:
55 | if i not in valid:
56 | continue
57 | neighbors.append(i)
58 | if k_neighbors > 0:
59 | return neighbors[:k_neighbors]
60 | return neighbors
61 |
62 |
63 | def init(train_logits, train_labels):
64 | global logits, labels
65 | logits = train_logits
66 | labels = train_labels
67 |
68 |
69 | def eval_pairing_acc(pairing):
70 | global logits, labels
71 | label_logits = np.take(logits, pairing, axis=-1)
72 | preds = np.argmax(label_logits, axis=-1)
73 | correct = np.sum(preds == labels)
74 | return correct / len(labels)
75 |
76 |
77 | def eval_pairing_corr(pairing):
78 | global logits, labels
79 | if pairing[0] == pairing[1]:
80 | return -1
81 | label_logits = np.take(logits, pairing, axis=-1)
82 | label_probs = special.softmax(label_logits, axis=-1)[:, 1]
83 | pearson_corr = stats.pearsonr(label_probs, labels)[0]
84 | return pearson_corr
85 |
86 |
87 | def find_labels(
88 | model,
89 | train_logits,
90 | train_labels,
91 | seed_labels=None,
92 | k_likely=1000,
93 | k_neighbors=None,
94 | top_n=-1,
95 | vocab=None,
96 | is_regression=False,
97 | ):
98 | # Get top indices based on conditional likelihood using the LM.
99 | likely_indices = select_likely_words(
100 | train_logits=train_logits,
101 | train_labels=train_labels,
102 | k_likely=k_likely,
103 | vocab=vocab,
104 | is_regression=is_regression)
105 |
106 | logger.info("Top labels (conditional) per class:")
107 | for i, inds in enumerate(likely_indices):
108 | logger.info("\t| Label %d: %s", i, ", ".join([vocab[i] for i in inds[:10]]))
109 |
110 | # Convert to sets.
111 | valid_indices = [set(inds) for inds in likely_indices]
112 |
113 | # If specified, further re-rank according to nearest neighbors of seed labels.
114 | # Otherwise, keep ranking as is (based on conditional likelihood only).
115 | if seed_labels:
116 | assert (vocab is not None)
117 | seed_ids = [vocab.index(l) for l in seed_labels]
118 | vocab_vecs = model.lm_head.decoder.weight.detach().cpu().numpy()
119 | seed_vecs = np.take(vocab_vecs, seed_ids, axis=0)
120 |
121 | # [num_labels, vocab_size]
122 | label_distances = spatial.distance.cdist(seed_vecs, vocab_vecs, metric="cosine")
123 |
124 | # Establish label candidates (as k nearest neighbors).
125 | label_candidates = []
126 | logger.info("Re-ranked by nearest neighbors:")
127 | for i, distances in enumerate(label_distances):
128 | label_candidates.append(select_neighbors(distances, k_neighbors, valid_indices[i]))
129 | logger.info("\t| Label: %s", seed_labels[i])
130 | logger.info("\t| Neighbors: %s", " ".join([vocab[idx] for idx in label_candidates[i]]))
131 | else:
132 | label_candidates = likely_indices
133 |
134 | # Brute-force search all valid pairings.
135 | pairings = list(itertools.product(*label_candidates))
136 |
137 | if is_regression:
138 | eval_pairing = eval_pairing_corr
139 | metric = "corr"
140 | else:
141 | eval_pairing = eval_pairing_acc
142 | metric = "acc"
143 |
144 | # Score each pairing.
145 | pairing_scores = []
146 | with multiprocessing.Pool(initializer=init, initargs=(train_logits, train_labels)) as workers:
147 | with tqdm.tqdm(total=len(pairings)) as pbar:
148 | chunksize = max(10, int(len(pairings) / 1000))
149 | for score in workers.imap(eval_pairing, pairings, chunksize=chunksize):
150 | pairing_scores.append(score)
151 | pbar.update()
152 |
153 | # Take top-n.
154 | best_idx = np.argsort(-np.array(pairing_scores))[:top_n]
155 | best_scores = [pairing_scores[i] for i in best_idx]
156 | best_pairings = [pairings[i] for i in best_idx]
157 |
158 | logger.info("Automatically searched pairings:")
159 | for i, indices in enumerate(best_pairings):
160 | logger.info("\t| %s (%s = %2.2f)", " ".join([vocab[j] for j in indices]), metric, best_scores[i])
161 |
162 | return best_pairings
163 |
--------------------------------------------------------------------------------
/examples/image_classification/README.md:
--------------------------------------------------------------------------------
1 | ## Image classification with Vision Transformers
2 |
3 | ### Notes
4 |
5 | `main.py` contains simple training code that enables either full fine-tuning (up to isolated embedding parameters) or
6 | linear probing on CIFAR-10. For image classification, linear probing generally tends to perform better than full
7 | fine-tuning. This is confirmed in my personal experiments and also in recent works (e.g., [[1]](https://arxiv.org/pdf/2204.13650.pdf) [[2]](https://arxiv.org/pdf/2205.02973.pdf)).
8 |
9 | [1] De, Soham, et al. "Unlocking high-accuracy differentially private image classification through scale." arXiv preprint arXiv:2204.13650 (2022).
10 |
11 | [2] Mehta, Harsh, et al. "Large scale transfer learning for differentially private image classification." arXiv preprint arXiv:2205.02973 (2022).
12 |
--------------------------------------------------------------------------------
/examples/image_classification/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/examples/image_classification/main.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """CIFAR-10 classification with Vi-T."""
16 | import logging
17 |
18 | import fire
19 | import torch
20 | import torch.nn.functional as F
21 | import tqdm
22 | import transformers
23 | from ml_swissknife import utils
24 | from torchvision import transforms
25 |
26 | import private_transformers
27 |
28 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
29 |
30 |
31 | @torch.no_grad()
32 | def evaluate(loader, model):
33 | model.eval()
34 | xents, zeons = [], []
35 | for i, (images, labels) in enumerate(loader):
36 | images, labels = tuple(t.to(device) for t in (images, labels))
37 | logits = model(pixel_values=images).logits
38 | xents.append(F.cross_entropy(logits, labels, reduction='none'))
39 | zeons.append(logits.argmax(dim=-1).ne(labels).float())
40 | return tuple(torch.cat(lst).mean().item() for lst in (xents, zeons))
41 |
42 |
43 | def main(
44 | model_name_or_path='google/vit-base-patch16-224',
45 | train_batch_size=1000,
46 | per_device_train_batch_size=50,
47 | test_batch_size=500,
48 | epochs=10,
49 | target_epsilon=2,
50 | lr=2e-3,
51 | max_grad_norm=0.1,
52 | linear_probe=True,
53 | ):
54 | gradient_accumulation_steps = train_batch_size // per_device_train_batch_size
55 |
56 | image_transform = transforms.Compose([
57 | transforms.Resize((224, 224)),
58 | transforms.ToTensor(),
59 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
60 | ])
61 | train_loader, test_loader = utils.get_loader(
62 | data_name='cifar10',
63 | task="classification",
64 | train_batch_size=per_device_train_batch_size,
65 | test_batch_size=test_batch_size,
66 | data_aug=False,
67 | train_transform=image_transform,
68 | test_transform=image_transform,
69 | )
70 |
71 | config = transformers.AutoConfig.from_pretrained(model_name_or_path)
72 | config.num_labels = 10
73 | model = transformers.ViTForImageClassification.from_pretrained(
74 | model_name_or_path,
75 | config=config,
76 | ignore_mismatched_sizes=True # Default pre-trained model has 1k classes; we only have 10.
77 | ).to(device)
78 | if linear_probe:
79 | model.requires_grad_(False)
80 | model.classifier.requires_grad_(True)
81 | logging.warning("Linear probe classification head.")
82 | else:
83 | private_transformers.freeze_isolated_params_for_vit(model)
84 | logging.warning("Full fine-tune up to isolated embedding parameters.")
85 |
86 | optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
87 | privacy_engine = private_transformers.PrivacyEngine(
88 | model,
89 | batch_size=train_batch_size,
90 | sample_size=50000,
91 | epochs=epochs,
92 | max_grad_norm=max_grad_norm,
93 | target_epsilon=target_epsilon,
94 | )
95 | privacy_engine.attach(optimizer)
96 |
97 | train_loss_meter = utils.AvgMeter()
98 | for epoch in range(epochs):
99 | optimizer.zero_grad()
100 | pbar = tqdm.tqdm(enumerate(train_loader, 1), total=len(train_loader))
101 | for global_step, (images, labels) in pbar:
102 | model.train()
103 | images, labels = tuple(t.to(device) for t in (images, labels))
104 | logits = model(pixel_values=images).logits
105 | loss = F.cross_entropy(logits, labels, reduction="none")
106 | train_loss_meter.step(loss.mean().item())
107 | if global_step % gradient_accumulation_steps == 0:
108 | optimizer.step(loss=loss)
109 | optimizer.zero_grad()
110 | else:
111 | optimizer.virtual_step(loss=loss)
112 | pbar.set_description(f"Train loss running average: {train_loss_meter.item():.4f}")
113 | avg_xent, avg_zeon = evaluate(test_loader, model)
114 | logging.warning(
115 | f"Epoch: {epoch}, average cross ent loss: {avg_xent:.4f}, average zero one loss: {avg_zeon:.4f}"
116 | )
117 |
118 |
119 | if __name__ == "__main__":
120 | fire.Fire(main)
121 |
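122 | # Note on the update pattern above (a sketch of how this example uses the privacy
123 | # engine, not authoritative library documentation): losses are computed with
124 | # `reduction="none"` so the engine sees per-example values;
125 | # `optimizer.virtual_step(loss=loss)` accumulates clipped per-example gradients
126 | # across micro-batches, and `optimizer.step(loss=loss)` finalizes the noisy update
127 | # once `gradient_accumulation_steps` micro-batches have been processed.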
--------------------------------------------------------------------------------
/examples/table2text/README.md:
--------------------------------------------------------------------------------
1 | ## Reproducing results for table-to-text generation
2 |
3 | ### Requirements
4 |
5 | In addition to requirements of the `private-transformers` package, install additional requirements by running the
6 | following from the `examples` folder of this repo:
7 |
8 | ```bash
9 | pip install -r table2text/requirements.txt
10 | ```
11 |
12 | ### Getting the data
13 |
14 | We host the datasets for E2E and DART on Google drive at
15 | this [link](https://drive.google.com/file/d/1Re1wyUPtS3IalSsVVJhSg2sn8UNa7DM7/view?usp=sharing). Download and unzip the
16 | folder to a reasonable location. The unzipped folder is named `prefix-tuning`, since it's adapted from data used in the
17 | prefix-tuning paper.
18 |
19 | ### Running
20 |
21 | Use the `run.sh` script in the folder.
22 |
23 | Supply at least the first 3 positional arguments, in order (the full signature is listed at the end of this README):
24 | 
25 | - `output_dir`: path to a folder where results will be written
26 | - `data_folder`: path to the unzipped data folder
27 | - `task_mode`: name of the task; one of `e2e` and `dart`
28 |
29 | For instance, to fine-tune GPT-2 on E2E at ε = 8, run the following from the `examples` folder of this repo:
30 |
31 | ```bash
32 | bash table2text/run.sh <output_dir> <data_folder> "e2e"
33 | ```
34 |
35 | The script uses ghost clipping by default, and the micro batch size is tweaked so that things should run smoothly
36 | even on a Titan Xp with 12 GB of VRAM. For E2E, the run-time of this script on an RTX 3090 is under one and a
37 | half hours.
38 |
39 | Feel free to adjust other arguments of the `run.sh` script, such as `target_epsilon` and `model_name_or_path`, to use
40 | different privacy levels and models. The remaining hyperparameters should still mostly work for workloads with varied
41 | models and privacy levels.
42 |
43 | ### Automatic evaluation
44 |
45 | While the runs automatically decode from the model (via beam-search) once in a while during training, the script does
46 | not run any evaluation on top of the generations. For our paper, we ran the
47 | official [e2e-metrics](https://github.com/tuetschek/e2e-metrics) for evaluating common metrics (e.g., BLEU, ROUGE) on
48 | E2E, and used the [evaluation pipeline in the GEM-benchmark](https://github.com/GEM-benchmark/GEM-metrics) for DART.
49 |
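50 | ### Full `run.sh` signature
51 | 
52 | For reference, `run.sh` accepts up to 8 positional arguments (defaults in brackets):
53 | 
54 | ```bash
55 | bash table2text/run.sh <output_dir> <data_folder> <task_mode> \
56 |     [model_name_or_path=gpt2] [target_epsilon=8] [cache_dir] [clipping_mode=ghost] [non_private=no]
57 | ```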
--------------------------------------------------------------------------------
/examples/table2text/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/examples/table2text/compiled_args.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Compilation of all the arguments."""
16 | import logging
17 | import os
18 | import sys
19 | from dataclasses import dataclass, field
20 | from typing import Optional
21 |
22 | import transformers
23 |
24 | MODEL_CONFIG_CLASSES = list(transformers.MODEL_WITH_LM_HEAD_MAPPING.keys())
25 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
26 |
27 | TRUE_TAGS = ('y', 'yes', 't', 'true')
28 |
29 |
30 | # See all possible arguments in src/transformers/training_args.py
31 | # or by passing the --help flag to this script.
32 | # We now keep distinct sets of args, for a cleaner separation of concerns.
33 | @dataclass
34 | class ModelArguments:
35 | """
36 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
37 | """
38 | model_name_or_path: Optional[str] = field(
39 | default=None,
40 | metadata={
41 | "help": "The model checkpoint for weights initialization. Leave None if you want to train a model from "
42 | "scratch."
43 | },
44 | )
45 | model_type: Optional[str] = field(
46 | default=None,
47 | metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
48 | )
49 | config_name: Optional[str] = field(
50 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
51 | )
52 | tokenizer_name: Optional[str] = field(
53 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
54 | )
55 | cache_dir: Optional[str] = field(
56 | default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
57 | )
58 |
59 | static_lm_head: str = field(default='no')
60 | static_embedding: str = field(default='no')
61 | attention_only: str = field(default="no")
62 |
63 | def __post_init__(self):
64 | self.static_lm_head = self.static_lm_head.lower() in TRUE_TAGS
65 | self.static_embedding = self.static_embedding.lower() in TRUE_TAGS
66 | self.attention_only = self.attention_only.lower() in TRUE_TAGS
67 |
68 |
69 | @dataclass
70 | class DataTrainingArguments:
71 | """
72 | Arguments pertaining to what data we are going to input our model for training and eval.
73 | """
74 | data_folder: Optional[str] = field(default=None, metadata={"help": "Path to folder with all the data."})
75 |
76 | # Useful for truncating the dataset.
77 | max_train_examples: Optional[int] = field(default=sys.maxsize)
78 | max_valid_examples: Optional[int] = field(default=sys.maxsize)
79 | max_eval_examples: Optional[int] = field(default=sys.maxsize)
80 |
81 | line_by_line: bool = field(
82 | default=True,
83 | metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
84 | )
85 | task_mode: Optional[str] = field(
86 | default=None, metadata={"help": "The name of the task."}
87 | )
88 | format_mode: Optional[str] = field(
89 | default='cat', metadata={"help": "The mode of data2text format (cat, peek, nopeek)"}
90 | )
91 | max_source_length: Optional[int] = field(
92 | default=512, metadata={"help": "the max source length of summarization data. "}
93 | )
94 | train_max_target_length: Optional[int] = field(
95 | default=100, metadata={"help": "the max target length for training data. "}
96 | )
97 | val_max_target_length: Optional[int] = field(
98 | default=100, metadata={"help": "the max target length for dev data. "}
99 | )
100 | block_size: int = field(
101 | default=-1,
102 | metadata={
103 |             "help": "Optional input sequence length after tokenization. "
104 |                     "The training dataset will be truncated in blocks of this size for training. "
105 |                     "Defaults to the model max input length for single sentence inputs (taking into account special "
106 | "tokens)."
107 | },
108 | )
109 | overwrite_cache: bool = field(
110 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
111 | )
112 | max_seq_len: int = field(default=sys.maxsize)
113 |
114 | def __post_init__(self):
115 | if self.data_folder is not None:
116 | logging.warning(f'Overriding dataset paths using those given in `data_folder`')
117 |
118 | if self.task_mode == "e2e":
119 | self.train_data_file = os.path.join(self.data_folder, 'src1_train.txt')
120 | self.valid_data_file = os.path.join(self.data_folder, 'src1_valid.txt')
121 | self.eval_data_file = os.path.join(self.data_folder, 'src1_test.txt')
122 |
123 | self.train_prompt_file = os.path.join(self.data_folder, 'prompts_train.txt')
124 | self.val_prompt_file = os.path.join(self.data_folder, 'prompts_valid.txt')
125 | self.eval_prompt_file = os.path.join(self.data_folder, 'prompts_test.txt')
126 |
127 | elif self.task_mode == "dart":
128 | self.train_data_file = os.path.join(self.data_folder, 'dart-v1.1.1-full-train.json')
129 | self.valid_data_file = os.path.join(self.data_folder, 'dart-v1.1.1-full-dev.json')
130 | self.eval_data_file = os.path.join(self.data_folder, 'dart-v1.1.1-full-test.json')
131 |
132 | self.train_prompt_file = os.path.join(self.data_folder, 'prompts_train.txt')
133 | self.val_prompt_file = os.path.join(self.data_folder, 'prompts_valid.txt')
134 | self.eval_prompt_file = os.path.join(self.data_folder, 'prompts_test.txt')
135 |
136 |
137 | @dataclass
138 | class TrainingArguments(transformers.TrainingArguments):
139 |     max_eval_batches: int = field(default=-1, metadata={"help": "Maximum number of evaluation batches to run."})
140 | max_generations: int = field(default=sys.maxsize)
141 | max_generations_train: int = field(default=10)
142 | max_generations_valid: int = field(default=10)
143 | skip_generation: str = field(default="no")
144 |
145 | ema_model_averaging: str = field(default="no")
146 | ema_model_gamma: float = field(default=0.99)
147 | ema_model_start_from: int = field(default=1000)
148 | lr_decay: str = field(default="yes")
149 | eval_epochs: int = field(default=10)
150 |
151 | evaluate_during_training: str = field(
152 | default="yes",
153 | metadata={"help": "Run evaluation during training at each logging step."},
154 | )
155 | evaluate_before_training: str = field(
156 | default="yes",
157 | metadata={"help": "Run evaluation before training."},
158 | )
159 | save_at_last: str = field(default="no", metadata={"help": "Save at the end of training."})
160 |
161 | def __post_init__(self):
162 | super(TrainingArguments, self).__post_init__()
163 | self.skip_generation = self.skip_generation.lower() in ('y', 'yes')
164 | self.ema_model_averaging = (self.ema_model_averaging.lower() in ('y', 'yes'))
165 | self.lr_decay = (self.lr_decay.lower() in ('y', 'yes'))
166 |         self.evaluate_during_training = (self.evaluate_during_training.lower() in ('y', 'yes'))
167 |         self.evaluate_before_training = (self.evaluate_before_training.lower() in ('y', 'yes'))
168 |         self.save_at_last = (self.save_at_last.lower() in ('y', 'yes'))
169 |
170 |
171 | @dataclass
172 | class PrivacyArguments:
173 | """Arguments for differentially private training."""
174 | per_example_max_grad_norm: float = field(
175 | default=.1, metadata={
176 | "help": "Clipping 2-norm of per-sample gradients."
177 | }
178 | )
179 | noise_multiplier: float = field(
180 | default=None, metadata={
181 | "help": "Standard deviation of noise added for privacy; if `target_epsilon` is specified, "
182 |                     "use the noise multiplier searched based on the privacy budget."
183 | }
184 | )
185 | target_epsilon: float = field(
186 | default=None, metadata={
187 | "help": "Privacy budget; if `None` use the noise multiplier specified."
188 | }
189 | )
190 | target_delta: float = field(
191 | default=None, metadata={
192 |             "help": "Failure probability in approximate differential privacy; if `None` use 1 / len(train_data)."
193 | }
194 | )
195 | accounting_mode: str = field(
196 | default="rdp", metadata={"help": "One of `rdp`, `glw`, `all`."}
197 | )
198 | non_private: str = field(default="no")
199 | clipping_mode: str = field(default="default")
200 |
201 | def __post_init__(self):
202 | self.non_private = self.non_private.lower() in ('y', 'yes')
203 |
204 |
205 | @dataclass
206 | class AuxiliaryArguments:
207 | eval_spectrum: str = field(default="no")
208 | max_spectrum_batches: int = field(default=100)
209 | max_lanczos_iter: int = field(default=100)
210 |
211 | store_grads: str = field(default="no")
212 | orthogonal_projection_path: Optional[str] = field(default=None)
213 | orthogonal_projection_rank: int = field(default=100)
214 |
215 | def __post_init__(self):
216 | self.eval_spectrum = self.eval_spectrum.lower() in TRUE_TAGS # noqa
217 | self.store_grads = self.store_grads.lower() in TRUE_TAGS # noqa
218 |
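219 | 
220 | # Parsing sketch (for illustration; mirrors `run_language_modeling.py`): the five
221 | # dataclasses above are consumed together by `transformers.HfArgumentParser`:
222 | #
223 | #   parser = transformers.HfArgumentParser(
224 | #       (ModelArguments, DataTrainingArguments, TrainingArguments, PrivacyArguments, AuxiliaryArguments)
225 | #   )
226 | #   model_args, data_args, training_args, privacy_args, auxiliary_args = parser.parse_args_into_dataclasses()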
--------------------------------------------------------------------------------
/examples/table2text/data_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/examples/table2text/data_utils/data_collator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from dataclasses import dataclass
16 | from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
17 |
18 | import torch
19 | from torch.nn.utils.rnn import pad_sequence
20 |
21 | from transformers.tokenization_utils import PreTrainedTokenizer
22 | from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy
23 | from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
24 |
25 |
26 | InputDataClass = NewType("InputDataClass", Any)
27 |
28 | """
29 | A DataCollator is a function that takes a list of samples from a Dataset
30 | and collates them into a batch, as a dictionary of Tensors.
31 | """
32 | DataCollator = NewType("DataCollator", Callable[[List[InputDataClass]], Dict[str, torch.Tensor]])
33 |
34 |
35 | @dataclass
36 | class DataCollatorForData2TextLanguageModeling:
37 | """
38 | Data collator used for language modeling.
39 | - collates batches of tensors, honoring their tokenizer's pad_token
40 | - preprocesses batches for masked language modeling
41 | """
42 | tokenizer: PreTrainedTokenizer
43 | mlm: bool = True
44 | format_mode: str = 'cat'
45 | mlm_probability: float = 0.15
46 |
47 | def __call__(
48 | self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
49 | ) -> Dict[str, torch.Tensor]:
50 | if isinstance(examples[0], (dict, BatchEncoding)):
51 | examples = [e["input_ids"] for e in examples]
52 | input_ids, labels, src, tgt, cate = zip(*examples)
53 | if self.mlm:
54 |             inputs, labels = self.mask_tokens(self._tensorize_batch(input_ids))  # Tensorize before masking.
55 | return {"input_ids": inputs, "labels": labels}
56 | else:
57 | if self.format_mode == 'cat':
58 | mode_input = 3
59 | elif self.format_mode == 'peek':
60 | mode_input = 1
61 | elif self.format_mode == 'nopeek':
62 | mode_input = 2
63 | elif self.format_mode == 'infix':
64 | mode_input = 4
65 |
66 | # mode_input = 1 # means that we take the input again.
67 | # mode_input = 2 # means that we do not peek at src again.
68 | # mode_input = 3 # means that we look at the categories, and see the input again.
69 |
70 | if mode_input == 1:
71 | # input, batch
72 | batch = self._tensorize_batch(input_ids)
73 | labels = self._tensorize_batch(labels)
74 | src = self._tensorize_batch(src)
75 | cate_batch, cate_attn = None, None
76 | # tgt = self._tensorize_batch(tgt)
77 | elif mode_input == 2:
78 | # nopeek.
79 | batch = self._tensorize_batch(tgt)
80 | labels = batch.clone()
81 | src = self._tensorize_batch(src)
82 | cate_batch, cate_attn = None, None
83 | elif mode_input == 3:
84 | batch = self._tensorize_batch(input_ids)
85 | labels = self._tensorize_batch(labels)
86 | src = self._tensorize_batch(cate)
87 | cate_batch, cate_attn = None, None
88 | elif mode_input == 4:
89 | batch = self._tensorize_batch(tgt)
90 | labels = batch.clone()
91 | src = self._tensorize_batch(src)
92 |
93 | cate_batch = self._tensorize_batch(cate)
94 | cate_attn = (cate_batch != self.tokenizer.pad_token_id)
95 |
96 | labels[labels == self.tokenizer.pad_token_id] = -100 # tgt
97 | src_attn = (src != self.tokenizer.pad_token_id) # src
98 | tgt_attn = (batch != self.tokenizer.pad_token_id) # tgt
99 |
100 | if cate_batch is None:
101 | return {"input_ids": batch, "labels": labels, 'src_attn': src_attn, 'tgt_attn':tgt_attn,
102 | 'src':src}
103 | else:
104 | return {"input_ids": batch, "labels": labels, 'src_attn': src_attn, 'tgt_attn': tgt_attn,
105 | 'src': src, "cate_batch":cate_batch, "cate_attn":cate_attn}
106 |
107 | def _tensorize_batch(
108 | self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
109 | ) -> torch.Tensor:
110 | # In order to accept both lists of lists and lists of Tensors
111 | if isinstance(examples[0], (list, tuple)):
112 | examples = [torch.tensor(e, dtype=torch.long) for e in examples]
113 | length_of_first = examples[0].size(0)
114 | are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
115 | if are_tensors_same_length:
116 | return torch.stack(examples, dim=0)
117 | else:
118 | if self.tokenizer._pad_token is None:
119 | raise ValueError(
120 | "You are attempting to pad samples but the tokenizer you are using"
121 | f" ({self.tokenizer.__class__.__name__}) does not have one."
122 | )
123 | return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id)
124 |
125 | def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
126 | """
127 | Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
128 | """
129 |
130 | if self.tokenizer.mask_token is None:
131 | raise ValueError(
132 | "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
133 | )
134 |
135 | labels = inputs.clone()
136 | # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
137 | probability_matrix = torch.full(labels.shape, self.mlm_probability)
138 | special_tokens_mask = [
139 | self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
140 | ]
141 | probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
142 | if self.tokenizer._pad_token is not None:
143 | padding_mask = labels.eq(self.tokenizer.pad_token_id)
144 | probability_matrix.masked_fill_(padding_mask, value=0.0)
145 | masked_indices = torch.bernoulli(probability_matrix).bool()
146 | labels[~masked_indices] = -100 # We only compute loss on masked tokens
147 |
148 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
149 | indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
150 | inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
151 |
152 | # 10% of the time, we replace masked input tokens with random word
153 | indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
154 | random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
155 | inputs[indices_random] = random_words[indices_random]
156 |
157 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged
158 | return inputs, labels
159 |
160 |
161 | @dataclass
162 | class DataCollatorForSumLanguageModeling:
163 | """
164 | Data collator used for language modeling.
165 | - collates batches of tensors, honoring their tokenizer's pad_token
166 | - preprocesses batches for masked language modeling
167 | """
168 | tokenizer: PreTrainedTokenizer
169 | mlm: bool = True
170 | format_mode: str = 'cat'
171 | mlm_probability: float = 0.15
172 |
173 | def __call__(
174 | self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
175 | ) -> Dict[str, torch.Tensor]:
176 | if isinstance(examples[0], (dict, BatchEncoding)):
177 | examples = [e["input_ids"] for e in examples]
178 | # print(examples[0])
179 | # print(len(examples))
180 | input_ids, labels, src, tgt = zip(*examples)
181 | # print(len(input_ids), len(labels), len(weights))
182 | if self.mlm:
183 |             inputs, labels = self.mask_tokens(self._tensorize_batch(input_ids))  # Tensorize before masking.
184 | return {"input_ids": inputs, "labels": labels}
185 | else:
186 |
187 | # print(self.format_mode)
188 |
189 | if self.format_mode == 'peek' or self.format_mode == 'cat':
190 | mode_input = 1
191 | elif self.format_mode == 'nopeek':
192 | assert False, 'should use format_mode = peek or cat.'
193 | mode_input = 2
194 | elif self.format_mode == 'infix':
195 | assert False, 'should use format_mode = peek or cat.'
196 | mode_input = 4
197 |
198 | # mode_input = 1 # means that we take the input again.
199 | # mode_input = 2 # means that we do not peek at src again.
200 | # mode_input = 3 # means that we look at the categories, and see the input again.
201 |
202 | # print(self.format_mode, mode_input)
203 |
204 | if mode_input == 1:
205 | # input, batch
206 | batch = self._tensorize_batch(input_ids)
207 | labels = self._tensorize_batch(labels)
208 | src = self._tensorize_batch(src)
209 |
210 | labels[labels == self.tokenizer.pad_token_id] = -100 # tgt
211 | src_attn = (src != self.tokenizer.pad_token_id) # src
212 | tgt_attn = (batch != self.tokenizer.pad_token_id) # tgt
213 |
214 | return {"input_ids": batch, "labels": labels, 'src_attn': src_attn, 'tgt_attn':tgt_attn,
215 | 'src':src}
216 |
217 |
218 | def _tensorize_batch(
219 | self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
220 | ) -> torch.Tensor:
221 | # In order to accept both lists of lists and lists of Tensors
222 | if isinstance(examples[0], (list, tuple)):
223 | examples = [torch.tensor(e, dtype=torch.long) for e in examples]
224 | length_of_first = examples[0].size(0)
225 | are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
226 | if are_tensors_same_length:
227 | return torch.stack(examples, dim=0)
228 | else:
229 | if self.tokenizer._pad_token is None:
230 | raise ValueError(
231 | "You are attempting to pad samples but the tokenizer you are using"
232 | f" ({self.tokenizer.__class__.__name__}) does not have one."
233 | )
234 | return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id)
235 |
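236 | 
237 | # Usage sketch (for illustration; mirrors `misc.get_all_datasets`): the data2text
238 | # collator is constructed as
239 | #   DataCollatorForData2TextLanguageModeling(tokenizer=tokenizer, mlm=False, format_mode=data_args.format_mode)
240 | # and expects each dataset example to be a 5-tuple (input_ids, labels, src, tgt, cate),
241 | # which `__call__` unpacks before tensorizing and building the attention masks.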
--------------------------------------------------------------------------------
/examples/table2text/decoding_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utilities for generation."""
16 | import logging
17 | import sys
18 | from typing import Optional
19 |
20 | import tqdm
21 | import transformers
22 |
23 |
24 | def generate(
25 | model: transformers.PreTrainedModel,
26 | tokenizer: transformers.PreTrainedTokenizer,
27 | loader=None,
28 | prompt_dataset=None,
29 | max_length=100,
30 | min_length=5,
31 | top_k=0,
32 | top_p=0.9, # Only filter with top_p.
33 | repetition_penalty=1,
34 | do_sample=False,
35 | num_beams=5,
36 | bad_words_ids=None,
37 | dummy_token_id=-100, # Used as mask.
38 | num_return_sequences=1,
39 | max_generations=sys.maxsize,
40 | device=None,
41 | padding_token="[PAD]",
42 | **kwargs,
43 | ):
44 |     assert not model.training, "Generation must happen when `model` is in eval mode."
45 | if kwargs:
46 | logging.warning(f"Unknown kwargs: {kwargs}")
47 |
48 |     # These are linebreaks; generating these will mess up the evaluation, since those files assume one example per line.
49 | if bad_words_ids is None:
50 | bad_words_ids = [[628], [198]]
51 | if padding_token in tokenizer.get_vocab():
52 | bad_words_ids.append(tokenizer.encode(padding_token))
53 |
54 | kwargs = dict(
55 | model=model,
56 | tokenizer=tokenizer,
57 | max_length=max_length,
58 | min_length=min_length,
59 | top_k=top_k,
60 | top_p=top_p,
61 | repetition_penalty=repetition_penalty,
62 | do_sample=do_sample,
63 | num_beams=num_beams,
64 | bad_words_ids=bad_words_ids,
65 | dummy_token_id=dummy_token_id,
66 | num_return_sequences=num_return_sequences,
67 | max_generations=max_generations,
68 | device=device,
69 | padding_token=padding_token,
70 | )
71 | if loader is not None:
72 | result = _generate_with_loader(loader=loader, **kwargs)
73 | elif prompt_dataset is not None:
74 | result = _generate_with_prompt_dataset(prompt_dataset=prompt_dataset, **kwargs)
75 | else:
76 | raise ValueError(f"`loader` and `prompt_dataset` cannot both be `None`.")
77 |
78 | return result
79 |
80 |
81 | def _generate_with_loader(
82 | loader,
83 |
84 | model,
85 | tokenizer: transformers.PreTrainedTokenizer,
86 | max_length,
87 | min_length,
88 | top_k,
89 | top_p,
90 | repetition_penalty,
91 | do_sample,
92 | num_beams,
93 | bad_words_ids,
94 | dummy_token_id,
95 | num_return_sequences,
96 | max_generations,
97 | device,
98 | padding_token,
99 | ):
100 | references = []
101 | full_generations = [] # Sentences including the prompt part.
102 | unstripped_generations = []
103 | generations = []
104 |
105 | stop_generation = False
106 | for batch_idx, batch in tqdm.tqdm(enumerate(loader), desc="generation"):
107 | if stop_generation:
108 | break
109 |
110 | batch_input_ids, batch_labels = batch["input_ids"], batch["labels"]
111 |         # e.g., input_ids may be [[95, 123, 32], [198, 19, 120]], and
112 |         # labels may be [[-100, 123, 32], [-100, -100, 120]].
113 |
114 | for input_ids, labels in zip(batch_input_ids, batch_labels):
115 | if stop_generation:
116 | break
117 |
118 | # Find the first pad token and end the sentence from there!
119 | if padding_token in tokenizer.get_vocab():
120 | pad_positions, = (
121 | input_ids == tokenizer.encode(padding_token, return_tensors="pt").squeeze()
122 | ).nonzero(as_tuple=True)
123 | # Some sentences might have padding; others might not.
124 | if pad_positions.numel() == 0:
125 | first_pad_position = None
126 | else:
127 | first_pad_position = pad_positions[0]
128 | reference_str: str = tokenizer.decode(input_ids[:first_pad_position], clean_up_tokenization_spaces=True)
129 | else:
130 | reference_str: str = tokenizer.decode(input_ids, clean_up_tokenization_spaces=True)
131 | references.append(reference_str)
132 |
133 | # Find the first non- -100 position. Note there are trailing -100s.
134 | non_prompt_positions, = (labels != dummy_token_id).nonzero(as_tuple=True)
135 | first_non_prompt_position = non_prompt_positions[0].item()
136 | prompt_len = first_non_prompt_position
137 | prompt_ids = input_ids[:prompt_len]
138 |
139 | output_ids = model.generate(
140 | input_ids=prompt_ids[None, ...].to(device),
141 | max_length=max_length + prompt_len, # This cannot be a 0-D tensor!
142 | min_length=min_length,
143 | top_k=top_k,
144 | top_p=top_p,
145 | repetition_penalty=repetition_penalty,
146 | do_sample=do_sample,
147 | bad_words_ids=bad_words_ids,
148 | num_return_sequences=num_return_sequences,
149 | num_beams=num_beams,
150 | pad_token_id=tokenizer.eos_token_id, # Stop the stupid logging...
151 | )
152 | output_ids = output_ids.squeeze(dim=0) # Throw away batch dimension.
153 |
154 | whole_str: str = tokenizer.decode(output_ids, clean_up_tokenization_spaces=True)
155 | prompt_str: str = tokenizer.decode(prompt_ids, clean_up_tokenization_spaces=True)
156 | output_str: str = whole_str[len(prompt_str):]
157 |
158 | full_generations.append(whole_str)
159 | del whole_str, prompt_str
160 |
161 | # Remove potential eos_token at the end.
162 | eos_position: Optional[int] = output_str.find(tokenizer.eos_token)
163 | if eos_position == -1: # Didn't generate eos_token; that's okay -- just skip!
164 | eos_position = None
165 | output_str = output_str[:eos_position]
166 | unstripped_generations.append(output_str)
167 |
168 | # Removing leading and trailing spaces.
169 | output_str = output_str.strip()
170 |
171 | generations.append(output_str)
172 |
173 | if len(generations) >= max_generations:
174 | stop_generation = True
175 |
176 | return full_generations, unstripped_generations, generations, references
177 |
178 |
179 | def _generate_with_prompt_dataset(
180 | prompt_dataset,
181 |
182 | model,
183 | tokenizer,
184 | max_length,
185 | min_length,
186 | top_k,
187 | top_p,
188 | repetition_penalty,
189 | do_sample,
190 | num_beams,
191 | bad_words_ids,
192 | dummy_token_id,
193 | num_return_sequences,
194 | max_generations,
195 | device,
196 | padding_token,
197 | ):
198 | references = []
199 | full_generations = [] # Sentences including the prompt part.
200 | unstripped_generations = []
201 | generations = []
202 |
203 | stop_generation = False
204 | for input_ids in tqdm.tqdm(prompt_dataset, desc="generation"):
205 | if stop_generation:
206 | break
207 |
208 | prompt_len = len(input_ids[0])
209 | output_ids = model.generate(
210 | input_ids=input_ids.to(device),
211 | max_length=max_length + prompt_len, # This cannot be a 0-D tensor!
212 | min_length=min_length,
213 | top_k=top_k,
214 | top_p=top_p,
215 | repetition_penalty=repetition_penalty,
216 | do_sample=do_sample,
217 | bad_words_ids=bad_words_ids,
218 | num_return_sequences=num_return_sequences,
219 | num_beams=num_beams,
220 | pad_token_id=tokenizer.eos_token_id, # Stop the stupid logging...
221 | )
222 | output_ids = output_ids.squeeze(dim=0) # Throw away batch dimension.
223 | input_ids = input_ids.squeeze(dim=0)
224 |
225 | whole_str: str = tokenizer.decode(output_ids, clean_up_tokenization_spaces=True)
226 | prompt_str: str = tokenizer.decode(input_ids, clean_up_tokenization_spaces=True)
227 | output_str: str = whole_str[len(prompt_str):]
228 |
229 | full_generations.append(whole_str)
230 | del whole_str, prompt_str
231 |
232 | # Remove potential eos_token at the end.
233 | eos_position: Optional[int] = output_str.find(tokenizer.eos_token)
234 | if eos_position == -1: # Didn't generate eos_token; that's okay -- just skip!
235 | eos_position = None
236 | output_str = output_str[:eos_position]
237 | unstripped_generations.append(output_str)
238 |
239 | # Removing leading and trailing spaces.
240 | output_str = output_str.strip()
241 |
242 | generations.append(output_str)
243 |
244 | if len(generations) >= max_generations:
245 | stop_generation = True
246 | return full_generations, unstripped_generations, generations, references
247 |
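248 | 
249 | # Usage sketch (for illustration; names like `eval_loader` are placeholders): with
250 | # the model in eval mode, decode either from a loader or from a prompt dataset;
251 | # both paths return the same 4-tuple (references stay empty on the prompt path):
252 | #
253 | #   full_generations, unstripped_generations, generations, references = generate(
254 | #       model=model, tokenizer=tokenizer, loader=eval_loader, device=device,
255 | #   )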
--------------------------------------------------------------------------------
/examples/table2text/density.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | # Copyright 2019 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Code for converting Lanczos outputs to densities."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import math
23 |
24 | import numpy as np
25 |
26 |
27 | def eigv_to_density(eig_vals, all_weights=None, grids=None,
28 | grid_len=10000, sigma_squared=None, grid_expand=1e-2):
29 | """Compute the smoothed spectral density from a set of eigenvalues.
30 |
31 | Convolves the given eigenvalues with a Gaussian kernel, weighting the values
32 | by all_weights (or uniform weighting if all_weights is None). Example output
33 | can be seen in Figure 1 of https://arxiv.org/pdf/1901.10159.pdf. Visualizing
34 | the estimated density can be done by calling plt.plot(grids, density). There
35 | is likely not a best value of sigma_squared that works for all use cases,
36 | so it is recommended to try multiple values in the range [1e-5,1e-1].
37 |
38 | Args:
39 | eig_vals: Array of shape [num_draws, order]
40 | all_weights: Array of shape [num_draws, order], if None then weights will be
41 | taken to be uniform.
42 | grids: Array of shape [grid_len], the smoothed spectrum will be plotted
43 | in the interval [grids[0], grids[-1]]. If None then grids will be
44 | computed based on max and min eigenvalues and grid length.
45 | grid_len: Integer specifying number of grid cells to use, only used if
46 | grids is None
47 | sigma_squared: Scalar. Controls the smoothing of the spectrum estimate.
48 | If None, an appropriate value is inferred.
49 | grid_expand: Controls the window of values that grids spans.
50 | grids[0] = smallest eigenvalue - grid_expand.
51 | grids[-1] = largest_eigenvalue + grid_expand.
52 |
53 | Returns:
54 | density: Array of shape [grid_len], the estimated density, averaged over
55 | all draws.
56 | grids: Array of shape [grid_len]. The values the density is estimated on.
57 | """
58 | if all_weights is None:
59 | all_weights = np.ones(eig_vals.shape) * 1.0 / float(eig_vals.shape[1])
60 | num_draws = eig_vals.shape[0]
61 |
62 | lambda_max = np.nanmean(np.max(eig_vals, axis=1), axis=0) + grid_expand
63 | lambda_min = np.nanmean(np.min(eig_vals, axis=1), axis=0) - grid_expand
64 |
65 | if grids is None:
66 | assert grid_len is not None, 'grid_len is required if grids is None.'
67 | grids = np.linspace(lambda_min, lambda_max, num=grid_len)
68 |
69 | grid_len = grids.shape[0]
70 | if sigma_squared is None:
71 | sigma = 10 ** -5 * max(1, (lambda_max - lambda_min))
72 | else:
73 | sigma = sigma_squared * max(1, (lambda_max - lambda_min))
74 |
75 | density_each_draw = np.zeros((num_draws, grid_len))
76 | for i in range(num_draws):
77 |
78 | if np.isnan(eig_vals[i, 0]):
79 |             raise ValueError('tridiag has nan values.')
80 | else:
81 | for j in range(grid_len):
82 | x = grids[j]
83 | vals = _kernel(eig_vals[i, :], x, sigma)
84 | density_each_draw[i, j] = np.sum(vals * all_weights[i, :])
85 | density = np.nanmean(density_each_draw, axis=0)
86 | norm_fact = np.sum(density) * (grids[1] - grids[0])
87 | density = density / norm_fact
88 | return density, grids
89 |
90 |
91 | def tridiag_to_eigv(tridiag_list):
92 | """Preprocess the tridiagonal matrices for density estimation.
93 |
94 | Args:
95 | tridiag_list: Array of shape [num_draws, order, order] List of the
96 | tridiagonal matrices computed from running num_draws independent runs
97 | of lanczos. The output of this function can be fed directly into
98 | eigv_to_density.
99 |
100 | Returns:
101 | eig_vals: Array of shape [num_draws, order]. The eigenvalues of the
102 |       tridiagonal matrices.
103 | all_weights: Array of shape [num_draws, order]. The weights associated with
104 | each eigenvalue. These weights are to be used in the kernel density
105 | estimate.
106 | """
107 | # Calculating the node / weights from Jacobi matrices.
108 | num_draws = len(tridiag_list)
109 | num_lanczos = tridiag_list[0].shape[0]
110 | eig_vals = np.zeros((num_draws, num_lanczos))
111 | all_weights = np.zeros((num_draws, num_lanczos))
112 | for i in range(num_draws):
113 | nodes, evecs = np.linalg.eigh(tridiag_list[i])
114 | index = np.argsort(nodes)
115 | nodes = nodes[index]
116 | evecs = evecs[:, index]
117 | eig_vals[i, :] = nodes
118 | all_weights[i, :] = evecs[0] ** 2
119 | return eig_vals, all_weights
120 |
121 |
122 | def tridiag_to_density(tridiag_list, sigma_squared=1e-5, grid_len=10000):
123 | """This function estimates the smoothed density from the output of lanczos.
124 |
125 | Args:
126 | tridiag_list: Array of shape [num_draws, order, order] List of the
127 | tridiagonal matrices computed from running num_draws independent runs
128 | of lanczos.
129 | sigma_squared: Controls the smoothing of the density.
130 | grid_len: Controls the granularity of the density.
131 |
132 | Returns:
133 | density: Array of size [grid_len]. The smoothed density estimate averaged
134 | over all num_draws.
135 | grids: Array of size [grid_len]. The values the density estimate is on.
136 | """
137 | eig_vals, all_weights = tridiag_to_eigv(tridiag_list)
138 | density, grids = eigv_to_density(eig_vals, all_weights,
139 | grid_len=grid_len,
140 | sigma_squared=sigma_squared)
141 | return density, grids
142 |
143 |
144 | def _kernel(x, x0, variance):
145 | """Point estimate of the Gaussian kernel.
146 |
147 | This function computes the Gaussian kernel for
148 |     C exp(-(x - x0)^2 / (2 * variance)), where C is the appropriate normalization.
149 |     variance should be a scalar. Exactly one of x and x0 may be a numpy array;
150 |     the other must be a scalar.
151 |
152 | Args:
153 | x: Can be either scalar or array of shape [order]. Points to estimate
154 | the kernel on.
155 | x0: Scalar. Mean of the kernel.
156 | variance: Scalar. Variance of the kernel.
157 |
158 | Returns:
159 | point_estimate: A scalar corresponding to
160 | C exp(-(x - x0) ^2 /(2 * variance)).
161 | """
162 | coeff = 1.0 / np.sqrt(2 * math.pi * variance)
163 | val = -(x0 - x) ** 2
164 | val = val / (2.0 * variance)
165 | val = np.exp(val)
166 | point_estimate = coeff * val
167 | return point_estimate
168 |
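169 | 
170 | # Usage sketch (for illustration), following the docstrings above:
171 | #
172 | #   density, grids = tridiag_to_density(tridiag_list, sigma_squared=1e-5, grid_len=10000)
173 | #   plt.plot(grids, density)  # `plt` is matplotlib.pyplot.
174 | #
175 | # There is likely no single best `sigma_squared`; values in [1e-5, 1e-1] are worth trying.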
--------------------------------------------------------------------------------
/examples/table2text/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Miscellaneous utilities.
16 |
17 | Mostly bespoke data loaders at the moment.
18 | """
19 |
20 | from transformers import (
21 | DataCollatorForLanguageModeling,
22 | DataCollatorForPermutationLanguageModeling,
23 | PreTrainedTokenizer
24 | )
25 |
26 | from .compiled_args import DataTrainingArguments
27 | from .data_utils.data_collator import DataCollatorForData2TextLanguageModeling
28 | from .data_utils.language_modeling import LineByLineE2ETextDataset, LineByLineTriplesTextDataset
29 |
30 |
31 | def get_dataset_with_path(
32 | data_args: DataTrainingArguments,
33 | tokenizer: PreTrainedTokenizer,
34 | file_path: str,
35 | max_examples: int,
36 | **_,
37 | ):
38 | if data_args.line_by_line:
39 | if data_args.task_mode == 'e2e':
40 | dataset = LineByLineE2ETextDataset(
41 | tokenizer=tokenizer,
42 | file_path=file_path,
43 | block_size=data_args.block_size,
44 | bos_tok=tokenizer.bos_token,
45 | eos_tok=tokenizer.eos_token,
46 | max_seq_len=data_args.max_seq_len,
47 | max_examples=max_examples,
48 | )
49 | elif data_args.task_mode == 'dart':
50 | dataset = LineByLineTriplesTextDataset(
51 | tokenizer=tokenizer,
52 | file_path=file_path,
53 | block_size=data_args.block_size,
54 | bos_tok=tokenizer.bos_token,
55 | eos_tok=tokenizer.eos_token,
56 | max_seq_len=data_args.max_seq_len,
57 | max_examples=max_examples,
58 | )
59 | else:
60 | raise ValueError(f"Unknown `args.task_mode`: {data_args.task_mode}")
61 |
62 | else:
63 |         raise ValueError("The table2text task doesn't support anything other than line_by_line!")
64 | return dataset
65 |
66 |
67 | def get_prompt_dataset(file_path, tokenizer):
68 | with open(file_path, 'r') as f:
69 | lines = f.readlines()
70 | encoded_lines = [
71 | tokenizer.encode(line.strip(), add_special_tokens=False, return_tensors="pt")
72 | for line in lines
73 | ]
74 | return encoded_lines
75 |
76 |
77 | def get_all_datasets(config, tokenizer, data_args, model_args, **_):
78 | kwargs = dict(data_args=data_args, tokenizer=tokenizer, cache_dir=model_args.cache_dir)
79 | train_dataset = get_dataset_with_path(
80 | **kwargs, file_path=data_args.train_data_file, max_examples=data_args.max_train_examples
81 | )
82 | valid_dataset = get_dataset_with_path(
83 | **kwargs, file_path=data_args.valid_data_file, max_examples=data_args.max_valid_examples
84 | )
85 | eval_dataset = get_dataset_with_path(
86 | **kwargs, file_path=data_args.eval_data_file, max_examples=data_args.max_eval_examples
87 | )
88 |
89 | if config.model_type == "xlnet":
90 | data_collator = DataCollatorForPermutationLanguageModeling(
91 | tokenizer=tokenizer,
92 | plm_probability=data_args.plm_probability,
93 | max_span_length=data_args.max_span_length,
94 | )
95 | else:
96 | if data_args.task_mode == 'e2e' or data_args.task_mode == 'dart':
97 | data_collator = DataCollatorForData2TextLanguageModeling(
98 | tokenizer=tokenizer, mlm=False, format_mode=data_args.format_mode
99 | )
100 | else:
101 | data_collator = DataCollatorForLanguageModeling(
102 | tokenizer=tokenizer, mlm=False,
103 | )
104 |
105 | return train_dataset, valid_dataset, eval_dataset, data_collator
106 |
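107 | 
108 | # Shape note (for illustration): `get_prompt_dataset` returns a list of LongTensors,
109 | # each of shape (1, prompt_len) because of `return_tensors="pt"`; downstream,
110 | # `decoding_utils._generate_with_prompt_dataset` relies on this leading batch
111 | # dimension when computing `prompt_len = len(input_ids[0])`.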
--------------------------------------------------------------------------------
/examples/table2text/models.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | from torch import nn
17 | from transformers import GPT2PreTrainedModel, GPT2LMHeadModel
18 |
19 |
20 | class _View(nn.Module):
21 | def __init__(self, shape):
22 | super(_View, self).__init__()
23 | self.shape = shape
24 |
25 | def forward(self, x):
26 | return x.reshape(*self.shape)
27 |
28 |
29 | class PrefixTuner(GPT2PreTrainedModel):
30 | """A minimalistic implementation of the core components."""
31 |
32 | def __init__(self, config, model_args, gpt2=None):
33 | super(PrefixTuner, self).__init__(config=config)
34 |
35 |         # Instantiate a GPT-2, and DON'T optimize it!
36 | if gpt2 is None:
37 | self.gpt2 = GPT2LMHeadModel.from_pretrained(
38 | model_args.model_name_or_path, config=config, cache_dir=model_args.cache_dir,
39 | )
40 | else:
41 | self.gpt2 = gpt2
42 |
43 | self.register_buffer('extra_prefix_ids', torch.arange(model_args.prefix_len))
44 | # TODO: Also introduce the easier net.
45 | self.extra_prefix_net = nn.Sequential(
46 | nn.Embedding(model_args.prefix_len, config.n_embd),
47 | nn.Linear(config.n_embd, model_args.mid_dim),
48 | nn.Tanh(),
49 | nn.Linear(model_args.mid_dim, config.n_layer * 2 * config.n_embd),
50 | _View((-1, model_args.prefix_len, config.n_layer * 2, config.n_head, config.n_embd // config.n_head)),
51 | nn.Dropout(model_args.prefix_dropout),
52 | )
53 |
54 | def make_past_key_values(self, bsz=None):
55 | extra_prefix_ids = self.extra_prefix_ids[None, :].expand(bsz, -1)
56 | past_key_values = self.extra_prefix_net(extra_prefix_ids)
57 |         # After the permute + split below: a tuple of n_layer tensors, each of shape
58 |         # (2, batch_size, n_head, prefix_len, n_embd // n_head), e.g., (2, 1, 12, 5, 64).
59 | past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2, dim=0)
60 | return past_key_values
61 |
62 | def state_dict(self):
63 | """Avoid storing GPT-2, since it's not even trained."""
64 | return self.extra_prefix_net.state_dict()
65 |
66 | def load_state_dict(self, state_dict):
67 | """Avoid loading GPT-2, since it's not even trained."""
68 | self.extra_prefix_net.load_state_dict(state_dict)
69 |
70 | @property
71 | def major_device(self):
72 | """Returns the device where the parameters are on."""
73 | return next(self.parameters()).device
74 |
75 | def forward(
76 | self,
77 | input_ids,
78 | attention_mask=None,
79 | token_type_ids=None,
80 | position_ids=None,
81 | head_mask=None,
82 | inputs_embeds=None,
83 | encoder_hidden_states=None,
84 | encoder_attention_mask=None,
85 | labels=None,
86 | use_cache=None,
87 | output_attentions=None,
88 | output_hidden_states=None,
89 | return_dict=None,
90 | **kwargs,
91 | ):
92 | past_key_values = self.make_past_key_values(bsz=input_ids.size(0))
93 | return self.gpt2(
94 | input_ids=input_ids,
95 | past_key_values=past_key_values,
96 | attention_mask=attention_mask,
97 | token_type_ids=token_type_ids,
98 | position_ids=position_ids,
99 | head_mask=head_mask,
100 | inputs_embeds=inputs_embeds,
101 | encoder_hidden_states=encoder_hidden_states,
102 | encoder_attention_mask=encoder_attention_mask,
103 | labels=labels,
104 | use_cache=use_cache,
105 | output_attentions=output_attentions,
106 | output_hidden_states=output_hidden_states,
107 | return_dict=return_dict,
108 | **kwargs
109 | )
110 |
111 | def generate(self, input_ids, num_beams, **kwargs):
112 | # Additional files also changed:
113 | # src/transformers/generation_utils.py
114 | # src/transformers/models/gpt2/modeling_gpt2.py
115 |
116 | # --- lxuechen: This part is really error-prone!
117 | # A sanity check is to optimize the model for a few updates and check if the beam-search generations changed.
118 | # The confusing logic in generation_utils:
119 | # 1) `past` is used in `GPT2LMHeadModel:prepare_inputs_for_generation`,
120 | # 2) it's converted to `past_key_values` in that function,
121 | # 3) `past_key_values` is then updated in forward due to return_dict,
122 | # 4) `past` is set to `past_key_values` in `generation_utils:_update_model_kwargs_for_generation`
123 |
124 |         # This expansion step is important for generation, since otherwise the shapes are wrong.
125 | past_key_values = self.make_past_key_values(bsz=input_ids.size(0) * num_beams)
126 | # ---
127 |
128 | return self.gpt2.generate(
129 | input_ids=input_ids,
130 | num_beams=num_beams,
131 | past_key_values=past_key_values,
132 |
133 | use_cache=True,
134 | position_ids=None,
135 |
136 | # --- lxuechen: These arguments I created to make sure prefix-tuning works correctly.
137 | # The logic: At beginning, past=None, and then it gets replaced with past_key_values.
138 |         # Can't directly pass in `past`, since otherwise `input_ids` gets truncated to the last index.
139 | use_past_key_values_as_past_at_init=True,
140 | nullify_attention_mask=True,
141 | # ---
142 |
143 | **kwargs
144 | )
145 |
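146 | 
147 | # Construction sketch (for illustration; field names follow the usages above):
148 | # `model_args` must provide `model_name_or_path`, `cache_dir`, `prefix_len`,
149 | # `mid_dim`, and `prefix_dropout`. Note that `state_dict`/`load_state_dict` only
150 | # cover `extra_prefix_net`, so checkpoints exclude the frozen GPT-2 weights.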
--------------------------------------------------------------------------------
/examples/table2text/requirements.txt:
--------------------------------------------------------------------------------
1 | fire
2 | datasets
3 |
--------------------------------------------------------------------------------
/examples/table2text/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | output_dir=${1}
4 | data_dir=${2}
5 | task_mode=${3}
6 | model_name_or_path=${4:-"gpt2"} # One of distilgpt2, gpt2, gpt2-medium, gpt2-large
7 | target_epsilon=${5:-"8"}
8 | cache_dir=${6}
9 | clipping_mode=${7:-"ghost"} # Pass 'default' to turn this off.
10 | non_private=${8:-"no"}
11 |
12 | if [[ ${task_mode} == "e2e" ]]; then
13 | data_dir="${data_dir}/data/e2e_data"
14 | target_delta=8e-6
15 | num_train_epochs=10
16 | learning_rate=2e-3
17 | max_seq_len=100
18 | else
19 | if [[ ${task_mode} == "dart" ]]; then
20 | target_delta=1e-5
21 | data_dir="${data_dir}/data/dart"
22 | num_train_epochs=15 # Approximately same number of updates.
23 | learning_rate=5e-4 # Lower learning rate for stability in large models.
24 | max_seq_len=120
25 | else
26 | echo "Unknown task: ${task_mode}"
27 | exit 1
28 | fi
29 | fi
30 |
31 | # Arguments in the last two lines are the most important.
32 | python -m table2text.run_language_modeling \
33 | --output_dir ${output_dir} --overwrite_output_dir \
34 | --task_mode ${task_mode} \
35 | --model_name_or_path ${model_name_or_path} \
36 | --tokenizer_name ${model_name_or_path} \
37 | --do_train --do_eval \
38 | --line_by_line \
39 | --save_steps 100 --save_total_limit 1 --save_at_last no \
40 | --logging_dir ${output_dir} --logging_steps -1 \
41 | --seed 0 \
42 | --eval_steps 100 --eval_epochs 2 --max_eval_batches 100 --evaluation_strategy epoch --evaluate_before_training "no" --evaluate_during_training "yes" --per_device_eval_batch_size 10 \
43 | --max_generations 9223372036854775807 --max_generations_train 10 --max_generations_valid 9223372036854775807 \
44 | --max_train_examples 9223372036854775807 --max_valid_examples 9223372036854775807 --max_eval_examples 9223372036854775807 \
45 | --data_folder ${data_dir} --max_seq_len ${max_seq_len} --format_mode cat \
46 | --per_example_max_grad_norm 0.1 --target_delta ${target_delta} --target_epsilon ${target_epsilon} \
47 | --learning_rate ${learning_rate} --lr_decay "no" --num_train_epochs ${num_train_epochs} --per_device_train_batch_size 16 --gradient_accumulation_steps 64 \
48 | --non_private ${non_private} \
49 | --clipping_mode "${clipping_mode}" \
50 | --cache_dir ${cache_dir}
51 |
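52 | 
53 | # Example invocation (illustrative paths):
54 | #   bash table2text/run.sh /tmp/table2text_output /path/to/prefix-tuning e2e gpt2 8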
--------------------------------------------------------------------------------
/examples/table2text/run_language_modeling.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Xuechen Li. All Rights Reserved.
3 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
4 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 | """
18 | Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, CTRL, BERT, RoBERTa, XLNet).
19 | GPT, GPT-2 and CTRL are fine-tuned using a causal language modeling (CLM) loss. BERT and RoBERTa are fine-tuned
20 | using a masked language modeling (MLM) loss. XLNet is fine-tuned using a permutation language modeling (PLM) loss.
21 | """
22 |
23 | import json
24 | import logging
25 | import os
26 |
27 | import torch
28 | from ml_swissknife import utils
29 | from transformers import HfArgumentParser, MODEL_WITH_LM_HEAD_MAPPING, set_seed
30 | from transformers.models.gpt2 import GPT2Tokenizer
31 | from transformers.optimization import get_linear_schedule_with_warmup
32 |
33 | from private_transformers import PrivacyEngine
34 | from .compiled_args import (AuxiliaryArguments, DataTrainingArguments, ModelArguments, PrivacyArguments,
35 | TrainingArguments)
36 | from .misc import get_all_datasets, get_prompt_dataset
37 | from .trainer import Trainer
38 |
39 | logger = logging.getLogger(__name__)
40 |
41 | MODEL_CONFIG_CLASSES = list(MODEL_WITH_LM_HEAD_MAPPING.keys())
42 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
43 |
44 |
45 | def main():
46 | parser = HfArgumentParser(
47 | (ModelArguments, DataTrainingArguments, TrainingArguments, PrivacyArguments, AuxiliaryArguments)
48 | )
49 | model_args, data_args, training_args, privacy_args, auxiliary_args = parser.parse_args_into_dataclasses()
50 |
51 | model_args: ModelArguments
52 | data_args: DataTrainingArguments
53 | training_args: TrainingArguments
54 | privacy_args: PrivacyArguments
55 | auxiliary_args: AuxiliaryArguments
56 |
57 | if data_args.eval_data_file is None and training_args.do_eval:
58 | raise ValueError(
59 | "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
60 | "or remove the --do_eval argument."
61 | )
62 |
63 | if (
64 | os.path.exists(training_args.output_dir)
65 | and os.listdir(training_args.output_dir)
66 | and training_args.do_train
67 | and not training_args.overwrite_output_dir
68 | ):
69 | raise ValueError(
70 | f"Output directory ({training_args.output_dir}) already exists and is not empty. Use "
71 | f"--overwrite_output_dir to overcome."
72 | )
73 |
74 | # Setup logging
75 | logging.basicConfig(
76 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
77 | datefmt="%m/%d/%Y %H:%M:%S",
78 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
79 | )
80 | logger.warning(
81 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
82 | training_args.local_rank,
83 | training_args.device,
84 | training_args.n_gpu,
85 | bool(training_args.local_rank != -1),
86 | training_args.fp16,
87 | )
88 | logger.info("Training/evaluation parameters %s", training_args)
89 |
90 | # Set seed
91 | set_seed(training_args.seed)
92 |
93 | # Debug mode
94 | if training_args.debug:
95 | import warnings
96 | warnings.filterwarnings("error")
97 |
98 |     # Low-rank models need special model classes!
99 | from transformers.models.gpt2 import GPT2Config, GPT2LMHeadModel
100 |
101 | # Config.
102 | config = GPT2Config.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
103 | config.return_dict = True
104 | config.tie_word_embeddings = False
105 |
106 |     # Tokenizer; `bos_token` and `eos_token` are the same for GPT2; both are 50256.
107 | tokenizer = GPT2Tokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
108 |
109 | # Model.
110 | gpt2 = GPT2LMHeadModel.from_pretrained(
111 | model_args.model_name_or_path,
112 | config=config,
113 | cache_dir=model_args.cache_dir,
114 | )
115 | print(f'base gpt2 model: {model_args.model_name_or_path}')
116 | print(gpt2)
117 |
118 | # Clone the embedding into the lm_head for better initialization.
119 | lm_head = gpt2.get_output_embeddings()
120 | embedding = gpt2.get_input_embeddings()
121 | lm_head.weight.data.copy_(embedding.weight.data)
122 | print(f'Cloning initial embedding into lm_head, '
123 | f'checking norms... \n'
124 | f'\tlm_head: {lm_head.weight.norm()}, embedding: {embedding.weight.norm()}')
125 | torch.testing.assert_allclose(lm_head.weight, embedding.weight)
126 | del lm_head, embedding
127 |
128 | if data_args.block_size <= 0:
129 | data_args.block_size = tokenizer.model_max_length
130 | else:
131 | data_args.block_size = min(data_args.block_size, tokenizer.model_max_length)
132 |
133 | # Adjust tokenizer and model embeddings.
134 | print('adapt tokenizer to include [PAD]')
135 | print(f'before len(tokenizer) = {len(tokenizer)}')
136 | tokenizer.add_special_tokens({'pad_token': '[PAD]'})
137 | print(f'after len(tokenizer) = {len(tokenizer)}')
138 | print('tokenizer.eos_token:', tokenizer.eos_token, tokenizer.eos_token_id)
139 | print('tokenizer.bos_token:', tokenizer.bos_token, tokenizer.bos_token_id)
140 |
141 | print('adapt the size of lm_head and input_embeddings to include [PAD]')
142 | print('use avg-based initialization')
143 |
144 | input_embeddings_before = gpt2.get_input_embeddings().weight
145 | lm_head_before = gpt2.get_output_embeddings().weight
146 | gpt2.resize_token_embeddings(len(tokenizer))
147 |
148 | input_embeddings_after = gpt2.get_input_embeddings().weight
149 | lm_head_after = gpt2.get_output_embeddings().weight
150 | print(
151 | f'before lm_head.weight.size() = {lm_head_before.size()}, '
152 | f'input_embeddings_before.size() = {input_embeddings_before.size()}'
153 | )
154 | print(
155 | f'after lm_head.weight.size() = {lm_head_after.size()}, '
156 | f'after input_embeddings_after.size() = {input_embeddings_after.size()}'
157 | )
158 | torch.testing.assert_allclose(lm_head_before, lm_head_after[:-1])
159 | print('pre-chunk equal for lm_head')
160 | torch.testing.assert_allclose(input_embeddings_before, input_embeddings_after[:-1])
161 | print('pre-chunk equal for input_embeddings')
162 | lm_head_after.data[-1] = lm_head_before.mean(dim=0)
163 | input_embeddings_after.data[-1] = input_embeddings_before.mean(dim=0)
164 |
165 | print('double check: ')
166 | print('embedding size', gpt2.get_input_embeddings().weight.size())
167 | print('lm_head size', gpt2.get_output_embeddings().weight.size())
168 | model = gpt2
169 |
170 | train_dataset, val_dataset, eval_dataset, data_collator = get_all_datasets(
171 | config=config,
172 | tokenizer=tokenizer,
173 | data_args=data_args,
174 | training_args=training_args,
175 | model_args=model_args,
176 | )
177 |
178 | # Materialize the prompts.
179 | generation_stuff = dict(
180 | train_prompts=get_prompt_dataset(file_path=data_args.train_prompt_file, tokenizer=tokenizer),
181 | val_prompts=get_prompt_dataset(file_path=data_args.val_prompt_file, tokenizer=tokenizer),
182 | eval_prompts=get_prompt_dataset(file_path=data_args.eval_prompt_file, tokenizer=tokenizer),
183 | )
184 |
185 | trainer = Trainer(
186 | model=model,
187 | tokenizer=tokenizer,
188 | args=training_args,
189 | model_args=model_args,
190 | data_args=data_args,
191 | privacy_args=privacy_args,
192 | auxiliary_args=auxiliary_args,
193 | train_dataset=train_dataset,
194 | val_dataset=val_dataset,
195 | eval_dataset=eval_dataset,
196 | data_collator=data_collator,
197 | generation_stuff=generation_stuff,
198 | )
199 |
200 | # Massage the parameters.
201 | if model_args.attention_only:
202 | model.requires_grad_(False)
203 | for name, param in model.named_parameters():
204 | if 'c_attn.weight' in name:
205 | param.requires_grad_(True)
206 | else:
207 | model.requires_grad_(True)
208 | if model_args.static_lm_head:
209 | model.get_output_embeddings().requires_grad_(False)
210 | if model_args.static_embedding:
211 | model.get_input_embeddings().requires_grad_(False)
212 | model.transformer.wpe.requires_grad_(False)
213 | params = tuple(param for param in model.parameters() if param.requires_grad)
214 | names = tuple(name for name, param in model.named_parameters() if param.requires_grad)
215 | num_trainable_params = sum(param.numel() for param in params)
216 | print(f"Number of trainable params: {num_trainable_params / 1e6:.4f} million")
217 | print(json.dumps(names, indent=4))
218 |
219 |     # TODO: Using a single gigantic parameter group is only okay when `weight_decay` is 0.
220 |     #   Biases and LM parameters should perhaps not be decayed, even with privacy.
221 | optimizer = torch.optim.AdamW(
222 | params=params,
223 | lr=training_args.learning_rate,
224 | betas=(training_args.adam_beta1, training_args.adam_beta2),
225 | eps=training_args.adam_epsilon,
226 | )
227 | trainer.optimizer = optimizer
228 |
229 | # Create the lr_scheduler.
230 | num_update_steps_per_epoch = len(trainer.get_train_dataloader()) // trainer.args.gradient_accumulation_steps
231 | num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
232 | t_total = int(num_update_steps_per_epoch * trainer.args.num_train_epochs)
233 | if training_args.lr_decay:
234 | trainer.lr_scheduler = get_linear_schedule_with_warmup(
235 | trainer.optimizer,
236 | num_warmup_steps=training_args.warmup_steps,
237 | num_training_steps=t_total,
238 | )
239 | else:
240 | trainer.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(trainer.optimizer, lambda _: 1.)
241 |
242 | # Hacky way to set noise_multiplier.
243 | if privacy_args.non_private:
244 | privacy_args.noise_multiplier = 0.
245 | privacy_args.per_example_max_grad_norm = None
246 | else:
247 | actual_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
248 | privacy_engine = PrivacyEngine(
249 | module=model,
250 | batch_size=actual_batch_size,
251 | sample_size=len(train_dataset),
252 | epochs=training_args.num_train_epochs,
253 | max_grad_norm=privacy_args.per_example_max_grad_norm,
254 | noise_multiplier=privacy_args.noise_multiplier,
255 | target_epsilon=privacy_args.target_epsilon,
256 | target_delta=privacy_args.target_delta,
257 | accounting_mode=privacy_args.accounting_mode,
258 | clipping_mode=privacy_args.clipping_mode,
259 | )
260 |         # These could originally have been None; the privacy engine fills them in.
261 | privacy_args.noise_multiplier = privacy_engine.noise_multiplier
262 | privacy_args.target_delta = privacy_engine.target_delta
263 |
264 | print('privacy_args: ')
265 | print(json.dumps(privacy_args.__dict__, indent=4))
266 | privacy_engine.attach(optimizer)
267 |
268 | # Training.
269 | if training_args.do_train:
270 | all_args = {
271 | **training_args.__dict__,
272 | **data_args.__dict__,
273 | **model_args.__dict__,
274 | **privacy_args.__dict__,
275 | }
276 | utils.jdump(
277 | all_args,
278 | os.path.join(training_args.output_dir, 'argparse.json'),
279 | default=lambda x: str(x),
280 | )
281 |
282 | # For convenience, we also re-save the tokenizer to the same directory,
283 | # so that you can share your model easily on huggingface.co/models =)
284 | if trainer.is_world_master():
285 | tokenizer.save_pretrained(training_args.output_dir)
286 |
287 | logger.info("*** Train ***")
288 | logger.info(
289 | f"Training set size: {len(train_dataset)}, "
290 | f"per_device_train_batch_size: {training_args.per_device_train_batch_size}, "
291 | f"gradient_accumulation_steps: {training_args.gradient_accumulation_steps}"
292 | )
293 | # lxuechen: Especially so for the restored checkpoints. Don't resume...
294 | trainer.train(model_path=None)
295 | if training_args.save_at_last:
296 | trainer.save_model()
297 |
298 | # Evaluation
299 | if training_args.do_eval:
300 | logger.info("*** Evaluate ***")
301 |
302 | output = trainer.evaluate(log_results=False)
303 | utils.jdump(
304 | output,
305 | os.path.join(training_args.output_dir, "final_results.json"),
306 | )
307 |
308 | logger.info("***** Eval results *****")
309 | logger.info(output)
310 |
311 |
312 | if __name__ == "__main__":
313 | main()
314 |
--------------------------------------------------------------------------------
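
The driver above constructs a `PrivacyEngine` and attaches it to the optimizer, but the actual private update is hidden inside the custom `Trainer`. A minimal sketch of the same flow outside the `Trainer`, assuming the engine's convention that the attached optimizer's `step` receives a 1-D tensor of per-example losses; the model, data, and hyperparameters below are illustrative:

    import torch
    import torch.nn.functional as F
    import transformers

    from private_transformers import PrivacyEngine

    model = transformers.GPT2LMHeadModel.from_pretrained('distilgpt2')
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    privacy_engine = PrivacyEngine(
        module=model,
        batch_size=4,          # logical batch size (per_device * accumulation)
        sample_size=50000,     # number of training examples
        epochs=3,
        max_grad_norm=0.1,     # per-example clipping threshold
        target_epsilon=8.0,
    )
    privacy_engine.attach(optimizer)

    batch = torch.randint(0, 50257, (4, 16))  # fake token ids
    logits = model(batch).logits[:, :-1]
    labels = batch[:, 1:]
    # Per-example losses, shape (batch_size,); the attached optimizer clips
    # per-example gradients and adds noise inside `.step`.
    loss = F.cross_entropy(logits.permute(0, 2, 1), labels, reduction="none").mean(dim=1)
    optimizer.step(loss=loss)
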
/private_transformers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from . import lora_utils
16 | from .privacy_engine import PrivacyEngine
17 | from .transformers_support import freeze_isolated_params_for_vit
18 |
19 | __version__ = '0.2.3'
20 |
--------------------------------------------------------------------------------
/private_transformers/accounting/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/private_transformers/accounting/accounting_manager.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import abc
16 | import math
17 | from typing import Dict, Optional, Union
18 |
19 | from . import rdp_accounting
20 |
21 | DEFAULT_ALPHAS = tuple(1 + x / 10.0 for x in range(1, 100)) + tuple(range(12, 64)) # RDP.
22 |
23 |
24 | class AccountingManager(abc.ABC):
25 | def _get_sigma_with_target_epsilon(
26 | self,
27 | target_epsilon,
28 | target_delta,
29 | sample_rate,
30 | steps,
31 | threshold,
32 | sigma_hi_init,
33 | sigma_lo_init,
34 | ):
35 | """Binary search σ given ε and δ."""
36 | if sigma_lo_init > sigma_hi_init:
37 | raise ValueError("`sigma_lo` should be smaller than `sigma_hi`.")
38 |
39 | # Find an appropriate region for binary search.
40 | sigma_hi = sigma_hi_init
41 | sigma_lo = sigma_lo_init
42 |
43 | # Ensure sigma_hi isn't too small.
44 | while True:
45 | eps = self._compute_epsilon_from_sigma(sigma_hi, sample_rate, target_delta, steps)
46 | if eps < target_epsilon:
47 | break
48 | sigma_hi *= 2
49 |
50 | # Ensure sigma_lo isn't too large.
51 | while True:
52 | eps = self._compute_epsilon_from_sigma(sigma_lo, sample_rate, target_delta, steps)
53 | if eps > target_epsilon:
54 | break
55 | sigma_lo /= 2
56 |
57 | # Binary search.
58 | while sigma_hi - sigma_lo > threshold:
59 | sigma = (sigma_hi + sigma_lo) / 2
60 | eps = self._compute_epsilon_from_sigma(sigma, sample_rate, target_delta, steps)
61 | if eps < target_epsilon:
62 | sigma_hi = sigma
63 | else:
64 | sigma_lo = sigma
65 |
66 | # Conservative estimate.
67 | return sigma_hi
68 |
69 | @abc.abstractmethod
70 | def compute_epsilon(self, sigma, sample_rate, target_delta, steps) -> Dict:
71 | """Override for reporting results."""
72 | raise NotImplementedError
73 |
74 | @abc.abstractmethod
75 | def _compute_epsilon_from_sigma(self, sigma, sample_rate, target_delta, steps) -> float:
76 | """Override for binary sigma search."""
77 | raise NotImplementedError
78 |
79 | def compute_sigma(
80 | self,
81 | target_epsilon: float,
82 | target_delta: float,
83 | sample_rate: float,
84 | epochs: Optional[Union[float, int]] = None,
85 | steps=None,
86 | threshold=1e-3,
87 | sigma_hi_init=4,
88 | sigma_lo_init=0.1,
89 | ) -> float:
90 | if steps is None:
91 | if epochs is None:
92 | raise ValueError("Epochs and steps cannot both be None.")
93 | steps = math.ceil(epochs / sample_rate)
94 | return self._get_sigma_with_target_epsilon(
95 | target_epsilon=target_epsilon,
96 | target_delta=target_delta,
97 | sample_rate=sample_rate,
98 | steps=steps,
99 | threshold=threshold,
100 | sigma_hi_init=sigma_hi_init,
101 | sigma_lo_init=sigma_lo_init,
102 | )
103 |
104 |
105 | class RDPManager(AccountingManager):
106 | def __init__(self, alphas):
107 | super(RDPManager, self).__init__()
108 | self._alphas = alphas
109 |
110 | def _compute_epsilon_from_sigma(self, sigma, sample_rate, target_delta, steps):
111 | return self.compute_epsilon(sigma, sample_rate, target_delta, steps)["eps_rdp"]
112 |
113 | def compute_epsilon(self, sigma, sample_rate, target_delta, steps) -> Dict:
114 | """Compute RDP as usual, but convert to (ε, δ)-DP based on the result by Canonne, Kamath, Steinke."""
115 | rdp = rdp_accounting.compute_rdp(q=sample_rate, noise_multiplier=sigma, steps=steps, orders=self._alphas)
116 | eps, alpha = rdp_accounting.get_privacy_spent(orders=self._alphas, rdp=rdp, delta=target_delta)
117 | return dict(eps_rdp=eps, alpha_rdp=alpha)
118 |
119 |
120 | class GLWManager(AccountingManager):
121 | def __init__(self, eps_error=0.05):
122 | super(GLWManager, self).__init__()
123 | self._eps_error = eps_error
124 |
125 | def _compute_epsilon_from_sigma(self, sigma, sample_rate, target_delta, steps):
126 | return self.compute_epsilon(sigma, sample_rate, target_delta, steps)["eps_upper"] # Be conservative.
127 |
128 | def compute_epsilon(self, sigma, sample_rate, target_delta, steps) -> Dict:
129 | if steps == 0:
130 | return dict(eps_low=None, eps_estimate=None, eps_upper=None)
131 |
132 | from prv_accountant import Accountant
133 | accountant = Accountant(
134 | noise_multiplier=sigma,
135 | sampling_probability=sample_rate,
136 | delta=target_delta,
137 | eps_error=self._eps_error,
138 | max_compositions=steps
139 | )
140 | eps_low, eps_estimate, eps_upper = accountant.compute_epsilon(num_compositions=steps)
141 | return dict(eps_low=eps_low, eps_estimate=eps_estimate, eps_upper=eps_upper)
142 |
--------------------------------------------------------------------------------
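
A quick illustration of the manager API above, using only names defined in this file (the budget and sampling numbers are illustrative):

    import math

    from private_transformers.accounting.accounting_manager import DEFAULT_ALPHAS, RDPManager

    manager = RDPManager(alphas=DEFAULT_ALPHAS)
    # Binary-search the noise multiplier that hits eps=8 at delta=1e-5,
    # for sampling rate 1024/50000 over 3 epochs.
    sigma = manager.compute_sigma(
        target_epsilon=8.0, target_delta=1e-5, sample_rate=1024 / 50000, epochs=3,
    )
    # Sanity check: the achieved epsilon should be at most the target.
    steps = math.ceil(3 / (1024 / 50000))
    report = manager.compute_epsilon(sigma=sigma, sample_rate=1024 / 50000, target_delta=1e-5, steps=steps)
    print(sigma, report["eps_rdp"], report["alpha_rdp"])
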
/private_transformers/accounting/rdp_accounting.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | r"""
17 | This file is adapted from the privacy accounting procedure in Opacus, which is in turn adapted from tf-privacy.
18 | Below is the original documentation in Opacus.
19 |
20 | *Based on Google's TF Privacy:* https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/analysis
21 | /rdp_accountant.py.
22 | *Here, we update this code to Python 3, and optimize dependencies.*
23 |
24 | Functionality for computing Renyi Differential Privacy (RDP) of an additive
25 | Sampled Gaussian Mechanism (SGM).
26 |
27 | Example:
28 | Suppose that we have run an SGM applied to a function with L2-sensitivity of 1.
29 |
30 | Its parameters are given as a list of tuples
31 | ``[(q_1, sigma_1, steps_1), ..., (q_k, sigma_k, steps_k)],``
32 | and we wish to compute epsilon for a given target delta.
33 |
34 | The example code would be:
35 |
36 | >>> max_order = 32
37 | >>> orders = range(2, max_order + 1)
38 | >>> rdp = np.zeros_like(orders, dtype=float)
39 | >>> for q, sigma, steps in parameters:
40 |     ...     rdp += compute_rdp(q, sigma, steps, orders)
41 |     >>> epsilon, opt_order = get_privacy_spent(orders, rdp, delta)
42 | """
43 |
44 | import math
45 | import numpy as np
46 | from scipy import special
47 | from typing import List, Sequence, Union
48 |
49 |
50 | ########################
51 | # LOG-SPACE ARITHMETIC #
52 | ########################
53 |
54 |
55 | def _log_add(logx: float, logy: float) -> float:
56 | r"""Adds two numbers in the log space.
57 |
58 | Args:
59 | logx: First term in log space.
60 | logy: Second term in log space.
61 |
62 | Returns:
63 | Sum of numbers in log space.
64 | """
65 | a, b = min(logx, logy), max(logx, logy)
66 | if a == -np.inf: # adding 0
67 | return b
68 | # Use exp(a) + exp(b) = (exp(a - b) + 1) * exp(b)
69 | return math.log1p(math.exp(a - b)) + b # log1p(x) = log(x + 1)
70 |
71 |
72 | def _log_sub(logx: float, logy: float) -> float:
73 | r"""Subtracts two numbers in the log space.
74 |
75 | Args:
76 | logx: First term in log space. Expected to be greater than the second term.
77 |         logy: Second term in log space. Expected to be less than the first term.
78 |
79 | Returns:
80 | Difference of numbers in log space.
81 |
82 | Raises:
83 | ValueError
84 | If the result is negative.
85 | """
86 | if logx < logy:
87 | raise ValueError("The result of subtraction must be non-negative.")
88 | if logy == -np.inf: # subtracting 0
89 | return logx
90 | if logx == logy:
91 | return -np.inf # 0 is represented as -np.inf in the log space.
92 |
93 | try:
94 | # Use exp(x) - exp(y) = (exp(x - y) - 1) * exp(y).
95 | return math.log(math.expm1(logx - logy)) + logy # expm1(x) = exp(x) - 1
96 | except OverflowError:
97 | return logx
98 |
99 |
100 | def _compute_log_a_for_int_alpha(q: float, sigma: float, alpha: int) -> float:
101 | r"""Computes :math:`log(A_\alpha)` for integer ``alpha``.
102 |
103 | Notes:
104 | Note that
105 | :math:`A_\alpha` is real valued function of ``alpha`` and ``q``,
106 | and that 0 < ``q`` < 1.
107 |
108 | Refer to Section 3.3 of https://arxiv.org/pdf/1908.10530.pdf for details.
109 |
110 | Args:
111 | q: Sampling rate of SGM.
112 | sigma: The standard deviation of the additive Gaussian noise.
113 | alpha: The order at which RDP is computed.
114 |
115 | Returns:
116 | :math:`log(A_\alpha)` as defined in Section 3.3 of
117 | https://arxiv.org/pdf/1908.10530.pdf.
118 | """
119 |
120 | # Initialize with 0 in the log space.
121 | log_a = -np.inf
122 |
123 | for i in range(alpha + 1):
124 | log_coef_i = (
125 | math.log(special.binom(alpha, i))
126 | + i * math.log(q)
127 | + (alpha - i) * math.log(1 - q)
128 | )
129 |
130 | s = log_coef_i + (i * i - i) / (2 * (sigma ** 2))
131 | log_a = _log_add(log_a, s)
132 |
133 | return float(log_a)
134 |
135 |
136 | def _compute_log_a_for_frac_alpha(q: float, sigma: float, alpha: float) -> float:
137 | r"""Computes :math:`log(A_\alpha)` for fractional ``alpha``.
138 |
139 | Notes:
140 | Note that
141 | :math:`A_\alpha` is real valued function of ``alpha`` and ``q``,
142 | and that 0 < ``q`` < 1.
143 |
144 | Refer to Section 3.3 of https://arxiv.org/pdf/1908.10530.pdf for details.
145 |
146 | Args:
147 | q: Sampling rate of SGM.
148 | sigma: The standard deviation of the additive Gaussian noise.
149 | alpha: The order at which RDP is computed.
150 |
151 | Returns:
152 | :math:`log(A_\alpha)` as defined in Section 3.3 of
153 | https://arxiv.org/pdf/1908.10530.pdf.
154 | """
155 | # The two parts of A_alpha, integrals over (-inf,z0] and [z0, +inf), are
156 | # initialized to 0 in the log space:
157 | log_a0, log_a1 = -np.inf, -np.inf
158 | i = 0
159 |
160 | z0 = sigma ** 2 * math.log(1 / q - 1) + 0.5
161 |
162 | while True: # do ... until loop
163 | coef = special.binom(alpha, i)
164 | log_coef = math.log(abs(coef))
165 | j = alpha - i
166 |
167 | log_t0 = log_coef + i * math.log(q) + j * math.log(1 - q)
168 | log_t1 = log_coef + j * math.log(q) + i * math.log(1 - q)
169 |
170 | log_e0 = math.log(0.5) + _log_erfc((i - z0) / (math.sqrt(2) * sigma))
171 | log_e1 = math.log(0.5) + _log_erfc((z0 - j) / (math.sqrt(2) * sigma))
172 |
173 | log_s0 = log_t0 + (i * i - i) / (2 * (sigma ** 2)) + log_e0
174 | log_s1 = log_t1 + (j * j - j) / (2 * (sigma ** 2)) + log_e1
175 |
176 | if coef > 0:
177 | log_a0 = _log_add(log_a0, log_s0)
178 | log_a1 = _log_add(log_a1, log_s1)
179 | else:
180 | log_a0 = _log_sub(log_a0, log_s0)
181 | log_a1 = _log_sub(log_a1, log_s1)
182 |
183 | i += 1
184 | if max(log_s0, log_s1) < -30:
185 | break
186 |
187 | return _log_add(log_a0, log_a1)
188 |
189 |
190 | def _compute_log_a(q: float, sigma: float, alpha: float) -> float:
191 | r"""Computes :math:`log(A_\alpha)` for any positive finite ``alpha``.
192 |
193 | Notes:
194 | Note that
195 | :math:`A_\alpha` is real valued function of ``alpha`` and ``q``,
196 | and that 0 < ``q`` < 1.
197 |
198 | Refer to Section 3.3 of https://arxiv.org/pdf/1908.10530.pdf
199 | for details.
200 |
201 | Args:
202 | q: Sampling rate of SGM.
203 | sigma: The standard deviation of the additive Gaussian noise.
204 | alpha: The order at which RDP is computed.
205 |
206 | Returns:
207 | :math:`log(A_\alpha)` as defined in the paper mentioned above.
208 | """
209 | if float(alpha).is_integer():
210 | return _compute_log_a_for_int_alpha(q, sigma, int(alpha))
211 | else:
212 | return _compute_log_a_for_frac_alpha(q, sigma, alpha)
213 |
214 |
215 | def _log_erfc(x: float) -> float:
216 | r"""Computes :math:`log(erfc(x))` with high accuracy for large ``x``.
217 |
218 | Helper function used in computation of :math:`log(A_\alpha)`
219 | for a fractional alpha.
220 |
221 | Args:
222 | x: The input to the function
223 |
224 | Returns:
225 | :math:`log(erfc(x))`
226 | """
227 | return math.log(2) + special.log_ndtr(-x * 2 ** 0.5)
228 |
229 |
230 | def _compute_rdp(q: float, sigma: float, alpha: float) -> float:
231 | r"""Computes RDP of the Sampled Gaussian Mechanism at order ``alpha``.
232 |
233 | Args:
234 | q: Sampling rate of SGM.
235 | sigma: The standard deviation of the additive Gaussian noise.
236 | alpha: The order at which RDP is computed.
237 |
238 | Returns:
239 | RDP at order ``alpha``; can be np.inf.
240 | """
241 | if q == 0:
242 | return 0
243 |
244 | # no privacy
245 | if sigma == 0:
246 | return np.inf
247 |
248 | if q == 1.0:
249 | return alpha / (2 * sigma ** 2)
250 |
251 | if np.isinf(alpha):
252 | return np.inf
253 |
254 | return _compute_log_a(q, sigma, alpha) / (alpha - 1)
255 |
256 |
257 | def compute_rdp(
258 | q: float, noise_multiplier: float, steps: int, orders: Union[Sequence[float], float]
259 | ) -> Union[List[float], float]:
260 | r"""Computes Renyi Differential Privacy (RDP) guarantees of the
261 | Sampled Gaussian Mechanism (SGM) iterated ``steps`` times.
262 |
263 | Args:
264 | q: Sampling rate of SGM.
265 | noise_multiplier: The ratio of the standard deviation of the
266 | additive Gaussian noise to the L2-sensitivity of the function
267 | to which it is added. Note that this is same as the standard
268 | deviation of the additive Gaussian noise when the L2-sensitivity
269 | of the function is 1.
270 | steps: The number of iterations of the mechanism.
271 | orders: An array (or a scalar) of RDP orders.
272 |
273 | Returns:
274 | The RDP guarantees at all orders; can be ``np.inf``.
275 | """
276 | if isinstance(orders, float):
277 | rdp = _compute_rdp(q, noise_multiplier, orders)
278 | else:
279 | rdp = np.array([_compute_rdp(q, noise_multiplier, order) for order in orders])
280 |
281 | return rdp * steps
282 |
283 |
284 | # Based on
285 | # https://github.com/tensorflow/privacy/blob/5f07198b66b3617b22609db983926e3ba97cd905/tensorflow_privacy/privacy/analysis/rdp_accountant.py#L237
286 | def get_privacy_spent(orders, rdp, delta):
287 | """Compute epsilon given a list of RDP values and target delta.
288 | Args:
289 | orders: An array (or a scalar) of orders.
290 | rdp: A list (or a scalar) of RDP guarantees.
291 | delta: The target delta.
292 | Returns:
293 | Pair of (eps, optimal_order).
294 | Raises:
295 | ValueError: If input is malformed.
296 | """
297 | orders_vec = np.atleast_1d(orders)
298 | rdp_vec = np.atleast_1d(rdp)
299 |
300 | if delta <= 0:
301 | raise ValueError("Privacy failure probability bound delta must be >0.")
302 | if len(orders_vec) != len(rdp_vec):
303 | raise ValueError("Input lists must have the same length.")
304 |
305 | # Basic bound (see https://arxiv.org/abs/1702.07476 Proposition 3 in v3):
306 | # eps = min( rdp_vec - math.log(delta) / (orders_vec - 1) )
307 |
308 | # Improved bound from https://arxiv.org/abs/2004.00010 Proposition 12 (in v4).
309 | # Also appears in https://arxiv.org/abs/2001.05990 Equation 20 (in v1).
310 | eps_vec = []
311 | for (a, r) in zip(orders_vec, rdp_vec):
312 | if a < 1:
313 | raise ValueError("Renyi divergence order must be >=1.")
314 | if r < 0:
315 | raise ValueError("Renyi divergence must be >=0.")
316 |
317 | if delta ** 2 + math.expm1(-r) >= 0:
318 | # In this case, we can simply bound via KL divergence:
319 | # delta <= sqrt(1-exp(-KL)).
320 | eps = 0 # No need to try further computation if we have eps = 0.
321 | elif a > 1.01:
322 | # This bound is not numerically stable as alpha->1.
323 | # Thus we have a min value of alpha.
324 | # The bound is also not useful for small alpha, so doesn't matter.
325 | eps = r + math.log1p(-1 / a) - math.log(delta * a) / (a - 1)
326 | else:
327 | # In this case we can't do anything. E.g., asking for delta = 0.
328 | eps = np.inf
329 | eps_vec.append(eps)
330 |
331 | idx_opt = np.argmin(eps_vec)
332 | return max(0, eps_vec[idx_opt]), orders_vec[idx_opt]
333 |
--------------------------------------------------------------------------------
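
A concrete end-to-end use of the two public functions above (q, sigma, and the step count are arbitrary illustrative values):

    from private_transformers.accounting.rdp_accounting import compute_rdp, get_privacy_spent

    orders = tuple(1 + x / 10.0 for x in range(1, 100)) + tuple(range(12, 64))
    # RDP of the subsampled Gaussian mechanism composed over 1000 steps.
    rdp = compute_rdp(q=0.01, noise_multiplier=1.1, steps=1000, orders=orders)
    eps, opt_order = get_privacy_spent(orders=orders, rdp=rdp, delta=1e-5)
    print(f"eps = {eps:.3f} at delta = 1e-5 (optimal order {opt_order})")
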
/private_transformers/autograd_grad_sample.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """
17 | A large portion of this code is adapted from Opacus (https://github.com/pytorch/opacus),
18 | which is licensed under Apache License 2.0.
19 |
20 | We have modified it considerably to support ghost clipping.
21 | """
22 |
23 | from typing import Tuple
24 |
25 | import torch
26 | import torch.nn as nn
27 |
28 | from .settings import BackwardHookMode
29 | from .supported_layers_grad_samplers import _supported_layers_grad_samplers
30 |
31 | # TODO: hooks mode should be settable based on the module.
32 | _hooks_disabled: bool = False
33 | _hooks_mode = BackwardHookMode.default
34 |
35 |
36 | def set_hooks_mode(mode):
37 | if mode not in BackwardHookMode.all():
38 | raise ValueError(f"Unknown mode for hooks: {mode}; expected one of {BackwardHookMode.all()}.")
39 |
40 | global _hooks_mode
41 | _hooks_mode = mode # Set mode.
42 |
43 | if _hooks_mode == BackwardHookMode.ghost_grad: # Second backward pass of ghost clipping doesn't need hooks.
44 | disable_hooks()
45 | elif _hooks_mode == BackwardHookMode.ghost_norm: # First backward pass of ghost clipping needs to accumulate norms.
46 | enable_hooks()
47 |
48 |
49 | def get_hooks_mode():
50 | global _hooks_mode
51 | return _hooks_mode
52 |
53 |
54 | def requires_grad(module: nn.Module, recurse: bool = False) -> bool:
55 | """
56 | Checks if any parameters in a specified module require gradients.
57 |
58 | Args:
59 | module: PyTorch module whose parameters are examined
60 | recurse: Flag specifying if the gradient requirement check should
61 | be applied recursively to sub-modules of the specified module
62 |
63 | Returns:
64 |         Flag indicating whether any parameters require gradients
65 | """
66 | return any(p.requires_grad for p in module.parameters(recurse))
67 |
68 |
69 | def add_hooks(model: nn.Module, loss_reduction: str = "mean"):
70 | r"""
71 | Adds hooks to model to save activations and backprop values.
72 | The hooks will
73 |
74 | 1. save activations into ``param.activations`` during forward pass.
75 | 2. compute per-sample gradients and save them in ``param.grad_sample`` during backward pass.
76 |
77 | Args:
78 | model: Model to which hooks are added.
79 | loss_reduction: Indicates if the loss reduction (for aggregating the gradients) is a sum or a mean operation.
80 | Can take values ``sum`` or ``mean``.
81 | """
82 | if hasattr(model, "autograd_grad_sample_hooks"):
83 | raise ValueError("Trying to add hooks twice to the same model")
84 |
85 | enable_hooks()
86 |
87 | handles = []
88 | for name, layer in model.named_modules():
89 | if type(layer) in _supported_layers_grad_samplers:
90 | if requires_grad(layer, recurse=False):
91 | handles.append(layer.register_forward_hook(_capture_activations))
92 |
93 | def this_backward(this_layer, grad_input, grad_output):
94 | return _capture_backprops(this_layer, grad_input, grad_output, loss_reduction)
95 |
96 |                 # Starting with PyTorch 1.8.0, `register_full_backward_hook` should be preferred; `register_backward_hook` is deprecated.
97 | handles.append(layer.register_backward_hook(this_backward))
98 |
99 | model.__dict__.setdefault("autograd_grad_sample_hooks", []).extend(handles)
100 |
101 |
102 | def remove_hooks(model: nn.Module):
103 | """Removes hooks added by `add_hooks()`."""
104 | if not hasattr(model, "autograd_grad_sample_hooks"):
105 | raise ValueError("Asked to remove hooks, but no hooks found")
106 | else:
107 | for handle in model.autograd_grad_sample_hooks:
108 | handle.remove()
109 | del model.autograd_grad_sample_hooks
110 |
111 |
112 | def disable_hooks():
113 | """Globally disables all hooks installed by this library."""
114 | global _hooks_disabled
115 | _hooks_disabled = True
116 |
117 |
118 | def enable_hooks():
119 | """Globally enables all hooks installed by this library."""
120 | global _hooks_disabled
121 | _hooks_disabled = False
122 |
123 |
124 | def _capture_activations(layer: nn.Module, inputs: Tuple, outputs: Tuple):
125 | """Forward hook handler captures and saves activations."""
126 | if not requires_grad(layer) or not layer.training or _hooks_disabled:
127 | return
128 |
129 | if not hasattr(layer, "activations"):
130 | layer.activations = []
131 |
132 | # This improves on original Opacus and supports additional arguments on top of the (first) activation tensor.
133 | stored_inputs = tuple(input_i.detach() if torch.is_tensor(input_i) else input_i for input_i in inputs)
134 | layer.activations.append(stored_inputs)
135 |
136 |
137 | def _capture_backprops(
138 | layer: nn.Module,
139 | inputs: Tuple[torch.Tensor],
140 | outputs: Tuple[torch.Tensor],
141 | loss_reduction: str
142 | ):
143 | """Backward hook handler captures grad_outputs."""
144 | # This improves on the original Opacus codebase and supports multiple outputs.
145 | backprops = tuple(output_i.detach() if torch.is_tensor(output_i) else output_i for output_i in outputs)
146 | _compute_grad_sample(layer, backprops, loss_reduction)
147 |
148 |
149 | def _compute_grad_sample(layer: nn.Module, backprops: Tuple, loss_reduction: str):
150 | """Computes per-sample gradients with respect to the parameters."""
151 | if not requires_grad(layer) or not layer.training or _hooks_disabled:
152 | return
153 |
154 | if not hasattr(layer, "activations"):
155 | raise ValueError(f"No activations detected for {type(layer)}, run forward after add_hooks(model)")
156 |
157 |     # Outside of the LSTM there is a "batch_first" flag, but not for the Linear layers inside the LSTM.
158 | if isinstance(layer.activations, list):
159 | A = layer.activations.pop()
160 | else:
161 | A = layer.activations
162 |
163 | if not hasattr(layer, "max_batch_len"):
164 | assert torch.is_tensor(A[0]), f"Internal error: first input of the following layer isn't a Tensor. \n{layer}"
165 | layer.max_batch_len = _get_batch_size(layer, A[0])
166 |
167 | n = layer.max_batch_len
168 | if loss_reduction == "mean":
169 | B = tuple(B_i * n if torch.is_tensor(B_i) else B_i for B_i in backprops)
170 | elif loss_reduction == "sum":
171 | B = backprops
172 | else:
173 | raise ValueError(f"loss_reduction = {loss_reduction}. Only 'sum' and 'mean' losses are supported")
174 |
175 | # compute grad sample for individual layers
176 | compute_layer_grad_sample = _supported_layers_grad_samplers.get(type(layer))
177 | compute_layer_grad_sample(layer, A, B)
178 |
179 | if (not isinstance(layer.activations, list) or len(layer.activations) == 0) and hasattr(layer, "max_batch_len"):
180 | del layer.max_batch_len
181 |
182 |
183 | def _get_batch_size(layer: nn.Module, grad_sample: torch.Tensor) -> int:
184 | r"""
185 |     Computes and returns the maximum batch size, i.e., the maximum of the dimension values
186 |     along the 'batch_dim' axis over layer.activations + [grad_sample], where layer.activations
187 |     is a list. If layer.activations is not a list, returns grad_sample.shape[batch_dim].
188 | """
189 |
190 | batch_dim = 0
191 | max_batch_len = 0
192 | if isinstance(layer.activations, list):
193 | for out in layer.activations:
194 | assert torch.is_tensor(out[0]), (
195 | f"Internal error: first input of the following layer isn't a Tensor. \n{layer}"
196 | )
197 | if out[0].shape[batch_dim] > max_batch_len:
198 | max_batch_len = out[0].shape[batch_dim]
199 |
200 | max_batch_len = max(max_batch_len, grad_sample.shape[batch_dim])
201 | return max_batch_len
202 |
--------------------------------------------------------------------------------
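
The hooks above are normally installed by the `PrivacyEngine`, but in default mode their effect is easy to inspect directly: after a forward/backward pass, each supported layer's parameters carry per-sample gradients in `param.grad_sample`, as the `add_hooks` docstring states. A minimal sketch, assuming `nn.Linear` is among the supported layers:

    import torch
    import torch.nn as nn

    from private_transformers.autograd_grad_sample import add_hooks, remove_hooks

    model = nn.Linear(4, 2)
    add_hooks(model, loss_reduction="mean")

    x = torch.randn(8, 4)                       # batch of 8 examples
    loss = model(x).pow(2).sum(dim=1).mean()    # mean over the batch, matching loss_reduction
    loss.backward()

    # One gradient per example: shape (8, 2, 4) for the weight.
    print(model.weight.grad_sample.shape)
    remove_hooks(model)
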
/private_transformers/lora_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | LoRA layers.
17 |
18 | This version does not merge weights for zero-latency inference, which makes the code easier to read and maintain.
19 | Adapted from
20 | https://github.com/microsoft/LoRA
21 | https://www.microsoft.com/en-us/research/project/dp-transformers/
22 | """
23 |
24 | import torch
25 | import transformers
26 | from torch import nn
27 |
28 |
29 | class DPMergedLinear(nn.Module):
30 | def __init__(
31 | self,
32 | in_features: int,
33 | out_features: int,
34 | lora_r=0,
35 | lora_alpha=1.,
36 | lora_dropout=0.,
37 | ):
38 | super(DPMergedLinear, self).__init__()
39 | self.linear = nn.Linear(in_features=in_features, out_features=out_features)
40 | self.lora_r = lora_r
41 | self.lora_alpha = lora_alpha
42 | self.lora_dropout = nn.Dropout(p=lora_dropout)
43 | if self.lora_r > 0:
44 | self.lora_A = nn.Linear(in_features=in_features, out_features=lora_r, bias=False)
45 | self.lora_B = nn.Linear(in_features=lora_r, out_features=out_features, bias=False)
46 | self.scaling = self.lora_alpha / lora_r
47 | self.reset_parameters()
48 |
49 | def forward(self, x: torch.Tensor):
50 | result = self.linear(x)
51 | if self.lora_r > 0:
52 | after_dropout = self.lora_dropout(x)
53 | after_A = self.lora_A(after_dropout)
54 | after_B = self.lora_B(after_A)
55 | result += after_B * self.scaling
56 | return result
57 |
58 | def reset_parameters(self):
59 | self.linear.reset_parameters()
60 | if self.lora_r > 0:
61 | self.lora_A.reset_parameters()
62 | self.lora_B.weight.data.zero_()
63 |
64 | @staticmethod
65 | def from_transformers_conv1d(
66 | original_layer,
67 | lora_r=0,
68 | lora_alpha=1.,
69 | lora_dropout=0.,
70 | ) -> "DPMergedLinear":
71 | lora_layer = DPMergedLinear(
72 | in_features=original_layer.weight.shape[0],
73 | out_features=original_layer.weight.shape[1],
74 | lora_r=lora_r,
75 | lora_alpha=lora_alpha,
76 | lora_dropout=lora_dropout,
77 | ).to(original_layer.weight.device)
78 | lora_layer.linear.weight.data.copy_(original_layer.weight.T.data)
79 | lora_layer.linear.bias.data.copy_(original_layer.bias.data)
80 | return lora_layer
81 |
82 |
83 | def convert_gpt2_attention_to_lora(
84 | model: transformers.GPT2PreTrainedModel,
85 | lora_r=0,
86 | lora_alpha=1.,
87 | lora_dropout=0.,
88 | ) -> transformers.GPT2PreTrainedModel:
89 | if not isinstance(model, transformers.GPT2PreTrainedModel):
90 | raise TypeError("Requires a GPT2 model")
91 |
92 | if not hasattr(model, "h") and hasattr(model, "transformer"):
93 | transformer = model.transformer
94 | else:
95 | transformer = model
96 |
97 | for h_i in transformer.h:
98 | new_layer = DPMergedLinear.from_transformers_conv1d(
99 | original_layer=h_i.attn.c_attn,
100 | lora_r=lora_r,
101 | lora_alpha=lora_alpha,
102 | lora_dropout=lora_dropout,
103 | )
104 | h_i.attn.c_attn = new_layer
105 |
106 | return model
107 |
108 |
109 | def mark_only_lora_as_trainable(model: torch.nn.Module) -> None:
110 | model.requires_grad_(True)
111 | for n, p in model.named_parameters():
112 | if 'lora_' not in n:
113 | p.requires_grad = False
114 |
--------------------------------------------------------------------------------
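
Putting the helpers above together, converting GPT-2's fused attention projection to LoRA and freezing everything else might look like the following sketch (the rank and alpha are illustrative):

    import transformers

    from private_transformers.lora_utils import convert_gpt2_attention_to_lora, mark_only_lora_as_trainable

    model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")
    model = convert_gpt2_attention_to_lora(model, lora_r=4, lora_alpha=32, lora_dropout=0.0)
    mark_only_lora_as_trainable(model)

    # Only the lora_A / lora_B parameters remain trainable.
    num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"trainable params: {num_trainable / 1e6:.2f}M")
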
/private_transformers/settings.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import transformers
16 | from ml_swissknife import utils
17 |
18 |
19 | class BackwardHookMode(metaclass=utils.ContainerMeta):
20 | ghost_norm = "ghost_norm"
21 | ghost_grad = "ghost_grad"
22 | default = "default"
23 |
24 |
25 | class ClippingMode(metaclass=utils.ContainerMeta):
26 | default = "default" # Global fixed.
27 | ghost = "ghost" # Global fixed clipping with ghost clipping.
28 | per_layer = "per_layer" # Per layer fixed clipping.
29 | per_layer_percentile = "per_layer_percentile" # Clip gradient per-layer based on gradient norm percentile.
30 |
31 |
32 | class AccountingMode(metaclass=utils.ContainerMeta):
33 | rdp = "rdp"
34 | glw = "glw"
35 | all_ = "all"
36 |
37 |
38 | SUPPORTED_TRANSFORMERS = (
39 | transformers.models.openai.modeling_openai.OpenAIGPTLMHeadModel,
40 | transformers.models.openai.modeling_openai.OpenAIGPTDoubleHeadsModel,
41 | transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel,
42 | transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel,
43 | transformers.models.bert.modeling_bert.BertForSequenceClassification,
44 | transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification,
45 | transformers.models.albert.modeling_albert.AlbertForSequenceClassification,
46 | transformers.models.bart.modeling_bart.BartForConditionalGeneration,
47 | transformers.models.t5.modeling_t5.T5ForConditionalGeneration,
48 | transformers.models.opt.modeling_opt.OPTForCausalLM,
49 | transformers.models.vit.modeling_vit.ViTForImageClassification,
50 | transformers.models.deit.modeling_deit.DeiTForImageClassification,
51 | transformers.models.beit.modeling_beit.BeitForImageClassification,
52 | )
53 |
--------------------------------------------------------------------------------
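
These enums spell out the values `PrivacyEngine` accepts. For instance, selecting memory-efficient ghost clipping with RDP accounting, mirroring the engine construction in run_language_modeling.py (the batch and privacy numbers are illustrative; the table2text driver also unties word embeddings before training, which is reproduced here):

    import torch
    import transformers

    from private_transformers import PrivacyEngine
    from private_transformers.settings import AccountingMode, ClippingMode

    config = transformers.GPT2Config.from_pretrained("gpt2")
    config.tie_word_embeddings = False  # as in the table2text driver
    model = transformers.GPT2LMHeadModel.from_pretrained("gpt2", config=config)

    engine = PrivacyEngine(
        module=model,
        batch_size=1024,
        sample_size=50000,
        epochs=3,
        max_grad_norm=0.1,
        target_epsilon=8.0,
        target_delta=1e-5,
        clipping_mode=ClippingMode.ghost,   # ghost clipping: no per-sample grads materialized
        accounting_mode=AccountingMode.rdp,
    )
    engine.attach(torch.optim.AdamW(model.parameters(), lr=1e-4))
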
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import re
17 |
18 | import setuptools
19 |
20 | # For simplicity, the version is stored in the `__version__` attribute in the source.
21 | here = os.path.realpath(os.path.dirname(__file__))
22 | with open(os.path.join(here, 'private_transformers', '__init__.py')) as f:
23 | meta_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", f.read(), re.M)
24 | if meta_match:
25 | version = meta_match.group(1)
26 | else:
27 | raise RuntimeError("Unable to find __version__ string.")
28 |
29 | with open(os.path.join(here, 'README.md')) as f:
30 | readme = f.read()
31 |
32 | setuptools.setup(
33 | name="private_transformers",
34 | version=version,
35 | author="Xuechen Li",
36 | author_email="lxuechen@cs.toronto.edu",
37 | description="Train Hugging Face transformers with differential privacy.",
38 | long_description=readme,
39 | url="https://github.com/lxuechen/private-transformers",
40 | packages=setuptools.find_packages(exclude=['examples', 'tests']),
41 | install_requires=[
42 | "torch>=1.8.0",
43 | "prv-accountant",
44 | "transformers>=4.20.1", # v0.1.0 uses 4.16.2.
45 | "numpy",
46 | "scipy",
47 | "jupyterlab",
48 | "jupyter",
49 | "ml-swissknife",
50 | "opt_einsum",
51 | "pytest"
52 | ],
53 | python_requires='~=3.8',
54 | classifiers=[
55 | "Programming Language :: Python :: 3",
56 | "License :: OSI Approved :: Apache Software License",
57 | ],
58 | )
59 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Xuechen Li. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------