├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── config
│   ├── __init__.py
│   ├── get_config.py
│   ├── images.py
│   └── tabular.py
├── datasets
│   ├── __init__.py
│   ├── dataset.py
│   ├── image.py
│   ├── loaders.py
│   ├── sample_weights.py
│   └── tabular.py
├── evaluators
│   ├── __init__.py
│   ├── evaluator.py
│   └── metrics.py
├── experiment_scripts
│   ├── adult_script.sh
│   ├── celeba_script.sh
│   ├── dutch_script.sh
│   └── mnist_script.sh
├── main.py
├── models
│   ├── __init__.py
│   ├── model_factory.py
│   └── neural_networks.py
├── optimizers
│   ├── dpsgd_f_optimizer.py
│   ├── dpsgd_global_adaptive_optimizer.py
│   └── dpsgd_global_optimizer.py
├── privacy_engines
│   ├── dpsgd_f_engine.py
│   ├── dpsgd_global_adaptive_engine.py
│   └── dpsgd_global_engine.py
├── trainers
│   ├── __init__.py
│   ├── trainer.py
│   └── trainer_factory.py
├── utils.py
└── writer.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | .idea/
12 | build/
13 | plots/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | pip-wheel-metadata/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 | 
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 | 
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 | 
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 | 
56 | # Translations
57 | *.mo
58 | *.pot
59 | 
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 | 
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 | 
70 | # Scrapy stuff:
71 | .scrapy
72 | 
73 | # Sphinx documentation
74 | docs/_build/
75 | 
76 | # PyBuilder
77 | target/
78 | 
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 | 
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 | test.py
86 | 
87 | # pyenv
88 | .python-version
89 | 
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 | 
97 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # Data/Results 135 | results/ 136 | 137 | # VSCode 138 | .vscode/ 139 | 140 | runs/ 141 | data/ 142 | experiments/ 143 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 | 
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 | 
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 | 
176 | END OF TERMS AND CONDITIONS
177 | 
178 | APPENDIX: How to apply the Apache License to your work.
179 | 
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 | 
189 | Copyright [yyyy] [name of copyright owner]
190 | 
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 | 
195 | http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Fair-DP
2 | This is the codebase accompanying the paper [Disparate Impact in Differential Privacy from Gradient Misalignment](https://arxiv.org/abs/2206.07737). The paper was accepted as a spotlight presentation at ICLR 2023; the reviews are available on [OpenReview](https://openreview.net/forum?id=qLOaeRvteqbx).
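
## Configuration

Every experiment is launched through `main.py` with `--dataset` and `--method` flags, and any default from `config/` can be overridden with repeated `--config key=value` arguments (see the scripts in `experiment_scripts/`). A minimal sketch of the two helpers involved, run from the repository root (outputs shown as comments):

```python
from config import get_config, parse_config_arg

# Merge a dataset's base config with a method's overrides.
cfg = get_config("adult", "dpsgd")
print(cfg["method"], cfg["net"], cfg["l2_norm_clip"])  # dpsgd mlp 0.5

# --config values are parsed with ast.literal_eval, falling back to str.
print(parse_config_arg("lr=0.01"))               # ('lr', 0.01)
print(parse_config_arg("group_ratios=1,0.333"))  # ('group_ratios', (1, 0.333))
print(parse_config_arg("net=mlp"))               # ('net', 'mlp')
```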
3 | ## Prerequisites
4 | 
5 | - Install conda, pip
6 | - Python 3.10
7 | 
8 | ```bash
9 | conda create -n FairDP python=3.10
10 | conda activate FairDP
11 | ```
12 | 
13 | - PyTorch 1.11.0
14 | 
15 | ```bash
16 | conda install pytorch=1.11.0 torchvision=0.12.0 numpy=1.22 -c pytorch
17 | ```
18 | 
19 | - functorch 0.1.1
20 | 
21 | ```bash
22 | pip install functorch==0.1.1
23 | ```
24 | 
25 | - opacus 1.1
26 | 
27 | ```bash
28 | conda install -c conda-forge opacus=1.1
29 | ```
30 | 
31 | - matplotlib 3.4.3
32 | 
33 | ```bash
34 | conda install -c conda-forge matplotlib=3.4.3
35 | ```
36 | 
37 | - Other requirements
38 | 
39 | ```bash
40 | conda install pandas tbb regex tqdm tensorboardX=2.2
41 | pip install tensorboard==2.9
42 | 
43 | ```
44 | 
45 | Scripts to reproduce the experiments are located in fair-dp/experiment_scripts; results are saved to fair-dp/runs.
46 | 
47 | ```
48 | bash ./experiment_scripts/mnist_script.sh
49 | tensorboard --logdir=runs
50 | ```
51 | 
52 | - Download CelebA dataset from https://www.kaggle.com/datasets/jessicali9530/celeba-dataset and save files to
53 | fair-dp/data/celeba/
54 | 
55 | - Download Adult dataset from https://archive.ics.uci.edu/ml/datasets/Adult and save files adult.data, adult.test to
56 | fair-dp/data/adult/
57 | 
58 | ```
59 | bash ./experiment_scripts/adult_script.sh
60 | ```
61 | 
62 | - Download Dutch dataset from https://easy.dans.knaw.nl/ui/datasets/id/easy-dataset:32357. Free registration is required
63 | on the website. Under the "Data Files" tab download all files. Unzip and save to fair-dp/data/dutch/. The full file path
64 | required is ./fair-dp/data/dutch/original/org/IPUMS2001.asc
65 | 
66 | ```
67 | bash ./experiment_scripts/dutch_script.sh
68 | ```
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from .writer import Writer
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
1 | import ast
2 | 
3 | from .get_config import get_config
4 | 
5 | def parse_config_arg(key_value):
6 |     assert "=" in key_value, "Must specify config items with format `key=value`"
7 | 
8 |     k, v = key_value.split("=", maxsplit=1)
9 | 
10 |     assert k, "Config item can't have empty key"
11 |     assert v, "Config item can't have empty value"
12 | 
13 |     try:
14 |         v = ast.literal_eval(v)
15 |     except (ValueError, SyntaxError):
16 |         v = str(v)
17 | 
18 |     return k, v
19 | 
--------------------------------------------------------------------------------
/config/get_config.py:
--------------------------------------------------------------------------------
1 | from .images import CFG_MAP_IMG
2 | from .tabular import CFG_MAP_TAB
3 | 
4 | _IMAGE_DATASETS = ["mnist", "fashion-mnist", "svhn", "cifar10", "celeba"]
5 | _TABULAR_DATASETS = ["adult", "dutch"]
6 | 
7 | 
8 | def get_config(dataset, method):
9 |     if dataset in _IMAGE_DATASETS:
10 |         cfg_map = CFG_MAP_IMG
11 |         print("Note: protected group set to labels")
12 |     elif dataset in _TABULAR_DATASETS:
13 |         cfg_map = CFG_MAP_TAB
14 |     else:
15 |         raise ValueError(
16 |             f"Invalid dataset {dataset}. "
17 |             + f"Valid choices are {_IMAGE_DATASETS + _TABULAR_DATASETS}."
18 |         )
19 | 
20 |     base_config = cfg_map["base"](dataset)
21 | 
22 |     try:
23 |         method_config_function = cfg_map[method]
24 |     except KeyError:
25 |         cfg_map.pop("base")
26 |         raise ValueError(
27 |             f"Invalid method {method}. "
28 |             + f"Valid choices are {list(cfg_map.keys())}."
29 | ) 30 | 31 | return { 32 | **base_config, 33 | 34 | "dataset": dataset, 35 | "method": method, 36 | 37 | **method_config_function(dataset) 38 | } 39 | -------------------------------------------------------------------------------- /config/images.py: -------------------------------------------------------------------------------- 1 | def get_base_config(dataset): 2 | if dataset in ["mnist", "fashion-mnist", "svhn", "cifar10"]: 3 | delta = 1e-6 4 | output_dim = 10 5 | num_groups = 10 6 | protected_group = "labels" 7 | selected_groups = [2, 8] 8 | elif dataset in ["celeba"]: 9 | delta = 1e-6 10 | output_dim = 2 11 | num_groups = 2 12 | protected_group = "eyeglasses" 13 | selected_groups = [0, 1] 14 | else: 15 | raise ValueError(f"Unknown dataset {dataset}") 16 | 17 | net_configs = { 18 | "net": "cnn", 19 | "activation": "tanh", 20 | "hidden_channels": [32, 16], 21 | "kernel_size": [3, 3, 3, 3], 22 | "stride": [1, 1, 1, 1], 23 | "output_dim": output_dim, 24 | } 25 | 26 | return { 27 | "protected_group": protected_group, 28 | "num_groups": num_groups, 29 | "selected_groups": selected_groups, 30 | 31 | "seed": 0, 32 | 33 | "optimizer": "sgd", 34 | "lr": 0.01, 35 | "use_lr_scheduler": False, 36 | "max_epochs": 60, 37 | "accountant": "rdp", 38 | "delta": delta, 39 | "noise_multiplier": 0.8, 40 | "l2_norm_clip": 1.0, 41 | 42 | "make_valid_loader": False, 43 | "train_batch_size": 256, 44 | "valid_batch_size": 256, 45 | "test_batch_size": 256, 46 | "group_ratios": [-1] * num_groups, 47 | 48 | "valid_metrics": ["accuracy", "accuracy_per_group"], 49 | "test_metrics": ["accuracy", "accuracy_per_group", "macro_accuracy"], 50 | "evaluate_angles": False, 51 | "evaluate_hessian": False, 52 | "angle_comp_step": 200, 53 | "num_hutchinson_estimates": 100, 54 | "sampled_expected_loss": False, 55 | 56 | **net_configs 57 | } 58 | 59 | 60 | def get_non_private_config(dataset): 61 | return {} 62 | 63 | 64 | def get_dpsgd_config(dataset): 65 | return { 66 | "activation": "tanh", 67 | } 68 | 69 | 70 | def get_dpsgd_f_config(dataset): 71 | return { 72 | "activation": "tanh", 73 | "base_max_grad_norm": 1.0, # C0 74 | "counts_noise_multiplier": 10.0 # noise scale applied on mk and ok 75 | } 76 | 77 | 78 | def get_fairness_lens_config(dataset): 79 | return { 80 | "activation": "tanh", 81 | "gradient_regularizer": 1.0, 82 | "boundary_regularizer": 1.0 83 | } 84 | 85 | 86 | def get_dpsgd_global_config(dataset): 87 | return { 88 | "activation": "tanh", 89 | "strict_max_grad_norm": 100, # Z 90 | } 91 | 92 | 93 | # TODO: change defaults 94 | def get_dpsgd_global_adapt_config(dataset): 95 | return { 96 | "activation": "tanh", 97 | "strict_max_grad_norm": 100, # Z 98 | "bits_noise_multiplier": 10.0, # noise scale applied on average of bits 99 | "lr_Z": 0.1, # learning rate with which Z^t is tuned 100 | "threshold": 1 # threshold in how we compare gradient norms to Z 101 | } 102 | 103 | 104 | CFG_MAP_IMG = { 105 | "base": get_base_config, 106 | "regular": get_non_private_config, 107 | "dpsgd": get_dpsgd_config, 108 | "dpsgd-f": get_dpsgd_f_config, 109 | "fairness-lens": get_fairness_lens_config, 110 | "dpsgd-global": get_dpsgd_global_config, 111 | "dpsgd-global-adapt": get_dpsgd_global_adapt_config 112 | } 113 | -------------------------------------------------------------------------------- /config/tabular.py: -------------------------------------------------------------------------------- 1 | def get_base_config(dataset): 2 | if dataset in ["adult", "dutch"]: 3 | delta = 1e-6 4 | output_dim = 2 5 | protected_group = "sex" 
6 | num_groups = 2 7 | else: 8 | raise ValueError(f"Unknown dataset {dataset}") 9 | 10 | net_configs = { 11 | "net": 'mlp', 12 | "activation": "tanh", 13 | "hidden_dims": [256, 256], 14 | "output_dim": output_dim, 15 | } 16 | 17 | return { 18 | "protected_group": protected_group, 19 | "num_groups": num_groups, 20 | "selected_groups": [0, 1], 21 | 22 | "seed": 0, 23 | 24 | "optimizer": "sgd", 25 | "lr": 0.01, 26 | "use_lr_scheduler": False, 27 | "max_epochs": 20, 28 | "accountant": "rdp", 29 | "delta": delta, 30 | "noise_multiplier": 1.0, 31 | "l2_norm_clip": 0.5, 32 | 33 | "make_valid_loader": False, 34 | "train_batch_size": 256, 35 | "valid_batch_size": 256, 36 | "test_batch_size": 256, 37 | "group_ratios": [-1] * num_groups, 38 | 39 | "valid_metrics": ["accuracy", "accuracy_per_group"], 40 | "test_metrics": ["accuracy", "accuracy_per_group", "macro_accuracy"], 41 | "evaluate_angles": False, 42 | "evaluate_hessian": False, 43 | "angle_comp_step": 100, 44 | "num_hutchinson_estimates": 100, 45 | "sampled_expected_loss": False, 46 | 47 | **net_configs 48 | } 49 | 50 | 51 | def get_non_private_config(dataset): 52 | return {} 53 | 54 | 55 | def get_dpsgd_config(dataset): 56 | return { 57 | "activation": "tanh", 58 | } 59 | 60 | 61 | def get_dpsgd_f_config(dataset): 62 | return { 63 | "activation": "tanh", 64 | "base_max_grad_norm": 1.0, # C0 65 | "counts_noise_multiplier": 10.0 # noise scale applied on mk and ok 66 | } 67 | 68 | 69 | def get_fairness_lens_config(dataset): 70 | return { 71 | "activation": "tanh", 72 | "gradient_regularizer": 1.0, 73 | "boundary_regularizer": 1.0 74 | } 75 | 76 | 77 | def get_dpsgd_global_config(dataset): 78 | return { 79 | "activation": "tanh", 80 | "strict_max_grad_norm": 100, # Z 81 | } 82 | 83 | 84 | # TODO: change defaults 85 | def get_dpsgd_global_adapt_config(dataset): 86 | return { 87 | "activation": "tanh", 88 | "strict_max_grad_norm": 100, # Z 89 | "bits_noise_multiplier": 10.0, # noise scale applied on average of bits 90 | "lr_Z": 0.1, # learning rate with which Z^t is tuned 91 | "threshold": 1 # threshold in how we compare gradient norms to Z 92 | } 93 | 94 | 95 | CFG_MAP_TAB = { 96 | "base": get_base_config, 97 | "regular": get_non_private_config, 98 | "dpsgd": get_dpsgd_config, 99 | "dpsgd-f": get_dpsgd_f_config, 100 | "fairness-lens": get_fairness_lens_config, 101 | "dpsgd-global": get_dpsgd_global_config, 102 | "dpsgd-global-adapt": get_dpsgd_global_adapt_config 103 | } 104 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .loaders import get_loaders_from_config, get_loaders -------------------------------------------------------------------------------- /datasets/dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Tuple 2 | 3 | import torch 4 | 5 | 6 | class GroupLabelDataset(torch.utils.data.Dataset): 7 | ''' 8 | Implementation of torch Dataset that returns features 'x', classification labels 'y', and protected group labels 'z' 9 | ''' 10 | 11 | def __init__(self, role, x, y=None, z=None): 12 | if y is None: 13 | y = torch.zeros(x.shape[0]).long() 14 | 15 | if z is None: 16 | z = torch.zeros(x.shape[0]).long() 17 | 18 | assert x.shape[0] == y.shape[0] and x.shape[0] == z.shape[0] 19 | assert role in ["train", "valid", "test"] 20 | 21 | self.role = role 22 | 23 | self.x = x 24 | self.y = y 25 | self.z = z 26 | 27 | def __len__(self) 
-> int: 28 | return self.x.shape[0] 29 | 30 | def __getitem__(self, index: int) -> Tuple[Any, Any, Any]: 31 | return self.x[index], self.y[index], self.z[index] 32 | 33 | def to(self, device): 34 | return GroupLabelDataset( 35 | self.role, 36 | self.x.to(device), 37 | self.y.to(device), 38 | self.z.to(device), 39 | ) 40 | -------------------------------------------------------------------------------- /datasets/image.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import PIL 4 | from pathlib import Path 5 | from typing import Tuple, Sequence, Any 6 | 7 | import torch 8 | from torch.utils.data import Dataset 9 | import torchvision.datasets 10 | import torchvision.transforms as transforms 11 | import numpy as np 12 | import pandas as pd 13 | 14 | from .dataset import GroupLabelDataset 15 | from .sample_weights import find_sample_weights 16 | 17 | 18 | class CelebA(Dataset): 19 | """ 20 | CelebA PyTorch dataset 21 | The built-in PyTorch dataset for CelebA is outdated. 22 | """ 23 | 24 | def __init__(self, root: str, group_ratios: Sequence[int], role: str = "train", seed: int = 0): 25 | self.root = Path(root) 26 | self.role = role 27 | 28 | self.transform = transforms.Compose([ 29 | transforms.Resize((64, 64)), 30 | transforms.ToTensor(), 31 | ]) 32 | 33 | celeb_path = lambda x: self.root / x 34 | 35 | role_map = { 36 | "train": 0, 37 | "valid": 1, 38 | "test": 2, 39 | "all": None, 40 | } 41 | splits_df = pd.read_csv(celeb_path("list_eval_partition.csv")) 42 | fields = ['image_id', 'Male', 'Eyeglasses'] 43 | attrs_df = pd.read_csv(celeb_path("list_attr_celeba.csv"), usecols=fields) 44 | df = pd.merge(splits_df, attrs_df, on='image_id') 45 | df = df[df['partition'] == role_map[self.role]].drop(labels='partition', axis=1) 46 | df = df.replace(to_replace=-1, value=0) 47 | 48 | if seed: 49 | # Shuffle order according to seed but keep standard partition because the same person appears multiple times 50 | state = np.random.default_rng(seed=seed) 51 | df = df.sample(frac=1, random_state=state) 52 | 53 | labels = df["Male"] 54 | if group_ratios and (role_map[self.role] != 2): 55 | # don't alter the test set, refer to sample_weights.py 56 | label_counts = labels.value_counts(dropna=False).tolist() 57 | sample_weights = find_sample_weights(group_ratios, label_counts) 58 | print(f"Number of samples by label (before sampling) in {self.role}:") 59 | print(f"Female: {label_counts[0]}, Male: {label_counts[1]}") 60 | 61 | random.seed(seed) 62 | idx = [random.random() <= sample_weights[label] for label in labels] 63 | labels = labels[idx] 64 | label_counts_after = labels.value_counts(dropna=False).tolist() 65 | 66 | print("Number of samples by label (after sampling):") 67 | print(f"Female: {label_counts_after[0]}, Male: {label_counts_after[1]}") 68 | df = df[idx] 69 | 70 | self.filename = df["image_id"].tolist() 71 | # Male is 1, Female is 0 72 | self.y = torch.Tensor(df["Male"].values).long() 73 | # Wearing glasses is 1, otherwise zero 74 | self.z = torch.Tensor(df["Eyeglasses"].values).long() 75 | 76 | self.shape = (len(self.filename), 3, 64, 64) 77 | 78 | def __getitem__(self, index: int) -> Tuple[Any, Any, Any]: 79 | img_path = (self.root / "img_align_celeba" / 80 | "img_align_celeba" / self.filename[index]) 81 | x = PIL.Image.open(img_path) 82 | x = self.transform(x).to(self.device) 83 | 84 | y = self.y[index].to(self.device) 85 | z = self.z[index].to(self.device) 86 | 87 | return x, y, z 88 | 89 | def __len__(self) -> int: 90 | 
return len(self.filename) 91 | 92 | def to(self, device): 93 | self.device = device 94 | return self 95 | 96 | 97 | def image_tensors_to_dataset(dataset_role, images, labels): 98 | images = images.to(dtype=torch.get_default_dtype()) 99 | labels = labels.long() 100 | # NOTE: assumed protected group is defined by labels for image dsets 101 | return GroupLabelDataset(dataset_role, images, labels, labels) 102 | 103 | 104 | # Returns tuple of form `(images, labels)`. 105 | # `images` has shape `(nimages, nchannels, nrows, ncols)`, and has 106 | # entries in {0, ..., 1} 107 | def get_raw_image_tensors(dataset_name, train, data_root, group_ratios=None, seed=0): 108 | data_dir = os.path.join(data_root, dataset_name) 109 | 110 | if dataset_name == "cifar10": 111 | dataset = torchvision.datasets.CIFAR10(root=data_dir, train=train, download=True) 112 | images = torch.tensor(dataset.data).permute((0, 3, 1, 2)) 113 | labels = torch.tensor(dataset.targets) 114 | 115 | elif dataset_name == "svhn": 116 | dataset = torchvision.datasets.SVHN(root=data_dir, split="train" if train else "test", download=True) 117 | images = torch.tensor(dataset.data) 118 | labels = torch.tensor(dataset.labels) 119 | 120 | elif dataset_name in ["mnist", "fashion-mnist"]: 121 | dataset_class = { 122 | "mnist": torchvision.datasets.MNIST, 123 | "fashion-mnist": torchvision.datasets.FashionMNIST 124 | }[dataset_name] 125 | dataset = dataset_class(root=data_dir, train=train, download=True) 126 | images = dataset.data.unsqueeze(1) 127 | labels = dataset.targets 128 | 129 | else: 130 | raise ValueError(f"Unknown dataset {dataset_name}") 131 | 132 | images = images / 255.0 133 | 134 | if group_ratios: 135 | # refer to sample_weights.py 136 | _, label_counts = torch.unique(labels, sorted=True, return_counts=True) 137 | sample_weights = find_sample_weights(group_ratios, label_counts.tolist()) 138 | print("Number of samples by label (before sampling):") 139 | print(label_counts) 140 | 141 | random.seed(seed) 142 | idx = [random.random() <= sample_weights[label.item()] for label in labels] 143 | labels = labels[idx] 144 | images = images[idx] 145 | _, label_counts_after = torch.unique(labels, sorted=True, return_counts=True) 146 | 147 | print("Number of samples by label (after sampling):") 148 | print(label_counts_after) 149 | 150 | return images, labels 151 | 152 | 153 | def get_torchvision_datasets(dataset_name, data_root, seed, group_ratios, valid_fraction, flatten): 154 | images, labels = get_raw_image_tensors(dataset_name, train=True, data_root=data_root, group_ratios=group_ratios, 155 | seed=seed) 156 | if flatten: 157 | images = images.flatten(start_dim=1) 158 | 159 | perm = torch.randperm(images.shape[0]) 160 | shuffled_images = images[perm] 161 | shuffled_labels = labels[perm] 162 | 163 | valid_size = int(valid_fraction * images.shape[0]) 164 | valid_images = shuffled_images[:valid_size] 165 | valid_labels = shuffled_labels[:valid_size] 166 | train_images = shuffled_images[valid_size:] 167 | train_labels = shuffled_labels[valid_size:] 168 | 169 | train_dset = image_tensors_to_dataset("train", train_images, train_labels) 170 | valid_dset = image_tensors_to_dataset("valid", valid_images, valid_labels) 171 | 172 | test_images, test_labels = get_raw_image_tensors(dataset_name, train=False, data_root=data_root) 173 | if flatten: 174 | test_images = test_images.flatten(start_dim=1) 175 | test_dset = image_tensors_to_dataset("test", test_images, test_labels) 176 | 177 | return train_dset, valid_dset, test_dset 178 | 179 | def 
get_image_datasets_by_class(dataset_name, data_root, seed, group_ratios, valid_fraction, flatten=False): 180 | data_dir = os.path.join(data_root, dataset_name) 181 | 182 | if dataset_name == "celeba": 183 | # valid_fraction and flatten ignored 184 | data_class = CelebA 185 | 186 | else: 187 | raise ValueError(f"Unknown dataset {dataset_name}") 188 | 189 | train_dset = data_class(root=data_dir, group_ratios=group_ratios, role="train", seed=seed) 190 | valid_dset = data_class(root=data_dir, group_ratios=group_ratios, role="valid", seed=seed) 191 | test_dset = data_class(root=data_dir, group_ratios=group_ratios, role="test", seed=seed) 192 | 193 | return train_dset, valid_dset, test_dset 194 | 195 | def get_image_datasets(dataset_name, data_root, seed, group_ratios, make_valid_loader=False, flatten=False): 196 | valid_fraction = 0.1 if make_valid_loader else 0 197 | 198 | torchvision_datasets = ["mnist", "fashion-mnist", "svhn", "cifar10"] 199 | 200 | get_datasets_fn = get_torchvision_datasets if dataset_name in torchvision_datasets else get_image_datasets_by_class 201 | 202 | return get_datasets_fn(dataset_name, data_root, seed, group_ratios, valid_fraction, flatten) 203 | -------------------------------------------------------------------------------- /datasets/loaders.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import DataLoader 3 | 4 | from .image import get_image_datasets 5 | from .tabular import get_tabular_datasets 6 | 7 | 8 | def get_loaders_from_config(cfg, device, **kwargs): 9 | if cfg["net"] == "cnn": 10 | flatten = False 11 | elif cfg["net"] == "mlp": 12 | flatten = True 13 | elif cfg["net"] == "logistic": 14 | flatten = True 15 | else: 16 | raise ValueError(f"Unknown net type {cfg['net']} for flattening") 17 | 18 | train_loader, valid_loader, test_loader = get_loaders( 19 | dataset=cfg["dataset"], 20 | device=device, 21 | data_root=cfg.get("data_root", "data/"), 22 | train_batch_size=cfg["train_batch_size"], 23 | valid_batch_size=cfg["valid_batch_size"], 24 | test_batch_size=cfg["test_batch_size"], 25 | group_ratios=cfg["group_ratios"], 26 | seed=cfg["seed"], 27 | protected_group=cfg["protected_group"], 28 | make_valid_loader=cfg["make_valid_loader"], 29 | flatten=flatten, 30 | ) 31 | 32 | if cfg["dataset"] in ["celeba"]: 33 | train_dataset_shape = train_loader.dataset.shape 34 | else: 35 | train_dataset_shape = train_loader.dataset.x.shape 36 | cfg["train_dataset_size"] = train_dataset_shape[0] 37 | cfg["data_shape"] = tuple(train_dataset_shape[1:]) 38 | cfg["data_dim"] = int(np.prod(cfg["data_shape"])) 39 | 40 | if not cfg["make_valid_loader"]: 41 | valid_loader = test_loader 42 | print("WARNING: Using test loader for validation") 43 | 44 | return train_loader, valid_loader, test_loader 45 | 46 | 47 | def get_loaders( 48 | dataset, 49 | device, 50 | data_root, 51 | train_batch_size, 52 | valid_batch_size, 53 | test_batch_size, 54 | group_ratios, 55 | seed, 56 | protected_group, 57 | make_valid_loader, 58 | flatten, 59 | ): 60 | # NOTE: only training and validation sets sampled according to group_ratios 61 | if dataset in ["mnist", "fashion-mnist", "cifar10", "svhn", "celeba"]: 62 | train_dset, valid_dset, test_dset = get_image_datasets(dataset, data_root, seed, group_ratios, 63 | make_valid_loader, flatten) 64 | 65 | # NOTE: entire dataset sampled according to group_ratios 66 | elif dataset in ["adult", "dutch"]: 67 | train_dset, valid_dset, test_dset = get_tabular_datasets(dataset, 
data_root, seed, protected_group,
68 |                                                                group_ratios, make_valid_loader)
69 |     else:
70 |         raise ValueError(f"Unknown dataset {dataset}")
71 | 
72 |     train_loader = get_loader(train_dset, device, train_batch_size, drop_last=False)
73 | 
74 |     if make_valid_loader:
75 |         valid_loader = get_loader(valid_dset, device, valid_batch_size, drop_last=False)
76 |     else:
77 |         valid_loader = None
78 | 
79 |     test_loader = get_loader(test_dset, device, test_batch_size, drop_last=False)
80 | 
81 |     return train_loader, valid_loader, test_loader
82 | 
83 | 
84 | def get_loader(dset, device, batch_size, drop_last):
85 |     return DataLoader(
86 |         dset.to(device),
87 |         batch_size=batch_size,
88 |         shuffle=True,
89 |         drop_last=drop_last,
90 |         num_workers=0,
91 |         pin_memory=False
92 |     )
93 | 
--------------------------------------------------------------------------------
/datasets/sample_weights.py:
--------------------------------------------------------------------------------
1 | from numpy import argmax
2 | 
3 | '''
4 | Suppose we have data partitioned by groups and want to subsample it by group_ratios such that,
5 | after sampling,
6 |     number of samples in group i / total number of samples ~= group_ratios[i] / sum(group_ratios)
7 | 
8 | Transform group_ratios -> sample_weights and then keep each sample of group i with probability
9 | sample_weights[i] to achieve this.
10 | 
11 | All group ratios should be <= 1; a ratio of -1 indicates the group is sampled with probability 1 (kept whole).
12 | '''
13 | 
14 | 
15 | def find_restricted(group_ratios, num_samples, sample_idx):
16 |     """
17 |     group_ratios: per-group sampling ratios, e.g. [-1, -1, 0.09, -1, ...]
18 |     num_samples: the number of samples that falls into each group
19 |     sample_idx: indices of the groups whose ratio is not -1
20 |     """
21 |     candidates = []
22 |     for i in sample_idx:
23 |         if all(group_ratios[j] * num_samples[i] <= num_samples[j] for j in sample_idx):
24 |             candidates.append(i)
25 |     restricted_index = argmax([group_ratios[i] * num_samples[i] for i in candidates])
26 |     restricted = candidates[restricted_index]
27 |     return restricted
28 | 
29 | 
30 | def find_sample_weights(group_ratios, num_samples):
31 |     to_sample_idx = [i for i, item in enumerate(group_ratios) if item != -1]
32 |     if not to_sample_idx:
33 |         return {j: 1 for j in range(len(group_ratios))}
34 |     restricted = find_restricted(group_ratios, num_samples, to_sample_idx)
35 |     sample_weights = {j: group_ratios[j] * num_samples[restricted] / num_samples[j] if j in to_sample_idx else 1 for j
36 |                       in range(len(group_ratios))}
37 |     return sample_weights
38 | 
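A quick illustration of the transformation (not part of the repository; the ratios mirror the experiment scripts and the counts are made up):

```python
from datasets.sample_weights import find_sample_weights

# Dutch-style ratios (group_ratios=1,0.333 in dutch_script.sh): keep group 0
# whole and keep roughly a third of group 1.
print(find_sample_weights([1, 0.333], [30000, 30000]))
# -> {0: 1.0, 1: 0.333} (up to float rounding)

# MNIST-style ratios: -1 means "keep with probability 1", so only class 8
# is subsampled, down to ~9% of its original size.
print(find_sample_weights([-1] * 8 + [0.09, -1], [6000] * 10))
# -> weight 1 for every class except class 8, which gets ~0.09
```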
--------------------------------------------------------------------------------
/datasets/tabular.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random as rand
3 | 
4 | import pandas as pd
5 | import regex as re
6 | import torch
7 | 
8 | from .dataset import GroupLabelDataset
9 | from .sample_weights import find_sample_weights
10 | 
11 | 
12 | # normalize df columns
13 | def normalize(df, columns):
14 |     result = df.copy()
15 |     for column in columns:
16 |         mu = df[column].mean(axis=0)
17 |         sigma = df[column].std(axis=0)
18 |         assert sigma != 0
19 |         result[column] = (df[column] - mu) / sigma
20 |     return result
21 | 
22 | 
23 | def make_tabular_train_valid_split(data, valid_frac):
24 |     n_valid = int(valid_frac * data.shape[0])
25 |     valid_data = data[:n_valid]
26 |     train_data = data[n_valid:]
27 |     return train_data, valid_data
28 | 
29 | 
30 | def make_tabular_train_valid_test_split(data, valid_frac, test_frac, seed):
31 |     # shuffle samples
32 |     data = data.sample(frac=1, random_state=seed)
33 | 
34 |     n_test = int(test_frac * data.shape[0])
35 |     test_data = data[:n_test]
36 |     data = data[n_test:]
37 | 
38 |     train_data, valid_data = make_tabular_train_valid_split(data, valid_frac)
39 | 
40 |     return train_data, valid_data, test_data
41 | 
42 | 
43 | # refer to sample_weights.py
44 | def sample_by_group_ratios(group_ratios, df, seed):
45 |     print("Number of samples by group (before sampling):")
46 |     print(df.protected_group.value_counts())
47 |     sample_weights = find_sample_weights(group_ratios, df.protected_group.value_counts().tolist())  # NOTE: value_counts() sorts by frequency, not by group label
48 |     rand.seed(seed)
49 |     idx = [rand.random() <= sample_weights[row.protected_group] for _, row in df.iterrows()]
50 |     df = df.loc[idx]
51 |     print("Number of samples by group (after sampling):")
52 |     print(df.protected_group.value_counts())
53 |     return df
54 | 
55 | 
56 | def preprocess_adult(df, protected_group, target, group_ratios, seed):
57 |     numerical_columns = ["age", "education_num", "capital_gain", "capital_loss",
58 |                          "hours_per_week"]
59 |     if protected_group in numerical_columns:
60 |         numerical_columns.remove(protected_group)
61 |     df = normalize(df, numerical_columns)
62 | 
63 |     mapped_income_values = df.income.map({"<=50K": 0, ">50K": 1, "<=50K.": 0, ">50K.": 1})
64 |     df.loc[:, "income"] = mapped_income_values
65 | 
66 |     mapped_sex_values = df.sex.map({"Male": 0, "Female": 1})
67 |     df.loc[:, "sex"] = mapped_sex_values
68 | 
69 |     # make race binary
70 |     def race_map(value):
71 |         if value != "White":
72 |             return 1
73 |         return 0
74 | 
75 |     mapped_race_values = df.race.map(race_map)
76 |     df.loc[:, "race"] = mapped_race_values
77 | 
78 |     categorical = df.columns.tolist()
79 |     for column in numerical_columns:
80 |         categorical.remove(column)
81 |     print("Possible protected groups are: {}".format(categorical))
82 | 
83 |     if protected_group == "labels":
84 |         df.loc[:, "protected_group"] = df[target]
85 |     elif protected_group not in categorical:
86 |         raise ValueError(
87 |             f"Invalid protected group {protected_group}. "
88 |             + f"Valid choices are {categorical}."
89 | ) 90 | else: 91 | df.loc[:, "protected_group"] = df[protected_group] 92 | 93 | df = sample_by_group_ratios(group_ratios, df, seed) 94 | 95 | # convert to one-hot vectors 96 | categorical_non_binary = ["workclass", "education", "marital_status", "occupation", 97 | "relationship", "native_country"] 98 | df = pd.get_dummies(df, columns=categorical_non_binary) 99 | 100 | return df 101 | 102 | 103 | def preprocess_dutch(df, protected_group, target, group_ratios, seed): 104 | # remove weight feature 105 | df = df.drop("weight", axis=1) 106 | 107 | # drop underage samples (under 14 yrs old) 108 | df = df.drop(df[df.age <= 3].index) 109 | 110 | # drop samples with occupation = not working, unknown = 999, 998 111 | df = df.drop(df[df.occupation == 999].index) 112 | df = df.drop(df[df.occupation == 998].index) 113 | 114 | # map occupations to high=1, mid, low=0 115 | occupation_map = { 116 | 1: 1, 117 | 2: 1, 118 | 3: "mid", 119 | 4: 0, 120 | 5: 0, 121 | 6: "mid", 122 | 7: "mid", 123 | 8: "mid", 124 | 9: 0 125 | } 126 | mapped_occupation_values = df.occupation.map(occupation_map) 127 | df.loc[:, "occupation"] = mapped_occupation_values 128 | 129 | # drop samples with occupation = mid 130 | df = df.drop(df[df.occupation == "mid"].index) 131 | 132 | mapped_sex_values = df.sex.map({1: 0, 2: 1}) 133 | df.loc[:, "sex"] = mapped_sex_values 134 | 135 | # note original dataset has values {0,1,9} for prev_res_place, but all samples with 9 are underage, hence get dropped 136 | mapped_prev_res_place_values = df.prev_res_place.map({1: 0, 2: 1}) 137 | df.loc[:, "prev_res_place"] = mapped_prev_res_place_values 138 | 139 | categorical = df.columns.to_list() 140 | print("Possible protected groups are: {}".format(categorical)) 141 | 142 | if protected_group == "labels": 143 | df.loc[:, "protected_group"] = df[target] 144 | elif protected_group not in categorical: 145 | raise ValueError( 146 | f"Invalid protected group {protected_group}. " 147 | + f"Valid choices are {categorical}." 
148 |         )
149 |     else:
150 |         df.loc[:, "protected_group"] = df[protected_group]
151 | 
152 |     # convert categorical unprotected features to one-hot vectors
153 |     if target in categorical:
154 |         categorical.remove(target)
155 |     if "sex" in categorical:
156 |         categorical.remove("sex")  # binary
157 |     if "prev_res_place" in categorical:
158 |         categorical.remove("prev_res_place")  # binary
159 | 
160 |     df = sample_by_group_ratios(group_ratios, df, seed)
161 | 
162 |     df = pd.get_dummies(df, columns=categorical)
163 | 
164 |     return df
165 | 
166 | 
167 | def get_dutch_raw(data_root, valid_frac, test_frac, seed, protected_group, target, group_ratios):
168 |     '''
169 |     Dutch dataset:
170 |     Download from https://easy.dans.knaw.nl/ui/datasets/id/easy-dataset:32357 (free registration required)
171 |     unzip and save directory to fair-dp/data/dutch/
172 |     '''
173 | 
174 |     columns = ["sex", "age", "household_posn", "household_size", "prev_res_place", "citizenship",
175 |                "country_birth", "edu_level", "econ_status", "occupation", "cur_eco_activity",
176 |                "marital_status", "weight"]
177 |     col_str = ",".join(columns)
178 |     col_str = col_str + "\n"
179 | 
180 |     # data location data/dutch/original/org/IPUMS2001.asc
181 |     write_from = open(os.path.join(data_root, "dutch", "original", "org", "IPUMS2001.asc"), "r")
182 |     write_to = open(os.path.join(data_root, "dutch", "dutch_data_formatted.csv"), "w")
183 | 
184 |     write_to.write(col_str)
185 | 
186 |     def to_csv(write_from, write_to):
187 |         while True:
188 |             line = write_from.readline()
189 | 
190 |             if not line:
191 |                 break
192 | 
193 |             # refer to IPUMS2001_meta.pdf page 8 for group values
194 |             result = re.search(r"(.{1})(.{2})(.{4})(.{3})(.{3})(.{2})(.{1})(.{2})(.{3})(.{3})(.{3})(.{1})(.{16})", line)
195 |             formatted_str = ""
196 |             for group in range(1, len(result.groups()) + 1):
197 |                 if group == len(result.groups()):
198 |                     formatted_str = formatted_str + result.group(group).strip() + "\n"
199 |                 else:
200 |                     formatted_str = formatted_str + result.group(group).strip() + ","
201 |             write_to.write(formatted_str)
202 | 
203 |     to_csv(write_from, write_to)
204 |     write_from.close()
205 |     write_to.close()
206 | 
207 |     df = pd.read_csv(os.path.join(data_root, "dutch", "dutch_data_formatted.csv"))
208 | 
209 |     df_preprocessed = preprocess_dutch(df, protected_group, target, group_ratios, seed)
210 | 
211 |     train_raw, valid_raw, test_raw = make_tabular_train_valid_test_split(df_preprocessed, valid_frac, test_frac, seed)
212 | 
213 |     return train_raw, valid_raw, test_raw
214 | 
215 | 
216 | def get_adult_raw(data_root, valid_frac, test_frac, seed, protected_group, target, group_ratios):
217 |     '''
218 |     Adult dataset:
219 |     Download from https://archive.ics.uci.edu/ml/datasets/Adult
220 |     and save files adult.data, adult.test to fair-dp/data/adult/
221 |     '''
222 |     columns = ["age", "workclass", "fnlwgt", "education", "education_num", "marital_status",
223 |                "occupation", "relationship", "race", "sex", "capital_gain", "capital_loss",
224 |                "hours_per_week", "native_country", "income"]
225 | 
226 |     df_1 = pd.read_csv(os.path.join(data_root, "adult", "adult.data"), sep=", ", engine='python', header=None)
227 |     df_2 = pd.read_csv(os.path.join(data_root, "adult", "adult.test"), sep=", ", engine='python', header=None,
228 |                        skiprows=1)
229 |     df_1.columns = columns
230 |     df_2.columns = columns
231 |     df = pd.concat((df_1, df_2), ignore_index=True)
232 | 
233 |     df = df.drop("fnlwgt", axis=1)
234 |     for column in df.columns:
235 |         df = df[df[column] != "?"]
236 |     df.to_csv(os.path.join(data_root, "adult", "adult_data_formatted.csv"), index=False)
237 | 
238 |     df = pd.read_csv(os.path.join(data_root, "adult", "adult_data_formatted.csv"))
239 | 
240 |     df_preprocessed = preprocess_adult(df, protected_group, target, group_ratios, seed)
241 | 
242 |     train_raw, valid_raw, test_raw = make_tabular_train_valid_test_split(df_preprocessed, valid_frac, test_frac, seed)
243 | 
244 |     return train_raw, valid_raw, test_raw
245 | 
246 | 
247 | def get_tabular_datasets(name, data_root, seed, protected_group, group_ratios=None, make_valid_loader=False):
248 |     if name == "adult":
249 |         data_fn = get_adult_raw
250 |         target = "income"
251 |     elif name == "dutch":
252 |         data_fn = get_dutch_raw
253 |         target = "occupation"
254 |     else:
255 |         raise ValueError(f"Unknown dataset {name}")
256 | 
257 |     valid_frac = 0
258 |     if make_valid_loader:
259 |         valid_frac = 0.1
260 |     test_frac = 0.2
261 |     train_raw, valid_raw, test_raw = data_fn(data_root, valid_frac, test_frac, seed, protected_group, target,
262 |                                              group_ratios)
263 | 
264 |     feature_columns = train_raw.columns.to_list()
265 |     feature_columns.remove(target)
266 |     feature_columns.remove("protected_group")
267 | 
268 |     train_dset = GroupLabelDataset("train",
269 |                                    torch.tensor(train_raw[feature_columns].values, dtype=torch.get_default_dtype()),
270 |                                    torch.tensor(train_raw[target].to_list(), dtype=torch.long),
271 |                                    torch.tensor(train_raw["protected_group"].values.tolist(), dtype=torch.long)
272 |                                    )
273 |     valid_dset = GroupLabelDataset("valid",
274 |                                    torch.tensor(valid_raw[feature_columns].values, dtype=torch.get_default_dtype()),
275 |                                    torch.tensor(valid_raw[target].to_list(), dtype=torch.long),
276 |                                    torch.tensor(valid_raw["protected_group"].values.tolist(), dtype=torch.long)
277 |                                    )
278 |     test_dset = GroupLabelDataset("test",
279 |                                   torch.tensor(test_raw[feature_columns].values, dtype=torch.get_default_dtype()),
280 |                                   torch.tensor(test_raw[target].to_list(), dtype=torch.long),
281 |                                   torch.tensor(test_raw["protected_group"].values.tolist(), dtype=torch.long)
282 |                                   )
283 | 
284 |     return train_dset, valid_dset, test_dset
285 | 
--------------------------------------------------------------------------------
/evaluators/__init__.py:
--------------------------------------------------------------------------------
1 | from .evaluator import create_evaluator
--------------------------------------------------------------------------------
/evaluators/evaluator.py:
--------------------------------------------------------------------------------
1 | from inspect import getmembers, isfunction
2 | 
3 | from . import metrics
4 | 
5 | metric_fn_dict = dict(getmembers(metrics, predicate=isfunction))
6 | 
7 | 
8 | class Evaluator:
9 |     def __init__(self, model, *,
10 |                  valid_loader, test_loader,
11 |                  valid_metrics=None,
12 |                  test_metrics=None,
13 |                  **kwargs):
14 |         self.model = model
15 |         self.valid_loader = valid_loader
16 |         self.test_loader = test_loader
17 |         self.valid_metrics = valid_metrics or {}
18 |         self.test_metrics = test_metrics or valid_metrics  # fall back to the validation metrics when none given
19 |         self.metric_kwargs = kwargs or {}
20 | 
21 |     def evaluate(self, dataloader, metric):
22 |         assert metric in metric_fn_dict, f"Metric name {metric} not present in `metrics.py`"
23 | 
24 |         metric_fn = metric_fn_dict[metric]
25 | 
26 |         self.model.eval()
27 |         return metric_fn(self.model, dataloader, **self.metric_kwargs)
28 | 
29 |     def validate(self):
30 |         print(f"Validating {self.valid_metrics}")
31 |         return {metric: self.evaluate(self.valid_loader, metric)
32 |                 for metric in self.valid_metrics}
33 | 
34 |     def test(self):
35 |         print(f"Testing {self.test_metrics}")
36 |         return {metric: self.evaluate(self.test_loader, metric)
37 |                 for metric in self.test_metrics}
38 | 
39 | 
40 | def create_evaluator(model, valid_loader, test_loader, valid_metrics, test_metrics, **kwargs):
41 |     valid_metrics = set(valid_metrics)
42 |     test_metrics = set(test_metrics)
43 | 
44 |     return Evaluator(
45 |         model,
46 |         valid_loader=valid_loader,
47 |         test_loader=test_loader,
48 |         valid_metrics=valid_metrics,
49 |         test_metrics=test_metrics,
50 |         **kwargs
51 |     )
52 | 
--------------------------------------------------------------------------------
/evaluators/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from opacus.grad_sample.grad_sample_module import GradSampleModule
3 | 
4 | from utils import split_by_group
5 | 
6 | 
7 | def accuracy(model, dataloader, **kwargs):
8 |     correct = 0
9 |     total = 0
10 |     with torch.no_grad():
11 |         device = model._module.device if isinstance(model, GradSampleModule) else model.device  # unwrap Opacus' GradSampleModule to read the device
12 |         for _batch_idx, (data, labels, group) in enumerate(dataloader):
13 |             data, labels = data.to(device), labels.to(device)
14 |             outputs = model(data)
15 |             _, predicted = torch.max(outputs, 1)
16 |             total += labels.size(0)
17 |             correct += (predicted == labels).sum()
18 |     return (correct / total).item()
19 | 
20 | 
21 | def accuracy_per_group(model, dataloader, num_groups=None, **kwargs):
22 |     correct_per_group = [0] * num_groups
23 |     total_per_group = [0] * num_groups
24 |     with torch.no_grad():
25 |         device = model._module.device if isinstance(model, GradSampleModule) else model.device
26 |         for _batch_idx, (data, labels, group) in enumerate(dataloader):
27 |             data, labels = data.to(device), labels.to(device)
28 | 
29 |             per_group = split_by_group(data, labels, group, num_groups)
30 |             for i, group in enumerate(per_group):
31 |                 data_group, labels_group = group
32 |                 outputs = model(data_group)
33 |                 _, predicted = torch.max(outputs, 1)
34 |                 total_per_group[i] += labels_group.size(0)
35 |                 correct_per_group[i] += (predicted == labels_group).sum()
36 |     return [float(correct_per_group[i] / total_per_group[i]) for i in range(num_groups)]
37 | 
38 | 
39 | def macro_accuracy(model, dataloader, num_classes=None, **kwargs):
40 |     confusion_matrix = torch.zeros(num_classes, num_classes)
41 |     with torch.no_grad():
42 |         device = model._module.device if isinstance(model, GradSampleModule) else model.device
43 |         for _batch_idx, (data, labels, group) in enumerate(dataloader):
44 |             data, labels = data.to(device), labels.to(device)
45 |             outputs = model(data)
46 |             _, predicted = torch.max(outputs, 1)
47 |             for true_p, all_p in zip(labels.view(-1), predicted.view(-1)):
48 |                 confusion_matrix[true_p.long(), all_p.long()] += 1  # row: true class, column: prediction
49 | 
50 |     accs = confusion_matrix.diag() / confusion_matrix.sum(1)  # per-class recall
51 |     return accs.mean().item()
52 | 
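`macro_accuracy` above is the unweighted mean of per-class recall, computed from a confusion matrix whose rows are true classes and columns are predictions. A tiny numeric check (toy values, not repo code):

```python
import torch

cm = torch.tensor([[8., 2.],   # class 0: 8 correct, 2 misclassified
                   [4., 6.]])  # class 1: 6 correct, 4 misclassified
print((cm.diag() / cm.sum(1)).mean().item())  # (0.8 + 0.6) / 2 ≈ 0.7
```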
dir="celeba/$seed" 5 | angles="False" 6 | hessians="False" 7 | 8 | echo "$seed nonpriv" 9 | python3 main.py --dataset=celeba --method=regular --config lr=0.01 --config max_epochs=60 --config logdir=$dir/celeba_nonpriv --config seed=$seed --config evaluate_angles=$angles --config evaluate_hessian=$hessians --config angle_comp_step=200 10 | 11 | echo "$seed dpsgd" 12 | python3 main.py --dataset=celeba --method=dpsgd --config lr=0.01 --config max_epochs=60 --config delta=1e-6 --config noise_multiplier=0.8 --config l2_norm_clip=1 --config logdir=$dir/celeba_dpsgd --config seed=$seed --config evaluate_angles=$angles --config evaluate_hessian=$hessians --config angle_comp_step=200 13 | 14 | echo "$seed dpsgd-f" 15 | python3 main.py --dataset=celeba --method=dpsgd-f--config lr=0.01 --config max_epochs=60 --config delta=1e-6 --config noise_multiplier=0.8 --config base_max_grad_norm=1 --config counts_noise_multiplier=8 --config logdir=$dir/celeba_dpsgdf --config seed=$seed --config evaluate_angles=$angles --config evaluate_hessian=$hessians --config angle_comp_step=200 16 | 17 | echo "$seed dpsgd-g" 18 | python3 main.py --dataset=celeba --method=dpsgd-global --config lr=0.1 --config max_epochs=60 --config delta=1e-6 --config noise_multiplier=0.8 --config l2_norm_clip=1 --config strict_max_grad_norm=100 --config logdir=$dir/celeba_dpsgdg --config seed=$seed --config evaluate_angles=$angles --config evaluate_hessian=$hessians --config angle_comp_step=200 19 | 20 | echo "$seed dpsgd-global-adapt" 21 | python3 main.py --dataset=celeba --method=dpsgd-global-adapt --config lr=0.1 --config max_epochs=60 --config delta=1e-6 --config noise_multiplier=0.8 --config l2_norm_clip=1 --config strict_max_grad_norm=50 --config logdir=$dir/celeba_dpsgdg_adapt --config threshold=0.7 --config seed=$seed --config evaluate_angles=$angles --config evaluate_hessian=$hessians --config angle_comp_step=200 22 | 23 | -------------------------------------------------------------------------------- /experiment_scripts/dutch_script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Dutch dataset: 4 | # Download from https://raw.githubusercontent.com/tailequy/fairness_dataset/main/Dutch_census/dutch_census_2001.arff 5 | # and save file to fair-dp/data/dutch/ 6 | 7 | for seed in {0..4} 8 | do 9 | dir="dutch_logreg/$seed" 10 | angles='False' 11 | hessian='False' 12 | step=100 13 | 14 | echo "$seed nonpriv" 15 | python3 main.py --dataset=dutch --method=regular --config group_ratios=1,0.333 --config make_valid_loader=0 --config net=logistic --config lr=0.5 --config train_batch_size=256 --config valid_batch_size=256 --config test_batch_size=256 --config max_epochs=20 --config evaluate_angles=$angles --config evaluate_hessian=$hessian --config angle_comp_step=$step --config logdir=$dir/dutch_nonpriv --config seed=$seed 16 | 17 | echo "$seed dpsgd" 18 | python3 main.py --dataset=dutch --method=dpsgd --config group_ratios=1,0.333 --config make_valid_loader=0 --config net=logistic --config lr=0.5 --config train_batch_size=256 --config valid_batch_size=256 --config test_batch_size=256 --config max_epochs=20 --config delta=1e-6 --config noise_multiplier=1 --config l2_norm_clip=0.5 --config evaluate_angles=$angles --config evaluate_hessian=$hessian --config angle_comp_step=$step --config logdir=$dir/dutch_dpsgd --config seed=$seed 19 | 20 | echo "$seed dpsgd-f" 21 | python3 main.py --dataset=dutch --method=dpsgd-f --config group_ratios=1,0.333 --config make_valid_loader=0 --config 
net=logistic --config lr=0.01 --config train_batch_size=256 --config valid_batch_size=256 --config test_batch_size=256 --config max_epochs=20 --config delta=1e-6 --config noise_multiplier=1 --config base_max_grad_norm=0.5 --config counts_noise_multiplier=10 --config evaluate_angles=$angles --config evaluate_hessian=$hessian --config angle_comp_step=$step --config logdir=$dir/dutch_dpsgdf --config seed=$seed 22 | 23 | echo "$seed dpsgd-g" 24 | python3 main.py --dataset=dutch --method=dpsgd-global --config group_ratios=1,0.333 --config make_valid_loader=0 --config net=logistic --config train_batch_size=256 --config valid_batch_size=256 --config test_batch_size=256 --config max_epochs=20 --config delta=1e-6 --config noise_multiplier=1 --config l2_norm_clip=0.5 --config strict_max_grad_norm=50 --config lr=0.5 --config evaluate_angles=$angles --config evaluate_hessian=$hessian --config angle_comp_step=$step --config logdir=$dir/dutch_dpsgdg --config seed=$seed 25 | 26 | echo "$seed dpsgd-g-adapt" 27 | python3 main.py --dataset=dutch --method=dpsgd-global-adapt --config group_ratios=1,0.333 --config make_valid_loader=0 --config net=logistic --config train_batch_size=256 --config valid_batch_size=256 --config test_batch_size=256 --config max_epochs=20 --config delta=1e-6 --config noise_multiplier=1 --config l2_norm_clip=0.5 --config strict_max_grad_norm=50 --config lr=0.5 --config lr_Z=0.1 --config threshold=1 --config evaluate_angles=$angles --config evaluate_hessian=$hessian --config angle_comp_step=100 --config logdir=$dir/dutch_dpsgdg_adapt --config seed=$seed 28 | done 29 | 30 | -------------------------------------------------------------------------------- /experiment_scripts/mnist_script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for seed in {0..4} 4 | do 5 | dir="mnist_$seed" 6 | angles="false" 7 | hessians="false" 8 | 9 | echo "$seed nonpriv" 10 | python3 main.py --dataset=mnist --method=regular --config group_ratios=-1,-1,-1,-1,-1,-1,-1,-1,0.09,-1 --config make_valid_loader=0 --config net=cnn --config hidden_channels=32,16 --config lr=0.01 --config train_batch_size=256 --config valid_batch_size=256 --config test_batch_size=256 --config max_epochs=60 --config logdir=$dir/mnist_nonpriv --config seed=$seed --config evaluate_angles=$angles --config evaluate_hessian=$hessians --config angle_comp_step=200 11 | 12 | echo "$seed dpsgd" 13 | python3 main.py --dataset=mnist --method=dpsgd --config group_ratios=-1,-1,-1,-1,-1,-1,-1,-1,0.09,-1 --config make_valid_loader=0 --config net=cnn --config hidden_channels=32,16 --config lr=0.01 --config train_batch_size=256 --config valid_batch_size=256 --config test_batch_size=256 --config max_epochs=60 --config delta=1e-6 --config noise_multiplier=0.8 --config l2_norm_clip=1 --config logdir=$dir/mnist_dpsgd --config seed=$seed --config evaluate_angles=$angles --config evaluate_hessian=$hessians --config angle_comp_step=200 14 | 15 | echo "$seed dpsgd-f" 16 | python3 main.py --dataset=mnist --method=dpsgd-f --config group_ratios=-1,-1,-1,-1,-1,-1,-1,-1,0.09,-1 --config make_valid_loader=0 --config net=cnn --config hidden_channels=32,16 --config lr=0.01 --config train_batch_size=256 --config valid_batch_size=256 --config test_batch_size=256 --config max_epochs=60 --config delta=1e-6 --config noise_multiplier=0.8 --config base_max_grad_norm=1 --config counts_noise_multiplier=8 --config logdir=$dir/mnist_dpsgdf --config seed=$seed --config evaluate_angles=$angles --config 
evaluate_hessian=$hessians --config angle_comp_step=200 17 | 18 | echo "$seed dpsgd-g" 19 | python3 main.py --dataset=mnist --method=dpsgd-global --config group_ratios=-1,-1,-1,-1,-1,-1,-1,-1,0.09,-1 --config make_valid_loader=0 --config net=cnn --config hidden_channels=32,16 --config train_batch_size=256 --config valid_batch_size=256 --config test_batch_size=256 --config max_epochs=60 --config delta=1e-6 --config noise_multiplier=0.8 --config l2_norm_clip=20 --config strict_max_grad_norm=100 --config lr=0.01 --config logdir=$dir/mnist_dpsgdg --config seed=$seed --config evaluate_angles=$angles --config evaluate_hessian=$hessians --config angle_comp_step=200 20 | 21 | echo "$seed dpsgd-global-adapt" 22 | python3 main.py --dataset=mnist --method=dpsgd-global-adapt --config group_ratios=-1,-1,-1,-1,-1,-1,-1,-1,0.09,-1 --config make_valid_loader=0 --config net=cnn --config hidden_channels=32,16 --config train_batch_size=256 --config valid_batch_size=256 --config test_batch_size=256 --config max_epochs=60 --config delta=1e-6 --config noise_multiplier=0.8 --config l2_norm_clip=1 --config strict_max_grad_norm=50 --config lr=0.1 --config logdir=$dir/mnist_dpsgdg_adapt --config threshold=0.7 --config seed=$seed --config evaluate_angles=$angles --config evaluate_hessian=$hessians --config angle_comp_step=200 23 | done 24 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pprint 3 | import random 4 | import sys 5 | 6 | import numpy as np 7 | import torch 8 | from opacus import PrivacyEngine 9 | 10 | from config import get_config, parse_config_arg 11 | from datasets import get_loaders_from_config 12 | from evaluators import create_evaluator 13 | from models import create_model 14 | from privacy_engines.dpsgd_f_engine import DPSGDF_PrivacyEngine 15 | from privacy_engines.dpsgd_global_adaptive_engine import DPSGDGlobalAdaptivePrivacyEngine 16 | from privacy_engines.dpsgd_global_engine import DPSGDGlobalPrivacyEngine 17 | from trainers import create_trainer 18 | from utils import privacy_checker 19 | from writer import Writer 20 | 21 | 22 | def main(): 23 | parser = argparse.ArgumentParser(description="Fairness for DP-SGD") 24 | 25 | parser.add_argument("--dataset", type=str, default="mnist", 26 | help="Dataset to train on.") 27 | parser.add_argument("--method", type=str, default="regular", 28 | choices=["regular", "dpsgd", "dpsgd-f", "fairness-lens", "dpsgd-global", "dpsgd-global-adapt"], 29 | help="Method for training and clipping.") 30 | 31 | parser.add_argument("--config", default=[], action="append", 32 | help="Override config entries. Specify as `key=value`.") 33 | 34 | args = parser.parse_args() 35 | 36 | device = "cuda" if torch.cuda.is_available() else "cpu" 37 | 38 | cfg = get_config( 39 | dataset=args.dataset, 40 | method=args.method, 41 | ) 42 | cfg = {**cfg, **dict(parse_config_arg(kv) for kv in args.config)} 43 | 44 | # Checks group_ratios is specified correctly 45 | if len(cfg["group_ratios"]) != cfg["num_groups"]: 46 | raise ValueError( 47 | "Number of group ratios, {}, not equal to number of groups of {}, {}" 48 | .format(len(cfg["group_ratios"]), cfg["protected_group"], cfg["num_groups"]) 49 | ) 50 | 51 | if any(x > 1 or (x < 0 and x != -1) for x in cfg["group_ratios"]): 52 | raise ValueError("All elements of group_ratios must be in [0,1]. 
Indicate no sampling with -1.") 53 | 54 | pprint.sorted = lambda x, key=None: x 55 | pp = pprint.PrettyPrinter(indent=4) 56 | print(10 * "-" + "-cfg--" + 10 * "-") 57 | pp.pprint(cfg) 58 | 59 | # Set random seeds based on config 60 | random.seed(cfg["seed"]) 61 | np.random.seed(cfg["seed"]) 62 | torch.manual_seed(cfg["seed"]) 63 | 64 | train_loader, valid_loader, test_loader = get_loaders_from_config( 65 | cfg, 66 | device 67 | ) 68 | 69 | writer = Writer( 70 | logdir=cfg.get("logdir_root", "runs"), 71 | make_subdir=True, 72 | tag_group=args.dataset, 73 | dir_name=cfg.get("logdir", "") 74 | ) 75 | writer.write_json(tag="config", data=cfg) 76 | 77 | model, optimizer = create_model(cfg, device) 78 | 79 | if cfg["method"] != "regular": 80 | sample_rate = 1 / len(train_loader) 81 | privacy_checker(sample_rate, cfg) 82 | 83 | if cfg["method"] == "dpsgd": 84 | privacy_engine = PrivacyEngine(accountant=cfg["accountant"]) 85 | model, optimizer, train_loader = privacy_engine.make_private( 86 | module=model, 87 | optimizer=optimizer, 88 | data_loader=train_loader, 89 | noise_multiplier=cfg["noise_multiplier"], 90 | max_grad_norm=cfg["l2_norm_clip"] # C 91 | ) 92 | elif cfg["method"] == "dpsgd-global": 93 | privacy_engine = DPSGDGlobalPrivacyEngine(accountant=cfg["accountant"]) 94 | model, optimizer, train_loader = privacy_engine.make_private( 95 | module=model, 96 | optimizer=optimizer, 97 | data_loader=train_loader, 98 | noise_multiplier=cfg["noise_multiplier"], # sigma in sigma * C 99 | max_grad_norm=cfg["l2_norm_clip"], # C 100 | ) 101 | elif cfg["method"] == "dpsgd-f": 102 | privacy_engine = DPSGDF_PrivacyEngine(accountant=cfg["accountant"]) 103 | model, optimizer, train_loader = privacy_engine.make_private( 104 | module=model, 105 | optimizer=optimizer, 106 | data_loader=train_loader, 107 | noise_multiplier=cfg["noise_multiplier"], 108 | max_grad_norm=0 # this parameter is not applicable for DPSGD-F 109 | ) 110 | 111 | elif cfg["method"] == "dpsgd-global-adapt": 112 | privacy_engine = DPSGDGlobalAdaptivePrivacyEngine(accountant=cfg["accountant"]) 113 | model, optimizer, train_loader = privacy_engine.make_private( 114 | module=model, 115 | optimizer=optimizer, 116 | data_loader=train_loader, 117 | noise_multiplier=cfg["noise_multiplier"], # sigma in sigma * C 118 | max_grad_norm=cfg["l2_norm_clip"], # C 119 | ) 120 | else: 121 | # doing regular training 122 | privacy_engine = PrivacyEngine() 123 | model, optimizer, train_loader = privacy_engine.make_private( 124 | module=model, 125 | optimizer=optimizer, 126 | data_loader=train_loader, 127 | noise_multiplier=0, 128 | max_grad_norm=sys.float_info.max, 129 | poisson_sampling=False 130 | ) 131 | 132 | evaluator = create_evaluator( 133 | model, 134 | valid_loader=valid_loader, test_loader=test_loader, 135 | valid_metrics=cfg["valid_metrics"], 136 | test_metrics=cfg["test_metrics"], 137 | num_classes=cfg["output_dim"], 138 | num_groups=cfg["num_groups"], 139 | ) 140 | 141 | trainer = create_trainer( 142 | train_loader, 143 | valid_loader, 144 | test_loader, 145 | model, 146 | optimizer, 147 | privacy_engine, 148 | evaluator, 149 | writer, 150 | device, 151 | cfg 152 | ) 153 | 154 | trainer.train() 155 | 156 | 157 | if __name__ == "__main__": 158 | main() 159 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .model_factory import create_model 
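A note on the `privacy_checker(sample_rate, cfg)` call above: it vets the (epsilon, delta) guarantee implied by the config before training starts. The following is a minimal, self-contained sketch of such a check using Opacus's RDP accountant; it is not the repository's `privacy_checker`, and the step arithmetic and example numbers are illustrative assumptions only.

from opacus.accountants import RDPAccountant

def approximate_epsilon(noise_multiplier, sample_rate, max_epochs, delta):
    # One sampled-Gaussian-mechanism step per expected batch; with
    # sample_rate = 1 / num_batches, an epoch contributes 1 / sample_rate steps.
    accountant = RDPAccountant()
    for _ in range(int(max_epochs / sample_rate)):
        accountant.step(noise_multiplier=noise_multiplier, sample_rate=sample_rate)
    return accountant.get_epsilon(delta=delta)

# Illustrative numbers only (not taken from the configs above):
# print(approximate_epsilon(1.0, 1 / 100, 20, 1e-6))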
-------------------------------------------------------------------------------- /models/model_factory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .neural_networks import (CNN, MLP, LogisticRegression) 5 | 6 | activation_map = { 7 | "relu": nn.ReLU, 8 | "tanh": nn.Tanh, 9 | "swish": nn.SiLU 10 | } 11 | 12 | 13 | def create_model(config, device): 14 | if config["net"] == "mlp": 15 | model = MLP( 16 | n_units_list=[config["data_dim"], *config["hidden_dims"], config["output_dim"]], 17 | activation=activation_map[config.get("activation", "relu")], 18 | ) 19 | 20 | elif config["net"] == "cnn": 21 | model = CNN( 22 | input_channels=config["data_shape"][0], 23 | hidden_channels_list=config["hidden_channels"], 24 | output_dim=config["output_dim"], 25 | kernel_size=config["kernel_size"], 26 | stride=config["stride"], 27 | image_height=config["data_shape"][1], 28 | activation=activation_map[config.get("activation", "relu")], 29 | ) 30 | 31 | elif config["net"] == "logistic": 32 | model = LogisticRegression( 33 | input_dim=config["data_shape"][0], 34 | output_dim=config["output_dim"], 35 | ) 36 | 37 | else: 38 | raise ValueError(f"Unknown network type {config['net']}") 39 | 40 | if config["optimizer"] == "adam": 41 | optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"]) 42 | elif config["optimizer"] == "sgd": 43 | optimizer = torch.optim.SGD(model.parameters(), lr=config["lr"]) 44 | else: 45 | raise ValueError(f"Unknown optimizer {config['optimizer']}") 46 | 47 | model.set_device(device) 48 | 49 | return model, optimizer 50 | -------------------------------------------------------------------------------- /models/neural_networks.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class LogisticRegression(nn.Module): 8 | def __init__(self, input_dim, output_dim): 9 | super().__init__() 10 | self.linear = torch.nn.Linear(input_dim, output_dim) 11 | 12 | def set_device(self, device): 13 | self.device = device 14 | self.to(device) 15 | 16 | def forward(self, x): 17 | outputs = torch.sigmoid(self.linear(x)) 18 | return outputs 19 | 20 | 21 | class MLP(nn.Module): 22 | 23 | def __init__(self, n_units_list, activation=nn.ReLU): 24 | super().__init__() 25 | layers = [] 26 | prev_layer_size = n_units_list[0] 27 | for n_units in n_units_list[1:-1]: 28 | layers.append(nn.Linear(in_features=prev_layer_size, out_features=n_units)) 29 | prev_layer_size = n_units 30 | layers.append(activation()) 31 | layers.append(nn.Linear(in_features=prev_layer_size, out_features=n_units_list[-1])) 32 | self.net = nn.Sequential(*layers) 33 | 34 | def set_device(self, device): 35 | self.device = device 36 | self.to(device) 37 | 38 | def forward(self, x): 39 | return self.net(x) 40 | 41 | 42 | class CNN(nn.Module): 43 | def __init__( 44 | self, 45 | input_channels, 46 | hidden_channels_list, 47 | output_dim, 48 | kernel_size, 49 | stride, 50 | image_height, 51 | activation=nn.ReLU, 52 | ): 53 | super().__init__() 54 | if type(stride) not in [list, tuple]: 55 | stride = [stride for _ in hidden_channels_list] 56 | 57 | if type(kernel_size) not in [list, tuple]: 58 | kernel_size = [kernel_size for _ in hidden_channels_list] 59 | 60 | cnn_layers = [] 61 | prev_channels = input_channels 62 | for hidden_channels, k, s in zip(hidden_channels_list, kernel_size, stride): 63 | cnn_layers.append(nn.Conv2d(prev_channels, hidden_channels, k, s)) 64 | 
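# Each Conv2d is immediately followed by its activation below; prev_channels and
# the running image_height (updated at the end of this loop body) track the output
# shape so that fc_layer can be sized as prev_channels * image_height ** 2.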
cnn_layers.append(activation()) 65 | prev_channels = hidden_channels 66 | 67 | # NOTE: Assumes square image 68 | image_height = self._get_new_image_height(image_height, k, s) 69 | self.cnn_layers = nn.ModuleList(cnn_layers) 70 | 71 | self.fc_layer = nn.Linear(prev_channels * image_height ** 2, output_dim) 72 | 73 | def set_device(self, device): 74 | self.device = device 75 | self.to(device) 76 | 77 | def forward(self, x): 78 | for layer in self.cnn_layers: 79 | x = layer(x) 80 | x = torch.flatten(x, start_dim=1) 81 | 82 | return self.fc_layer(x) 83 | 84 | def _get_new_image_height(self, height, kernel, stride): 85 | # cf. https://pytorch.org/docs/1.9.1/generated/torch.nn.Conv2d.html 86 | # Assume dilation = 1, padding = 0 87 | return math.floor((height - kernel) / stride + 1) 88 | -------------------------------------------------------------------------------- /optimizers/dpsgd_f_optimizer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Callable 2 | 3 | import torch 4 | from opacus.optimizers.optimizer import DPOptimizer, _check_processed_flag, _mark_as_processed, _generate_noise, \ 5 | _get_flat_grad_sample 6 | from torch.optim import Optimizer 7 | 8 | 9 | class DPSGDF_Optimizer(DPOptimizer): 10 | """ 11 | Customized optimizer for DPSGD-F, inherited from DPOptimizer and overriding the following: 12 | 13 | - clip_and_accumulate(self, per_sample_clip_bound) takes an extra tensor parameter giving the clipping bound per sample 14 | - add_noise(self, max_grad_clip: float) takes an extra parameter ``max_grad_clip``, 15 | which is the maximum clipping bound among all the groups, i.e. max(per_sample_clip_bound) 16 | - pre_step() and step() are overridden to accept this extra parameter 17 | """ 18 | 19 | def __init__( 20 | self, 21 | optimizer: Optimizer, 22 | *, 23 | noise_multiplier: float, 24 | expected_batch_size: Optional[int], 25 | loss_reduction: str = "mean", 26 | generator=None, 27 | secure_mode: bool = False, 28 | ): 29 | super().__init__( 30 | optimizer, 31 | noise_multiplier=noise_multiplier, 32 | max_grad_norm=0, # not applicable for DPSGDF_Optimizer 33 | expected_batch_size=expected_batch_size, 34 | loss_reduction=loss_reduction, 35 | generator=generator, 36 | secure_mode=secure_mode, 37 | ) 38 | 39 | def clip_and_accumulate(self, per_sample_clip_bound): 40 | """ 41 | Clips gradients according to the per-sample clipping bounds and accumulates them for a given batch 42 | Args: 43 | per_sample_clip_bound: a tensor of clipping bounds, one per sample 44 | """ 45 | # self.grad_samples is computed by the parent class, equivalent to the following 46 | # 47 | # ret = [] 48 | # for p in self.params: 49 | # ret.append(_get_flat_grad_sample(p)) 50 | # return ret 51 | 52 | # per_param_norms holds, for each parameter tensor (roughly per layer), the per-sample gradient norms 53 | # output dimension: num_layers entries of size num_samples (per batch) 54 | per_param_norms = [ 55 | g.view(len(g), -1).norm(2, dim=-1) for g in self.grad_samples 56 | ] 57 | 58 | # torch.stack(per_param_norms, dim=1) will make the dimension num_samples * num_layers 59 | # per_sample_norms has dimension of num_samples 60 | per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1) 61 | per_sample_clip_factor = (per_sample_clip_bound / (per_sample_norms + 1e-6)).clamp( 62 | max=1.0 63 | ) 64 | 65 | for p in self.params: 66 | _check_processed_flag(p.grad_sample) 67 | grad_sample = _get_flat_grad_sample(p) 68 | # equivalent to grad = grad * min(1, Ck 
/ norm(grad)) 69 | grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample) 70 | 71 | if p.summed_grad is not None: 72 | p.summed_grad += grad 73 | else: 74 | p.summed_grad = grad 75 | 76 | _mark_as_processed(p.grad_sample) 77 | 78 | def add_noise(self, max_grad_clip: float): 79 | """ 80 | Adds noise to clipped gradients. Stores clipped and noised result in ``p.grad`` 81 | Args: 82 | max_grad_clip: C = max(C_k), for all group k 83 | """ 84 | 85 | for p in self.params: 86 | _check_processed_flag(p.summed_grad) 87 | 88 | noise = _generate_noise( 89 | std=self.noise_multiplier * max_grad_clip, 90 | reference=p.summed_grad, 91 | generator=self.generator, 92 | secure_mode=self.secure_mode, 93 | ) 94 | p.grad = (p.summed_grad + noise).view_as(p.grad) 95 | _mark_as_processed(p.summed_grad) 96 | 97 | def pre_step( 98 | self, per_sample_clip_bound: torch.Tensor, closure: Optional[Callable[[], float]] = None 99 | ) -> Optional[float]: 100 | """ 101 | Perform actions specific to ``DPOptimizer`` before calling 102 | underlying ``optimizer.step()`` 103 | Args: 104 | per_sample_clip_bound: Defines the clipping bound for each sample. 105 | closure: A closure that reevaluates the model and 106 | returns the loss. Optional for most optimizers. 107 | """ 108 | self.clip_and_accumulate(per_sample_clip_bound) 109 | if self._check_skip_next_step(): 110 | self._is_last_step_skipped = True 111 | return False 112 | 113 | self.add_noise(torch.max(per_sample_clip_bound).item()) 114 | self.scale_grad() 115 | 116 | if self.step_hook: 117 | self.step_hook(self) 118 | 119 | self._is_last_step_skipped = False 120 | return True 121 | 122 | def step(self, per_sample_clip_bound: torch.Tensor, closure: Optional[Callable[[], float]] = None) -> Optional[ 123 | float]: 124 | if closure is not None: 125 | with torch.enable_grad(): 126 | closure() 127 | 128 | if self.pre_step(per_sample_clip_bound): 129 | return self.original_optimizer.step() 130 | else: 131 | return None 132 | -------------------------------------------------------------------------------- /optimizers/dpsgd_global_adaptive_optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from opacus.optimizers.optimizer import _check_processed_flag, _get_flat_grad_sample, _mark_as_processed 3 | 4 | from optimizers.dpsgd_global_optimizer import DPSGD_Global_Optimizer 5 | 6 | 7 | class DPSGD_Global_Adaptive_Optimizer(DPSGD_Global_Optimizer): 8 | 9 | def clip_and_accumulate(self, strict_max_grad_norm): 10 | """ 11 | Performs gradient clipping. 
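Per-sample gradients with norm at most strict_max_grad_norm (Z) are scaled by max_grad_norm / Z (= C/Z); larger gradients are clipped to norm C rather than discarded, unlike the strict global variant.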
12 | Stores clipped and aggregated gradients into ``p.summed_grad`` 13 | """ 14 | 15 | per_param_norms = [ 16 | g.view(len(g), -1).norm(2, dim=-1) for g in self.grad_samples 17 | ] 18 | per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1) 19 | per_sample_clip_factor = (self.max_grad_norm / (per_sample_norms + 1e-6)).clamp( 20 | max=1.0 21 | ) 22 | 23 | # C = max_grad_norm 24 | # Z = strict_max_grad_norm 25 | # condition is equivalent to norm[i] <= Z 26 | # when condition holds, scale gradient by C/Z 27 | # otherwise, clip to norm C; note that here we remove the aggressive clipping of the global method 28 | per_sample_global_clip_factor = torch.where(per_sample_clip_factor >= self.max_grad_norm / strict_max_grad_norm, 29 | # scale by C/Z 30 | torch.ones_like( 31 | per_sample_clip_factor) * self.max_grad_norm / strict_max_grad_norm, 32 | per_sample_clip_factor) # clip to C 33 | for p in self.params: 34 | _check_processed_flag(p.grad_sample) 35 | 36 | grad_sample = _get_flat_grad_sample(p) 37 | 38 | # refer to lines 197-199 in 39 | # https://github.com/pytorch/opacus/blob/ee6867e6364781e67529664261243c16c3046b0b/opacus/per_sample_gradient_clip.py 40 | # as well as https://github.com/woodyx218/opacus_global_clipping README 41 | grad = torch.einsum("i,i...", per_sample_global_clip_factor, grad_sample) 42 | 43 | if p.summed_grad is not None: 44 | p.summed_grad += grad 45 | else: 46 | p.summed_grad = grad 47 | 48 | _mark_as_processed(p.grad_sample) 49 | -------------------------------------------------------------------------------- /optimizers/dpsgd_global_optimizer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Callable 2 | 3 | import torch 4 | from opacus.optimizers.optimizer import DPOptimizer, _check_processed_flag, _get_flat_grad_sample, _mark_as_processed 5 | 6 | 7 | class DPSGD_Global_Optimizer(DPOptimizer): 8 | 9 | def clip_and_accumulate(self, strict_max_grad_norm): 10 | """ 11 | Performs gradient clipping. 
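Per-sample gradients with norm at most strict_max_grad_norm (Z) are scaled by max_grad_norm / Z (= C/Z); larger gradients are discarded entirely (scaled to zero).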
12 | Stores clipped and aggregated gradients into ``p.summed_grad`` 13 | """ 14 | 15 | per_param_norms = [ 16 | g.view(len(g), -1).norm(2, dim=-1) for g in self.grad_samples 17 | ] 18 | per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1) 19 | per_sample_clip_factor = (self.max_grad_norm / (per_sample_norms + 1e-6)).clamp( 20 | max=1.0 21 | ) 22 | 23 | # C = max_grad_norm 24 | # Z = strict_max_grad_norm 25 | # condition is equivalent to norm[i] <= Z 26 | # when condition holds, scale gradient by C/Z 27 | # otherwise, clip to 0 28 | per_sample_global_clip_factor = torch.where(per_sample_clip_factor >= self.max_grad_norm / strict_max_grad_norm, 29 | # scale by C/Z 30 | torch.ones_like( 31 | per_sample_clip_factor) * self.max_grad_norm / strict_max_grad_norm, 32 | torch.zeros_like(per_sample_clip_factor)) # clip to 0 33 | for p in self.params: 34 | _check_processed_flag(p.grad_sample) 35 | 36 | grad_sample = _get_flat_grad_sample(p) 37 | 38 | # refer to lines 197-199 in 39 | # https://github.com/pytorch/opacus/blob/ee6867e6364781e67529664261243c16c3046b0b/opacus/per_sample_gradient_clip.py 40 | # as well as https://github.com/woodyx218/opacus_global_clipping README 41 | grad = torch.einsum("i,i...", per_sample_global_clip_factor, grad_sample) 42 | 43 | if p.summed_grad is not None: 44 | p.summed_grad += grad 45 | else: 46 | p.summed_grad = grad 47 | 48 | _mark_as_processed(p.grad_sample) 49 | 50 | # note add_noise does not have to be modified since max_grad_norm = C is sensitivity 51 | 52 | def pre_step( 53 | self, strict_max_grad_norm, closure: Optional[Callable[[], float]] = None 54 | ) -> Optional[float]: 55 | """ 56 | Perform actions specific to ``DPOptimizer`` before calling 57 | underlying ``optimizer.step()`` 58 | Args: 59 | closure: A closure that reevaluates the model and 60 | returns the loss. Optional for most optimizers. 61 | """ 62 | self.clip_and_accumulate(strict_max_grad_norm) 63 | if self._check_skip_next_step(): 64 | self._is_last_step_skipped = True 65 | return False 66 | 67 | self.add_noise() 68 | self.scale_grad() 69 | 70 | if self.step_hook: 71 | self.step_hook(self) 72 | 73 | self._is_last_step_skipped = False 74 | return True 75 | 76 | def step(self, strict_max_grad_norm, closure: Optional[Callable[[], float]] = None) -> Optional[float]: 77 | if closure is not None: 78 | with torch.enable_grad(): 79 | closure() 80 | 81 | if self.pre_step(strict_max_grad_norm): 82 | return self.original_optimizer.step() 83 | else: 84 | return None 85 | -------------------------------------------------------------------------------- /privacy_engines/dpsgd_f_engine.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | from opacus import PrivacyEngine 4 | from opacus.optimizers.optimizer import DPOptimizer 5 | from optimizers.dpsgd_f_optimizer import DPSGDF_Optimizer 6 | from torch import optim 7 | 8 | 9 | class DPSGDF_PrivacyEngine(PrivacyEngine): 10 | """ 11 | This class defines the customized privacy engine for DPSGD-F. 
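DPSGD-F derives a per-sample clipping bound from noisy per-group counts of gradients above and below a base threshold, so that groups whose gradients are clipped more often receive larger bounds.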
12 | Specifically, it overwrites the _prepare_optimizer() method from parent class to return DPSGDF_Optimizer 13 | """ 14 | 15 | def __init__(self, *, accountant: str = "rdp", secure_mode: bool = False): 16 | if accountant != 'rdp': 17 | raise ValueError("DPSGD-F must use an RDP accountant since it composes SGM with different parameters.") 18 | 19 | super().__init__(accountant=accountant, secure_mode=secure_mode) 20 | 21 | def _prepare_optimizer( 22 | self, 23 | optimizer: optim.Optimizer, 24 | *, 25 | noise_multiplier: float, 26 | max_grad_norm: Union[float, List[float]], # not applicable in DPSGDF 27 | expected_batch_size: int, 28 | loss_reduction: str = "mean", 29 | distributed: bool = False, 30 | clipping: str = "flat", 31 | noise_generator=None, 32 | ) -> DPOptimizer: 33 | if isinstance(optimizer, DPOptimizer): 34 | optimizer = optimizer.original_optimizer 35 | 36 | generator = None 37 | if self.secure_mode: 38 | generator = self.secure_rng 39 | elif noise_generator is not None: 40 | generator = noise_generator 41 | 42 | return DPSGDF_Optimizer( 43 | optimizer=optimizer, 44 | noise_multiplier=noise_multiplier, 45 | expected_batch_size=expected_batch_size, 46 | loss_reduction=loss_reduction, 47 | generator=generator, 48 | secure_mode=self.secure_mode, 49 | ) 50 | -------------------------------------------------------------------------------- /privacy_engines/dpsgd_global_adaptive_engine.py: -------------------------------------------------------------------------------- 1 | # adapted from opacus/privacy_engine.py 2 | # https://github.com/pytorch/opacus/blob/030b723fb89aabf3cde663018bb63e5bb95f197a/opacus/privacy_engine.py 3 | # opacus v1.1.0 4 | 5 | from typing import List, Union 6 | 7 | from opacus import PrivacyEngine 8 | from opacus.optimizers import DPOptimizer 9 | from optimizers.dpsgd_global_adaptive_optimizer import DPSGD_Global_Adaptive_Optimizer 10 | from torch import optim 11 | 12 | 13 | class DPSGDGlobalAdaptivePrivacyEngine(PrivacyEngine): 14 | """ 15 | This class defines the customized privacy engine for DPSGD-Global-Adaptive. 
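The adaptive variant additionally lets the trainer update the strict threshold Z between steps (see the lr_Z and threshold config entries and the trainer's _update_Z).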
16 | Specifically, it overwrites the _prepare_optimizer() method from the parent class to return a DPSGD_Global_Adaptive_Optimizer 17 | """ 18 | 19 | def _prepare_optimizer( 20 | self, 21 | optimizer: optim.Optimizer, 22 | *, 23 | noise_multiplier: float, 24 | max_grad_norm: Union[float, List[float]], 25 | expected_batch_size: int, 26 | loss_reduction: str = "mean", 27 | distributed: bool = False, # deprecated for this method 28 | clipping: str = "flat", # deprecated for this method 29 | noise_generator=None, 30 | ) -> DPOptimizer: 31 | if isinstance(optimizer, DPOptimizer): 32 | optimizer = optimizer.original_optimizer 33 | 34 | generator = None 35 | if self.secure_mode: 36 | generator = self.secure_rng 37 | elif noise_generator is not None: 38 | generator = noise_generator 39 | 40 | optimizer = DPSGD_Global_Adaptive_Optimizer(optimizer=optimizer, 41 | noise_multiplier=noise_multiplier, 42 | max_grad_norm=max_grad_norm, 43 | expected_batch_size=expected_batch_size, 44 | loss_reduction=loss_reduction, 45 | generator=generator, 46 | secure_mode=self.secure_mode) 47 | 48 | return optimizer 49 | -------------------------------------------------------------------------------- /privacy_engines/dpsgd_global_engine.py: -------------------------------------------------------------------------------- 1 | # adapted from opacus/privacy_engine.py 2 | # https://github.com/pytorch/opacus/blob/030b723fb89aabf3cde663018bb63e5bb95f197a/opacus/privacy_engine.py 3 | # opacus v1.1.0 4 | 5 | from typing import List, Union 6 | 7 | from opacus import PrivacyEngine 8 | from opacus.optimizers import DPOptimizer 9 | from optimizers.dpsgd_global_optimizer import DPSGD_Global_Optimizer 10 | from torch import optim 11 | 12 | 13 | class DPSGDGlobalPrivacyEngine(PrivacyEngine): 14 | """ 15 | This class defines the customized privacy engine for DPSGD-Global. 
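DPSGD-Global scales every gradient with norm at most Z by C/Z and discards larger ones, so the sensitivity stays C and add_noise needs no modification.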
16 | Specifically, it overwrites the _prepare_optimizer() method from parent class to return DPSGD_Global_Optimizer 17 | """ 18 | 19 | def _prepare_optimizer( 20 | self, 21 | optimizer: optim.Optimizer, 22 | *, 23 | noise_multiplier: float, 24 | max_grad_norm: Union[float, List[float]], 25 | expected_batch_size: int, 26 | loss_reduction: str = "mean", 27 | distributed: bool = False, # deprecated for this method 28 | clipping: str = "flat", # deprecated for this method 29 | noise_generator=None, 30 | ) -> DPOptimizer: 31 | if isinstance(optimizer, DPOptimizer): 32 | optimizer = optimizer.original_optimizer 33 | 34 | generator = None 35 | if self.secure_mode: 36 | generator = self.secure_rng 37 | elif noise_generator is not None: 38 | generator = noise_generator 39 | 40 | optimizer = DPSGD_Global_Optimizer(optimizer=optimizer, 41 | noise_multiplier=noise_multiplier, 42 | max_grad_norm=max_grad_norm, 43 | expected_batch_size=expected_batch_size, 44 | loss_reduction=loss_reduction, 45 | generator=generator, 46 | secure_mode=self.secure_mode) 47 | 48 | return optimizer 49 | -------------------------------------------------------------------------------- /trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer_factory import create_trainer 2 | -------------------------------------------------------------------------------- /trainers/trainer.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os 3 | import time 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | from datasets.loaders import get_loader 9 | from functorch import make_functional, vjp, grad 10 | from torch.nn.functional import cosine_similarity 11 | from tqdm import tqdm 12 | from utils import * 13 | 14 | 15 | class BaseTrainer: 16 | """Base class for various training methods""" 17 | 18 | def __init__(self, 19 | model, 20 | optimizer, 21 | train_loader, 22 | valid_loader, 23 | test_loader, 24 | writer, 25 | evaluator, 26 | device, 27 | method="regular", 28 | max_epochs=100, 29 | num_groups=None, 30 | selected_groups=[0, 1], 31 | evaluate_angles=False, 32 | evaluate_hessian=False, 33 | angle_comp_step=10, 34 | lr=0.01, 35 | seed=0, 36 | num_hutchinson_estimates=100, 37 | sampled_expected_loss=False 38 | ): 39 | 40 | self.model = model 41 | self.optimizer = optimizer 42 | self.train_loader = train_loader 43 | self.valid_loader = valid_loader 44 | self.test_loader = test_loader 45 | self.writer = writer 46 | self.evaluator = evaluator 47 | self.device = device 48 | 49 | self.method = method 50 | self.max_epochs = max_epochs 51 | self.num_groups = num_groups 52 | self.num_batch = len(self.train_loader) 53 | self.selected_groups = selected_groups 54 | self.epoch = 0 55 | self.num_layers = get_num_layers(self.model) 56 | 57 | self.evaluate_angles = evaluate_angles 58 | self.evaluate_hessian = evaluate_hessian 59 | self.angle_comp_step = angle_comp_step 60 | self.lr = lr 61 | self.seed = seed 62 | self.num_hutchinson_estimates = num_hutchinson_estimates 63 | self.sampled_expected_loss = sampled_expected_loss 64 | 65 | def _train_epoch(self, cosine_sim_per_epoch, expected_loss, param_for_step=None): 66 | # methods: regular, dpsgd, dpsgd-global, dpsgd-f, dpsgd-global-adapt 67 | criterion = torch.nn.CrossEntropyLoss() 68 | losses = [] 69 | losses_per_group = np.zeros(self.num_groups) 70 | all_grad_norms = [[] for _ in range(self.num_groups)] 71 | group_max_grads = [0] * self.num_groups 72 | 
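# Notation used throughout this method (and in the CSV/TensorBoard outputs):
# g_B / g_B_k are the mean batch gradients overall / for group k, bar_g_B and
# bar_g_B_k their clipped counterparts, and g_D_k group k's mean gradient over
# the whole training set; g_B_k_norms below collects ||g_B_k|| once per batch.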
g_B_k_norms = [[] for _ in range(self.num_groups)] 73 | 74 | for _batch_idx, (data, target, group) in enumerate(tqdm(self.train_loader)): 75 | data, target = data.to(self.device), target.to(self.device) 76 | self.optimizer.zero_grad() 77 | output = self.model(data) 78 | loss = criterion(output, target) 79 | losses_per_group = self.get_losses_per_group(criterion, data, target, group, losses_per_group) 80 | loss.backward() 81 | per_sample_grads = self.flatten_all_layer_params() 82 | 83 | # get sum of grads over groups over current batch 84 | if self.method == "regular": 85 | grad_norms, _, sum_grad_vec_batch, sum_clip_grad_vec_batch = self.get_sum_grad_batch_from_vec( 86 | per_sample_grads, group) 87 | elif self.method in ["dpsgd", "dpsgd-global", "dpsgd-global-adapt"]: 88 | grad_norms, _, sum_grad_vec_batch, sum_clip_grad_vec_batch = self.get_sum_grad_batch_from_vec( 89 | per_sample_grads, group, clipping_bound=self.optimizer.max_grad_norm) 90 | elif self.method == "dpsgd-f": 91 | C = self.compute_clipping_bound_per_sample(per_sample_grads, group) 92 | grad_norms, _, sum_grad_vec_batch, sum_clip_grad_vec_batch = self.get_sum_grad_batch_from_vec( 93 | per_sample_grads, group) 94 | _, group_counts_batch = split_by_group(data, target, group, self.num_groups, return_counts=1) 95 | g_B, g_B_k, bar_g_B, bar_g_B_k = self.mean_grads_over(group_counts_batch, sum_grad_vec_batch, 96 | sum_clip_grad_vec_batch) 97 | if (self.evaluate_angles or self.evaluate_hessian) and ( 98 | self.epoch * self.num_batch + _batch_idx) % self.angle_comp_step == 0: 99 | # compute sum of gradients over groups over entire training dataset 100 | if self.method == "regular": 101 | sum_grad_vec_all, sum_clip_grad_vec_all, group_counts = self.get_sum_grad( 102 | self.train_loader.dataset, criterion, g_B, bar_g_B, expected_loss, _batch_idx) 103 | elif self.method in ["dpsgd", "dpsgd-f", "dpsgd-global", "dpsgd-global-adapt"]: 104 | sum_grad_vec_all, sum_clip_grad_vec_all, group_counts = self.get_sum_grad(self.train_loader.dataset, 105 | criterion, 106 | g_B, 107 | bar_g_B, 108 | expected_loss, 109 | _batch_idx, 110 | clipping_bound=self.optimizer.max_grad_norm) 111 | 112 | # average sum of gradients per group over entire training dataset 113 | _, g_D_k, _, _ = self.mean_grads_over(group_counts, sum_grad_vec_all, sum_clip_grad_vec_all) 114 | cosine_sim_per_epoch.append(self.evaluate_cosine_sim(_batch_idx, g_D_k, g_B, bar_g_B, g_B_k, bar_g_B_k)) 115 | self.optimizer.zero_grad() 116 | output = self.model(data) 117 | loss = criterion(output, target) 118 | loss.backward() 119 | 120 | for i in range(self.num_groups): 121 | if len(grad_norms[i]) != 0: 122 | all_grad_norms[i] = all_grad_norms[i] + grad_norms[i] 123 | group_max_grads[i] = max(group_max_grads[i], max(grad_norms[i])) 124 | g_B_k_norms[i].append(torch.linalg.norm(g_B_k[i]).item()) 125 | 126 | if self.method == "dpsgd-f": 127 | self.optimizer.step(C) 128 | elif self.method == "dpsgd-global": 129 | self.optimizer.step(self.strict_max_grad_norm) 130 | elif self.method == "dpsgd-global-adapt": 131 | next_Z = self._update_Z(per_sample_grads, self.strict_max_grad_norm) 132 | self.optimizer.step(self.strict_max_grad_norm) 133 | self.strict_max_grad_norm = next_Z 134 | else: 135 | self.optimizer.step() 136 | losses.append(loss.item()) 137 | if self.method != "regular": 138 | if self.method in ["dpsgd-f", "dpsgd-global-adapt"]: 139 | self._update_privacy_accountant() 140 | epsilon = self.privacy_engine.get_epsilon(delta=self.delta) 141 | print(f"(ε = {epsilon:.2f}, δ = 
{self.delta})") 142 | privacy_dict = {"epsilon": epsilon, "delta": self.delta} 143 | self.writer.record_dict("Privacy", privacy_dict, step=0, save=True, print_results=False) 144 | group_ave_grad_norms = [np.mean(all_grad_norms[i]) for i in range(self.num_groups)] 145 | group_norm_grad_ave = [np.mean(g_B_k_norms[i]) for i in range(self.num_groups)] 146 | return group_ave_grad_norms, group_max_grads, group_norm_grad_ave, losses, losses_per_group / self.num_batch 147 | 148 | def train(self, write_checkpoint=True): 149 | training_time = 0 150 | group_loss_epochs = [] 151 | cos_sim_per_epoch = [] 152 | expected_loss = [] 153 | avg_grad_norms_epochs = [] 154 | max_grads_epochs = [] 155 | norm_avg_grad_epochs = [] 156 | while self.epoch < self.max_epochs: 157 | epoch_start_time = time.time() 158 | self.model.train() 159 | avg_grad_norms, max_grads, norm_avg_grad, losses, group_losses = self._train_epoch(cos_sim_per_epoch, 160 | expected_loss) 161 | group_loss_epochs.append([self.epoch, np.mean(losses)] + list(group_losses)) 162 | avg_grad_norms_epochs.append([self.epoch] + list(avg_grad_norms)) 163 | max_grads_epochs.append([self.epoch] + list(max_grads)) 164 | norm_avg_grad_epochs.append([self.epoch] + list(norm_avg_grad)) 165 | 166 | epoch_training_time = time.time() - epoch_start_time 167 | training_time += epoch_training_time 168 | 169 | print( 170 | f"Train Epoch: {self.epoch} \t" 171 | f"Loss: {np.mean(losses):.6f} \t" 172 | f"Loss per group: {group_losses}" 173 | ) 174 | 175 | self._validate() 176 | self.writer.write_scalar("train/" + "Loss", np.mean(losses), self.epoch) 177 | self.writer.write_scalars("train/AverageGrad", 178 | {'group' + str(k): v for k, v in enumerate(avg_grad_norms)}, 179 | self.epoch) 180 | self.writer.write_scalars("train/MaxGrad", 181 | {'group' + str(k): v for k, v in enumerate(max_grads)}, 182 | self.epoch) 183 | if write_checkpoint: self.write_checkpoint("latest") 184 | self.epoch += 1 185 | 186 | if self.epoch == self.max_epochs: 187 | loss_dict = dict() 188 | 189 | loss_dict["final_loss"] = np.mean(losses) 190 | loss_dict["final_loss_per_group"] = group_losses 191 | self.writer.record_dict("final_loss", loss_dict, 0, save=1, print_results=0) 192 | 193 | K = self.num_groups 194 | # write group_loss to csv 195 | columns = ["epoch", "train_loss"] + [f"train_loss_{k}" for k in range(K)] 196 | self.create_csv(group_loss_epochs, columns, "train_loss_per_epochs") 197 | 198 | # write avg_grad_norms to csv 199 | columns = ["epoch"] + [f"ave_grads_{k}" for k in range(K)] 200 | self.create_csv(avg_grad_norms_epochs, columns, "avg_grad_norms_per_epochs") 201 | 202 | # write max_grads_epochs to csv 203 | columns = ["epoch"] + [f"max_grads_{k}" for k in range(K)] 204 | self.create_csv(max_grads_epochs, columns, "max_grad_norms_per_epochs") 205 | 206 | # write norm_avg_grad to csv 207 | columns = ["epoch"] + [f"norm_avg_grad_{k}" for k in range(K)] 208 | self.create_csv(norm_avg_grad_epochs, columns, "norm_avg_grad_per_epochs") 209 | 210 | # write norms, angles to csv 211 | columns = ["epoch", "batch"] + \ 212 | [f"cos_g_D_{k}_g_B_{k}" for k in self.selected_groups] + \ 213 | [f"cos_g_D_{k}_bar_g_B_{k}" for k in self.selected_groups] + \ 214 | [f"cos_g_D_{k}_g_B" for k in self.selected_groups] + \ 215 | [f"cos_g_D_{k}_bar_g_B" for k in self.selected_groups] + \ 216 | ["cos_g_B_bar_g_B", "|g_B|", "|bar_g_B|"] + \ 217 | [f"|g_D_{k}|" for k in self.selected_groups] + \ 218 | [f"|g_B_{k}|" for k in self.selected_groups] + \ 219 | [f"|bar_g_B_{k}|" for k in self.selected_groups] 
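# The column order above must match the row layout built in evaluate_cosine_sim:
# (epoch, batch), the four cosine-similarity blocks, then the global and
# per-group gradient norms.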
220 | self.create_csv(cos_sim_per_epoch, columns, "angles_per_epochs") 221 | 222 | # write expected loss terms to csv 223 | columns = ["epoch", "batch"] + \ 224 | [f"R_non_private_{k}" for k in self.selected_groups] + \ 225 | [f"R_clip_{k}" for k in self.selected_groups] + \ 226 | [f"R_clip_dir_inner_prod_term_{k}" for k in self.selected_groups] + \ 227 | [f"R_clip_dir_hess_term_{k}" for k in self.selected_groups] + \ 228 | [f"R_clip_dir_{k}" for k in self.selected_groups] + \ 229 | [f"R_clip_mag_inner_prod_term_{k}" for k in self.selected_groups] + \ 230 | [f"R_clip_mag_hess_term_{k}" for k in self.selected_groups] + \ 231 | [f"R_clip_mag_{k}" for k in self.selected_groups] + \ 232 | [f"R_noise_{k}" for k in self.selected_groups] 233 | self.create_csv(expected_loss, columns, "expected_loss_per_epochs") 234 | 235 | self.writer.write_scalar("train/" + "avg_train_time_over_epoch", 236 | training_time / (self.max_epochs * 60)) # in minutes 237 | self._test() 238 | 239 | def create_csv(self, data, columns, title): 240 | df = pd.DataFrame(data, columns=columns) 241 | df.to_csv(os.path.join(self.writer.logdir, f"{title}.csv"), index=False) 242 | 243 | def flatten_all_layer_params(self): 244 | """ 245 | Flatten the parameters of all layers in a model 246 | 247 | Args: 248 | model: a pytorch model 249 | 250 | Returns: 251 | a tensor of shape num_samples in a batch * num_params 252 | """ 253 | per_sample_grad = None 254 | for n, p in self.model.named_parameters(): 255 | if p.requires_grad: 256 | if per_sample_grad is None: 257 | per_sample_grad = torch.flatten(p.grad_sample, 1, -1) 258 | else: 259 | per_sample_grad = torch.cat((per_sample_grad, torch.flatten(p.grad_sample, 1, -1)), 1) 260 | return per_sample_grad 261 | 262 | def _validate(self): 263 | valid_results = self.evaluator.validate() 264 | self.writer.record_dict("Validation", valid_results, self.epoch, save=True) 265 | 266 | def _test(self): 267 | test_results = self.evaluator.test() 268 | self.writer.record_dict("Test", test_results, self.epoch, save=True) 269 | if "accuracy_per_group" in test_results.keys(): 270 | plot_by_group(test_results["accuracy_per_group"], self.writer, data_title="final accuracy_per_group", 271 | scale_to_01=1) 272 | 273 | def write_checkpoint(self, tag): 274 | checkpoint = { 275 | "epoch": self.epoch, 276 | 277 | "module_state_dict": self.model.state_dict(), 278 | "opt_state_dict": self.optimizer.state_dict(), 279 | } 280 | 281 | self.writer.write_checkpoint(f"{tag}", checkpoint) 282 | 283 | def record_expected_loss(self, R_non_private, R_clip, R_noise, R_clip_dir_inner_prod_term, R_clip_dir_hess_term, 284 | R_clip_dir, R_clip_mag_inner_prod_term, R_clip_mag_hess_term, R_clip_mag, batch_idx): 285 | step = self.epoch * self.num_batch + batch_idx 286 | self.writer.write_scalars("R_non_private", {'group' + str(k): v for k, v in enumerate(R_non_private)}, step) 287 | self.writer.write_scalars("R_clip", {'group' + str(k): v for k, v in enumerate(R_clip)}, step) 288 | self.writer.write_scalars("R_noise", {'group' + str(k): v for k, v in enumerate(R_noise)}, step) 289 | self.writer.write_scalars("R_clip_dir_inner_prod_term", 290 | {'group' + str(k): v for k, v in enumerate(R_clip_dir_inner_prod_term)}, step) 291 | self.writer.write_scalars("R_clip_dir_hess_term", 292 | {'group' + str(k): v for k, v in enumerate(R_clip_dir_hess_term)}, step) 293 | self.writer.write_scalars("R_clip_dir", {'group' + str(k): v for k, v in enumerate(R_clip_dir)}, step) 294 | self.writer.write_scalars("R_clip_mag_inner_prod_term", 295 | 
{'group' + str(k): v for k, v in enumerate(R_clip_mag_inner_prod_term)}, step) 296 | self.writer.write_scalars("R_clip_mag_hess_term", 297 | {'group' + str(k): v for k, v in enumerate(R_clip_mag_hess_term)}, step) 298 | self.writer.write_scalars("R_clip_mag", {'group' + str(k): v for k, v in enumerate(R_clip_mag)}, step) 299 | 300 | def expected_loss_batch_terms(self, data, target, group, g_B, bar_g_B, C, criterion): 301 | def create_hvp_fn(data, target): 302 | func_model, params = make_functional(self.model) 303 | 304 | def compute_loss(params): 305 | preds = func_model(params, data) 306 | loss = criterion(preds, target) 307 | return loss 308 | 309 | _, hvp_fn = vjp(grad(compute_loss), params) 310 | return hvp_fn 311 | 312 | per_group, counts = split_by_group(data, target, group, self.num_groups, True) 313 | per_slct_group = [per_group[i] for i in self.selected_groups] 314 | slct_counts = [counts[i] for i in self.selected_groups] 315 | groups_len = len(self.selected_groups) 316 | grad_hess_grad = np.zeros(groups_len) 317 | clip_grad_hess_clip_grad = np.zeros(groups_len) 318 | R_noise = np.zeros(groups_len) 319 | loss = np.zeros(groups_len) 320 | self.model.disable_hooks() 321 | _, params = make_functional(self.model) 322 | unflattened_g_B = unflatten_grads(params, g_B) 323 | unflattened_bar_g_B = unflatten_grads(params, bar_g_B) 324 | for group_idx, (data_group, target_group) in enumerate(per_slct_group): 325 | with torch.no_grad(): 326 | hvp_fn = create_hvp_fn(data_group, target_group) 327 | self.optimizer.zero_grad() 328 | preds = self.model(data_group) 329 | loss[group_idx] = criterion(preds, target_group) * slct_counts[group_idx] 330 | result = 0 331 | for i in range(self.num_hutchinson_estimates): 332 | rand_z = tuple(rademacher(el) for el in params) 333 | hess_z = hvp_fn(rand_z)[0] 334 | z_hess_z = torch.sum( 335 | torch.stack([torch.dot(x.flatten(), y.flatten()) for (x, y) in zip(rand_z, hess_z)])) 336 | result += z_hess_z.item() 337 | # combine results taking into account different batch sizes 338 | hessian_trace = result * slct_counts[group_idx] / self.num_hutchinson_estimates 339 | grad_hess = hvp_fn(unflattened_g_B)[0] 340 | flat_grad_hess = torch.cat([torch.flatten(t) for t in grad_hess]) 341 | grad_hess_grad[group_idx] = torch.dot(flat_grad_hess, g_B) * slct_counts[group_idx] 342 | clip_grad_hess = hvp_fn(unflattened_bar_g_B)[0] 343 | flat_clip_grad_hess = torch.cat([torch.flatten(t) for t in clip_grad_hess]) 344 | clip_grad_hess_clip_grad[group_idx] = torch.dot(flat_clip_grad_hess, bar_g_B) * slct_counts[group_idx] 345 | R_noise[group_idx] = self.lr ** 2 / 2 * hessian_trace * C ** 2 * self.optimizer.noise_multiplier ** 2 346 | self.model.enable_hooks() 347 | return grad_hess_grad, clip_grad_hess_clip_grad, R_noise, loss 348 | 349 | def expected_loss(self, g_B, bar_g_B, sum_grad_vec, grad_hess_grad, clip_grad_hess_clip_grad, 350 | R_noise, loss, group_counts, expected_loss_terms, batch_indx): 351 | norm_g_B = torch.linalg.norm(g_B).item() 352 | norm_bar_g_B = torch.linalg.norm(bar_g_B).item() 353 | groups_len = len(self.selected_groups) 354 | R_non_private = np.zeros(groups_len) 355 | R_clip = np.zeros(groups_len) 356 | new_R_clip_dir = np.zeros(groups_len) 357 | new_R_clip_dir_inner_prod_term = np.zeros(groups_len) 358 | new_R_clip_dir_hess_term = np.zeros(groups_len) 359 | new_R_clip_mag = np.zeros(groups_len) 360 | new_R_clip_mag_inner_prod_term = np.zeros(groups_len) 361 | new_R_clip_mag_hess_term = np.zeros(groups_len) 362 | for group_idx in range(groups_len): 363 | 
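# The terms below follow a second-order Taylor expansion of each group's loss
# under one SGD step of size lr, with H_k the group Hessian:
#   R_non_private ~ L_k - lr * <g_D_k, g_B> + (lr**2 / 2) * g_B^T H_k g_B
#   R_clip is the excess from stepping along bar_g_B instead of g_B, split into
#   a direction term (R_clip_dir) and a magnitude term (R_clip_mag).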
g_D_a = sum_grad_vec[group_idx] / group_counts[group_idx] 364 | group_grad_dot_grad = torch.dot(g_D_a, g_B) 365 | R_non_private[group_idx] = loss[group_idx] - self.lr * group_grad_dot_grad + self.lr ** 2 / 2 * \ 366 | grad_hess_grad[group_idx] 367 | R_clip[group_idx] = self.lr * ( 368 | group_grad_dot_grad - torch.dot(g_D_a, bar_g_B)) \ 369 | + self.lr ** 2 / 2 * (clip_grad_hess_clip_grad[group_idx] - grad_hess_grad[group_idx]) 370 | 371 | new_R_clip_dir_inner_prod_term[group_idx] = self.lr * torch.dot(g_D_a, 372 | norm_bar_g_B / norm_g_B * g_B - bar_g_B) 373 | new_R_clip_dir_hess_term[group_idx] = self.lr ** 2 / 2 * ( 374 | clip_grad_hess_clip_grad[group_idx] - (norm_bar_g_B / norm_g_B) ** 2 * grad_hess_grad[ 375 | group_idx]) 376 | new_R_clip_dir[group_idx] = new_R_clip_dir_inner_prod_term[group_idx] + new_R_clip_dir_hess_term[group_idx] 377 | new_R_clip_mag_inner_prod_term[group_idx] = self.lr * torch.dot(g_D_a, g_B - norm_bar_g_B / norm_g_B * g_B) 378 | new_R_clip_mag_hess_term[group_idx] = self.lr ** 2 / 2 * ((norm_bar_g_B / norm_g_B) ** 2 - 1) * \ 379 | grad_hess_grad[group_idx] 380 | new_R_clip_mag[group_idx] = new_R_clip_mag_inner_prod_term[group_idx] + new_R_clip_mag_hess_term[group_idx] 381 | 382 | self.record_expected_loss(R_non_private, R_clip, R_noise, new_R_clip_dir_inner_prod_term, 383 | new_R_clip_dir_hess_term, 384 | new_R_clip_dir, new_R_clip_mag_inner_prod_term, new_R_clip_mag_hess_term, 385 | new_R_clip_mag, batch_indx) 386 | row = [self.epoch, 387 | batch_indx] + R_non_private.tolist() + R_clip.tolist() + new_R_clip_dir_inner_prod_term.tolist() + \ 388 | new_R_clip_dir_hess_term.tolist() + new_R_clip_dir.tolist() + new_R_clip_mag_inner_prod_term.tolist() + \ 389 | new_R_clip_mag_hess_term.tolist() + new_R_clip_mag.tolist() + R_noise.tolist() 390 | expected_loss_terms.append(row) 391 | 392 | def get_losses_per_group(self, criterion, data, target, group, group_losses): 393 | ''' 394 | Given subset of GroupLabelDataset (data, target, group), computes 395 | loss of model on each subset (data, target, group=k) and returns 396 | np array of length num_groups = group_losses + group losses over given data 397 | ''' 398 | per_group = split_by_group(data, target, group, self.num_groups) 399 | group_loss_batch = np.zeros(self.num_groups) 400 | for group_idx, (data_group, target_group) in enumerate(per_group): 401 | with torch.no_grad(): 402 | if data_group.shape[0] == 0: # if batch does not contain samples of group i 403 | group_loss_batch[group_idx] = 0 404 | else: 405 | group_output = self.model(data_group) 406 | group_loss_batch[group_idx] = criterion(group_output, target_group).item() 407 | group_losses = group_loss_batch + group_losses 408 | return group_losses 409 | 410 | def get_sum_grad_batch(self, data, targets, groups, criterion, **kwargs): 411 | data = data.to(self.device) 412 | targets = targets.to(self.device) 413 | 414 | self.optimizer.zero_grad() 415 | outputs = self.model(data) 416 | loss = criterion(outputs, targets) 417 | loss.backward() 418 | per_sample_grads = self.flatten_all_layer_params() 419 | 420 | return self.get_sum_grad_batch_from_vec(per_sample_grads, groups, **kwargs) 421 | 422 | def get_sum_grad_batch_from_vec(self, per_sample_grads, groups, **kwargs): 423 | if self.method == "dpsgd-f": 424 | clipping_bounds = self.compute_clipping_bound_per_sample(per_sample_grads, groups) 425 | grad_norms, clip_grad_norms, sum_grad_vec, sum_clip_grad_vec = get_grad_norms_clip(per_sample_grads, groups, 426 | self.num_groups, 427 | self.clipping_scale_fn, 428 | 
clipping_bounds=clipping_bounds) 429 | else: 430 | grad_norms, clip_grad_norms, sum_grad_vec, sum_clip_grad_vec = get_grad_norms_clip(per_sample_grads, groups, 431 | self.num_groups, 432 | self.clipping_scale_fn, 433 | **kwargs) 434 | return grad_norms, clip_grad_norms, sum_grad_vec, sum_clip_grad_vec 435 | 436 | def get_sum_grad(self, dataset, criterion, g_B, bar_g_B, expected_loss_terms, batch_idx, **kwargs): 437 | loader = get_loader(self.train_loader.dataset, self.device, 1000, drop_last=False) 438 | groups_len = len(self.selected_groups) 439 | running_sum_grad_vec = None 440 | running_sum_clip_grad_vec = None 441 | sum_grad_hess_grad = np.zeros(groups_len) 442 | sum_clip_grad_hess_clip_grad = np.zeros(groups_len) 443 | sum_R_noise = np.zeros(groups_len) 444 | sum_loss = np.zeros(groups_len) 445 | # First argument is a dummy 446 | _, group_counts = split_by_group(dataset.y, dataset.y, dataset.z, self.num_groups, return_counts=True) 447 | for data, target, group in loader: 448 | if self.method == "dpsgd-f": 449 | _, _, sum_grad_vec_batch, sum_clip_grad_vec_batch = self.get_sum_grad_batch( 450 | data, target, group, criterion, **kwargs) 451 | else: 452 | _, _, sum_grad_vec_batch, sum_clip_grad_vec_batch = self.get_sum_grad_batch( 453 | data, target, group, criterion, **kwargs) 454 | if running_sum_grad_vec is None: 455 | running_sum_grad_vec = sum_grad_vec_batch 456 | else: 457 | running_sum_grad_vec = [a + b for a, b in zip(running_sum_grad_vec, sum_grad_vec_batch)] 458 | if running_sum_clip_grad_vec is None: 459 | running_sum_clip_grad_vec = sum_clip_grad_vec_batch 460 | else: 461 | running_sum_clip_grad_vec = [a + b for a, b in zip(running_sum_clip_grad_vec, sum_clip_grad_vec_batch)] 462 | if self.evaluate_hessian and self.method != "regular": 463 | clipping_bound = kwargs['clipping_bound'] 464 | grad_hess_grad, clip_grad_hess_clip_grad, R_noise, loss = self.expected_loss_batch_terms( 465 | data, target, group, g_B, bar_g_B, clipping_bound, criterion) 466 | sum_grad_hess_grad += grad_hess_grad 467 | sum_clip_grad_hess_clip_grad += clip_grad_hess_clip_grad 468 | sum_R_noise += R_noise 469 | sum_loss += loss 470 | if self.sampled_expected_loss: 471 | _, group_counts = split_by_group(data, target, group, self.num_groups, return_counts=True) 472 | break 473 | 474 | 475 | if self.evaluate_hessian: 476 | final_sum_grad_vec_batch = [running_sum_grad_vec[i] for i in self.selected_groups] 477 | group_counts_vec = np.array([group_counts[i] for i in self.selected_groups]) 478 | final_grad_hess_grad = sum_grad_hess_grad / group_counts_vec 479 | final_clip_grad_hess_clip_grad = sum_clip_grad_hess_clip_grad / group_counts_vec 480 | final_R_noise = sum_R_noise / group_counts_vec 481 | final_loss = sum_loss / group_counts_vec 482 | self.expected_loss(g_B, bar_g_B, final_sum_grad_vec_batch, final_grad_hess_grad, 483 | final_clip_grad_hess_clip_grad, final_R_noise, final_loss, 484 | group_counts_vec, expected_loss_terms, batch_idx) 485 | return running_sum_grad_vec, running_sum_clip_grad_vec, group_counts 486 | 487 | def mean_grads_over(self, group_counts, sum_grad_vec, clip_sum_grad_vec): 488 | g_D = torch.stack(sum_grad_vec, dim=0).sum(dim=0) / sum(group_counts) 489 | g_D_k = [sum_grad_vec[i] / group_counts[i] for i in range(self.num_groups)] 490 | 491 | bar_g_D = torch.stack(clip_sum_grad_vec, dim=0).sum(dim=0) / sum(group_counts) 492 | bar_g_D_k = [clip_sum_grad_vec[i] / group_counts[i] for i in range(self.num_groups)] 493 | return g_D, g_D_k, bar_g_D, bar_g_D_k 494 | 495 | def 
495 |     def evaluate_cosine_sim(self, batch_idx, g_D_k, g_B, bar_g_B, g_B_k, bar_g_B_k):
496 |         cos_g_D_k_g_B_k = []
497 |         cos_g_D_k_bar_g_B_k = []
498 |         cos_g_D_k_g_B = []
499 |         cos_g_D_k_bar_g_B = []
500 |         norm_g_D_k = []
501 |         norm_g_B_k = []
502 |         norm_bar_g_B_k = []
503 | 
504 |         cos_g_B_bar_g_B = cosine_similarity(g_B, bar_g_B, dim=0).item()
505 |         norm_g_B = torch.linalg.norm(g_B).item()
506 |         norm_bar_g_B = torch.linalg.norm(bar_g_B).item()
507 | 
508 |         for k in self.selected_groups:
509 |             cos_g_D_k_g_B_k.append(cosine_similarity(g_D_k[k], g_B_k[k], dim=0).item())
510 |             cos_g_D_k_bar_g_B_k.append(cosine_similarity(g_D_k[k], bar_g_B_k[k], dim=0).item())
511 |             cos_g_D_k_g_B.append(cosine_similarity(g_D_k[k], g_B, dim=0).item())
512 |             cos_g_D_k_bar_g_B.append(cosine_similarity(g_D_k[k], bar_g_B, dim=0).item())
513 | 
514 |             norm_g_D_k.append(torch.linalg.norm(g_D_k[k]).item())
515 |             norm_g_B_k.append(torch.linalg.norm(g_B_k[k]).item())
516 |             norm_bar_g_B_k.append(torch.linalg.norm(bar_g_B_k[k]).item())
517 | 
518 |         row = [self.epoch, batch_idx] + cos_g_D_k_g_B_k + cos_g_D_k_bar_g_B_k + cos_g_D_k_g_B + cos_g_D_k_bar_g_B + [
519 |             cos_g_B_bar_g_B, norm_g_B, norm_bar_g_B] + norm_g_D_k + norm_g_B_k + norm_bar_g_B_k
520 |         return row
521 | 
522 | 
523 | class RegularTrainer(BaseTrainer):
524 |     """Class for non-private training"""
525 | 
526 |     # Given the norm of a gradient, computes the scale S such that the clipped gradient = S * gradient
527 |     def clipping_scale_fn(self, grad_norm, idx):
528 |         return 1
529 | 
530 | 
531 | class DpsgdTrainer(BaseTrainer):
532 |     """Class for DPSGD training"""
533 | 
534 |     # Given the norm of a gradient, computes the scale S such that the clipped gradient = S * gradient
535 |     def clipping_scale_fn(self, grad_norm, idx, clipping_bound):
536 |         return min(1, clipping_bound / grad_norm)
537 | 
538 |     def __init__(
539 |         self,
540 |         model,
541 |         optimizer,
542 |         privacy_engine,
543 |         train_loader,
544 |         valid_loader,
545 |         test_loader,
546 |         writer,
547 |         evaluator,
548 |         device,
549 |         delta=1e-5,
550 |         **kwargs
551 |     ):
552 |         super().__init__(
553 |             model,
554 |             optimizer,
555 |             train_loader,
556 |             valid_loader,
557 |             test_loader,
558 |             writer,
559 |             evaluator,
560 |             device,
561 |             **kwargs
562 |         )
563 | 
564 |         self.privacy_engine = privacy_engine
565 |         self.delta = delta
566 | 
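# Illustrative sketch, not used anywhere in this repo: a standalone check of the
# DPSGD clipping rule above. The helper name is hypothetical.
def _dpsgd_clip_example():
    import torch

    clipping_bound = 1.0
    g = torch.tensor([3.0, 4.0])  # norm 5.0 > clipping_bound, so scale = 1/5
    scale = min(1, clipping_bound / torch.linalg.norm(g).item())
    clipped = scale * g  # tensor([0.6, 0.8]); its norm is exactly clipping_bound
    return clipped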
567 | 
568 | class DpsgdFTrainer(BaseTrainer):
569 |     """Class for DPSGD-F training"""
570 | 
571 |     # Given the norm of a gradient, computes the scale S such that the clipped gradient = S * gradient
572 |     def clipping_scale_fn(self, grad_norm, idx, **kwargs):
573 |         clipping_bounds = kwargs["clipping_bounds"]
574 |         return min((clipping_bounds[idx] / grad_norm).item(), 1)
575 | 
576 |     def __init__(
577 |         self,
578 |         model,
579 |         optimizer,
580 |         privacy_engine,
581 |         train_loader,
582 |         valid_loader,
583 |         test_loader,
584 |         writer,
585 |         evaluator,
586 |         device,
587 |         delta=1e-5,
588 |         base_max_grad_norm=1,  # C0
589 |         counts_noise_multiplier=10,  # noise multiplier applied on mk and ok
590 |         **kwargs
591 |     ):
592 |         """
593 |         Initialization function. Initializes the parent class while adding the new parameters base_max_grad_norm and counts_noise_multiplier.
594 | 
595 |         Args:
596 |             model: model from privacy_engine.make_private()
597 |             optimizer: a DPSGDF_Optimizer
598 |             privacy_engine: DPSGDF_Engine
599 |             train_loader: train_loader from privacy_engine.make_private()
600 |             valid_loader: normal pytorch data loader for the validation set
601 |             test_loader: normal pytorch data loader for the test set
602 |             writer: writer to tensorboard
603 |             evaluator: evaluator for model performance
604 |             device: device on which to train the model
605 |             delta: the delta in the (epsilon, delta) privacy budget
606 |             base_max_grad_norm: C0 in the original paper, the base clipping threshold for gradient norms
607 |             counts_noise_multiplier: sigma1 in the original paper, the scale of the noise added to the per-group counts of samples with gradient norm above (mk) and below (ok) the base clipping bound C0
608 |         """
609 |         super().__init__(
610 |             model,
611 |             optimizer,
612 |             train_loader,
613 |             valid_loader,
614 |             test_loader,
615 |             writer,
616 |             evaluator,
617 |             device,
618 |             **kwargs
619 |         )
620 | 
621 |         self.privacy_engine = privacy_engine
622 |         self.delta = delta
623 |         # new parameters for DPSGDF
624 |         self.base_max_grad_norm = base_max_grad_norm  # C0
625 |         self.counts_noise_multiplier = counts_noise_multiplier  # noise scale applied on mk and ok
626 |         self.sample_rate = 1 / self.num_batch
627 |         self.privacy_step_history = []
628 | 
629 |     def _update_privacy_accountant(self):
630 |         """
631 |         The Opacus RDP accountant minimizes computation when many SGM steps are taken in a row with the same parameters.
632 |         We alternate between privatizing counts and privatizing gradients, with different parameters.
633 |         Accounting is sped up by tracking steps in groups rather than alternating.
634 |         The order of accounting does not affect the privacy guarantee.
635 |         """
636 |         for step in self.privacy_step_history:
637 |             self.privacy_engine.accountant.step(noise_multiplier=step[0], sample_rate=step[1])
638 |         self.privacy_step_history = []
639 | 
640 |     def compute_clipping_bound_per_sample(self, per_sample_grads, group):
641 |         """Compute the clipping bound for each sample."""
642 |         # calculate mk, ok
643 |         mk = collections.defaultdict(int)
644 |         ok = collections.defaultdict(int)
645 |         # get the l2 norm of the gradients of all parameters for each sample, in shape (batch_size,)
646 |         l2_norm_grad_per_sample = torch.norm(per_sample_grads, p=2, dim=1)  # batch_size
647 | 
648 |         assert len(group) == len(l2_norm_grad_per_sample)
649 | 
650 |         for i in range(len(group)):  # looping over batch
651 |             group_idx = group[i].item()
652 |             if l2_norm_grad_per_sample[i].item() > self.base_max_grad_norm:
653 |                 mk[group_idx] += 1
654 |             else:
655 |                 ok[group_idx] += 1
656 | 
657 |         # add noise to mk and ok
658 |         m2k = {}
659 |         o2k = {}
660 |         m = 0
661 | 
662 |         # note that some group indices might have 0 sample counts in the batch, and we still add noise to them
663 |         for group_idx in range(self.num_groups):
664 |             m2k[group_idx] = mk[group_idx] + torch.normal(0, self.counts_noise_multiplier, (1,)).item()
665 |             m2k[group_idx] = max(int(m2k[group_idx]), 0)
666 |             o2k[group_idx] = ok[group_idx] + torch.normal(0, self.counts_noise_multiplier, (1,)).item()
667 |             o2k[group_idx] = max(int(o2k[group_idx]), 0)
668 |             m += m2k[group_idx]
669 | 
670 |         # Account for the privacy cost of privately estimating group sizes
671 |         # using the built-in sampled-gaussian-mechanism accountant.
672 |         # The L2 sensitivity of the per-group counts vector is always 1,
673 |         # so the std used in torch.normal is the same as the noise_multiplier in the accountant.
674 |         # Accounting is done lazily, see the _update_privacy_accountant method.
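        # Hypothetical worked example of the bound computed below (values made up):
        # with C0 = 1 and noisy per-group ratios m2k/bk of [0.2, 0.6], the mean is
        # 0.4, so Ck = C0 * (1 + 0.2/0.4) = 1.5 for group 0 and
        # Ck = C0 * (1 + 0.6/0.4) = 2.5 for group 1; groups whose gradients exceed
        # C0 more often are assigned a larger clipping bound.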
675 | self.privacy_step_history.append([self.counts_noise_multiplier, self.sample_rate]) 676 | array = [] 677 | bk = {} 678 | Ck = {} 679 | for group_idx in range(self.num_groups): 680 | bk[group_idx] = m2k[group_idx] + o2k[group_idx] 681 | # added 682 | if bk[group_idx] == 0: 683 | array.append(1) # when bk = 0, m2k = 0, we have 0/0 = 1 684 | else: 685 | array.append(m2k[group_idx] * 1.0 / bk[group_idx]) 686 | 687 | for group_idx in range(self.num_groups): 688 | Ck[group_idx] = self.base_max_grad_norm * (1 + array[group_idx] / (np.mean(array) + 1e-8)) 689 | 690 | per_sample_clipping_bound = [] 691 | for i in range(len(group)): # looping over batch 692 | group_idx = group[i].item() 693 | per_sample_clipping_bound.append(Ck[group_idx]) 694 | 695 | return torch.Tensor(per_sample_clipping_bound).to(device=self.device) 696 | 697 | 698 | class DpsgdGlobalTrainer(DpsgdTrainer): 699 | 700 | # given norm of gradient, computes S such that clipped gradient = S * gradient 701 | def clipping_scale_fn(self, grad_norm, idx, clipping_bound): 702 | if grad_norm > self.strict_max_grad_norm: 703 | return 0 704 | else: 705 | return clipping_bound / self.strict_max_grad_norm 706 | 707 | def __init__( 708 | self, 709 | model, 710 | optimizer, 711 | privacy_engine, 712 | train_loader, 713 | valid_loader, 714 | test_loader, 715 | writer, 716 | evaluator, 717 | device, 718 | delta=1e-5, 719 | strict_max_grad_norm=100, 720 | **kwargs 721 | ): 722 | super().__init__( 723 | model, 724 | optimizer, 725 | privacy_engine, 726 | train_loader, 727 | valid_loader, 728 | test_loader, 729 | writer, 730 | evaluator, 731 | device, 732 | delta=delta, 733 | **kwargs 734 | ) 735 | self.strict_max_grad_norm = strict_max_grad_norm 736 | 737 | 738 | class DpsgdGlobalAdaptiveTrainer(BaseTrainer): 739 | 740 | # given norm of gradient, computes S such that clipped gradient = S * gradient 741 | def clipping_scale_fn(self, grad_norm, idx, clipping_bound): 742 | if grad_norm > self.strict_max_grad_norm: 743 | return min(1, clipping_bound / grad_norm) 744 | else: 745 | return clipping_bound / self.strict_max_grad_norm 746 | 747 | def __init__( 748 | self, 749 | model, 750 | optimizer, 751 | privacy_engine, 752 | train_loader, 753 | valid_loader, 754 | test_loader, 755 | writer, 756 | evaluator, 757 | device, 758 | delta=1e-5, 759 | strict_max_grad_norm=100, 760 | bits_noise_multiplier=10, 761 | lr_Z=0.01, 762 | threshold=1.0, 763 | **kwargs 764 | ): 765 | super().__init__( 766 | model, 767 | optimizer, 768 | train_loader, 769 | valid_loader, 770 | test_loader, 771 | writer, 772 | evaluator, 773 | device, 774 | **kwargs 775 | ) 776 | self.privacy_engine = privacy_engine 777 | self.delta = delta 778 | self.strict_max_grad_norm = strict_max_grad_norm # Z 779 | self.bits_noise_multiplier = bits_noise_multiplier 780 | self.lr_Z = lr_Z 781 | self.sample_rate = 1 / self.num_batch 782 | self.privacy_step_history = [] 783 | self.threshold = threshold 784 | 785 | def _update_privacy_accountant(self): 786 | """ 787 | The Opacus RDP accountant minimizes computation when many SGM steps are taken in a row with the same parameters. 788 | We alternate between privatizing counts, and gradients with different parameters. 789 | Accounting is sped up by tracking steps in groups rather than alternating. 790 | The order of accounting does not affect the privacy guarantee. 
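        For example, every batch's count-privatization step is first recorded in
        privacy_step_history with its (noise multiplier, sample rate), and this
        method later replays the whole history to the accountant in one pass,
        rather than interleaving it with the gradient steps as they occur.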
791 | """ 792 | for step in self.privacy_step_history: 793 | self.privacy_engine.accountant.step(noise_multiplier=step[0], sample_rate=step[1]) 794 | self.privacy_step_history = [] 795 | 796 | def _update_Z(self, per_sample_grads, Z): 797 | # get the l2 norm of gradients of all parameters for each sample, in shape of (batch_size, ) 798 | l2_norm_grad_per_sample = torch.norm(per_sample_grads, p=2, dim=1) 799 | batch_size = len(l2_norm_grad_per_sample) 800 | 801 | dt = 0 # sample count in a batch exceeding Z * threshold 802 | for i in range(batch_size): # looping over batch 803 | if l2_norm_grad_per_sample[i].item() > self.threshold * Z: 804 | dt += 1 805 | 806 | dt = dt * 1.0 / batch_size # percentage of samples in a batch that's bigger than the threshold * Z 807 | noisy_dt = dt + torch.normal(0, self.bits_noise_multiplier, (1,)).item() * 1.0 / batch_size 808 | 809 | factor = math.exp(- self.lr_Z + noisy_dt) 810 | 811 | next_Z = Z * factor 812 | 813 | self.privacy_step_history.append([self.bits_noise_multiplier, self.sample_rate]) 814 | return next_Z 815 | -------------------------------------------------------------------------------- /trainers/trainer_factory.py: -------------------------------------------------------------------------------- 1 | from .trainer import RegularTrainer, DpsgdTrainer, DpsgdFTrainer, DpsgdGlobalTrainer, DpsgdGlobalAdaptiveTrainer 2 | 3 | 4 | def create_trainer( 5 | train_loader, 6 | valid_loader, 7 | test_loader, 8 | model, 9 | optimizer, 10 | privacy_engine, 11 | evaluator, 12 | writer, 13 | device, 14 | config 15 | ): 16 | kwargs = { 17 | 'method': config['method'], 18 | 'max_epochs': config['max_epochs'], 19 | 'num_groups': config['num_groups'], 20 | 'selected_groups': config['selected_groups'], 21 | 'evaluate_angles': config['evaluate_angles'], 22 | 'evaluate_hessian': config['evaluate_hessian'], 23 | 'angle_comp_step': config['angle_comp_step'], 24 | 'lr': config['lr'], 25 | 'seed': config['seed'], 26 | 'num_hutchinson_estimates': config['num_hutchinson_estimates'], 27 | 'sampled_expected_loss': config['sampled_expected_loss'] 28 | } 29 | 30 | if config["method"] == "regular": 31 | trainer = RegularTrainer( 32 | model, 33 | optimizer, 34 | train_loader, 35 | valid_loader, 36 | test_loader, 37 | writer, 38 | evaluator, 39 | device, 40 | **kwargs 41 | ) 42 | elif config["method"] == "dpsgd": 43 | trainer = DpsgdTrainer( 44 | model, 45 | optimizer, 46 | privacy_engine, 47 | train_loader, 48 | valid_loader, 49 | test_loader, 50 | writer, 51 | evaluator, 52 | device, 53 | delta=config["delta"], 54 | **kwargs 55 | ) 56 | elif config["method"] == "dpsgd-f": 57 | trainer = DpsgdFTrainer( 58 | model, 59 | optimizer, 60 | privacy_engine, 61 | train_loader, 62 | valid_loader, 63 | test_loader, 64 | writer, 65 | evaluator, 66 | device, 67 | delta=config["delta"], 68 | base_max_grad_norm=config["base_max_grad_norm"], # C0 69 | counts_noise_multiplier=config["counts_noise_multiplier"], # noise scale applied on mk and ok 70 | **kwargs 71 | ) 72 | elif config["method"] == "dpsgd-global": 73 | trainer = DpsgdGlobalTrainer( 74 | model, 75 | optimizer, 76 | privacy_engine, 77 | train_loader, 78 | valid_loader, 79 | test_loader, 80 | writer, 81 | evaluator, 82 | device, 83 | delta=config["delta"], 84 | strict_max_grad_norm=config["strict_max_grad_norm"], 85 | **kwargs 86 | ) 87 | elif config["method"] == "dpsgd-global-adapt": 88 | trainer = DpsgdGlobalAdaptiveTrainer( 89 | model, 90 | optimizer, 91 | privacy_engine, 92 | train_loader, 93 | valid_loader, 94 | test_loader, 95 | 
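            # The keyword arguments below are specific to DPSGD-Global-Adapt:
            # strict_max_grad_norm is the initial threshold Z, bits_noise_multiplier
            # the noise scale on the privately estimated fraction of large gradients,
            # lr_Z the rate at which Z is updated each step, and threshold the
            # multiplier t in the comparison ||g||_2 > t * Z.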
writer, 96 | evaluator, 97 | device, 98 | delta=config["delta"], 99 | strict_max_grad_norm=config["strict_max_grad_norm"], 100 | bits_noise_multiplier=config["bits_noise_multiplier"], 101 | lr_Z=config["lr_Z"], 102 | threshold=config["threshold"], 103 | **kwargs 104 | ) 105 | else: 106 | raise ValueError("Training method not implemented") 107 | 108 | return trainer 109 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import matplotlib.pyplot as plt 4 | import torch 5 | from opacus.accountants.analysis.gdp import compute_eps_poisson 6 | from opacus.accountants.analysis.rdp import compute_rdp, get_privacy_spent 7 | 8 | 9 | def privacy_checker(sample_rate, config): 10 | assert sample_rate <= 1.0 11 | steps = config["max_epochs"] * math.ceil(1 / sample_rate) 12 | 13 | if config["accountant"] == 'rdp': 14 | orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64)) 15 | rdp = compute_rdp( 16 | q=sample_rate, 17 | noise_multiplier=config["noise_multiplier"], 18 | steps=steps, 19 | orders=orders) 20 | epsilon, alpha = get_privacy_spent( 21 | orders=orders, 22 | rdp=rdp, 23 | delta=config["delta"]) 24 | print( 25 | "-----------privacy------------" 26 | f"\nDP-SGD (RDP) with\n\tsampling rate = {100 * sample_rate:.3g}%," 27 | f"\n\tnoise_multiplier = {config['noise_multiplier']}," 28 | f"\n\titerated over {steps} steps,\nsatisfies " 29 | f"differential privacy with\n\tepsilon = {epsilon:.3g}," 30 | f"\n\tdelta = {config['delta']}." 31 | f"\nThe optimal alpha is {alpha}." 32 | ) 33 | elif config["accountant"] == 'gdp': 34 | eps = compute_eps_poisson( 35 | steps=steps, 36 | noise_multiplier=config["noise_multiplier"], 37 | sample_rate=sample_rate, 38 | delta=config["delta"], 39 | ) 40 | print( 41 | "-----------privacy------------" 42 | f"\nDP-SGD (GDP) with\n\tsampling rate = {100 * sample_rate:.3g}%," 43 | f"\n\tnoise_multiplier = {config['noise_multiplier']}," 44 | f"\n\titerated over {steps} steps,\nsatisfies " 45 | f"differential privacy with\n\tepsilon = {eps:.3g}," 46 | f"\n\tdelta = {config['delta']}." 47 | ) 48 | else: 49 | raise ValueError(f"Unknown accountant {config['accountant']}. 
Try 'rdp' or 'gdp'.")
50 | 
51 | 
52 | def get_grads(named_parameters, group, num_groups):
53 |     ave_grads = [[] for _ in range(num_groups)]
54 |     max_grads = [[] for _ in range(num_groups)]
55 |     name_grads = list(named_parameters)
56 |     for batch_idx in range(group.shape[0]):
57 |         grads_per_sample = []
58 |         for layer_idx in range(len(name_grads)):
59 |             if name_grads[layer_idx][1].requires_grad:
60 |                 grads_per_sample.append(name_grads[layer_idx][1].grad_sample[batch_idx].abs().reshape(-1))
61 |         ave_grads[group[batch_idx]].append(torch.mean(torch.cat(grads_per_sample, 0)))
62 |         max_grads[group[batch_idx]].append(torch.max(torch.cat(grads_per_sample, 0)))
63 |     return ave_grads, max_grads
64 | 
65 | 
66 | def get_grad_norms_clip(per_sample_grads, group, num_groups, clipping_scale_fn, **kwargs):
67 |     grad_norms = [[] for _ in range(num_groups)]
68 |     clip_grad_norms = [[] for _ in range(num_groups)]
69 |     sum_grad_vec = []
70 |     sum_clip_grad_vec = []
71 |     for sample_idx in range(group.shape[0]):
72 |         grad_vec = per_sample_grads[sample_idx]
73 |         grad_norm = torch.linalg.norm(grad_vec).item()
74 |         clipping_scale = clipping_scale_fn(grad_norm, sample_idx, **kwargs)
75 |         clip_grad_vec = clipping_scale * grad_vec
76 |         clip_grad_norm = torch.linalg.norm(clip_grad_vec).item()
77 |         grad_norms[group[sample_idx]].append(grad_norm)
78 |         clip_grad_norms[group[sample_idx]].append(clip_grad_norm)
79 |         if sample_idx == 0:
80 |             for _ in range(num_groups):
81 |                 sum_grad_vec.append(torch.zeros(grad_vec.shape[0], device=grad_vec.device, requires_grad=False))
82 |                 sum_clip_grad_vec.append(torch.zeros(grad_vec.shape[0], device=grad_vec.device, requires_grad=False))
83 |         sum_grad_vec[group[sample_idx]] += grad_vec
84 |         sum_clip_grad_vec[group[sample_idx]] += clip_grad_vec
85 |     return grad_norms, clip_grad_norms, sum_grad_vec, sum_clip_grad_vec
86 | 
87 | 
88 | def get_grad_norms(per_sample_grads, group, num_groups):
89 |     grad_norms = [[] for _ in range(num_groups)]
90 |     sum_grad_vec = []
91 |     for sample_idx in range(group.shape[0]):
92 |         grad_vec = per_sample_grads[sample_idx]
93 |         grad_norms[group[sample_idx]].append(torch.linalg.norm(grad_vec).item())
94 |         if sample_idx == 0:
95 |             for _ in range(num_groups):
96 |                 sum_grad_vec.append(torch.zeros(grad_vec.shape[0], device=grad_vec.device, requires_grad=False))
97 |         sum_grad_vec[group[sample_idx]] += grad_vec
98 |     return grad_norms, sum_grad_vec
99 | 
100 | 
101 | def get_num_layers(model):
102 |     num_layers = 0
103 |     for n, p in model.named_parameters():
104 |         if (p.requires_grad) and ("bias" not in n):
105 |             num_layers += 1
106 | 
107 |     return num_layers
108 | 
109 | 
110 | # Splits data and labels according to the group of each data point.
111 | # Returns a list of length num_groups, where element k is
112 | # (subset of data, subset of labels) for the group with index k.
113 | def split_by_group(data, labels, group, num_groups, return_counts=False):
114 |     sorter = torch.argsort(group)
115 |     unique, counts = torch.unique(group, return_counts=True)
116 |     unique = unique.tolist()
117 |     counts = counts.tolist()
118 | 
119 |     complete_unique = [0] * num_groups
120 |     complete_counts = [0] * num_groups
121 |     for i in range(num_groups):
122 |         complete_unique[i] = i
123 |         if i in unique:
124 |             j = unique.index(i)
125 |             complete_counts[i] = counts[j]
126 | 
127 |     sorted_data = torch.split(data[sorter], complete_counts)
128 |     sorted_labels = torch.split(labels[sorter], complete_counts)
129 | 
130 |     if not return_counts:
131 |         return list(zip(sorted_data, sorted_labels))
132 |     return
list(zip(sorted_data, sorted_labels)), complete_counts 133 | 134 | 135 | def plot_by_group(data_by_group, writer, data_title=None, scale_to_01=False): 136 | fig = plt.figure() 137 | plt.bar(range(len(data_by_group)), data_by_group, width=0.9) 138 | plt.xlabel("Groups") 139 | if data_title is not None: 140 | plt.ylabel(data_title) 141 | plt.title(data_title) 142 | plt.xticks(range(len(data_by_group))) 143 | if scale_to_01: 144 | plt.ylim(0, 1) 145 | writer.write_figure(data_title, fig) 146 | 147 | 148 | def unflatten_grads(params, grad_vec): 149 | grad_size = [] 150 | for layer in params: 151 | grad_size.append(layer.reshape(-1).shape[0]) 152 | grad_list = list(torch.split(grad_vec, grad_size)) 153 | for layer_idx in range(len(grad_list)): 154 | grad_list[layer_idx] = grad_list[layer_idx].reshape(params[layer_idx].shape) 155 | return tuple(grad_list) 156 | 157 | 158 | def rademacher(tens): 159 | """Draws a random tensor of size [tens.shape] from the Rademacher distribution (P(x=1) == P(x=-1) == 0.5)""" 160 | x = torch.empty_like(tens) 161 | x.random_(0, 2) 162 | x[x == 0] = -1 163 | return x 164 | -------------------------------------------------------------------------------- /writer.py: -------------------------------------------------------------------------------- 1 | # NOTE: The below file is modified from commit `aeaf5fd` of 2 | # https://github.com/jrmcornish/cif/blob/master/cif/writer.py 3 | 4 | import datetime 5 | import json 6 | import os 7 | import sys 8 | from typing import Iterable 9 | 10 | import numpy as np 11 | import torch 12 | from tensorboardX import SummaryWriter 13 | 14 | 15 | class Tee: 16 | """This class allows for redirecting of stdout and stderr""" 17 | 18 | def __init__(self, primary_file, secondary_file): 19 | self.primary_file = primary_file 20 | self.secondary_file = secondary_file 21 | 22 | self.encoding = self.primary_file.encoding 23 | 24 | def isatty(self): 25 | return self.primary_file.isatty() 26 | 27 | def fileno(self): 28 | return self.primary_file.fileno() 29 | 30 | def write(self, data): 31 | if isinstance(data, bytes): 32 | data = data.decode() 33 | 34 | self.primary_file.write(data) 35 | self.secondary_file.write(data) 36 | 37 | def flush(self): 38 | self.primary_file.flush() 39 | self.secondary_file.flush() 40 | 41 | 42 | class Writer: 43 | _STDOUT = sys.stdout 44 | _STDERR = sys.stderr 45 | 46 | def __init__(self, logdir, make_subdir, tag_group, dir_name): 47 | if make_subdir: 48 | os.makedirs(logdir, exist_ok=True) 49 | 50 | if dir_name == "": 51 | dir_name = datetime.datetime.now().strftime('%b%d_%H-%M-%S') 52 | logdir = os.path.join(logdir, dir_name) 53 | 54 | self._writer = SummaryWriter(logdir=logdir) 55 | 56 | assert logdir == self._writer.logdir 57 | self.logdir = logdir 58 | 59 | self._tag_group = tag_group 60 | 61 | sys.stdout = Tee( 62 | primary_file=self._STDOUT, 63 | secondary_file=open(os.path.join(logdir, "stdout"), "a") 64 | ) 65 | 66 | sys.stderr = Tee( 67 | primary_file=self._STDERR, 68 | secondary_file=open(os.path.join(logdir, "stderr"), "a") 69 | ) 70 | 71 | def write_scalar(self, tag, scalar_value, global_step=None): 72 | self._writer.add_scalar(self._tag(tag), scalar_value, global_step=global_step) 73 | 74 | def write_scalars(self, tag, scalar_dict, global_step=None): 75 | self._writer.add_scalars(self._tag(tag), scalar_dict, global_step=global_step) 76 | 77 | def write_image(self, tag, img_tensor, global_step=None): 78 | self._writer.add_image(self._tag(tag), img_tensor, global_step=global_step) 79 | 80 | def 
write_figure(self, tag, figure, global_step=None): 81 | self._writer.add_figure(self._tag(tag), figure, global_step=global_step) 82 | 83 | def write_hparams(self, hparam_dict=None, metric_dict=None): 84 | self._writer.add_hparams(hparam_dict=hparam_dict, metric_dict=metric_dict) 85 | 86 | def write_json(self, tag, data): 87 | text = json.dumps(data, indent=4) 88 | 89 | self._writer.add_text( 90 | self._tag(tag), 91 | 4 * " " + text.replace("\n", "\n" + 4 * " ") # Indent by 4 to ensure codeblock formatting 92 | ) 93 | 94 | json_path = os.path.join(self.logdir, f"{tag}.json") 95 | 96 | with open(json_path, "w") as f: 97 | f.write(text) 98 | 99 | def write_textfile(self, tag, text): 100 | path = os.path.join(self.logdir, f"{tag}.txt") 101 | with open(path, "w") as f: 102 | f.write(text) 103 | 104 | def write_numpy(self, tag, arr): 105 | path = os.path.join(self.logdir, f"{tag}.npy") 106 | np.save(path, arr) 107 | print(f"Saved array to {path}") 108 | 109 | def write_checkpoint(self, tag, data): 110 | os.makedirs(self._checkpoints_dir, exist_ok=True) 111 | checkpoint_path = self._checkpoint_path(tag) 112 | 113 | tmp_checkpoint_path = os.path.join( 114 | os.path.dirname(checkpoint_path), 115 | f"{os.path.basename(checkpoint_path)}.tmp" 116 | ) 117 | 118 | torch.save(data, tmp_checkpoint_path) 119 | # replace is atomic, so we guarantee our checkpoints are always good 120 | os.replace(tmp_checkpoint_path, checkpoint_path) 121 | 122 | def record_dict(self, tag_prefix, value_dict, step, save=False, print_results=True): 123 | if print_results: 124 | for k, v in value_dict.items(): 125 | if isinstance(v, Iterable): 126 | print("{} {}: {}".format(tag_prefix, k, v)) 127 | for idx, item in enumerate(v): 128 | self.write_scalar(f"{tag_prefix}/{k}", item, idx) 129 | else: 130 | print(f"{tag_prefix} {k}: {v:.4f}") 131 | self.write_scalar(f"{tag_prefix}/{k}", v, step) 132 | 133 | if save: 134 | values = dict() 135 | for k, v in value_dict.items(): 136 | if isinstance(v, torch.Tensor): 137 | values[k] = v.item() 138 | elif isinstance(v, Iterable): 139 | for idx, item in enumerate(v): 140 | tag = k + "_" + str(idx) 141 | values[tag] = item 142 | else: 143 | values[k] = v 144 | 145 | self.write_json( 146 | f"{tag_prefix}_metrics", values 147 | ) 148 | 149 | def load_checkpoint(self, tag, device): 150 | return torch.load(self._checkpoint_path(tag), map_location=device) 151 | 152 | def _checkpoint_path(self, tag): 153 | return os.path.join(self._checkpoints_dir, f"{tag}.pt") 154 | 155 | @property 156 | def _checkpoints_dir(self): 157 | return os.path.join(self.logdir, "checkpoints") 158 | 159 | def _tag(self, tag): 160 | return f"{self._tag_group}/{tag}" 161 | --------------------------------------------------------------------------------