├── .gitmodules
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── accountant.py
├── configs
│   ├── cifar.yaml
│   └── mnist.yaml
├── datasets.py
├── mnist_logistic_reconstruction.py
├── mnist_logistic_regression.py
├── train_classifier.py
├── trainer.py
└── utils.py
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
[submodule "fisher_information_loss"]
	path = fisher_information_loss
	url = git@github.com:facebookresearch/fisher_information_loss.git
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
# Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. 
Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Bounding Data Reconstruction 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. 
Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to Bounding Data Reconstruction, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Attribution-NonCommercial 4.0 International 3 | 4 | ======================================================================= 5 | 6 | Creative Commons Corporation ("Creative Commons") is not a law firm and 7 | does not provide legal services or legal advice. Distribution of 8 | Creative Commons public licenses does not create a lawyer-client or 9 | other relationship. Creative Commons makes its licenses and related 10 | information available on an "as-is" basis. Creative Commons gives no 11 | warranties regarding its licenses, any material licensed under their 12 | terms and conditions, or any related information. Creative Commons 13 | disclaims all liability for damages resulting from their use to the 14 | fullest extent possible. 
15 | 16 | Using Creative Commons Public Licenses 17 | 18 | Creative Commons public licenses provide a standard set of terms and 19 | conditions that creators and other rights holders may use to share 20 | original works of authorship and other material subject to copyright 21 | and certain other rights specified in the public license below. The 22 | following considerations are for informational purposes only, are not 23 | exhaustive, and do not form part of our licenses. 24 | 25 | Considerations for licensors: Our public licenses are 26 | intended for use by those authorized to give the public 27 | permission to use material in ways otherwise restricted by 28 | copyright and certain other rights. Our licenses are 29 | irrevocable. Licensors should read and understand the terms 30 | and conditions of the license they choose before applying it. 31 | Licensors should also secure all rights necessary before 32 | applying our licenses so that the public can reuse the 33 | material as expected. Licensors should clearly mark any 34 | material not subject to the license. This includes other CC- 35 | licensed material, or material used under an exception or 36 | limitation to copyright. More considerations for licensors: 37 | wiki.creativecommons.org/Considerations_for_licensors 38 | 39 | Considerations for the public: By using one of our public 40 | licenses, a licensor grants the public permission to use the 41 | licensed material under specified terms and conditions. If 42 | the licensor's permission is not necessary for any reason--for 43 | example, because of any applicable exception or limitation to 44 | copyright--then that use is not regulated by the license. Our 45 | licenses grant only permissions under copyright and certain 46 | other rights that a licensor has authority to grant. Use of 47 | the licensed material may still be restricted for other 48 | reasons, including because others have copyright or other 49 | rights in the material. 
A licensor may make special requests, 50 | such as asking that all changes be marked or described. 51 | Although not required by our licenses, you are encouraged to 52 | respect those requests where reasonable. More_considerations 53 | for the public: 54 | wiki.creativecommons.org/Considerations_for_licensees 55 | 56 | ======================================================================= 57 | 58 | Creative Commons Attribution-NonCommercial 4.0 International Public 59 | License 60 | 61 | By exercising the Licensed Rights (defined below), You accept and agree 62 | to be bound by the terms and conditions of this Creative Commons 63 | Attribution-NonCommercial 4.0 International Public License ("Public 64 | License"). To the extent this Public License may be interpreted as a 65 | contract, You are granted the Licensed Rights in consideration of Your 66 | acceptance of these terms and conditions, and the Licensor grants You 67 | such rights in consideration of benefits the Licensor receives from 68 | making the Licensed Material available under these terms and 69 | conditions. 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. 
Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. 
Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | Section 2 -- Scope. 142 | 143 | a. License grant. 144 | 145 | 1. Subject to the terms and conditions of this Public License, 146 | the Licensor hereby grants You a worldwide, royalty-free, 147 | non-sublicensable, non-exclusive, irrevocable license to 148 | exercise the Licensed Rights in the Licensed Material to: 149 | 150 | a. reproduce and Share the Licensed Material, in whole or 151 | in part, for NonCommercial purposes only; and 152 | 153 | b. produce, reproduce, and Share Adapted Material for 154 | NonCommercial purposes only. 155 | 156 | 2. Exceptions and Limitations. For the avoidance of doubt, where 157 | Exceptions and Limitations apply to Your use, this Public 158 | License does not apply, and You do not need to comply with 159 | its terms and conditions. 160 | 161 | 3. Term. The term of this Public License is specified in Section 162 | 6(a). 163 | 164 | 4. Media and formats; technical modifications allowed. 
The 165 | Licensor authorizes You to exercise the Licensed Rights in 166 | all media and formats whether now known or hereafter created, 167 | and to make technical modifications necessary to do so. The 168 | Licensor waives and/or agrees not to assert any right or 169 | authority to forbid You from making technical modifications 170 | necessary to exercise the Licensed Rights, including 171 | technical modifications necessary to circumvent Effective 172 | Technological Measures. For purposes of this Public License, 173 | simply making modifications authorized by this Section 2(a) 174 | (4) never produces Adapted Material. 175 | 176 | 5. Downstream recipients. 177 | 178 | a. Offer from the Licensor -- Licensed Material. Every 179 | recipient of the Licensed Material automatically 180 | receives an offer from the Licensor to exercise the 181 | Licensed Rights under the terms and conditions of this 182 | Public License. 183 | 184 | b. No downstream restrictions. You may not offer or impose 185 | any additional or different terms or conditions on, or 186 | apply any Effective Technological Measures to, the 187 | Licensed Material if doing so restricts exercise of the 188 | Licensed Rights by any recipient of the Licensed 189 | Material. 190 | 191 | 6. No endorsement. Nothing in this Public License constitutes or 192 | may be construed as permission to assert or imply that You 193 | are, or that Your use of the Licensed Material is, connected 194 | with, or sponsored, endorsed, or granted official status by, 195 | the Licensor or others designated to receive attribution as 196 | provided in Section 3(a)(1)(A)(i). 197 | 198 | b. Other rights. 199 | 200 | 1. 
Moral rights, such as the right of integrity, are not 201 | licensed under this Public License, nor are publicity, 202 | privacy, and/or other similar personality rights; however, to 203 | the extent possible, the Licensor waives and/or agrees not to 204 | assert any such rights held by the Licensor to the limited 205 | extent necessary to allow You to exercise the Licensed 206 | Rights, but not otherwise. 207 | 208 | 2. Patent and trademark rights are not licensed under this 209 | Public License. 210 | 211 | 3. To the extent possible, the Licensor waives any right to 212 | collect royalties from You for the exercise of the Licensed 213 | Rights, whether directly or through a collecting society 214 | under any voluntary or waivable statutory or compulsory 215 | licensing scheme. In all other cases the Licensor expressly 216 | reserves any right to collect such royalties, including when 217 | the Licensed Material is used other than for NonCommercial 218 | purposes. 219 | 220 | Section 3 -- License Conditions. 221 | 222 | Your exercise of the Licensed Rights is expressly made subject to the 223 | following conditions. 224 | 225 | a. Attribution. 226 | 227 | 1. If You Share the Licensed Material (including in modified 228 | form), You must: 229 | 230 | a. retain the following if it is supplied by the Licensor 231 | with the Licensed Material: 232 | 233 | i. identification of the creator(s) of the Licensed 234 | Material and any others designated to receive 235 | attribution, in any reasonable manner requested by 236 | the Licensor (including by pseudonym if 237 | designated); 238 | 239 | ii. a copyright notice; 240 | 241 | iii. a notice that refers to this Public License; 242 | 243 | iv. a notice that refers to the disclaimer of 244 | warranties; 245 | 246 | v. a URI or hyperlink to the Licensed Material to the 247 | extent reasonably practicable; 248 | 249 | b. 
indicate if You modified the Licensed Material and 250 | retain an indication of any previous modifications; and 251 | 252 | c. indicate the Licensed Material is licensed under this 253 | Public License, and include the text of, or the URI or 254 | hyperlink to, this Public License. 255 | 256 | 2. You may satisfy the conditions in Section 3(a)(1) in any 257 | reasonable manner based on the medium, means, and context in 258 | which You Share the Licensed Material. For example, it may be 259 | reasonable to satisfy the conditions by providing a URI or 260 | hyperlink to a resource that includes the required 261 | information. 262 | 263 | 3. If requested by the Licensor, You must remove any of the 264 | information required by Section 3(a)(1)(A) to the extent 265 | reasonably practicable. 266 | 267 | 4. If You Share Adapted Material You produce, the Adapter's 268 | License You apply must not prevent recipients of the Adapted 269 | Material from complying with this Public License. 270 | 271 | Section 4 -- Sui Generis Database Rights. 272 | 273 | Where the Licensed Rights include Sui Generis Database Rights that 274 | apply to Your use of the Licensed Material: 275 | 276 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 277 | to extract, reuse, reproduce, and Share all or a substantial 278 | portion of the contents of the database for NonCommercial purposes 279 | only; 280 | 281 | b. if You include all or a substantial portion of the database 282 | contents in a database in which You have Sui Generis Database 283 | Rights, then the database in which You have Sui Generis Database 284 | Rights (but not its individual contents) is Adapted Material; and 285 | 286 | c. You must comply with the conditions in Section 3(a) if You Share 287 | all or a substantial portion of the contents of the database. 
288 | 289 | For the avoidance of doubt, this Section 4 supplements and does not 290 | replace Your obligations under this Public License where the Licensed 291 | Rights include other Copyright and Similar Rights. 292 | 293 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 294 | 295 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 296 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 297 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 298 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 299 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 300 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 301 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 302 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 303 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 304 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 305 | 306 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 307 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 308 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 309 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 310 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 311 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 312 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 313 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 314 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 315 | 316 | c. The disclaimer of warranties and limitation of liability provided 317 | above shall be interpreted in a manner that, to the extent 318 | possible, most closely approximates an absolute disclaimer and 319 | waiver of all liability. 320 | 321 | Section 6 -- Term and Termination. 322 | 323 | a. 
This Public License applies for the term of the Copyright and 324 | Similar Rights licensed here. However, if You fail to comply with 325 | this Public License, then Your rights under this Public License 326 | terminate automatically. 327 | 328 | b. Where Your right to use the Licensed Material has terminated under 329 | Section 6(a), it reinstates: 330 | 331 | 1. automatically as of the date the violation is cured, provided 332 | it is cured within 30 days of Your discovery of the 333 | violation; or 334 | 335 | 2. upon express reinstatement by the Licensor. 336 | 337 | For the avoidance of doubt, this Section 6(b) does not affect any 338 | right the Licensor may have to seek remedies for Your violations 339 | of this Public License. 340 | 341 | c. For the avoidance of doubt, the Licensor may also offer the 342 | Licensed Material under separate terms or conditions or stop 343 | distributing the Licensed Material at any time; however, doing so 344 | will not terminate this Public License. 345 | 346 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 347 | License. 348 | 349 | Section 7 -- Other Terms and Conditions. 350 | 351 | a. The Licensor shall not be bound by any additional or different 352 | terms or conditions communicated by You unless expressly agreed. 353 | 354 | b. Any arrangements, understandings, or agreements regarding the 355 | Licensed Material not stated herein are separate from and 356 | independent of the terms and conditions of this Public License. 357 | 358 | Section 8 -- Interpretation. 359 | 360 | a. For the avoidance of doubt, this Public License does not, and 361 | shall not be interpreted to, reduce, limit, restrict, or impose 362 | conditions on any use of the Licensed Material that could lawfully 363 | be made without permission under this Public License. 364 | 365 | b. 
To the extent possible, if any provision of this Public License is 366 | deemed unenforceable, it shall be automatically reformed to the 367 | minimum extent necessary to make it enforceable. If the provision 368 | cannot be reformed, it shall be severed from this Public License 369 | without affecting the enforceability of the remaining terms and 370 | conditions. 371 | 372 | c. No term or condition of this Public License will be waived and no 373 | failure to comply consented to unless expressly agreed to by the 374 | Licensor. 375 | 376 | d. Nothing in this Public License constitutes or may be interpreted 377 | as a limitation upon, or waiver of, any privileges and immunities 378 | that apply to the Licensor or You, including from the legal 379 | processes of any jurisdiction or authority. 380 | 381 | ======================================================================= 382 | 383 | Creative Commons is not a party to its public 384 | licenses. Notwithstanding, Creative Commons may elect to apply one of 385 | its public licenses to material it publishes and in those instances 386 | will be considered the “Licensor.” The text of the Creative Commons 387 | public licenses is dedicated to the public domain under the CC0 Public 388 | Domain Dedication. Except for the limited purpose of indicating that 389 | material is shared under a Creative Commons public license or as 390 | otherwise permitted by the Creative Commons policies published at 391 | creativecommons.org/policies, Creative Commons does not authorize the 392 | use of the trademark "Creative Commons" or any other trademark or logo 393 | of Creative Commons without its prior written consent including, 394 | without limitation, in connection with any unauthorized modifications 395 | to any of its public licenses or any other arrangements, 396 | understandings, or agreements concerning use of licensed material. For 397 | the avoidance of doubt, this paragraph does not form part of the 398 | public licenses. 
Creative Commons may be contacted at creativecommons.org.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Bounding Training Data Reconstruction in Private (Deep) Learning

This repository contains code for reproducing results in the paper:
- Chuan Guo, Brian Karrer, Kamalika Chaudhuri, Laurens van der Maaten. **[Bounding Training Data Reconstruction in Private (Deep) Learning](https://arxiv.org/abs/2201.12383)**.

## Setup

Dependencies: [hydra](https://github.com/facebookresearch/hydra), [numpy](https://numpy.org/), [pytorch](https://pytorch.org/), [fisher_information_loss](https://github.com/facebookresearch/fisher_information_loss)

For private SGD experiments: [jax](https://github.com/google/jax), [tensorflow-privacy](https://github.com/tensorflow/privacy)

After installing the dependencies, download the [fisher_information_loss](https://github.com/facebookresearch/fisher_information_loss) submodule:
```
git submodule update --init
```

## Experiments

### MNIST Logistic Regression

Trains a logistic regression model for MNIST 0 vs. 1 classification and computes RDP and FIL privacy accounting:
```
python mnist_logistic_regression.py --lam 1e-2 --sigma 1e-2
```

Runs the [Balle et al.](https://arxiv.org/abs/2201.04845) GLM attack on the logistic regression model:
```
python mnist_logistic_reconstruction.py --lam 1e-2 --sigma 1e-5
```

### Private SGD Training

Trains a private model on MNIST/CIFAR-10 with RDP and FIL privacy accounting:
```
python train_classifier.py --config-name [mnist.yaml/cifar.yaml]
```
Check the `configs` directory for the Hydra configs, and see the appendix of our paper for the full grid of hyperparameter values.
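The FIL accounting estimates the largest singular value of each per-example Jacobian by power iteration, alternating a JVP and a VJP (see `get_grad_jacobian_norm_func` in `accountant.py`). The idea can be sketched on an explicit matrix, where `J @ v` and `J.T @ u` stand in for the JVP and VJP; `spectral_norm` is an illustrative name, not a function from this repository:

```python
import numpy as np

def spectral_norm(J, num_iters=50, seed=0):
    """Estimates the largest singular value of J by power iteration.

    Each iteration applies J (the "JVP") and then J.T (the "VJP"),
    re-normalizing the iterate; the norm of J @ v converges to the
    top singular value of J.
    """
    rng = np.random.default_rng(seed)
    v = rng.normal(size=J.shape[1])
    v /= np.linalg.norm(v)
    for _ in range(num_iters):
        u = J @ v              # forward application of the Jacobian
        sigma = np.linalg.norm(u)
        v = J.T @ u            # pull the iterate back through the Jacobian
        v /= np.linalg.norm(v)
    return sigma

# diagonal example: singular values are 3 and 1, so the estimate -> 3
J = np.array([[3.0, 0.0], [0.0, 1.0], [0.0, 0.0]])
```

The JVP/VJP formulation matters because, as in `accountant.py`, the Jacobian of the per-example gradients is never materialized; only matrix-vector products with it are needed.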
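As background for the RDP accounting, the RDP-to-(ε, δ) conversion can be illustrated for the plain (non-subsampled) Gaussian mechanism. This is a simplified sketch with an illustrative function name, not the repository's accountant, which uses `tensorflow-privacy` and additionally handles Poisson subsampling:

```python
import numpy as np

def gaussian_rdp_to_dp(sigma, steps, delta, orders=None):
    """(eps, delta)-DP bound from RDP for `steps` Gaussian mechanisms.

    One Gaussian mechanism with noise multiplier sigma satisfies
    (alpha, alpha / (2 sigma^2))-RDP; RDP composes additively over steps
    and converts to (eps, delta)-DP via
    eps = min_alpha [ steps * alpha / (2 sigma^2) + log(1/delta) / (alpha - 1) ].
    """
    if orders is None:
        orders = np.linspace(1.1, 63.0, 1000)
    rdp = steps * orders / (2.0 * sigma ** 2)
    eps = rdp + np.log(1.0 / delta) / (orders - 1.0)
    return float(eps.min())

# more noise (larger sigma) or fewer steps yields a smaller epsilon
eps_strong = gaussian_rdp_to_dp(sigma=4.0, steps=100, delta=1e-5)
eps_weak = gaussian_rdp_to_dp(sigma=1.0, steps=100, delta=1e-5)
```

Minimizing over a grid of orders mirrors the `orders` list that `accountant.py` passes to `get_privacy_spent`.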

## Code Acknowledgements

The majority of Bounding Data Reconstruction is licensed under CC-BY-NC; however, portions of the project are available under separate terms: [hydra](https://github.com/facebookresearch/hydra) is licensed under the MIT license, and [jax](https://github.com/google/jax) and [tensorflow-privacy](https://github.com/tensorflow/privacy) are licensed under the Apache 2.0 license.
--------------------------------------------------------------------------------
/accountant.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import warnings

import jax.numpy as jnp
import jax.random as jnr

from jax import jit, jvp, vjp, jacrev, vmap, nn
from jax.tree_util import tree_flatten
import trainer
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp, get_privacy_spent


def get_grad_jacobian_norm_func(grad_func, get_params, method="jvp", reshape=True, label_privacy=False):
    """
    Returns a function that computes the norm of the Jacobian of the parameter
    gradients for the specified `grad_func` for an optimizer in which the
    `get_params` function returns the model parameters.
    """

    # assertions:
    assert method in ["jvp", "full"], f"Unknown method: {method}"

    @jit
    def compute_power_iteration_jvp(params, w, inputs, targets):
        """
        Computes a single power iteration via the JVP method. Does not include
        the Jacobian w.r.t. targets.
        """

        # compute JVP of per-example parameter gradient Jacobian with w:
        if label_privacy:
            perex_grad = lambda x: vmap(grad_func, in_axes=(None, 0, 0))(
                params, inputs, x
            )
            _, w = jvp(perex_grad, (targets,), (w,))
        else:
            perex_grad = lambda x: vmap(grad_func, in_axes=(None, 0, 0))(
                params, x, targets
            )
            _, w = jvp(perex_grad, (inputs,), (w,))

        # compute norm of the JVP:
        w_flattened, _ = tree_flatten(w)
        norms = [
            jnp.power(jnp.reshape(v, (v.shape[0], -1)), 2).sum(axis=1)
            for v in w_flattened
        ]
        norms = jnp.sqrt(sum(norms) + 1e-7)

        # compute VJP of per-example parameter gradient Jacobian with w:
        if label_privacy:
            _, f_vjp = vjp(perex_grad, targets)
        else:
            _, f_vjp = vjp(perex_grad, inputs)
        w_out = f_vjp(w)[0]

        return norms, w_out

    @jit
    def compute_power_iteration_full(params, w, inputs, targets):
        """
        Computes a single power iteration by computing the full Jacobian and
        right-multiplying it. Does not include the Jacobian w.r.t. targets.
69 | """ 70 | 71 | # compute per-example parameter gradient Jacobian: 72 | J = jacrev(grad_func, 1)(params, inputs, targets) 73 | J_flattened, _ = tree_flatten(J) 74 | 75 | # compute JVP with w: 76 | jvp_exact = [(v * w).sum(-1) for v in J_flattened] 77 | 78 | # compute norm of the JVP: 79 | norms = [ 80 | jnp.power(jnp.reshape(v, (-1, v.shape[-1])), 2).sum(axis=0) 81 | for v in jvp_exact 82 | ] 83 | norms = jnp.sqrt(sum(norms)) 84 | 85 | # compute VJP of per-example parameter gradient Jacobian with w: 86 | vjp_exact = [ 87 | J_flattened[i] * jnp.expand_dims(jvp_exact[i], -1) 88 | for i in jnp.arange(len(jvp_exact)) 89 | ] 90 | w_out = sum( 91 | [jnp.reshape(v, (-1, v.shape[-2], v.shape[-1])).sum(0) for v in vjp_exact] 92 | ) 93 | return norms, w_out 94 | 95 | @jit 96 | def grad_jacobian_norm(rng, opt_state, batch, num_iters=20): 97 | """ 98 | Computes norm of the Jacobian of the parameter gradients. The function 99 | performs `num_iters` power iterations. 100 | """ 101 | 102 | # initialize power iterates: 103 | inputs, targets = batch 104 | if reshape: 105 | inputs = jnp.expand_dims(inputs, 1) 106 | 107 | w = jnr.normal(rng, shape=(targets.shape if label_privacy else inputs.shape)) 108 | w_norm = jnp.sqrt(jnp.power(w.reshape(w.shape[0], -1), 2).sum(axis=1) + 1e-7) 109 | w = w / jnp.expand_dims(w_norm, tuple(range(1, len(w.shape)))) 110 | 111 | # perform power iterations: 112 | params = get_params(opt_state) 113 | for i in jnp.arange(num_iters): 114 | if method == "jvp": 115 | norms, w = compute_power_iteration_jvp(params, w, inputs, targets) 116 | elif method == "full": 117 | norms, w = compute_power_iteration_full(params, w, inputs, targets) 118 | w_norm = jnp.sqrt(jnp.power(w.reshape(w.shape[0], -1), 2).sum(axis=1) + 1e-7) 119 | w = w / jnp.expand_dims(w_norm, tuple(range(1, len(w.shape)))) 120 | 121 | # set nan values to 0 because gradient is 0 122 | norms = jnp.nan_to_num(norms) 123 | return norms 124 | 125 | # return the function: 126 | return 
grad_jacobian_norm 127 | 128 | 129 | def get_grad_jacobian_trace_func(grad_func, get_params, reshape=True, label_privacy=False): 130 | """ 131 | Returns a function that computes the (square root of the) trace of the Jacobian 132 | of the parameter gradients. 133 | """ 134 | 135 | @jit 136 | def grad_jacobian_trace(rng, opt_state, batch, num_iters=50): 137 | 138 | params = get_params(opt_state) 139 | inputs, targets = batch 140 | if reshape: 141 | inputs = jnp.expand_dims(inputs, 1) 142 | 143 | if label_privacy: 144 | flattened_shape = jnp.reshape(targets, (targets.shape[0], -1)).shape 145 | perex_grad = lambda x: vmap(grad_func, in_axes=(None, 0, 0))( 146 | params, inputs, x 147 | ) 148 | else: 149 | flattened_shape = jnp.reshape(inputs, (inputs.shape[0], -1)).shape 150 | perex_grad = lambda x: vmap(grad_func, in_axes=(None, 0, 0))( 151 | params, x, targets 152 | ) 153 | 154 | num_iters = targets.shape[1] if label_privacy else num_iters 155 | rngs = jnr.split(rng, num_iters) 156 | trace = jnp.zeros(inputs.shape[0]) 157 | for i, g in zip(jnp.arange(num_iters), rngs): 158 | indices = jnr.categorical(g, jnp.ones(shape=flattened_shape)) 159 | if label_privacy: 160 | indices = i * jnp.ones(flattened_shape[0]) 161 | w = jnp.reshape(nn.one_hot(indices, flattened_shape[1]), targets.shape) 162 | _, w = jvp(perex_grad, (targets,), (w,)) 163 | else: 164 | indices = jnr.categorical(g, jnp.ones(shape=flattened_shape)) 165 | w = jnp.reshape(nn.one_hot(indices, flattened_shape[1]), inputs.shape) 166 | _, w = jvp(perex_grad, (inputs,), (w,)) 167 | # compute norm of the JVP: 168 | w_flattened, _ = tree_flatten(w) 169 | norms = [ 170 | jnp.power(jnp.reshape(v, (v.shape[0], -1)), 2).sum(axis=1) 171 | for v in w_flattened 172 | ] 173 | trace = trace + sum(norms) / num_iters 174 | 175 | # set nan values to 0 because gradient is 0 176 | trace = jnp.nan_to_num(trace) 177 | return jnp.sqrt(trace + 1e-7) 178 | 179 | # return the function: 180 | return grad_jacobian_trace 181 | 182 | 183 | def
get_dp_accounting_func(batch_size, sigma): 184 | """ 185 | Returns the (eps, delta)-DP accountant if alpha=None, 186 | or the (alpha, eps)-RDP accountant otherwise. 187 | """ 188 | 189 | def compute_epsilon(steps, num_examples, target_delta=1e-5, alpha=None): 190 | if num_examples * target_delta > 1.: 191 | warnings.warn('Your delta might be too high.') 192 | q = batch_size / float(num_examples) 193 | if alpha is None: 194 | orders = list(jnp.linspace(1.1, 10.9, 99)) + list(range(11, 64)) 195 | rdp_const = compute_rdp(q, sigma, steps, orders) 196 | eps, _, _ = get_privacy_spent(orders, rdp_const, target_delta=target_delta) 197 | else: 198 | eps = compute_rdp(q, sigma, steps, alpha) 199 | return eps 200 | 201 | return compute_epsilon -------------------------------------------------------------------------------- /configs/cifar.yaml: -------------------------------------------------------------------------------- 1 | dataset: "cifar10" 2 | model: "cnn_cifar" # "linear", "mlp", "mlp_tanh", "cnn", "cnn_tanh" or "cnn_cifar" 3 | do_accounting: True 4 | label_privacy: False 5 | binary: False 6 | pca_dims: 0 7 | batch_size: 200 8 | momentum_mass: 0.5 9 | num_epochs: 150 10 | optimizer: "sgd" # "sgd" or "adam" 11 | step_size: 0.03 12 | weight_decay: 0 13 | sigma: 0.5 14 | norm_clip: 1 15 | delta: 1e-10 16 | -------------------------------------------------------------------------------- /configs/mnist.yaml: -------------------------------------------------------------------------------- 1 | dataset: "mnist" 2 | model: "cnn_tanh" # "linear", "mlp", "mlp_tanh", "cnn", "cnn_tanh" or "cnn_cifar" 3 | do_accounting: True 4 | label_privacy: False 5 | binary: False 6 | pca_dims: 0 7 | batch_size: 600 8 | momentum_mass: 0.5 9 | num_epochs: 50 10 | optimizer: "sgd" # "sgd" or "adam" 11 | step_size: 0.03 12 | weight_decay: 0 13 | sigma: 0.5 14 | norm_clip: 8 15 | delta: 1e-9 -------------------------------------------------------------------------------- /datasets.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import array 10 | import gzip 11 | import logging 12 | import os 13 | from os import path 14 | import struct 15 | import math 16 | import urllib.request 17 | from torchvision import datasets as torch_datasets 18 | from torchvision import transforms 19 | 20 | import numpy as np 21 | import numpy.random as npr 22 | from sklearn.decomposition import PCA 23 | 24 | 25 | _DATA_FOLDER = "data/" 26 | 27 | 28 | def _download(url, data_folder, filename): 29 | """ 30 | Download a URL to a file in the temporary data directory, if it does not 31 | already exist. 32 | """ 33 | if not path.exists(data_folder): 34 | os.makedirs(data_folder) 35 | out_file = path.join(data_folder, filename) 36 | if not path.isfile(out_file): 37 | urllib.request.urlretrieve(url, out_file) 38 | logging.info(f"Downloaded {url} to {data_folder}") 39 | 40 | 41 | def _partial_flatten(x): 42 | """ 43 | Flatten all but the first dimension of an ndarray. 44 | """ 45 | return np.reshape(x, (x.shape[0], -1)) 46 | 47 | 48 | def _one_hot(x, k, dtype=np.float32): 49 | """ 50 | Create a one-hot encoding of x of size k. 51 | """ 52 | return np.array(x[:, None] == np.arange(k), dtype) 53 | 54 | 55 | def mnist_raw(dataset): 56 | """ 57 | Download and parse the raw MNIST dataset. 
58 | """ 59 | 60 | if dataset == "mnist": 61 | # mirror of http://yann.lecun.com/exdb/mnist/: 62 | base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/" 63 | elif dataset == "fmnist": 64 | base_url = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/" 65 | elif dataset == "kmnist": 66 | base_url = "http://codh.rois.ac.jp/kmnist/dataset/kmnist/" 67 | else: 68 | raise RuntimeError("Unknown dataset: " + dataset) 69 | data_folder = path.join(_DATA_FOLDER, dataset) 70 | 71 | def parse_labels(filename): 72 | """ 73 | Parses labels in MNIST raw label file. 74 | """ 75 | with gzip.open(filename, "rb") as fh: 76 | _ = struct.unpack(">II", fh.read(8)) 77 | return np.array(array.array("B", fh.read()), dtype=np.uint8) 78 | 79 | def parse_images(filename): 80 | """ 81 | Parses images in MNIST raw label file. 82 | """ 83 | with gzip.open(filename, "rb") as fh: 84 | _, num_DATA_FOLDER, rows, cols = struct.unpack(">IIII", fh.read(16)) 85 | return np.array(array.array("B", fh.read()), dtype=np.uint8).reshape( 86 | num_DATA_FOLDER, rows, cols 87 | ) 88 | 89 | # download all MNIST files: 90 | for filename in [ 91 | "train-images-idx3-ubyte.gz", 92 | "train-labels-idx1-ubyte.gz", 93 | "t10k-images-idx3-ubyte.gz", 94 | "t10k-labels-idx1-ubyte.gz", 95 | ]: 96 | _download(base_url + filename, data_folder, filename) 97 | 98 | # parse all images and labels: 99 | train_images = parse_images(path.join(data_folder, "train-images-idx3-ubyte.gz")) 100 | train_labels = parse_labels(path.join(data_folder, "train-labels-idx1-ubyte.gz")) 101 | test_images = parse_images(path.join(data_folder, "t10k-images-idx3-ubyte.gz")) 102 | test_labels = parse_labels(path.join(data_folder, "t10k-labels-idx1-ubyte.gz")) 103 | return train_images, train_labels, test_images, test_labels 104 | 105 | 106 | def preprocess_data(train_images, train_labels, test_images, test_labels, 107 | binary, permute_train, normalize, pca_dims): 108 | if binary: 109 | num_labels = 2 110 | train_mask = 
np.logical_or(train_labels == 0, train_labels == 1) 111 | test_mask = np.logical_or(test_labels == 0, test_labels == 1) 112 | train_images, train_labels = train_images[train_mask], train_labels[train_mask] 113 | test_images, test_labels = test_images[test_mask], test_labels[test_mask] 114 | else: 115 | num_labels = np.max(test_labels) + 1 116 | train_labels = _one_hot(train_labels, num_labels) 117 | test_labels = _one_hot(test_labels, num_labels) 118 | 119 | if pca_dims > 0: 120 | pca = PCA(n_components=pca_dims, svd_solver='full') 121 | pca.fit(train_images) 122 | train_images = pca.transform(train_images) 123 | test_images = pca.transform(test_images) 124 | 125 | if normalize: 126 | train_images /= np.linalg.norm(train_images, 2, 1)[:, None] 127 | test_images /= np.linalg.norm(test_images, 2, 1)[:, None] 128 | 129 | # permute training data: 130 | if permute_train: 131 | perm = np.random.RandomState(0).permutation(train_images.shape[0]) 132 | train_images = train_images[perm] 133 | train_labels = train_labels[perm] 134 | return train_images, train_labels, test_images, test_labels 135 | 136 | 137 | def mnist(dataset="mnist", binary=False, permute_train=False, normalize=False, pca_dims=0): 138 | """ 139 | Download, parse and process MNIST data to unit scale and one-hot labels. 
140 | """ 141 | 142 | # obtain raw MNIST data: 143 | train_images, train_labels, test_images, test_labels = mnist_raw(dataset) 144 | 145 | # flatten and normalize images, create one-hot labels: 146 | train_images = _partial_flatten(train_images) / np.float32(255.0) 147 | test_images = _partial_flatten(test_images) / np.float32(255.0) 148 | 149 | return preprocess_data(train_images, train_labels, test_images, test_labels, 150 | binary, permute_train, normalize, pca_dims) 151 | 152 | 153 | def cifar(dataset="cifar10", binary=False, permute_train=False, normalize=False, pca_dims=0): 154 | 155 | data_folder = path.join(_DATA_FOLDER, dataset) 156 | normalizer = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 157 | train_transforms = transforms.Compose([transforms.ToTensor(), normalizer]) 158 | if dataset == "cifar10": 159 | train_set = torch_datasets.CIFAR10(root=data_folder, train=True, transform=train_transforms, download=True) 160 | test_set = torch_datasets.CIFAR10(root=data_folder, train=False, transform=train_transforms, download=True) 161 | elif dataset == "cifar100": 162 | train_set = torch_datasets.CIFAR100(root=data_folder, train=True, transform=train_transforms, download=True) 163 | test_set = torch_datasets.CIFAR100(root=data_folder, train=False, transform=train_transforms, download=True) 164 | 165 | train_images = [] 166 | train_labels = [] 167 | for (x, y) in train_set: 168 | train_images.append(np.rollaxis(x.numpy(), 0, 3).flatten()) 169 | train_labels.append(y) 170 | train_images = np.stack(train_images) 171 | train_labels = np.array(train_labels) 172 | 173 | test_images = [] 174 | test_labels = [] 175 | for (x, y) in test_set: 176 | test_images.append(np.rollaxis(x.numpy(), 0, 3).flatten()) 177 | test_labels.append(y) 178 | test_images = np.stack(test_images) 179 | test_labels = np.array(test_labels) 180 | 181 | return preprocess_data(train_images, train_labels, test_images, test_labels, 182 | binary, permute_train, normalize, 
pca_dims) 183 | 184 | 185 | def get_datastream(images, labels, batch_size, permutation=False, last_batch=True): 186 | """ 187 | Returns a data stream of `images` and corresponding `labels` in batches of 188 | size `batch_size`. Also returns the number of batches per epoch, `num_batches`. 189 | 190 | To loop through the whole dataset in permuted order, set `permutation` to `True`. 191 | To not return the last batch, set `last_batch` to `False`. 192 | """ 193 | 194 | # compute number of batches to return: 195 | num_images = images.shape[0] 196 | 197 | def permutation_datastream(): 198 | """ 199 | Data stream iterator that returns randomly permuted images until eternity. 200 | """ 201 | while True: 202 | perm = npr.permutation(num_images) 203 | for i in range(num_batches): 204 | batch_idx = perm[i * batch_size : (i + 1) * batch_size] 205 | yield images[batch_idx], labels[batch_idx], batch_idx 206 | 207 | def random_sampler_datastream(): 208 | """ 209 | Data stream iterator that returns a uniformly random batch of images until eternity. 210 | """ 211 | while True: 212 | batch_idx = npr.permutation(num_images)[:batch_size] 213 | yield images[batch_idx], labels[batch_idx], batch_idx 214 | 215 | # return iterator factory: 216 | if permutation: 217 | num_batches = int((math.ceil if last_batch else math.floor)(float(num_images) / float(batch_size))) 218 | return permutation_datastream, num_batches 219 | else: 220 | num_complete_batches, leftover = divmod(num_images, batch_size) 221 | num_batches = num_complete_batches + (last_batch and bool(leftover)) 222 | return random_sampler_datastream, num_batches 223 | -------------------------------------------------------------------------------- /mnist_logistic_reconstruction.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved.
5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import argparse 10 | import math 11 | import torch 12 | import random 13 | import numpy as np 14 | from tqdm import tqdm 15 | import os 16 | import matplotlib.pyplot as plt 17 | 18 | import sys 19 | sys.path.append("fisher_information_loss") 20 | import models 21 | import dataloading 22 | 23 | def recons_attack(model, X, y, lam, link_func): 24 | """ 25 | Runs the Balle et al. GLM attack https://arxiv.org/abs/2201.04845. 26 | """ 27 | def compute_grad(model, X, y): 28 | return ((X @ model.theta).sigmoid() - y)[:, None] * X 29 | n = len(y) 30 | grad = compute_grad(model, X, y) 31 | B1 = (grad.sum(0)[None, :] - grad)[:, 0] 32 | denom = B1 + n * lam * model.theta[0][None] 33 | X_hat = (grad.sum(0)[None, :] - grad + n * lam * model.theta[None, :]) / denom[:, None] 34 | y_hat = link_func(X_hat @ model.theta) + denom 35 | return X_hat, y_hat 36 | 37 | def compute_correct_ratio(etas, num_bins, predictions, target): 38 | order = etas.argsort() 39 | bin_size = len(target) // num_bins + 1 40 | bin_accs = [] 41 | for prediction in predictions: 42 | prediction = np.array(prediction) 43 | correct = (prediction == target) 44 | bin_accs.append([correct[order[lower:lower + bin_size]].mean() 45 | for lower in range(0, len(correct), bin_size)]) 46 | return np.array(bin_accs) 47 | 48 | parser = argparse.ArgumentParser(description="Evaluate GLM reconstruction attack.") 49 | parser.add_argument("--data_folder", default="data/", type=str, 50 | help="folder in which to store data") 51 | parser.add_argument("--num_trials", default=10000, type=int, 52 | help="Number of trials") 53 | parser.add_argument("--lam", default=0.01, type=float, 54 | help="regularization parameter for logistic regression") 55 | parser.add_argument("--sigma", default=1e-5, type=float, 56 | help="Gaussian noise parameter for output perturbation") 57 | args = parser.parse_args()
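As a sanity check on the closed-form gradient used by `recons_attack` above: for the logistic loss, the per-example parameter gradient is exactly `(sigmoid(X @ theta) - y)[:, None] * X`, which is what `compute_grad` returns. The following sketch (illustrative only, not part of the attack script) verifies the summed closed-form gradient against a finite-difference approximation:

```python
import numpy as np

# Toy problem: 4 examples, 3 features (all values here are illustrative).
rng = np.random.default_rng(0)
n, d = 4, 3
X = rng.standard_normal((n, d))
y = rng.integers(0, 2, size=n).astype(float)
theta = rng.standard_normal(d)

def loss(th):
    # Summed binary cross-entropy of a logistic model.
    s = 1.0 / (1.0 + np.exp(-(X @ th)))
    return -(y * np.log(s) + (1 - y) * np.log(1 - s)).sum()

# Closed-form per-example gradients, as in `compute_grad` above:
closed_form = (1.0 / (1.0 + np.exp(-(X @ theta))) - y)[:, None] * X

# Central finite differences of the summed loss along each coordinate:
eps = 1e-6
numeric = np.array([
    (loss(theta + eps * e) - loss(theta - eps * e)) / (2 * eps)
    for e in np.eye(d)
])
assert np.allclose(closed_form.sum(0), numeric, atol=1e-4)
```

The attack exploits exactly this structure: each example's contribution to the total gradient is a rescaled copy of its feature vector.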
58 | 59 | train_data = dataloading.load_dataset( 60 | name="mnist", split="train", normalize=False, 61 | num_classes=2, root=args.data_folder, regression=False) 62 | test_data = dataloading.load_dataset( 63 | name="mnist", split="test", normalize=False, 64 | num_classes=2, root=args.data_folder, regression=False) 65 | train_data['features'] = torch.cat([torch.ones(len(train_data['targets']), 1), train_data['features']], 1) 66 | test_data['features'] = torch.cat([torch.ones(len(test_data['targets']), 1), test_data['features']], 1) 67 | 68 | model = models.get_model("logistic") 69 | model.train(train_data, l2=args.lam, weights=None) 70 | true_theta = model.theta.clone() 71 | 72 | predictions = model.predict(train_data["features"]) 73 | acc = ((predictions == train_data["targets"]).float()).mean() 74 | print(f"Training accuracy of classifier {acc.item():.3f}") 75 | 76 | predictions = model.predict(test_data["features"]) 77 | acc = ((predictions == test_data["targets"]).float()).mean() 78 | print(f"Test accuracy of classifier {acc.item():.3f}") 79 | 80 | J = model.influence_jacobian(train_data)[:, :, 1:-1] / args.sigma 81 | etas = J.pow(2).sum(1).mean(1) 82 | 83 | X = train_data["features"] 84 | y = train_data["targets"].float() 85 | n, d = X.size(0), X.size(1) - 1 86 | link_func = torch.sigmoid 87 | 88 | X_means = torch.zeros(X.shape) 89 | errors = torch.zeros(len(y)) 90 | with torch.no_grad(): 91 | print('Running reconstruction attack for %d trials:' % args.num_trials) 92 | for i in tqdm(range(args.num_trials)): 93 | model.theta = true_theta + args.sigma * torch.randn(true_theta.size()) 94 | X_hat, y_hat = recons_attack(model, X, y, args.lam, link_func) 95 | X_means += X_hat / args.num_trials 96 | errors += (X_hat[:, 1:] - X[:, 1:]).pow(2).sum(1) / (d * args.num_trials) 97 | X_means = X_means[:, 1:] 98 | 99 | # filter out examples that the attack failed on 100 | mask = torch.logical_not(torch.isnan(errors)) 101 | etas = etas[mask] 102 | errors = errors[mask] 103 | _, 
order = etas.reciprocal().sort() 104 | 105 | # plot MSE lower bound vs. true MSE 106 | plt.figure(figsize=(8,5)) 107 | below_bound = etas.reciprocal() < errors 108 | plt.scatter(etas[below_bound].reciprocal().detach(), errors[below_bound].detach(), s=10) 109 | plt.scatter(etas[torch.logical_not(below_bound)].reciprocal().detach(), errors[torch.logical_not(below_bound)].detach(), 110 | s=10, color='indianred') 111 | plt.plot(np.power(10, np.arange(-5.5, 3, 0.1)), np.power(10, np.arange(-5.5, 3, 0.1)), 'k', label='Lower bound') 112 | plt.axvline(x=1, color='k', linestyle=':') 113 | plt.xticks(fontsize=20) 114 | plt.xlim([1e-6, 1e4]) 115 | plt.xlabel('Predicted MSE', fontsize=20) 116 | plt.xscale('log') 117 | plt.yticks(fontsize=20) 118 | plt.ylabel('Recons. attack MSE', fontsize=20) 119 | plt.yscale('log') 120 | plt.legend(loc='lower right', fontsize=20) 121 | os.makedirs("figs", exist_ok=True) 122 | plt.savefig("figs/recons_mse.pdf", bbox_inches="tight") 123 | 124 | # plot reconstructed samples 125 | plt.figure(figsize=(48, 6)) 126 | for i in range(8): 127 | plt.subplot(1, 8, i+1) 128 | plt.imshow(X[mask][order[i], 1:].clamp(0, 1).view(28, 28).detach()) 129 | plt.axis('off') 130 | plt.savefig("figs/orig_highest8.pdf", bbox_inches="tight") 131 | 132 | plt.figure(figsize=(48, 6)) 133 | for i in range(8): 134 | plt.subplot(1, 8, i+1) 135 | plt.imshow(X_means[mask][order[i]].clamp(0, 1).view(28, 28).detach()) 136 | plt.axis('off') 137 | plt.savefig("figs/recons_highest8.pdf", bbox_inches="tight") 138 | 139 | plt.figure(figsize=(48, 6)) 140 | for i in range(8): 141 | plt.subplot(1, 8, i+1) 142 | plt.imshow(X[mask][order[-i-1], 1:].clamp(0, 1).view(28, 28).detach()) 143 | plt.axis('off') 144 | plt.savefig("figs/orig_lowest8.pdf", bbox_inches="tight") 145 | 146 | plt.figure(figsize=(48, 6)) 147 | for i in range(8): 148 | plt.subplot(1, 8, i+1) 149 | plt.imshow(X_means[mask][order[-i-1]].clamp(0, 1).view(28, 28).detach()) 150 | plt.axis('off') 151 | 
plt.savefig("figs/recons_lowest8.pdf", bbox_inches="tight") 152 | -------------------------------------------------------------------------------- /mnist_logistic_regression.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import argparse 10 | import math 11 | import torch 12 | import numpy as np 13 | import os 14 | import matplotlib.pyplot as plt 15 | 16 | import sys 17 | sys.path.append("fisher_information_loss") 18 | import models 19 | import dataloading 20 | 21 | parser = argparse.ArgumentParser(description="MNIST training with FIL.") 22 | parser.add_argument("--data_folder", default="data/", type=str, 23 | help="folder in which to store data") 24 | parser.add_argument("--num_trials", default=10, type=int, 25 | help="number of repeated trials") 26 | parser.add_argument("--lam", default=0.01, type=float, 27 | help="l2 regularization parameter") 28 | parser.add_argument("--sigma", default=0.01, type=float, 29 | help="Gaussian noise multiplier") 30 | args = parser.parse_args() 31 | 32 | train_data = dataloading.load_dataset( 33 | name="mnist", split="train", normalize=False, 34 | num_classes=2, root=args.data_folder, regression=False) 35 | test_data = dataloading.load_dataset( 36 | name="mnist", split="test", normalize=False, 37 | num_classes=2, root=args.data_folder, regression=False) 38 | n = len(train_data["targets"]) 39 | 40 | all_etas, all_epsilons, all_rdp_epsilons = [], [], [] 41 | 42 | for i in range(args.num_trials): 43 | 44 | model = models.get_model("logistic") 45 | model.train(train_data, l2=args.lam, weights=None) 46 | # Renyi-DP accounting 47 | rdp_eps = 4 / (n * args.lam * args.sigma) ** 2 48 | # FIL accounting 49 | J = model.influence_jacobian(train_data)[:, 
:, :-1] / args.sigma 50 | etas = J.pow(2).sum(1).mean(1).sqrt() 51 | print(f"Trial {i+1:d}: RDP epsilon = {rdp_eps:.4f}, Max FIL eta = {etas.max():.4f}") 52 | model.theta = model.theta + args.sigma * torch.randn_like(model.theta) 53 | 54 | all_etas.append(etas.detach().numpy()) 55 | all_rdp_epsilons.append(rdp_eps) 56 | 57 | predictions = model.predict(train_data["features"]) 58 | acc = ((predictions == train_data["targets"]).float()).mean() 59 | print(f"Training accuracy of classifier {acc.item():.3f}") 60 | 61 | predictions = model.predict(test_data["features"]) 62 | acc = ((predictions == test_data["targets"]).float()).mean() 63 | print(f"Test accuracy of classifier {acc.item():.3f}") 64 | 65 | all_etas = np.stack(all_etas, 0) 66 | all_rdp_epsilons = np.stack(all_rdp_epsilons, 0) 67 | 68 | fil_bound = 1 / np.power(all_etas, 2).mean(0) 69 | rdp_bound = 0.25 / (math.exp(all_rdp_epsilons.mean()) - 1) 70 | 71 | plt.figure(figsize=(8,5)) 72 | _ = plt.hist(np.log10(fil_bound), bins=100, label='dFIL bound', color='silver', edgecolor='black', linewidth=0.3) 73 | plt.axvline(x=np.log10(rdp_bound), color='k', linestyle='--', label='RDP bound') 74 | plt.axvline(x=0, color='k', linestyle=':') 75 | plt.xlabel('MSE lower bound', fontsize=20) 76 | plt.ylabel('Count', fontsize=20) 77 | plt.xticks(np.arange(-1, 11, 2), labels=['$10^{%d}$' % t for t in np.arange(-1, 11, 2)], fontsize=20) 78 | plt.yticks(fontsize=20) 79 | plt.legend(loc='upper left', fontsize=20) 80 | os.makedirs("figs", exist_ok=True) 81 | plt.savefig("figs/mnist_linear_hist.pdf", bbox_inches="tight") 82 | -------------------------------------------------------------------------------- /train_classifier.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 
5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import itertools 10 | import logging 11 | 12 | import jax 13 | import jax.numpy as jnp 14 | import jax.random as jnr 15 | import hydra 16 | 17 | from jax import grad 18 | from jax.experimental import optimizers 19 | from jax.tree_util import tree_flatten, tree_unflatten 20 | 21 | import math 22 | import accountant 23 | import datasets 24 | import trainer 25 | import utils 26 | import time 27 | 28 | 29 | def batch_predict(predict, params, images, batch_size): 30 | num_images = images.shape[0] 31 | num_batches = int(math.ceil(float(num_images) / float(batch_size))) 32 | predictions = [] 33 | for i in range(num_batches): 34 | lower = i * batch_size 35 | upper = min((i+1) * batch_size, num_images) 36 | predictions.append(predict(params, images[lower:upper])) 37 | return jnp.concatenate(predictions) 38 | 39 | 40 | @hydra.main(config_path="configs", config_name="mnist") 41 | def main(cfg): 42 | 43 | # set up random number generator: 44 | logging.info(f"Running using JAX {jax.__version__}...") 45 | rng = jnr.PRNGKey(int(time.time())) 46 | 47 | # create dataloader for the specified dataset: 48 | if cfg.dataset.startswith("cifar"): 49 | num_channels = 3 50 | image_size = 32 51 | train_images, train_labels, test_images, test_labels = datasets.cifar( 52 | dataset=cfg.dataset, binary=cfg.binary, pca_dims=cfg.pca_dims) 53 | else: 54 | num_channels = 1 55 | image_size = 28 56 | train_images, train_labels, test_images, test_labels = datasets.mnist( 57 | dataset=cfg.dataset, binary=cfg.binary, pca_dims=cfg.pca_dims) 58 | logging.info("Training set max variance: %.4f" % train_images.var(0).max()) 59 | 60 | num_samples, d = train_images.shape 61 | num_labels = train_labels.shape[1] 62 | if num_labels == 2: 63 | num_labels = 1 64 | if cfg.model.startswith("cnn"): 65 | assert cfg.pca_dims == 0, f"Cannot use PCA with {cfg.model} model."
66 | image_shape = (-1, image_size, image_size, num_channels) 67 | train_images = jnp.reshape(train_images, image_shape) 68 | test_images = jnp.reshape(test_images, image_shape) 69 | data_stream, num_batches = datasets.get_datastream( 70 | train_images, train_labels, cfg.batch_size 71 | ) 72 | batches = data_stream() 73 | 74 | # set up model: 75 | if cfg.model.startswith("cnn"): 76 | input_shape = (-1, image_size, image_size, num_channels) 77 | else: 78 | input_shape = (-1, d) 79 | init_params, predict = utils.get_model(rng, cfg.model, input_shape, num_labels) 80 | num_params = sum(p.size for p in tree_flatten(init_params)[0]) 81 | 82 | # create optimizer: 83 | if cfg.optimizer == "sgd": 84 | opt_init, opt_update, get_params = optimizers.momentum( 85 | cfg.step_size, cfg.momentum_mass 86 | ) 87 | elif cfg.optimizer == "adam": 88 | opt_init, opt_update, get_params = optimizers.adam(cfg.step_size) 89 | else: 90 | raise ValueError(f"Unknown optimizer: {cfg.optimizer}") 91 | opt_state = opt_init(init_params) 92 | 93 | # get loss function and update functions: 94 | loss = trainer.get_loss_func(predict) 95 | grad_func = trainer.get_grad_func(loss, norm_clip=cfg.norm_clip, soft_clip=True) 96 | update = trainer.get_update_func( 97 | get_params, grad_func, opt_update, norm_clip=cfg.norm_clip, 98 | reshape=cfg.model.startswith("cnn") 99 | ) 100 | 101 | # get function that computes the Jacobian norms for privacy accounting: 102 | gelu_approx = 1.115 103 | fil_accountant = accountant.get_grad_jacobian_trace_func( 104 | grad_func, get_params, reshape=cfg.model.startswith("cnn"), 105 | label_privacy=cfg.label_privacy 106 | ) 107 | dp_accountant = accountant.get_dp_accounting_func(cfg.batch_size, cfg.sigma / gelu_approx) 108 | 109 | # compute subsampling factor 110 | if cfg.sigma > 0: 111 | eps = math.sqrt(2 * math.log(1.25 / cfg.delta)) * 2 * gelu_approx / cfg.sigma 112 | q = float(cfg.batch_size) / num_samples 113 | subsampling_factor = q / (q + (1-q) * math.exp(-eps)) 114 | 
else: 115 | subsampling_factor = 0 116 | logging.info(f"Subsampling factor is {subsampling_factor:.4f}") 117 | 118 | # train the model: 119 | logging.info(f"Training {cfg.model} model with {num_params} parameters using {cfg.optimizer}...") 120 | etas_squared = jnp.zeros((cfg.num_epochs, train_images.shape[0])) 121 | epsilons = jnp.zeros(cfg.num_epochs) 122 | rdp_epsilons = jnp.zeros(cfg.num_epochs) 123 | train_accs = jnp.zeros(cfg.num_epochs) 124 | test_accs = jnp.zeros(cfg.num_epochs) 125 | num_iters = 0 126 | for epoch in range(cfg.num_epochs): 127 | 128 | # perform full training sweep through the data: 129 | itercount = itertools.count() 130 | if epoch > 0: 131 | etas_squared = etas_squared.at[epoch].set(etas_squared[epoch-1]) 132 | 133 | for batch_counter in range(num_batches): 134 | 135 | # get next batch: 136 | num_iters += 1 137 | i = next(itercount) 138 | rng = jnr.fold_in(rng, i) 139 | images, labels, batch_idx = next(batches) 140 | batch = (images, labels) 141 | 142 | # update privacy loss: 143 | if cfg.sigma > 0 and cfg.do_accounting: 144 | etas_batch = fil_accountant(rng, opt_state, batch) / cfg.sigma / cfg.norm_clip 145 | etas_squared = etas_squared.at[epoch, batch_idx].add( 146 | subsampling_factor * jnp.power(etas_batch, 2), unique_indices=True 147 | ) 148 | 149 | # perform private parameter update: 150 | opt_state = update(i, rng, opt_state, batch, cfg.sigma, cfg.weight_decay) 151 | 152 | 153 | # measure training and test accuracy, and average privacy loss: 154 | params = get_params(opt_state) 155 | spectral_norm = utils.estimate_spectral_norm(lambda x: predict(params, x), input_shape) 156 | train_predictions = batch_predict(predict, params, train_images, cfg.batch_size) 157 | test_predictions = batch_predict(predict, params, test_images, cfg.batch_size) 158 | train_accuracy = utils.accuracy(train_predictions, train_labels) 159 | test_accuracy = utils.accuracy(test_predictions, test_labels) 160 | train_accs = train_accs.at[epoch].set(train_accuracy) 
161 | test_accs = test_accs.at[epoch].set(test_accuracy) 162 | params, _ = tree_flatten(params) 163 | params_norm = math.sqrt(sum([jnp.power(p, 2).sum() for p in params])) 164 | if cfg.sigma > 0 and cfg.do_accounting: 165 | median_eta = jnp.median(jnp.sqrt(etas_squared[epoch])) 166 | max_eta = jnp.sqrt(etas_squared[epoch]).max() 167 | delta = 1e-5 168 | epsilon = dp_accountant(num_iters, len(train_labels), delta) 169 | epsilons = epsilons.at[epoch].set(epsilon) 170 | rdp_epsilon = dp_accountant(num_iters, len(train_labels), delta, alpha=2) 171 | rdp_epsilons = rdp_epsilons.at[epoch].set(rdp_epsilon) 172 | 173 | # print out progress: 174 | logging.info(f"Epoch {epoch + 1}:") 175 | logging.info(f" -> training accuracy = {train_accuracy:.4f}") 176 | logging.info(f" -> test accuracy = {test_accuracy:.4f}") 177 | logging.info(f" -> parameter norm = {params_norm:.4f}, spectral norm = {spectral_norm:.4f}") 178 | if cfg.sigma > 0 and cfg.do_accounting: 179 | logging.info(f" -> Median FIL privacy loss = {median_eta:.4f}") 180 | logging.info(f" -> Max FIL privacy loss = {max_eta:.4f}") 181 | logging.info(f" -> DP privacy loss = ({epsilon:.4f}, {delta:.2e})") 182 | logging.info(f" -> 2-RDP privacy loss = {rdp_epsilon:.4f}") 183 | 184 | etas = jnp.sqrt(etas_squared) if cfg.sigma > 0 and cfg.do_accounting else float("inf") 185 | 186 | return etas, epsilons, rdp_epsilons, train_accs, test_accs 187 | 188 | 189 | # run all the things: 190 | if __name__ == "__main__": 191 | main() 192 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 
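For intuition about the privacy calibration in `train_classifier.py` above, the Gaussian-mechanism epsilon and the amplification-by-subsampling factor can be evaluated standalone. The sketch below mirrors the formulas in `main` but is not part of the original code; it plugs in the `configs/mnist.yaml` values (`batch_size=600`, `sigma=0.5`, `delta=1e-9`) and assumes the full 60,000-example MNIST training set:

```python
import math

# Values from configs/mnist.yaml; num_samples assumes the full MNIST training set.
gelu_approx = 1.115   # soft-clipping relaxation constant used in train_classifier.py
sigma, delta = 0.5, 1e-9
batch_size, num_samples = 600, 60000

# Gaussian-mechanism epsilon for noise multiplier sigma (scaled by gelu_approx):
eps = math.sqrt(2 * math.log(1.25 / delta)) * 2 * gelu_approx / sigma

# Amplification-by-subsampling factor for sampling rate q:
q = batch_size / num_samples
subsampling_factor = q / (q + (1 - q) * math.exp(-eps))

# The factor always lies between q (as eps -> 0) and 1 (as eps -> infinity);
# with these settings eps is large, so hardly any amplification survives.
assert q <= subsampling_factor <= 1.0
```

Small `sigma` relative to `delta` drives `eps` up and pushes the factor toward 1, which is why the per-example FIL terms in the training loop are scaled by this quantity.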
8 | 9 | import jax.numpy as jnp 10 | import jax.random as jnr 11 | from jax import jit, grad, vmap, nn 12 | from jax.tree_util import tree_flatten, tree_unflatten 13 | import math 14 | 15 | 16 | def get_loss_func(predict): 17 | """ 18 | Returns the loss function for the specified `predict`ion function. 19 | """ 20 | 21 | @jit 22 | def loss(params, inputs, targets): 23 | """ 24 | Multi-class cross-entropy loss function for model with parameters `params` 25 | and the specified `inputs` and one-hot `targets`. 26 | """ 27 | predictions = nn.log_softmax(predict(params, inputs)) 28 | if predictions.ndim == 1: 29 | return -jnp.sum(predictions * targets) 30 | return -jnp.mean(jnp.sum(predictions * targets, axis=-1)) 31 | 32 | return loss 33 | 34 | 35 | def get_grad_func(loss, norm_clip=0, soft_clip=False): 36 | 37 | @jit 38 | def clipped_grad(params, inputs, targets): 39 | grads = grad(loss)(params, inputs, targets) 40 | if norm_clip == 0: 41 | return grads 42 | else: 43 | nonempty_grads, tree_def = tree_flatten(grads) 44 | total_grad_norm = jnp.add(jnp.linalg.norm( 45 | [jnp.linalg.norm(neg.ravel()) for neg in nonempty_grads]), 1e-7) 46 | if soft_clip: 47 | divisor = nn.gelu(total_grad_norm / norm_clip - 1) + 1 48 | else: 49 | divisor = jnp.maximum(total_grad_norm / norm_clip, 1.) 50 | normalized_nonempty_grads = [g / divisor for g in nonempty_grads] 51 | return tree_unflatten(tree_def, normalized_nonempty_grads) 52 | 53 | return clipped_grad 54 | 55 | 56 | def get_update_func(get_params, grad_func, opt_update, norm_clip=0, reshape=True): 57 | """ 58 | Returns the parameter update function for the specified gradient function `grad_func`. 59 | """ 60 | 61 | @jit 62 | def update(i, rng, opt_state, batch, sigma, weight_decay): 63 | """ 64 | Function that performs `i`-th model update using the specified `batch` on 65 | optimizer state `opt_state`. Updates are privatized by adding Gaussian noise 66 | with standard deviation proportional to `sigma`.
67 | """ 68 | 69 | # compute parameter gradient: 70 | inputs, targets = batch 71 | if reshape: 72 | inputs = jnp.expand_dims(inputs, 1) 73 | params = get_params(opt_state) 74 | multiplier = 1 if norm_clip == 0 else norm_clip 75 | 76 | # add noise to gradients: 77 | grads = vmap(grad_func, in_axes=(None, 0, 0))(params, inputs, targets) 78 | grads_flat, grads_treedef = tree_flatten(grads) 79 | grads_flat = [g.sum(0) for g in grads_flat] 80 | rngs = jnr.split(rng, len(grads_flat)) 81 | noisy_grads = [ 82 | (g + multiplier * sigma * jnr.normal(r, g.shape)) / len(targets) 83 | for r, g in zip(rngs, grads_flat) 84 | ] 85 | 86 | # weight decay 87 | params_flat, _ = tree_flatten(params) 88 | noisy_grads = [ 89 | g + weight_decay * param 90 | for g, param in zip(noisy_grads, params_flat) 91 | ] 92 | noisy_grads = tree_unflatten(grads_treedef, noisy_grads) 93 | 94 | # perform parameter update: 95 | return opt_update(i, noisy_grads, opt_state) 96 | 97 | return update -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 
8 | 
9 | import jax
10 | import jax.numpy as jnp
11 | from jax.experimental import stax
12 | 
13 | DTYPE_MAPPING = {
14 |     "float32": "f32",
15 |     "float64": "f64",
16 |     "int32": "s32",
17 |     "int64": "s64",
18 |     "uint32": "u32",
19 |     "uint64": "u64",
20 | }
21 | 
22 | 
23 | def _l2_normalize(x, eps=1e-7):
24 |     return x * jax.lax.rsqrt((x ** 2).sum() + eps)
25 | 
26 | 
27 | def estimate_spectral_norm(f, input_shape, seed=0, n_steps=20):
28 |     input_shape = tuple([1] + [input_shape[i] for i in range(1, len(input_shape))])
29 |     rng = jax.random.PRNGKey(seed)
30 |     u0 = jax.random.normal(rng, input_shape)
31 |     v0 = jnp.zeros_like(f(u0))
32 |     def fun(carry, _):
33 |         u, v = carry
34 |         v, f_vjp = jax.vjp(f, u)
35 |         v = _l2_normalize(v)
36 |         u, = f_vjp(v)
37 |         u = _l2_normalize(u)
38 |         return (u, v), None
39 |     (u, v), _ = jax.lax.scan(fun, (u0, v0), xs=None, length=n_steps)
40 |     return jnp.vdot(v, f(u))
41 | 
42 | 
43 | def accuracy(predictions, targets):
44 |     """
45 |     Compute accuracy of `predictions` given the associated `targets`.
46 |     """
47 |     target_class = jnp.argmax(targets, axis=-1)
48 |     predicted_class = jnp.argmax(predictions, axis=-1)
49 |     return jnp.mean(predicted_class == target_class)
50 | 
51 | 
52 | def get_model(rng, model_name, input_shape, num_labels):
53 |     """
54 |     Returns the model specified by `model_name`. The model is initialized using
55 |     the specified random number generator `rng`.
56 | 
57 |     `input_shape` is used to initialize the parameters; `num_labels` sets the output size.
58 | """ 59 | 60 | # initialize convolutional network: 61 | if model_name == "cnn": 62 | init_random_params, predict = stax.serial( 63 | stax.Conv(16, (8, 8), padding="SAME", strides=(2, 2)), 64 | stax.Gelu, 65 | stax.AvgPool((2, 2), (1, 1)), 66 | stax.Conv(32, (4, 4), padding="VALID", strides=(2, 2)), 67 | stax.Gelu, 68 | stax.AvgPool((2, 2), (1, 1)), 69 | stax.Flatten, 70 | stax.Dense(32), 71 | stax.Gelu, 72 | stax.Dense(num_labels), 73 | ) 74 | _, init_params = init_random_params(rng, input_shape) 75 | 76 | elif model_name == "cnn_tanh": 77 | init_random_params, predict = stax.serial( 78 | stax.Conv(16, (8, 8), padding="SAME", strides=(2, 2)), 79 | stax.Tanh, 80 | stax.AvgPool((2, 2), (1, 1)), 81 | stax.Conv(32, (4, 4), padding="VALID", strides=(2, 2)), 82 | stax.Tanh, 83 | stax.AvgPool((2, 2), (1, 1)), 84 | stax.Flatten, 85 | stax.Dense(32), 86 | stax.Tanh, 87 | stax.Dense(num_labels), 88 | ) 89 | _, init_params = init_random_params(rng, input_shape) 90 | 91 | elif model_name == "cnn_cifar": 92 | init_random_params, predict = stax.serial( 93 | stax.Conv(32, (3, 3), padding="SAME", strides=(1, 1)), 94 | stax.Tanh, 95 | stax.Conv(32, (3, 3), padding="SAME", strides=(1, 1)), 96 | stax.Tanh, 97 | stax.AvgPool((2, 2), (2, 2)), 98 | stax.Conv(64, (3, 3), padding="SAME", strides=(1, 1)), 99 | stax.Tanh, 100 | stax.Conv(64, (3, 3), padding="SAME", strides=(1, 1)), 101 | stax.Tanh, 102 | stax.AvgPool((2, 2), (2, 2)), 103 | stax.Conv(128, (3, 3), padding="SAME", strides=(1, 1)), 104 | stax.Tanh, 105 | stax.Conv(128, (3, 3), padding="SAME", strides=(1, 1)), 106 | stax.Tanh, 107 | stax.AvgPool((2, 2), (2, 2)), 108 | stax.Flatten, 109 | stax.Dense(128), 110 | stax.Tanh, 111 | stax.Dense(num_labels), 112 | ) 113 | _, init_params = init_random_params(rng, input_shape) 114 | 115 | # initialize multi-layer perceptron: 116 | elif model_name == "mlp": 117 | init_random_params, predict = stax.serial( 118 | stax.Dense(256), 119 | stax.Gelu, 120 | stax.Dense(256), 121 | stax.Gelu, 
122 | stax.Dense(num_labels), 123 | ) 124 | _, init_params = init_random_params(rng, input_shape) 125 | 126 | elif model_name == "mlp_tanh": 127 | init_random_params, predict = stax.serial( 128 | stax.Dense(256), 129 | stax.Tanh, 130 | stax.Dense(256), 131 | stax.Tanh, 132 | stax.Dense(num_labels), 133 | ) 134 | _, init_params = init_random_params(rng, input_shape) 135 | # initialize linear model: 136 | elif model_name == "linear": 137 | init_random_params, predict_raw = stax.Dense(num_labels) 138 | def predict(params, inputs): 139 | logits = predict_raw(params, inputs) 140 | return jnp.hstack([logits, jnp.zeros(logits.shape)]) 141 | _, init_params = init_random_params(rng, input_shape) 142 | 143 | else: 144 | raise ValueError(f"Unknown model: {model_name}") 145 | 146 | # return initial model parameters and prediction function: 147 | return init_params, predict 148 | --------------------------------------------------------------------------------
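The hard and soft clipping rules in `trainer.py`'s `get_grad_func` differ only in the divisor applied to the per-example gradient norm. A minimal NumPy sketch of that rule (a stand-in for the JAX version above; the tanh-based GELU here is the approximation `jax.nn.gelu` uses by default):

```python
import numpy as np

def gelu(x):
    # tanh approximation of the GELU activation
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def clip_divisor(grad_norm, norm_clip, soft_clip=False):
    # hard clip: divide by max(norm / C, 1), so gradients with norm <= C pass through
    # soft clip: GELU gives a smooth, differentiable transition around norm == C
    if soft_clip:
        return gelu(grad_norm / norm_clip - 1.0) + 1.0
    return max(grad_norm / norm_clip, 1.0)

g = np.array([3.0, 4.0])                            # per-example gradient with norm 5
clipped = g / clip_divisor(np.linalg.norm(g), norm_clip=1.0)
print(np.linalg.norm(clipped))                      # hard clipping caps the norm at norm_clip
```

Note that the soft divisor is not exactly 1 below the threshold (GELU is nonzero for negative inputs), so soft clipping slightly rescales even small gradients; the hard rule leaves them untouched.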