├── ACKNOWLEDGEMENTS.txt ├── ATTRIBUTION.txt ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE.md ├── README.md ├── assets └── FLAIR_sample.jpeg ├── benchmark ├── __init__.py ├── central_main.py ├── federated_main.py ├── flair_data.py ├── flair_metrics.py └── flair_model.py ├── download_dataset.py ├── explore_images.ipynb ├── explore_labels.ipynb ├── prepare_dataset.py ├── prepare_tfrecords.py └── requirements.txt /ACKNOWLEDGEMENTS.txt: -------------------------------------------------------------------------------- 1 | Acknowledgements 2 | Portions of this Software may utilize the following copyrighted 3 | material, the use of which is hereby acknowledged. 4 | 5 | _____________________ 6 | 7 | AQR Capital Management, LLC, Lambda Foundry, Inc., PyData Development Team, and open source contributors (pandas) 8 | BSD 3-Clause License 9 | 10 | Copyright (c) 2008-2011, AQR Capital Management, LLC, Lambda Foundry, Inc. and PyData Development Team 11 | All rights reserved. 12 | 13 | Copyright (c) 2011-2022, Open source contributors. 14 | 15 | Redistribution and use in source and binary forms, with or without 16 | modification, are permitted provided that the following conditions are met: 17 | 18 | * Redistributions of source code must retain the above copyright notice, this 19 | list of conditions and the following disclaimer. 20 | 21 | * Redistributions in binary form must reproduce the above copyright notice, 22 | this list of conditions and the following disclaimer in the documentation 23 | and/or other materials provided with the distribution. 24 | 25 | * Neither the name of the copyright holder nor the names of its 26 | contributors may be used to endorse or promote products derived from 27 | this software without specific prior written permission. 28 | 29 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 30 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 31 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 32 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 33 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 34 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 35 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 36 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 37 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 38 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 39 | 40 | Andrew Collette and contributors (h5py) 41 | Copyright (c) 2008 Andrew Collette and contributors 42 | All rights reserved. 43 | 44 | Redistribution and use in source and binary forms, with or without 45 | modification, are permitted provided that the following conditions are 46 | met: 47 | 48 | 1. Redistributions of source code must retain the above copyright 49 | notice, this list of conditions and the following disclaimer. 50 | 51 | 2. Redistributions in binary form must reproduce the above copyright 52 | notice, this list of conditions and the following disclaimer in the 53 | documentation and/or other materials provided with the 54 | distribution. 55 | 56 | 3. Neither the name of the copyright holder nor the names of its 57 | contributors may be used to endorse or promote products derived from 58 | this software without specific prior written permission. 
59 | 60 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 61 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 62 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 63 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 64 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 65 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 66 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 67 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 68 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 69 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 70 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 71 | 72 | IPython Development Team and Jupyter Development Team (Jupyter notebook) 73 | This project is licensed under the terms of the Modified BSD License 74 | (also known as New or Revised or 3-Clause BSD), as follows: 75 | 76 | - Copyright (c) 2001-2015, IPython Development Team 77 | - Copyright (c) 2015-, Jupyter Development Team 78 | 79 | All rights reserved. 80 | 81 | Redistribution and use in source and binary forms, with or without 82 | modification, are permitted provided that the following conditions are met: 83 | 84 | Redistributions of source code must retain the above copyright notice, this 85 | list of conditions and the following disclaimer. 86 | 87 | Redistributions in binary form must reproduce the above copyright notice, this 88 | list of conditions and the following disclaimer in the documentation and/or 89 | other materials provided with the distribution. 90 | 91 | Neither the name of the Jupyter Development Team nor the names of its 92 | contributors may be used to endorse or promote products derived from this 93 | software without specific prior written permission. 94 | 95 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 96 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 97 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 98 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE 99 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 100 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 101 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 102 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 103 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 104 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 105 | 106 | ## About the Jupyter Development Team 107 | 108 | The Jupyter Development Team is the set of all contributors to the Jupyter project. 109 | This includes all of the Jupyter subprojects. 110 | 111 | The core team that coordinates development on GitHub can be found here: 112 | https://github.com/jupyter/. 113 | 114 | ## Our Copyright Policy 115 | 116 | Jupyter uses a shared copyright model. Each contributor maintains copyright 117 | over their contributions to Jupyter. But, it is important to note that these 118 | contributions are typically only changes to the repositories. Thus, the Jupyter 119 | source code, in its entirety is not the copyright of any single person or 120 | institution. Instead, it is the collective copyright of the entire Jupyter 121 | Development Team. 
If individual contributors want to maintain a record of what 122 | changes/contributions they have specific copyright on, they should indicate 123 | their copyright in the commit message of the change, when they commit the 124 | change to one of the Jupyter repositories. 125 | 126 | With this in mind, the following banner should be used in any source code file 127 | to indicate the copyright and license terms: 128 | 129 | # Copyright (c) Jupyter Development Team. 130 | # Distributed under the terms of the Modified BSD License. 131 | 132 | NumPy Developers (NumPy) 133 | Copyright (c) 2005-2022, NumPy Developers. 134 | All rights reserved. 135 | 136 | Redistribution and use in source and binary forms, with or without 137 | modification, are permitted provided that the following conditions are 138 | met: 139 | 140 | * Redistributions of source code must retain the above copyright 141 | notice, this list of conditions and the following disclaimer. 142 | 143 | * Redistributions in binary form must reproduce the above 144 | copyright notice, this list of conditions and the following 145 | disclaimer in the documentation and/or other materials provided 146 | with the distribution. 147 | 148 | * Neither the name of the NumPy Developers nor the names of any 149 | contributors may be used to endorse or promote products derived 150 | from this software without specific prior written permission. 151 | 152 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 153 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 154 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 155 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 156 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 157 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 158 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 159 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 160 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 161 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 162 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 163 | 164 | Secret Labs AB, Fredrik Lundh, Alex Clark and contributors (Pillow) 165 | The Python Imaging Library (PIL) is 166 | 167 | Copyright © 1997-2011 by Secret Labs AB 168 | Copyright © 1995-2011 by Fredrik Lundh 169 | 170 | Pillow is the friendly PIL fork. It is 171 | 172 | Copyright © 2010-2022 by Alex Clark and contributors 173 | 174 | Like PIL, Pillow is licensed under the open source HPND License: 175 | 176 | By obtaining, using, and/or copying this software and/or its associated 177 | documentation, you agree that you have read, understood, and will comply 178 | with the following terms and conditions: 179 | 180 | Permission to use, copy, modify, and distribute this software and its 181 | associated documentation for any purpose and without fee is hereby granted, 182 | provided that the above copyright notice appears in all copies, and that 183 | both that copyright notice and this permission notice appear in supporting 184 | documentation, and that the name of Secret Labs AB or the author not be 185 | used in advertising or publicity pertaining to distribution of the software 186 | without specific, written prior permission. 
187 | 188 | SECRET LABS AB AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS 189 | SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. 190 | IN NO EVENT SHALL SECRET LABS AB OR THE AUTHOR BE LIABLE FOR ANY SPECIAL, 191 | INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM 192 | LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE 193 | OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR 194 | PERFORMANCE OF THIS SOFTWARE. 195 | 196 | The TensorFlow Authors (TensorFlow, TensorFlow Addons, Tensorflow Federated, TensorFlow Privacy) 197 | Apache License 198 | Version 2.0, January 2004 199 | http://www.apache.org/licenses/ 200 | 201 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 202 | 203 | 1. Definitions. 204 | 205 | "License" shall mean the terms and conditions for use, reproduction, 206 | and distribution as defined by Sections 1 through 9 of this document. 207 | 208 | "Licensor" shall mean the copyright owner or entity authorized by 209 | the copyright owner that is granting the License. 210 | 211 | "Legal Entity" shall mean the union of the acting entity and all 212 | other entities that control, are controlled by, or are under common 213 | control with that entity. For the purposes of this definition, 214 | "control" means (i) the power, direct or indirect, to cause the 215 | direction or management of such entity, whether by contract or 216 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 217 | outstanding shares, or (iii) beneficial ownership of such entity. 218 | 219 | "You" (or "Your") shall mean an individual or Legal Entity 220 | exercising permissions granted by this License. 221 | 222 | "Source" form shall mean the preferred form for making modifications, 223 | including but not limited to software source code, documentation 224 | source, and configuration files. 225 | 226 | "Object" form shall mean any form resulting from mechanical 227 | transformation or translation of a Source form, including but 228 | not limited to compiled object code, generated documentation, 229 | and conversions to other media types. 230 | 231 | "Work" shall mean the work of authorship, whether in Source or 232 | Object form, made available under the License, as indicated by a 233 | copyright notice that is included in or attached to the work 234 | (an example is provided in the Appendix below). 235 | 236 | "Derivative Works" shall mean any work, whether in Source or Object 237 | form, that is based on (or derived from) the Work and for which the 238 | editorial revisions, annotations, elaborations, or other modifications 239 | represent, as a whole, an original work of authorship. For the purposes 240 | of this License, Derivative Works shall not include works that remain 241 | separable from, or merely link (or bind by name) to the interfaces of, 242 | the Work and Derivative Works thereof. 243 | 244 | "Contribution" shall mean any work of authorship, including 245 | the original version of the Work and any modifications or additions 246 | to that Work or Derivative Works thereof, that is intentionally 247 | submitted to Licensor for inclusion in the Work by the copyright owner 248 | or by an individual or Legal Entity authorized to submit on behalf of 249 | the copyright owner. 
For the purposes of this definition, "submitted" 250 | means any form of electronic, verbal, or written communication sent 251 | to the Licensor or its representatives, including but not limited to 252 | communication on electronic mailing lists, source code control systems, 253 | and issue tracking systems that are managed by, or on behalf of, the 254 | Licensor for the purpose of discussing and improving the Work, but 255 | excluding communication that is conspicuously marked or otherwise 256 | designated in writing by the copyright owner as "Not a Contribution." 257 | 258 | "Contributor" shall mean Licensor and any individual or Legal Entity 259 | on behalf of whom a Contribution has been received by Licensor and 260 | subsequently incorporated within the Work. 261 | 262 | 2. Grant of Copyright License. Subject to the terms and conditions of 263 | this License, each Contributor hereby grants to You a perpetual, 264 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 265 | copyright license to reproduce, prepare Derivative Works of, 266 | publicly display, publicly perform, sublicense, and distribute the 267 | Work and such Derivative Works in Source or Object form. 268 | 269 | 3. Grant of Patent License. Subject to the terms and conditions of 270 | this License, each Contributor hereby grants to You a perpetual, 271 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 272 | (except as stated in this section) patent license to make, have made, 273 | use, offer to sell, sell, import, and otherwise transfer the Work, 274 | where such license applies only to those patent claims licensable 275 | by such Contributor that are necessarily infringed by their 276 | Contribution(s) alone or by combination of their Contribution(s) 277 | with the Work to which such Contribution(s) was submitted. If You 278 | institute patent litigation against any entity (including a 279 | cross-claim or counterclaim in a lawsuit) alleging that the Work 280 | or a Contribution incorporated within the Work constitutes direct 281 | or contributory patent infringement, then any patent licenses 282 | granted to You under this License for that Work shall terminate 283 | as of the date such litigation is filed. 284 | 285 | 4. Redistribution. 
You may reproduce and distribute copies of the 286 | Work or Derivative Works thereof in any medium, with or without 287 | modifications, and in Source or Object form, provided that You 288 | meet the following conditions: 289 | 290 | (a) You must give any other recipients of the Work or 291 | Derivative Works a copy of this License; and 292 | 293 | (b) You must cause any modified files to carry prominent notices 294 | stating that You changed the files; and 295 | 296 | (c) You must retain, in the Source form of any Derivative Works 297 | that You distribute, all copyright, patent, trademark, and 298 | attribution notices from the Source form of the Work, 299 | excluding those notices that do not pertain to any part of 300 | the Derivative Works; and 301 | 302 | (d) If the Work includes a "NOTICE" text file as part of its 303 | distribution, then any Derivative Works that You distribute must 304 | include a readable copy of the attribution notices contained 305 | within such NOTICE file, excluding those notices that do not 306 | pertain to any part of the Derivative Works, in at least one 307 | of the following places: within a NOTICE text file distributed 308 | as part of the Derivative Works; within the Source form or 309 | documentation, if provided along with the Derivative Works; or, 310 | within a display generated by the Derivative Works, if and 311 | wherever such third-party notices normally appear. The contents 312 | of the NOTICE file are for informational purposes only and 313 | do not modify the License. You may add Your own attribution 314 | notices within Derivative Works that You distribute, alongside 315 | or as an addendum to the NOTICE text from the Work, provided 316 | that such additional attribution notices cannot be construed 317 | as modifying the License. 318 | 319 | You may add Your own copyright statement to Your modifications and 320 | may provide additional or different license terms and conditions 321 | for use, reproduction, or distribution of Your modifications, or 322 | for any such Derivative Works as a whole, provided Your use, 323 | reproduction, and distribution of the Work otherwise complies with 324 | the conditions stated in this License. 325 | 326 | 5. Submission of Contributions. Unless You explicitly state otherwise, 327 | any Contribution intentionally submitted for inclusion in the Work 328 | by You to the Licensor shall be under the terms and conditions of 329 | this License, without any additional terms or conditions. 330 | Notwithstanding the above, nothing herein shall supersede or modify 331 | the terms of any separate license agreement you may have executed 332 | with Licensor regarding such Contributions. 333 | 334 | 6. Trademarks. This License does not grant permission to use the trade 335 | names, trademarks, service marks, or product names of the Licensor, 336 | except as required for reasonable and customary use in describing the 337 | origin of the Work and reproducing the content of the NOTICE file. 338 | 339 | 7. Disclaimer of Warranty. Unless required by applicable law or 340 | agreed to in writing, Licensor provides the Work (and each 341 | Contributor provides its Contributions) on an "AS IS" BASIS, 342 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 343 | implied, including, without limitation, any warranties or conditions 344 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 345 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 346 | appropriateness of using or redistributing the Work and assume any 347 | risks associated with Your exercise of permissions under this License. 348 | 349 | 8. Limitation of Liability. In no event and under no legal theory, 350 | whether in tort (including negligence), contract, or otherwise, 351 | unless required by applicable law (such as deliberate and grossly 352 | negligent acts) or agreed to in writing, shall any Contributor be 353 | liable to You for damages, including any direct, indirect, special, 354 | incidental, or consequential damages of any character arising as a 355 | result of this License or out of the use or inability to use the 356 | Work (including but not limited to damages for loss of goodwill, 357 | work stoppage, computer failure or malfunction, or any and all 358 | other commercial damages or losses), even if such Contributor 359 | has been advised of the possibility of such damages. 360 | 361 | 9. Accepting Warranty or Additional Liability. While redistributing 362 | the Work or Derivative Works thereof, You may choose to offer, 363 | and charge a fee for, acceptance of support, warranty, indemnity, 364 | or other liability obligations and/or rights consistent with this 365 | License. However, in accepting such obligations, You may act only 366 | on Your own behalf and on Your sole responsibility, not on behalf 367 | of any other Contributor, and only if You agree to indemnify, 368 | defend, and hold each Contributor harmless for any liability 369 | incurred by, or claims asserted against, such Contributor by reason 370 | of your accepting any such warranty or additional liability. 371 | 372 | END OF TERMS AND CONDITIONS 373 | 374 | APPENDIX: How to apply the Apache License to your work. 375 | 376 | To apply the Apache License to your work, attach the following 377 | boilerplate notice, with the fields enclosed by brackets "[]" 378 | replaced with your own identifying information. (Don't include 379 | the brackets!) The text should be enclosed in the appropriate 380 | comment syntax for the file format. We also recommend that a 381 | file or class name and description of purpose be included on the 382 | same "printed page" as the copyright notice for easier 383 | identification within third-party archives. 384 | 385 | Copyright [yyyy] [name of copyright owner] 386 | 387 | Licensed under the Apache License, Version 2.0 (the "License"); 388 | you may not use this file except in compliance with the License. 389 | You may obtain a copy of the License at 390 | 391 | http://www.apache.org/licenses/LICENSE-2.0 392 | 393 | Unless required by applicable law or agreed to in writing, software 394 | distributed under the License is distributed on an "AS IS" BASIS, 395 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 396 | See the License for the specific language governing permissions and 397 | limitations under the License. 398 | 399 | ## Some of TensorFlow's code is derived from Caffe, which is subject to the following copyright notice: 400 | 401 | COPYRIGHT 402 | 403 | All contributions by the University of California: 404 | 405 | Copyright (c) 2014, The Regents of the University of California (Regents) 406 | All rights reserved. 407 | 408 | All other contributions: 409 | 410 | Copyright (c) 2014, the respective contributors 411 | All rights reserved. 412 | 413 | Caffe uses a shared copyright model: each contributor holds copyright over 414 | their contributions to Caffe. 
The project versioning records all such 415 | contribution and copyright details. If a contributor wants to further mark 416 | their specific copyright on a particular contribution, they should indicate 417 | their copyright solely in the commit message of the change when it is 418 | committed. 419 | 420 | LICENSE 421 | 422 | Redistribution and use in source and binary forms, with or without 423 | modification, are permitted provided that the following conditions are met: 424 | 425 | 1. Redistributions of source code must retain the above copyright notice, this 426 | list of conditions and the following disclaimer. 427 | 428 | 2. Redistributions in binary form must reproduce the above copyright notice, 429 | this list of conditions and the following disclaimer in the documentation 430 | and/or other materials provided with the distribution. 431 | 432 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 433 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 434 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 435 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 436 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 437 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 438 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 439 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 440 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 441 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 442 | 443 | CONTRIBUTION AGREEMENT 444 | 445 | By contributing to the BVLC/caffe repository through pull-request, comment, 446 | or otherwise, the contributor releases their content to the 447 | license and copyright terms herein. 448 | 449 | The HDF Group (HDF5) 450 | HDF5 (Hierarchical Data Format 5) Software Library and Utilities 451 | Copyright 2006-2007 by The HDF Group (THG). 452 | 453 | NCSA HDF5 (Hierarchical Data Format 5) Software Library and Utilities 454 | Copyright 1998-2006 by the Board of Trustees of the University of Illinois. 455 | 456 | All rights reserved. 457 | 458 | Contributors: National Center for Supercomputing Applications (NCSA) 459 | at the University of Illinois, Fortner Software, Unidata Program 460 | Center (netCDF), The Independent JPEG Group (JPEG), Jean-loup Gailly 461 | and Mark Adler (gzip), and Digital Equipment Corporation (DEC). 462 | 463 | Redistribution and use in source and binary forms, with or without 464 | modification, are permitted for any purpose (including commercial 465 | purposes) provided that the following conditions are met: 466 | 467 | 1. Redistributions of source code must retain the above copyright 468 | notice, this list of conditions, and the following disclaimer. 469 | 2. Redistributions in binary form must reproduce the above 470 | copyright notice, this list of conditions, and the following 471 | disclaimer in the documentation and/or materials provided with the 472 | distribution. 473 | 3. In addition, redistributions of modified forms of the source or 474 | binary code must carry prominent notices stating that the original 475 | code was changed and the date of the change. 476 | 4. 
All publications or advertising materials mentioning features or 477 | use of this software are asked, but not required, to acknowledge that 478 | it was developed by The HDF Group and by the National Center for 479 | Supercomputing Applications at the University of Illinois at 480 | Urbana-Champaign and credit the contributors. 481 | 5. Neither the name of The HDF Group, the name of the University, 482 | nor the name of any Contributor may be used to endorse or promote 483 | products derived from this software without specific prior written 484 | permission from THG, the University, or the Contributor, respectively. 485 | 486 | DISCLAIMER: THIS SOFTWARE IS PROVIDED BY THE HDF GROUP (THG) AND THE 487 | CONTRIBUTORS "AS IS" WITH NO WARRANTY OF ANY KIND, EITHER EXPRESSED OR 488 | IMPLIED. In no event shall THG or the Contributors be liable for any 489 | damages suffered by the users arising out of the use of this software, 490 | even if advised of the possibility of such damage. 491 | 492 | Portions of HDF5 were developed with support from the University of 493 | California, Lawrence Livermore National Laboratory (UC LLNL). The 494 | following statement applies to those portions of the product and must 495 | be retained in any redistribution of source code, binaries, 496 | documentation, and/or accompanying materials: 497 | 498 | This work was partially produced at the University of California, 499 | Lawrence Livermore National Laboratory (UC LLNL) under contract 500 | no. W-7405-ENG-48 (Contract 48) between the U.S. Department of Energy 501 | (DOE) and The Regents of the University of California (University) for 502 | the operation of UC LLNL. 503 | 504 | DISCLAIMER: This work was prepared as an account of work sponsored by 505 | an agency of the United States Government. Neither the United States 506 | Government nor the University of California nor any of their 507 | employees, makes any warranty, express or implied, or assumes any 508 | liability or responsibility for the accuracy, completeness, or 509 | usefulness of any information, apparatus, product, or process 510 | disclosed, or represents that its use would not infringe privately- 511 | owned rights. Reference herein to any specific commercial products, 512 | process, or service by trade name, trademark, manufacturer, or 513 | otherwise, does not necessarily constitute or imply its endorsement, 514 | recommendation, or favoring by the United States Government or the 515 | University of California. The views and opinions of authors expressed 516 | herein do not necessarily state or reflect those of the United States 517 | Government or the University of California, and shall not be used for 518 | advertising or product endorsement purposes. 519 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4, 71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thanks for your interest in contributing. 
This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository. 4 | 5 | While we welcome new pull requests and issues, please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged. 6 | 7 | ## Before you get started 8 | 9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE). 10 | 11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright (C) 2020 Apple Inc. All Rights Reserved. 2 | 3 | IMPORTANT: This Apple software is supplied to you by Apple 4 | Inc. ("Apple") in consideration of your agreement to the following 5 | terms, and your use, installation, modification or redistribution of 6 | this Apple software constitutes acceptance of these terms. If you do 7 | not agree with these terms, please do not use, install, modify or 8 | redistribute this Apple software. 9 | 10 | In consideration of your agreement to abide by the following terms, and 11 | subject to these terms, Apple grants you a personal, non-exclusive 12 | license, under Apple's copyrights in this original Apple software (the 13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple 14 | Software, with or without modifications, in source and/or binary forms; 15 | provided that if you redistribute the Apple Software in its entirety and 16 | without modifications, you must retain this notice and the following 17 | text and disclaimers in all such redistributions of the Apple Software. 18 | Neither the name, trademarks, service marks or logos of Apple Inc. may 19 | be used to endorse or promote products derived from the Apple Software 20 | without specific prior written permission from Apple. Except as 21 | expressly stated in this notice, no other rights or licenses, express or 22 | implied, are granted by Apple herein, including but not limited to any 23 | patent rights that may be infringed by your derivative works or by other 24 | works in which the Apple Software may be incorporated. 25 | 26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE 27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION 28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS 29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND 30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. 31 | 32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL 33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, 36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED 37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), 38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE 39 | POSSIBILITY OF SUCH DAMAGE.
40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Federated Learning Annotated Image Repository (FLAIR): A large labelled image dataset for benchmarking in federated learning 2 | 3 | **Paper:** [link](https://proceedings.neurips.cc/paper_files/paper/2022/file/f64e55d03e2fe61aa4114e49cb654acb-Paper-Datasets_and_Benchmarks.pdf) 4 | 5 | **Update:** FLAIR is now available in [pfl-research](https://github.com/apple/pfl-research/tree/develop/benchmarks/flair)! 6 | 7 | FLAIR is a large dataset of images that captures a number of characteristics encountered in federated learning and privacy-preserving ML tasks. 8 | This dataset comprises approximately 430,000 images from 51,000 Flickr users, which better reflects federated learning problems arising in practice, and it is being released to aid research in the field. 9 | 10 | ![alt text](assets/FLAIR_sample.jpeg) 11 | 12 | ## Image Labels 13 | These images have been annotated by humans and assigned labels from a taxonomy of more than 1,600 fine-grained labels. 14 | All main subjects present in the images have been labeled, so images may have multiple labels. 15 | The taxonomy is hierarchical: the fine-grained labels can be mapped to 17 coarse-grained categories. 16 | The dataset includes both fine-grained and coarse-grained labels so researchers can vary the complexity of a machine learning task. 17 | 18 | ## User Labels and their use for Federated Learning 19 | We have used image metadata to extract artist names/IDs for the purposes of creating user datasets for federated learning. 20 | While optimization algorithms for machine learning are often designed under the assumption that each example is an independent sample from the distribution, federated learning applications deviate from this assumption in a few different ways that are reflected in our user-annotated examples. 21 | Users differ in the number of images they have, as well as in the number of classes represented in their image collections. 22 | Further, images of the same class but taken by different users are likely to have some distribution shift. 23 | These properties of the dataset better reflect federated learning applications, and we expect that benchmark tasks on this dataset will benefit from algorithms designed to handle such data heterogeneity. 24 | 25 | ## Getting Started 26 | ### Prerequisites 27 | Please make sure you have Python >= 3.8 and the required packages installed (see below). 28 | ```sh 29 | python3 -m pip install -r requirements.txt 30 | ``` 31 | ### Download the dataset 32 | Ensure you have a good network connection to download the ~6GB of image data, and enough local space to store and decompress it. 33 | Download the dataset with the following command: 34 | ```sh 35 | python3 download_dataset.py --dataset_dir=/path/to/data 36 | ``` 37 | The images and metadata will be saved to the provided `dataset_dir`. 38 | By default the script will download the down-sized images (size = 256 x 256). 39 | The images are split and compressed into dozens of tarball archives and will be decompressed after downloading. 40 | If you wish to download the full-size raw images, add the `--download_raw` flag to the above command.\ 41 | ⚠️ Warning: the raw images take up to ~1.2TB of disk space after decompression.
42 | 43 | After downloading and decompressing, the `dataset_dir` will have the following layout: 44 | ``` 45 | dataset_dir 46 | ├── labels_and_metadata.json # a list of labels and metadata for each image 47 | ├── label_relationship.txt # a list of `(fine-grained label, coarse-grained label)` pairs 48 | ├── small_images 49 | │ └── *.jpg # all down-sized images 50 | └── raw_images # exists if you added the `--download_raw` flag 51 | └── *.jpg # all raw images 52 | ``` 53 | 54 | ### Dataset split 55 | We include a standard train/val/test split in `labels_and_metadata.json`. 56 | The partition is based on user IDs with an 8:1:1 ratio, i.e. the train, val and test sets have disjoint users. 57 | Below are the numbers for each partition: 58 | 59 | | Partition | Train | Val | Test | 60 | | ---------------- | ------- | ------ | ------ | 61 | | Number of users | 41,131 | 5,141 | 5,142 | 62 | | Number of images | 345,879 | 39,239 | 43,960 | 63 | 64 | We recommend using the provided split for reproducible benchmarks. 65 | 66 | ### Explore the dataset 67 | Below is an example of the metadata and labels for one image from `labels_and_metadata.json`: 68 | ```json 69 | { 70 | "user_id": "59769174@N00", 71 | "image_id": "14913474848", 72 | "fine_grained_labels": [ 73 | "bag", 74 | "document", 75 | "furniture", 76 | "material", 77 | "printed_page" 78 | ], 79 | "labels": [ 80 | "equipment", 81 | "material", 82 | "structure" 83 | ], 84 | "partition": "train" 85 | } 86 | ``` 87 | Field `image_id` is the Flickr PhotoID and `user_id` is the Flickr NSID that owns the image. 88 | Field `partition` denotes which `train/dev/test` partition the image belongs to. 89 | Field `fine_grained_labels` is a list of annotated labels describing the subjects present in the image, and `labels` is the list of coarse-grained labels obtained by mapping fine-grained labels to higher-order categories. 90 | The file `label_relationship.txt` includes the mapping from ~1,600 fine-grained labels to 17 higher-order categories. 91 | 92 | We provide scripts to explore the images and labels in more detail. First, start a Jupyter notebook: 93 | ```sh 94 | jupyter notebook 95 | ``` 96 | - To explore the downloaded images, open [`explore_images.ipynb`](./explore_images.ipynb) in Jupyter; it displays the images with their corresponding metadata and labels. 97 | - To explore the labels, open [`explore_labels.ipynb`](./explore_labels.ipynb) in Jupyter; it displays statistics of the user and label distributions. 98 | 99 | ### (Optional) Prepare the dataset in HDF5 100 | We provide a script to prepare the dataset in HDF5 format for more efficient processing and training: 101 | ```sh 102 | python3 prepare_dataset.py --dataset_dir=/path/to/data --output_file=/path/to/hdf5 103 | ``` 104 | By default the script will group the images and labels by train/val/test split and then by user IDs, making it suitable for federated learning experiments. 105 | With the flag `--not_group_data_by_user`, the script will simply group the images and labels by train/val/test split and ignore the user IDs, which is the typical setup for centralized training. \ 106 | ⚠️ Warning: the HDF5 file takes up to ~80GB of disk space after processing.
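To sanity-check the resulting file, a minimal sketch like the one below (using `h5py`) walks the HDF5 hierarchy and prints dataset shapes. It is not part of the benchmark code, the path is a placeholder, and it deliberately does not assume a particular split/user grouping, since that depends on whether `--not_group_data_by_user` was used.
```python
# Minimal sketch: inspect the HDF5 file produced by prepare_dataset.py.
# It only walks the group hierarchy and prints dataset shapes/dtypes.
import h5py

hdf5_path = "/path/to/hdf5"  # placeholder: the --output_file passed above


def describe(name, obj):
    # Print every dataset (leaf) found anywhere in the hierarchy.
    if isinstance(obj, h5py.Dataset):
        print(f"{name}: shape={obj.shape}, dtype={obj.dtype}")


with h5py.File(hdf5_path, "r") as f:
    f.visititems(describe)
```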
107 | 108 | ## Benchmark FLAIR with TensorFlow Federated 109 | 110 | ### Prepare the dataset in TFRecords 111 | We provide a script to prepare the dataset in TFRecords format for benchmarking with TensorFlow Federated: 112 | ```sh 113 | python3 prepare_tfrecords.py --dataset_dir=/path/to/data --tfrecords_dir=/path/to/tfrecords 114 | ``` 115 | When the above script finishes, the `tfrecords_dir` will have the following layout: 116 | ``` 117 | tfrecords_dir 118 | ├── label_index.json # a mapping from class label to index 119 | ├── train 120 | │ └── *.tfrecords # tfrecords for all train users 121 | ├── dev 122 | │ └── *.tfrecords # tfrecords for all dev users 123 | └── test 124 | └── *.tfrecords # tfrecords for all test users 125 | ``` 126 | ### Training in centralized setting 127 | In the centralized setting, the user split is ignored and all users' data are concatenated. 128 | Centralized model training can be done in TensorFlow Keras with the following command: 129 | ```sh 130 | python3 -m benchmark.central_main --tfrecords_dir=/path/to/tfrecords 131 | ``` 132 | To view all available arguments, please use the following command: 133 | ```sh 134 | python3 -m benchmark.central_main --help 135 | ``` 136 | Please refer to our [benchmark paper](https://arxiv.org/abs/2207.08869) for the recommended hyperparameters. 137 | 138 | ### Training in federated setting 139 | In the federated setting, sampled users train on their own data locally and then share the model updates with the central server. 140 | Federated model training can be simulated in TensorFlow Federated with the following command: 141 | ```sh 142 | python3 -m benchmark.federated_main --tfrecords_dir=/path/to/tfrecords 143 | ``` 144 | To view all available arguments, please use the following command: 145 | ```sh 146 | python3 -m benchmark.federated_main --help 147 | ``` 148 | Please refer to our [benchmark paper](https://arxiv.org/abs/2207.08869) for the recommended hyperparameters. 149 | 150 | ### Training in federated setting with differential privacy 151 | To provide a formal privacy guarantee, we use [DP-SGD](https://arxiv.org/abs/1607.00133) 152 | in the [federated context](https://arxiv.org/abs/1710.06963), which is supported in TensorFlow Federated. 153 | The following command enables federated learning with differential privacy: 154 | ```sh 155 | python3 -m benchmark.federated_main --tfrecords_dir=/path/to/tfrecords --epsilon=2.0 --l2_norm_clip=0.1 156 | ``` 157 | where `epsilon` is the privacy budget and `l2_norm_clip` is the L2 norm clipping bound for the Gaussian mechanism. 158 | By default, we use [adaptive clipping](https://arxiv.org/abs/1905.03871v3) to tune the L2 norm clipping bound automatically by setting `--target_unclipped_quantile=0.1`. 159 | 160 | ### Fine-tuning a pretrained ImageNet model 161 | The commands above all train from a randomly initialized model. 162 | We also provide a ResNet model pretrained on ImageNet, which can be downloaded with the following command: 163 | ```sh 164 | wget -O /path/to/model https://docs-assets.developer.apple.com/ml-research/datasets/flair/models/resnet18.h5 165 | ``` 166 | The pretrained model is originally from [torchvision](https://download.pytorch.org/models/resnet18-f37072fd.pth) and converted to Keras format. 167 | To use the pretrained model, please add the argument `--restore_model_path=/path/to/model` to the above training commands.
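As a quick check that the download succeeded, the sketch below loads the checkpoint into the benchmark model, mirroring how `benchmark/central_main.py` restores weights. Run it from the repository root; the model path and the number of classes (17 coarse-grained labels) are placeholders for illustration.
```python
# Minimal sketch: load the ImageNet-pretrained ResNet-18 checkpoint into the
# benchmark model, the same way benchmark/central_main.py does.
from benchmark import flair_model

model = flair_model.resnet18(
    input_shape=(256, 256, 3),  # image shape used by the benchmark scripts
    num_classes=17,             # placeholder: 17 coarse-grained labels
    pretrained=True)
# by_name/skip_mismatch allow the ImageNet classification head to be skipped.
model.load_weights("/path/to/model", skip_mismatch=True, by_name=True)
model.summary()
```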
168 | 169 | ### Training a binary classifier for a single label 170 | By default, we train a multi-label classification model where the output is a multi-hot vector indicating which labels are present in the input image. 171 | We also provide the option to train a simpler binary classifier for a single label. 172 | For example, adding the argument `--binary_label=structure` trains a model to predict only whether the `structure` label is present in the image. 173 | 174 | ## Disclaimer 175 | The annotations and Apple’s other rights in the dataset are licensed under the CC-BY-NC 4.0 license. 176 | The images are the copyright of their respective owners; the license terms for each image can be found using the links provided in ATTRIBUTION.txt (by matching the Image ID). 177 | Apple makes no representations or warranties regarding the license status of each image, and you should verify the license for each image yourself. 178 | 179 | ## Citing FLAIR 180 | 181 | ``` 182 | @article{song2022flair, 183 | title={FLAIR: Federated Learning Annotated Image Repository}, 184 | author={Song, Congzheng and Granqvist, Filip and Talwar, Kunal}, 185 | journal={Advances in Neural Information Processing Systems}, 186 | volume={35}, 187 | pages={37792--37805}, 188 | year={2022} 189 | } 190 | ``` 191 | -------------------------------------------------------------------------------- /assets/FLAIR_sample.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-flair/afb9203d9b230ed56bcf7c617774bd7edc6e9529/assets/FLAIR_sample.jpeg -------------------------------------------------------------------------------- /benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # For licensing see accompanying LICENSE file. 4 | # Copyright (C) 2022 Apple Inc. All Rights Reserved. 5 | -------------------------------------------------------------------------------- /benchmark/central_main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # For licensing see accompanying LICENSE file. 4 | # Copyright (C) 2022 Apple Inc. All Rights Reserved. 5 | 6 | import atexit 7 | import os 8 | import tensorflow as tf 9 | from absl import app 10 | from absl import flags 11 | from absl import logging 12 | 13 | from .
import flair_data, flair_metrics, flair_model 14 | 15 | # Central training hyperparameters 16 | flags.DEFINE_float('learning_rate', 1e-4, 'Learning rate') 17 | flags.DEFINE_float('clipnorm', 10.0, 'Max L2 norm for gradient of each weight.') 18 | flags.DEFINE_integer('train_batch_size', 512, 'Batch size on the clients.') 19 | # Training loop configuration 20 | flags.DEFINE_integer('num_epochs', 100, 'Number of total training rounds.') 21 | flags.DEFINE_integer('eval_batch_size', 512, 22 | 'Batch size when evaluating on central datasets.') 23 | # Model configuration 24 | flags.DEFINE_string('restore_model_path', None, 'Path to pretrained model.') 25 | flags.DEFINE_string( 26 | 'save_model_dir', './', 'Path to directory for saving model.') 27 | # Data configuration 28 | flags.DEFINE_string('tfrecords_dir', None, 'Path to FLAIR tfrecords.') 29 | flags.DEFINE_integer('image_height', 224, 'Height of input image.') 30 | flags.DEFINE_integer('image_width', 224, 'Width of input image.') 31 | flags.DEFINE_boolean('use_fine_grained_labels', False, 32 | 'use_fine_grained_labels.') 33 | flags.DEFINE_string( 34 | 'binary_label', None, 35 | 'If set, train a binary classification model on the provided binary label.') 36 | 37 | FLAGS = flags.FLAGS 38 | 39 | 40 | def main(argv): 41 | if len(argv) > 1: 42 | raise app.UsageError('Expected no command-line arguments, ' 43 | 'got: {}'.format(argv)) 44 | 45 | image_shape = (256, 256, 3) 46 | label_to_index = flair_data.load_label_to_index( 47 | os.path.join(FLAGS.tfrecords_dir, "label_to_index.json"), 48 | FLAGS.use_fine_grained_labels) 49 | num_labels = len(label_to_index) 50 | 51 | binary_label_index = None 52 | if FLAGS.binary_label is not None: 53 | binary_label_index = label_to_index[FLAGS.binary_label] 54 | 55 | train_fed_data, val_fed_data, test_fed_data = flair_data.load_tfrecords_data( 56 | FLAGS.tfrecords_dir, 57 | image_shape=image_shape, 58 | num_labels=num_labels, 59 | use_fine_grained_labels=FLAGS.use_fine_grained_labels, 60 | binary_label_index=binary_label_index) 61 | 62 | if binary_label_index is not None: 63 | num_labels = 1 64 | 65 | logging.info( 66 | "{} training users, {} validating users".format( 67 | len(train_fed_data.client_ids), len(val_fed_data.client_ids))) 68 | 69 | def preprocess_fn(data: tf.data.Dataset, 70 | is_training: bool) -> tf.data.Dataset: 71 | """Preprocesses `tf.data.Dataset` by shuffling and batching.""" 72 | if is_training: 73 | return data.shuffle(10000).batch(FLAGS.train_batch_size) 74 | else: 75 | return data.batch(FLAGS.eval_batch_size) 76 | 77 | train_data = preprocess_fn( 78 | train_fed_data.create_tf_dataset_from_all_clients(), is_training=True) 79 | val_data = preprocess_fn( 80 | val_fed_data.create_tf_dataset_from_all_clients(), is_training=False) 81 | test_data = preprocess_fn( 82 | test_fed_data.create_tf_dataset_from_all_clients(), is_training=False) 83 | 84 | strategy = tf.distribute.MirroredStrategy() 85 | # To prevent OSError: [Errno 9] Bad file descriptor 86 | # https://github.com/tensorflow/tensorflow/issues/50487 87 | atexit.register(strategy._extended._collective_ops._pool.close) 88 | logging.info('Number of devices: {}'.format(strategy.num_replicas_in_sync)) 89 | 90 | # Open a strategy scope. 
91 | with strategy.scope(): 92 | model = flair_model.resnet18( 93 | input_shape=image_shape, 94 | num_classes=num_labels, 95 | pretrained=FLAGS.restore_model_path is not None) 96 | if FLAGS.restore_model_path is not None: 97 | logging.info("Loading pretrained weights from {}".format( 98 | FLAGS.restore_model_path)) 99 | model.load_weights( 100 | FLAGS.restore_model_path, skip_mismatch=True, by_name=True) 101 | model.compile( 102 | loss=tf.keras.losses.BinaryCrossentropy(from_logits=True), 103 | metrics=flair_metrics.metrics_builder(num_labels), 104 | optimizer=tf.keras.optimizers.Adam( 105 | FLAGS.learning_rate, clipnorm=FLAGS.clipnorm)) 106 | 107 | os.makedirs(FLAGS.save_model_dir, exist_ok=True) 108 | save_model_path = os.path.join( 109 | FLAGS.save_model_dir, f"resnet18_central_{num_labels}labels.h5") 110 | model_ckpt_callback = tf.keras.callbacks.ModelCheckpoint( 111 | filepath=save_model_path, 112 | save_weights_only=True, 113 | monitor='val_loss', 114 | mode='min', 115 | save_best_only=True) 116 | 117 | logging.info('Training model:') 118 | logging.info(model.summary()) 119 | 120 | model.fit( 121 | train_data, 122 | epochs=FLAGS.num_epochs, 123 | batch_size=FLAGS.train_batch_size, 124 | validation_data=val_data, 125 | validation_batch_size=FLAGS.eval_batch_size, 126 | callbacks=[model_ckpt_callback]) 127 | 128 | model.load_weights(save_model_path, by_name=True) 129 | # final dev evaluation 130 | logging.info("Evaluating best model on val set.") 131 | val_metrics = model.evaluate( 132 | val_data, batch_size=FLAGS.eval_batch_size, return_dict=True) 133 | val_metrics = {'final val ' + k: v for k, v in 134 | flair_metrics.flatten_metrics(val_metrics).items()} 135 | flair_metrics.print_metrics(val_metrics) 136 | 137 | # final test evaluation 138 | logging.info("Evaluating best model on test set.") 139 | test_metrics = model.evaluate( 140 | test_data, batch_size=FLAGS.eval_batch_size, return_dict=True) 141 | test_metrics = {'final test ' + k: v for k, v in 142 | flair_metrics.flatten_metrics(test_metrics).items()} 143 | flair_metrics.print_metrics(test_metrics) 144 | 145 | 146 | if __name__ == '__main__': 147 | app.run(main) 148 | -------------------------------------------------------------------------------- /benchmark/federated_main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # For licensing see accompanying LICENSE file. 4 | # Copyright (C) 2022 Apple Inc. All Rights Reserved. 5 | 6 | import atexit 7 | import functools 8 | import os 9 | import tensorflow as tf 10 | import tensorflow_federated as tff 11 | import time 12 | from absl import app 13 | from absl import flags 14 | from absl import logging 15 | from tensorflow_privacy.privacy.analysis import compute_noise_from_budget_lib 16 | from typing import Any, Callable, Optional, Dict 17 | 18 | from . import flair_data, flair_metrics, flair_model 19 | 20 | # Defining optimizer flags 21 | flags.DEFINE_float('client_learning_rate', 0.1, 'Client local learning rate') 22 | flags.DEFINE_float( 23 | 'client_clipnorm', 10.0, 24 | 'Max L2 norm for gradient of each weight. 
' 25 | 'This is used to prevent gradient explosion in client local training') 26 | flags.DEFINE_float('server_learning_rate', 0.1, 'Server learning_rate') 27 | 28 | # Federated training hyperparameters 29 | flags.DEFINE_integer('client_epochs_per_round', 2, 30 | 'Number of epochs in the client to take per round.') 31 | flags.DEFINE_integer('client_batch_size', 16, 'Batch size on the clients.') 32 | flags.DEFINE_integer('clients_per_round', 200, 33 | 'How many clients to sample per round.') 34 | flags.DEFINE_integer('clients_per_thread', 50, 35 | 'How many clients to sample per thread.') 36 | flags.DEFINE_integer('client_datasets_random_seed', None, 37 | 'Random seed for client sampling.') 38 | # Training loop configuration 39 | flags.DEFINE_integer('total_rounds', 5000, 'Number of total training rounds.') 40 | flags.DEFINE_integer( 41 | 'rounds_per_eval', 10, 42 | 'How often to evaluate the global model on the validation dataset.') 43 | flags.DEFINE_integer('max_elements_per_client', 512, 44 | 'Max number of training examples to use per client.') 45 | flags.DEFINE_integer('eval_batch_size', 512, 46 | 'Batch size when evaluating on central datasets.') 47 | # Model configuration 48 | flags.DEFINE_string('restore_model_path', None, 'Path to pretrained model.') 49 | flags.DEFINE_string( 50 | 'save_model_dir', './', 'Path to directory for saving model.') 51 | # Data configuration 52 | flags.DEFINE_string('tfrecords_dir', None, 'Path to FLAIR tfrecords.') 53 | flags.DEFINE_integer('image_height', 224, 'Height of input image.') 54 | flags.DEFINE_integer('image_width', 224, 'Width of input image.') 55 | flags.DEFINE_boolean('use_fine_grained_labels', False, 56 | 'use_fine_grained_labels.') 57 | flags.DEFINE_string( 58 | 'binary_label', None, 59 | 'If set, train a binary classification model on the provided binary label.') 60 | # Differential privacy configuration 61 | flags.DEFINE_float('epsilon', 0.0, 'DP epsilon.') 62 | flags.DEFINE_float('l2_norm_clip', 0.1, 'DP clipping bound.') 63 | flags.DEFINE_float( 64 | 'target_unclipped_quantile', 0.1, 65 | 'Quantile for adaptive clipping bound. Value 0 turns off adaptive clipping') 66 | flags.DEFINE_integer( 67 | 'simulated_clients_per_round', None, 68 | 'A simulated `clients_per_round` for experimenting DP more efficiently.' 69 | 'If set larger than `clients_per_round`, the DP noise scale will be the ' 70 | 'same as if training with `simulated_clients_per_round` clients when only ' 71 | '`clients_per_round` clients are actually sampled. 
See detailed description' 72 | ' in Section 5.1 of https://arxiv.org/abs/2207.08869') 73 | 74 | FLAGS = flags.FLAGS 75 | 76 | 77 | def main(argv): 78 | if len(argv) > 1: 79 | raise app.UsageError('Expected no command-line arguments, ' 80 | 'got: {}'.format(argv)) 81 | 82 | gpu_devices = tf.config.list_logical_devices('GPU') 83 | if len(gpu_devices) > 0: 84 | tff.backends.native.set_local_python_execution_context( 85 | default_num_clients=FLAGS.clients_per_round, 86 | max_fanout=2 * FLAGS.clients_per_round, 87 | server_tf_device=tf.config.list_logical_devices('CPU')[0], 88 | client_tf_devices=gpu_devices, 89 | clients_per_thread=FLAGS.clients_per_thread) 90 | 91 | client_optimizer_fn = lambda: tf.keras.optimizers.SGD( 92 | FLAGS.client_learning_rate, clipnorm=FLAGS.client_clipnorm) 93 | server_optimizer_fn = lambda: tf.keras.optimizers.Adam( 94 | FLAGS.server_learning_rate, epsilon=0.01) 95 | 96 | image_shape = (256, 256, 3) 97 | label_to_index = flair_data.load_label_to_index( 98 | os.path.join(FLAGS.tfrecords_dir, "label_to_index.json"), 99 | FLAGS.use_fine_grained_labels) 100 | num_labels = len(label_to_index) 101 | 102 | binary_label_index = None 103 | if FLAGS.binary_label is not None: 104 | binary_label_index = label_to_index[FLAGS.binary_label] 105 | 106 | train_fed_data, val_fed_data, test_fed_data = flair_data.load_tfrecords_data( 107 | FLAGS.tfrecords_dir, 108 | image_shape=image_shape, 109 | num_labels=num_labels, 110 | use_fine_grained_labels=FLAGS.use_fine_grained_labels, 111 | binary_label_index=binary_label_index) 112 | 113 | if binary_label_index is not None: 114 | num_labels = 1 115 | 116 | logging.info( 117 | "{} training users, {} validating users".format( 118 | len(train_fed_data.client_ids), len(val_fed_data.client_ids))) 119 | 120 | model_update_aggregation_factory = None 121 | if FLAGS.epsilon > 0.0: 122 | # Setup TFF with differential privacy 123 | n = len(train_fed_data.client_ids) 124 | if FLAGS.simulated_clients_per_round is not None: 125 | assert FLAGS.simulated_clients_per_round >= FLAGS.clients_per_round 126 | batch_size = FLAGS.simulated_clients_per_round 127 | else: 128 | batch_size = FLAGS.clients_per_round 129 | 130 | # Compute central DP noise scale added to aggregated model updates 131 | noise_multiplier = compute_noise_from_budget_lib.compute_noise( 132 | n=n, 133 | batch_size=batch_size, 134 | target_epsilon=FLAGS.epsilon, 135 | epochs=FLAGS.total_rounds, 136 | delta=1 / (n ** 1.1), 137 | noise_lbd=1e-5) 138 | 139 | # Simulate the noise level of large cohort with small cohort 140 | if FLAGS.simulated_clients_per_round is not None: 141 | noise_multiplier = (noise_multiplier / 142 | FLAGS.simulated_clients_per_round * 143 | FLAGS.clients_per_round) 144 | 145 | logging.info("DP noise multiplier: {:.2f}".format(noise_multiplier)) 146 | if FLAGS.target_unclipped_quantile == 0.0: 147 | model_update_aggregation_factory = tff.aggregators. \ 148 | DifferentiallyPrivateFactory.gaussian_fixed( 149 | noise_multiplier=noise_multiplier, 150 | clients_per_round=FLAGS.clients_per_round, 151 | clip=FLAGS.l2_norm_clip) 152 | else: 153 | logging.info("Use adaptive clipping for L2 norm clip") 154 | model_update_aggregation_factory = tff.aggregators. 
\ 155 | DifferentiallyPrivateFactory.gaussian_adaptive( 156 | noise_multiplier=noise_multiplier, 157 | clients_per_round=FLAGS.clients_per_round, 158 | initial_l2_norm_clip=FLAGS.l2_norm_clip, 159 | target_unclipped_quantile=FLAGS.target_unclipped_quantile) 160 | 161 | # Add DP related metrics to report 162 | model_update_aggregation_factory = tff.learning.add_debug_measurements( 163 | model_update_aggregation_factory) 164 | 165 | def iterative_process_builder( 166 | model_fn: Callable[[], tff.learning.Model], 167 | client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None, 168 | ) -> tff.templates.IterativeProcess: 169 | """Creates an iterative process using a given TFF `model_fn`.""" 170 | return tff.learning.build_federated_averaging_process( 171 | model_fn=model_fn, 172 | client_optimizer_fn=client_optimizer_fn, 173 | server_optimizer_fn=server_optimizer_fn, 174 | client_weighting=client_weight_fn, 175 | model_update_aggregation_factory=model_update_aggregation_factory, 176 | use_experimental_simulation_loop=True) 177 | 178 | model_builder = functools.partial( 179 | flair_model.resnet18, 180 | input_shape=image_shape, 181 | num_classes=num_labels, 182 | pretrained=FLAGS.restore_model_path is not None) 183 | 184 | loss_builder = functools.partial( 185 | tf.keras.losses.BinaryCrossentropy, from_logits=True) 186 | 187 | metrics_builder = functools.partial( 188 | flair_metrics.metrics_builder, num_labels=num_labels) 189 | 190 | def preprocess_fn(data: tf.data.Dataset, 191 | is_training: bool) -> tf.data.Dataset: 192 | """Preprocesses `tf.data.Dataset` by shuffling and batching.""" 193 | if is_training: 194 | data = data.shuffle(FLAGS.max_elements_per_client) 195 | # Repeat data by client epochs and batch 196 | return data.take(FLAGS.max_elements_per_client).repeat( 197 | FLAGS.client_epochs_per_round).batch(FLAGS.client_batch_size) 198 | else: 199 | return data.batch(FLAGS.eval_batch_size) 200 | 201 | train_fed_data = train_fed_data.preprocess( 202 | functools.partial(preprocess_fn, is_training=True)) 203 | input_spec = train_fed_data.element_type_structure 204 | val_data = preprocess_fn( 205 | val_fed_data.create_tf_dataset_from_all_clients(), is_training=False) 206 | test_data = preprocess_fn( 207 | test_fed_data.create_tf_dataset_from_all_clients(), is_training=False) 208 | 209 | def tff_model_fn() -> tff.learning.Model: 210 | """Wraps a tensorflow model to TFF model.""" 211 | return tff.learning.from_keras_model(keras_model=model_builder(), 212 | input_spec=input_spec, 213 | loss=loss_builder(), 214 | metrics=metrics_builder()) 215 | 216 | iterative_process = iterative_process_builder( 217 | tff_model_fn, client_weight_fn=None) 218 | 219 | # training_process accepts client ids as input 220 | training_process = tff.simulation.compose_dataset_computation_with_iterative_process( 221 | train_fed_data.dataset_computation, iterative_process) 222 | 223 | client_ids_fn = functools.partial( 224 | tff.simulation.build_uniform_sampling_fn( 225 | train_fed_data.client_ids, 226 | replace=False, 227 | random_seed=FLAGS.client_datasets_random_seed), 228 | size=FLAGS.clients_per_round) 229 | # We convert the output to a list (instead of an np.ndarray) so that it can 230 | # be used as input to the iterative process. 
231 | client_sampling_fn = lambda x: list(client_ids_fn(x)) 232 | 233 | # Build central Keras model for evaluation 234 | strategy = tf.distribute.MirroredStrategy() 235 | # To prevent OSError: [Errno 9] Bad file descriptor 236 | # https://github.com/tensorflow/tensorflow/issues/50487 237 | atexit.register(strategy._extended._collective_ops._pool.close) 238 | 239 | # Open a strategy scope. 240 | with strategy.scope(): 241 | eval_model = model_builder() 242 | eval_model.compile(loss=loss_builder(), metrics=metrics_builder()) 243 | 244 | def evaluation_fn(state, eval_data: tf.data.Dataset) -> Dict: 245 | """Evaluate TFF model state on `eval_data`""" 246 | state.model.assign_weights_to(eval_model) 247 | eval_metrics = eval_model.evaluate( 248 | eval_data, 249 | verbose=0, 250 | batch_size=FLAGS.eval_batch_size, 251 | return_dict=True) 252 | return flair_metrics.flatten_metrics(eval_metrics) 253 | 254 | logging.info('Training model:') 255 | logging.info(model_builder().summary()) 256 | 257 | state = training_process.initialize() 258 | if FLAGS.restore_model_path is not None: 259 | logging.info("Loading pretrained weights from {}".format( 260 | FLAGS.restore_model_path)) 261 | pretrained_model = model_builder() 262 | pretrained_model.load_weights( 263 | FLAGS.restore_model_path, skip_mismatch=True, by_name=True) 264 | # Load our pre-trained model weights into the global model state. 265 | state = tff.learning.state_with_new_model_weights( 266 | state, 267 | trainable_weights=[ 268 | v.numpy() for v in pretrained_model.trainable_weights], 269 | non_trainable_weights=[ 270 | v.numpy() for v in pretrained_model.non_trainable_weights 271 | ]) 272 | 273 | round_num = 0 274 | loop_start_time = time.time() 275 | best_val_loss = float('inf') 276 | 277 | os.makedirs(FLAGS.save_model_dir, exist_ok=True) 278 | save_model_path = os.path.join( 279 | FLAGS.save_model_dir, f"resnet18_federated_{num_labels}labels.h5") 280 | 281 | # Main training loop 282 | while round_num < FLAGS.total_rounds: 283 | data_prep_start_time = time.time() 284 | sampled_clients = client_sampling_fn(round_num) 285 | metrics = {'prepare datasets secs': time.time() - data_prep_start_time} 286 | 287 | state, round_metrics = training_process.next(state, sampled_clients) 288 | metrics.update(flair_metrics.flatten_metrics(round_metrics)) 289 | logging.info('Round {:2d}, {:.2f}s per round in average.'.format( 290 | round_num, (time.time() - loop_start_time) / (round_num + 1))) 291 | 292 | if (round_num + 1) % FLAGS.rounds_per_eval == 0: 293 | # Compute evaluation metrics 294 | val_metrics = evaluation_fn(state, val_data) 295 | metrics.update({'val ' + k: v for k, v in val_metrics.items()}) 296 | # Save model if current iteration has better val metrics 297 | current_val_loss = float(val_metrics["loss"]) 298 | if current_val_loss < best_val_loss: 299 | logging.info(f"Saving current best model to {save_model_path}") 300 | eval_model.save(save_model_path) 301 | best_val_loss = current_val_loss 302 | 303 | metrics['duration of iteration'] = time.time() - data_prep_start_time 304 | flair_metrics.print_metrics(metrics, iteration=round_num) 305 | round_num += 1 306 | 307 | eval_model.load_weights(save_model_path, by_name=True) 308 | # final dev evaluation 309 | logging.info("Evaluating best model on val set.") 310 | val_metrics = eval_model.evaluate( 311 | val_data, batch_size=FLAGS.eval_batch_size, return_dict=True) 312 | val_metrics = {'final val ' + k: v for k, v in 313 | flair_metrics.flatten_metrics(val_metrics).items()} 314 | 
flair_metrics.print_metrics(val_metrics) 315 | 316 | # final test evaluation 317 | logging.info("Evaluating best model on test set.") 318 | test_metrics = eval_model.evaluate( 319 | test_data, batch_size=FLAGS.eval_batch_size, return_dict=True) 320 | test_metrics = {'final test ' + k: v for k, v in 321 | flair_metrics.flatten_metrics(test_metrics).items()} 322 | flair_metrics.print_metrics(test_metrics) 323 | 324 | 325 | if __name__ == '__main__': 326 | app.run(main) 327 | -------------------------------------------------------------------------------- /benchmark/flair_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # For licensing see accompanying LICENSE file. 4 | # Copyright (C) 2022 Apple Inc. All Rights Reserved. 5 | 6 | import functools 7 | import json 8 | import os 9 | from typing import Dict, Tuple, Optional 10 | 11 | import tensorflow as tf 12 | import tensorflow_federated as tff 13 | 14 | 15 | KEY_IMAGE_BYTES = 'image/encoded_jpeg' 16 | KEY_IMAGE_DECODED = 'image/decoded' 17 | KEY_LABELS = 'labels' 18 | KEY_FINE_GRAINED_LABELS = 'fine_grained_labels' 19 | 20 | 21 | def load_tfrecords( 22 | filename: str, 23 | image_shape: Tuple[int, int, int], 24 | num_labels: int, 25 | use_fine_grained_labels: bool, 26 | binary_label_index: Optional[int], 27 | ) -> tf.data.Dataset: 28 | """Load tfrecords from `filename` and return a `tf.data.Dataset`""" 29 | dataset = tf.data.TFRecordDataset([filename]) 30 | key_labels = KEY_FINE_GRAINED_LABELS if use_fine_grained_labels else KEY_LABELS 31 | 32 | def parse(example_proto): 33 | """Parse an example into image and label tensors.""" 34 | feature_description = { 35 | KEY_IMAGE_BYTES: tf.io.FixedLenFeature([], tf.string), 36 | key_labels: tf.io.VarLenFeature(tf.int64), 37 | } 38 | example = tf.io.parse_single_example(example_proto, feature_description) 39 | image = tf.io.decode_jpeg(example[KEY_IMAGE_BYTES]) 40 | labels = tf.reduce_sum( 41 | tf.one_hot( 42 | example[key_labels].values, depth=num_labels, dtype=tf.int32), 43 | axis=0) 44 | if binary_label_index is not None: 45 | labels = labels[binary_label_index] 46 | return tf.reshape(image, image_shape), labels 47 | 48 | return dataset.map(parse, tf.data.AUTOTUNE) 49 | 50 | 51 | def load_label_to_index(label_to_index_file: str, 52 | use_fine_grained_labels: bool) -> Dict[str, int]: 53 | """ 54 | Load label to index mapping. 55 | 56 | :param label_to_index_file: 57 | Path to json file that has the label to index mapping. 58 | :param use_fine_grained_labels: 59 | Whether to load mapping for fine-grained labels. 60 | 61 | :return: 62 | A dictionary that maps label to index. 63 | """ 64 | with open(label_to_index_file) as f: 65 | return json.load(f)[ 66 | "fine_grained_labels" if use_fine_grained_labels else "labels"] 67 | 68 | 69 | def load_tfrecords_data( 70 | tfrecords_dir: str, 71 | image_shape: Tuple[int, int, int], 72 | num_labels: int, 73 | use_fine_grained_labels: bool, 74 | binary_label_index: Optional[int] = None 75 | ) -> Tuple[tff.simulation.datasets.FilePerUserClientData, ...]: 76 | """ 77 | Load tfrecords data into TFF format. 78 | 79 | :param tfrecords_dir: 80 | Directory with all tfrecords saved, processed by `prepare_tfrecords.py`. 81 | :param image_shape: 82 | 3D tuple indicating shape of image [height, width, channels]. 83 | :param num_labels: 84 | Number of labels. 85 | :param use_fine_grained_labels: 86 | Whether to use fine-grained labels.
87 | :param binary_label_index: 88 | Optional integer. If set, label will be a binary value for the given 89 | `binary_label_index`, and other label indices will be ignored. 90 | 91 | :return: 92 | A tuple of three `tff.simulation.datasets.FilePerUserClientData` object 93 | for train, val and test set respectively. 94 | """ 95 | def get_client_ids_to_files(partition: str): 96 | """Get the tfrecords filenames for a train/val/test partition""" 97 | partition_dir = os.path.join(tfrecords_dir, partition) 98 | partition_client_files = os.listdir(partition_dir) 99 | return { 100 | client_file.split(".tfrecords")[0]: os.path.join( 101 | partition_dir, client_file) 102 | for client_file in partition_client_files 103 | } 104 | 105 | return tuple([ 106 | tff.simulation.datasets.FilePerUserClientData( 107 | client_ids_to_files=get_client_ids_to_files(partition), 108 | dataset_fn=functools.partial( 109 | load_tfrecords, 110 | image_shape=image_shape, 111 | num_labels=num_labels, 112 | use_fine_grained_labels=use_fine_grained_labels, 113 | binary_label_index=binary_label_index) 114 | ) 115 | for partition in ['train', 'val', 'test'] 116 | ]) 117 | -------------------------------------------------------------------------------- /benchmark/flair_metrics.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # For licensing see accompanying LICENSE file. 4 | # Copyright (C) 2022 Apple Inc. All Rights Reserved. 5 | 6 | import sys 7 | import numpy as np 8 | import tensorflow as tf 9 | 10 | from keras.utils import metrics_utils 11 | from typing import List, Optional, Dict 12 | 13 | 14 | class ConfusionMatrixMetrics(tf.keras.metrics.AUC): 15 | """ 16 | Base class for metrics based on confusion matrix, including precision, 17 | recall, F1 score and averaged precision. 18 | Please refer https://www.tensorflow.org/api_docs/python/tf/keras/metrics/AUC 19 | for arguments description. 
20 | """ 21 | def __init__(self, 22 | num_labels: int, 23 | multi_label: bool, 24 | num_thresholds: int = 200, 25 | name: Optional[str] = None, 26 | dtype: Optional[tf.DType] = None, 27 | thresholds: Optional[List[float]] = None, 28 | label_weights: Optional[List[float]] = None, 29 | from_logits: bool = False): 30 | 31 | if not multi_label: 32 | num_labels = None 33 | 34 | self._num_labels = num_labels 35 | if isinstance(self, (Precision, Recall, F1)): 36 | thresholds = metrics_utils.parse_init_thresholds( 37 | thresholds, default_threshold=0.5) 38 | 39 | super(ConfusionMatrixMetrics, self).__init__( 40 | num_thresholds=num_thresholds, 41 | curve='ROC', 42 | summation_method='interpolation', 43 | name=name, 44 | dtype=dtype, 45 | thresholds=thresholds, 46 | multi_label=multi_label, 47 | num_labels=num_labels, 48 | label_weights=label_weights, 49 | from_logits=from_logits) 50 | 51 | def result(self): 52 | raise NotImplementedError( 53 | "ConfusionMatrixMetrics does not return any result") 54 | 55 | def get_config(self): 56 | config = super(ConfusionMatrixMetrics, self).get_config() 57 | # Pop unrelated arguments 58 | config.pop("curve") 59 | config.pop("summation_method") 60 | # Add arguments in __init__ to pass TFF metrics builder checks 61 | config['thresholds'] = self.thresholds[1:-1] 62 | config['num_labels'] = self._num_labels 63 | config['from_logits'] = self._from_logits 64 | return config 65 | 66 | def update_state(self, *args, **kwargs): 67 | # Return None to pass TFF metrics builder checks 68 | super(ConfusionMatrixMetrics, self).update_state(*args, **kwargs) 69 | return 70 | 71 | def get_precision(self): 72 | tp, fp = self.true_positives, self.false_positives 73 | return tf.squeeze(tf.math.divide_no_nan(tp, tf.math.add(tp, fp))) 74 | 75 | def get_recall(self): 76 | tp, fn = self.true_positives, self.false_negatives 77 | return tf.squeeze(tf.math.divide_no_nan(tp, tf.math.add(tp, fn))) 78 | 79 | def get_macro_average(self, by_label_metrics): 80 | assert self.multi_label 81 | if self.label_weights is None: 82 | macro_average = tf.reduce_mean(by_label_metrics) 83 | else: 84 | macro_average = tf.reduce_sum(by_label_metrics * self.label_weights 85 | ) / tf.reduce_sum(self.label_weights) 86 | return macro_average 87 | 88 | 89 | class AveragedPrecision(ConfusionMatrixMetrics): 90 | """ 91 | Averaged precision metrics. 
Implementation follows sklearn in 92 | https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html 93 | """ 94 | def get_averaged_precision(self): 95 | precision = self.get_precision() 96 | recall = self.get_recall() 97 | recall_diff = recall[1:] - recall[:-1] 98 | return -tf.reduce_sum(recall_diff * precision[:-1], axis=0) 99 | 100 | def result(self): 101 | averaged_precision = self.get_averaged_precision() 102 | if self.multi_label: 103 | averaged_precision = self.get_macro_average(averaged_precision) 104 | return averaged_precision 105 | 106 | 107 | class Precision(ConfusionMatrixMetrics): 108 | """Precision metrics""" 109 | def result(self): 110 | precision = self.get_precision()[1:-1] 111 | if self.multi_label: 112 | precision = self.get_macro_average(precision) 113 | return precision 114 | 115 | 116 | class Recall(ConfusionMatrixMetrics): 117 | """Recall metrics""" 118 | def result(self): 119 | recall = self.get_recall()[1:-1] 120 | if self.multi_label: 121 | recall = self.get_macro_average(recall) 122 | return recall 123 | 124 | 125 | class F1(ConfusionMatrixMetrics): 126 | """F1 score metrics""" 127 | def result(self): 128 | precision = self.get_precision()[1:-1] 129 | recall = self.get_recall()[1:-1] 130 | if self.multi_label: 131 | precision = self.get_macro_average(precision) 132 | recall = self.get_macro_average(recall) 133 | return tf.math.divide_no_nan( 134 | 2 * precision * recall, tf.math.add(precision, recall)) 135 | 136 | 137 | def metrics_builder(num_labels: int) -> List[tf.keras.metrics.Metric]: 138 | """ 139 | Build a list of metrics to track during training. 140 | 141 | :param num_labels: 142 | Number of class labels. 143 | :return: 144 | A list of `tf.keras.metrics.Metric` object. 145 | """ 146 | metrics = [ 147 | tf.keras.metrics.BinaryCrossentropy(name="loss", from_logits=True), 148 | ] 149 | metric_classes = [Precision, Recall, F1, AveragedPrecision] 150 | metric_names = ["precision", "recall", "f1", "averaged_precision"] 151 | for metric_class, metric_name in zip(metric_classes, metric_names): 152 | if num_labels == 1: 153 | metrics.append( 154 | metric_class(num_labels=num_labels, multi_label=False, 155 | from_logits=True, name=metric_name)) 156 | else: 157 | metrics.extend( 158 | [ 159 | metric_class(num_labels=num_labels, multi_label=True, 160 | from_logits=True, name=f"macro_{metric_name}"), 161 | metric_class(num_labels=num_labels, multi_label=False, 162 | from_logits=True, name=f"micro_{metric_name}") 163 | ]) 164 | return metrics 165 | 166 | 167 | def flatten_metrics(nested_metrics, prefix: Optional[str] = None 168 | ) -> Dict[str, float]: 169 | """ 170 | Flatten a nested metrics structure to a dictionary where key is the metric 171 | name and value is the metric value in float. 172 | 173 | :param nested_metrics: 174 | A nested metrics structure. 175 | :param prefix: 176 | Optional prefix description for a metrics. 177 | 178 | :return: 179 | Dictionary of metrics names and metrics values. 
180 | """ 181 | flattened_metrics = {} 182 | 183 | if isinstance(nested_metrics, dict): 184 | for key, value in nested_metrics.items(): 185 | flattened_metrics.update(flatten_metrics( 186 | value, str(key) if prefix is None else f'{prefix} {key}')) 187 | elif isinstance(nested_metrics, (np.ndarray, list, tuple)): 188 | if len(nested_metrics) == 1: 189 | flattened_metrics[prefix] = float(nested_metrics[0]) 190 | else: 191 | for index, value in enumerate(nested_metrics): 192 | flattened_metrics.update(flatten_metrics( 193 | value, str(index) if prefix is None else f'{prefix} {index}' 194 | )) 195 | else: 196 | flattened_metrics[prefix] = float(nested_metrics) 197 | return flattened_metrics 198 | 199 | 200 | def print_metrics(metrics: Dict[str, float], iteration: Optional[int] = None): 201 | """ 202 | Print a dictionary of metrics names and metrics values. 203 | 204 | :param metrics: 205 | A dictionary of metrics names and metrics values. 206 | :param iteration: 207 | Optional value indicating training iteration. 208 | """ 209 | if iteration is not None: 210 | sys.stdout.write('Metrics at iteration {}:\n'.format(iteration)) 211 | for key, value in metrics.items(): 212 | sys.stdout.write(' {:<50}: {}\n'.format(key, value)) 213 | -------------------------------------------------------------------------------- /benchmark/flair_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # For licensing see accompanying LICENSE file. 4 | # Copyright (C) 2022 Apple Inc. All Rights Reserved. 5 | 6 | import tensorflow as tf 7 | from typing import Optional, Callable, List, Tuple 8 | 9 | from keras.layers import ( 10 | RandomCrop, RandomFlip, Normalization, Rescaling, 11 | Conv2D, ZeroPadding2D, ReLU, MaxPooling2D, BatchNormalization) 12 | import tensorflow_addons.layers.normalizations as tfa_norms 13 | 14 | 15 | # ImageNet statistics from https://pytorch.org/vision/stable/models.html 16 | IMAGENET_MEAN = [0.485, 0.456, 0.406] 17 | IMAGENET_VARIANCE = [0.229 ** 2, 0.224 ** 2, 0.225 ** 2] 18 | 19 | 20 | class FrozenBatchNormalization(BatchNormalization): 21 | """ 22 | BatchNormalization layer that freezes the moving mean and average updates. 23 | It is intended to be used in fine-tuning a pretrained model in federated 24 | learning setting, where the moving mean and average will be assigned to 25 | the ones in the pretrained model. Only beta and gamma are updated. 
26 | """ 27 | def call(self, inputs, training=None): 28 | inputs_dtype = inputs.dtype.base_dtype 29 | if inputs_dtype in (tf.float16, tf.bfloat16): 30 | inputs = tf.cast(inputs, tf.float32) 31 | 32 | # Compute the axes along which to reduce the mean / variance 33 | input_shape = inputs.shape 34 | ndims = len(input_shape) 35 | reduction_axes = [i for i in range(ndims) if i not in self.axis] 36 | 37 | # Broadcasting only necessary for single-axis batch norm where the axis 38 | # is not the last dimension 39 | broadcast_shape = [1] * ndims 40 | broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value 41 | 42 | def _broadcast(v): 43 | if (v is not None and len(v.shape) != ndims and 44 | reduction_axes != list(range(ndims - 1))): 45 | return tf.reshape(v, broadcast_shape) 46 | return v 47 | 48 | scale, offset = _broadcast(self.gamma), _broadcast(self.beta) 49 | # use pretrained moving_mean and moving_variance for normalization 50 | mean, variance = self.moving_mean, self.moving_variance 51 | mean = tf.cast(mean, inputs.dtype) 52 | variance = tf.cast(variance, inputs.dtype) 53 | if offset is not None: 54 | offset = tf.cast(offset, inputs.dtype) 55 | if scale is not None: 56 | scale = tf.cast(scale, inputs.dtype) 57 | outputs = tf.nn.batch_normalization(inputs, _broadcast(mean), 58 | _broadcast(variance), offset, scale, 59 | self.epsilon) 60 | if inputs_dtype in (tf.float16, tf.bfloat16): 61 | outputs = tf.cast(outputs, inputs_dtype) 62 | 63 | # If some components of the shape got lost due to adjustments, fix that. 64 | outputs.set_shape(input_shape) 65 | return outputs 66 | 67 | 68 | def conv3x3(x: tf.Tensor, scope: str, out_planes: int, stride: int = 1, 69 | groups: int = 1, dilation: int = 1, seed: int = 0): 70 | """3x3 convolution with padding""" 71 | x = ZeroPadding2D(padding=(dilation, dilation), name=f"{scope}_padding")(x) 72 | return Conv2D( 73 | out_planes, 74 | kernel_size=3, 75 | strides=stride, 76 | groups=groups, 77 | use_bias=False, 78 | dilation_rate=dilation, 79 | name=f"{scope}_3x3", 80 | kernel_initializer=tf.keras.initializers.HeNormal(seed=seed), 81 | )(x) 82 | 83 | 84 | def conv1x1(x: tf.Tensor, scope: str, out_planes: int, stride: int = 1, 85 | seed: int = 0): 86 | """1x1 convolution""" 87 | return Conv2D( 88 | out_planes, 89 | kernel_size=1, 90 | strides=stride, 91 | use_bias=False, 92 | name=f"{scope}_1x1", 93 | kernel_initializer=tf.keras.initializers.HeNormal(seed=seed), 94 | )(x) 95 | 96 | 97 | def norm(x: tf.Tensor, scope: str, use_batch_norm: bool): 98 | """Normalization layer""" 99 | if use_batch_norm: 100 | return FrozenBatchNormalization(axis=3, epsilon=1e-5, name=scope)(x) 101 | else: 102 | return tfa_norms.GroupNormalization(epsilon=1e-5, name=scope)(x) 103 | 104 | 105 | def relu(x: tf.Tensor, scope: str): 106 | """ReLU activation layer""" 107 | return ReLU(name=scope)(x) 108 | 109 | 110 | def basic_block(x: tf.Tensor, scope: str, out_planes: int, use_batch_norm: bool, 111 | stride: int = 1, downsample: Optional[Callable] = None, 112 | seed: int = 0): 113 | """Basic ResNet block""" 114 | out = conv3x3(x, f"{scope}_conv1", out_planes, stride, seed=seed) 115 | out = norm(out, scope=f"{scope}_norm1", use_batch_norm=use_batch_norm) 116 | out = relu(out, f"{scope}_relu1") 117 | out = conv3x3(out, f"{scope}_conv2", out_planes, seed=seed) 118 | out = norm(out, scope=f"{scope}_norm2", use_batch_norm=use_batch_norm) 119 | if downsample is not None: 120 | x = downsample(x) 121 | out += x 122 | out = relu(out, f"{scope}_relu2") 123 | return out 124 | 125 | 126 | 
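# --- Editor's illustrative aside (not part of the original flair_model.py) ---
# A minimal sketch of how `basic_block` is wired, assuming a dummy 56x56x64
# feature map; names prefixed with "demo_" are hypothetical. With stride=2 and
# a 1x1-conv projection on the shortcut, the spatial resolution halves while
# the channel count grows to `out_planes`, mirroring torchvision's BasicBlock.
def _demo_basic_block():
    dummy_input = tf.keras.layers.Input(shape=(56, 56, 64))

    def demo_downsample(h: tf.Tensor):
        # Project the shortcut so the residual addition has matching shapes.
        h = conv1x1(h, "demo_downsample_conv", out_planes=128, stride=2)
        return norm(h, "demo_downsample_norm", use_batch_norm=False)

    out = basic_block(dummy_input, "demo", out_planes=128,
                      use_batch_norm=False, stride=2,
                      downsample=demo_downsample)
    # Expected output shape: (None, 28, 28, 128)
    return tf.keras.Model(dummy_input, out)
# ------------------------------------------------------------------------------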
def block_layers( 127 | x: tf.Tensor, 128 | scope: str, 129 | in_planes: int, 130 | out_planes: int, 131 | blocks: int, 132 | use_batch_norm: bool, 133 | stride: int = 1, 134 | seed: int = 0, 135 | ): 136 | """Layers of ResNet block""" 137 | downsample = None 138 | if stride != 1 or in_planes != out_planes: 139 | # Downsample is performed when stride > 1 according to Section 3.3 in 140 | # https://arxiv.org/pdf/1512.03385.pdf 141 | def downsample(h: tf.Tensor): 142 | h = conv1x1(h, f"{scope}_downsample_conv", out_planes, stride) 143 | return norm(h, f"{scope}_downsample_norm", use_batch_norm) 144 | 145 | x = basic_block(x, f"{scope}_block1", out_planes, use_batch_norm, stride, 146 | downsample, seed=seed) 147 | for i in range(1, blocks): 148 | x = basic_block(x, f"{scope}_block{i + 1}", out_planes, use_batch_norm, 149 | seed=seed) 150 | return x 151 | 152 | 153 | def create_resnet(input_shape: Tuple[int, int, int], 154 | num_classes: int, 155 | use_batch_norm: bool, 156 | repetitions: List[int] = None, 157 | initial_filters: int = 64, 158 | seed: int = 0): 159 | """ 160 | Creates a ResNet Keras model. Implementation follows torchvision in 161 | https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py 162 | """ 163 | img_input = tf.keras.layers.Input(shape=input_shape) 164 | 165 | # preprocessing layer 166 | x = RandomCrop(height=224, width=224)(img_input) 167 | x = RandomFlip()(x) 168 | x = Rescaling(scale=1 / 255.0)(x) 169 | x = Normalization( 170 | axis=-1, mean=IMAGENET_MEAN, variance=IMAGENET_VARIANCE)(x) 171 | 172 | # initial conv layer 173 | x = ZeroPadding2D((3, 3), name="initial_padding")(x) 174 | x = Conv2D( 175 | initial_filters, 176 | kernel_size=7, strides=2, use_bias=False, name="initial_conv", 177 | kernel_initializer=tf.keras.initializers.HeNormal(seed=seed))(x) 178 | x = norm(x, scope="initial_norm", use_batch_norm=use_batch_norm) 179 | x = relu(x, scope="initial_relu") 180 | x = ZeroPadding2D((1, 1), name="pooling_padding")(x) 181 | x = MaxPooling2D(pool_size=3, strides=2, name="initial_pooling")(x) 182 | 183 | # residual blocks 184 | x = block_layers(x, "layer1", initial_filters, 64, repetitions[0], 185 | use_batch_norm, seed=seed) 186 | x = block_layers(x, "layer2", initial_filters, 128, repetitions[1], 187 | use_batch_norm, 2, seed=seed) 188 | x = block_layers(x, "layer3", initial_filters, 256, repetitions[2], 189 | use_batch_norm, 2, seed=seed) 190 | x = block_layers(x, "layer4", initial_filters, 512, repetitions[3], 191 | use_batch_norm, 2, seed=seed) 192 | 193 | # classification layers 194 | x = tf.keras.layers.GlobalAveragePooling2D(name="global_pooling")(x) 195 | x = tf.keras.layers.Dense( 196 | num_classes, name="classifier", 197 | kernel_initializer=tf.keras.initializers.GlorotUniform(seed=seed))(x) 198 | model = tf.keras.models.Model(img_input, x) 199 | return model 200 | 201 | 202 | def resnet18(input_shape: Tuple[int, int, int], 203 | num_classes: int, 204 | pretrained: bool, 205 | seed: int = 0): 206 | """ 207 | Creates a ResNet18 keras model. 208 | 209 | :param input_shape: 210 | Input image shape in [height, width, channels]. 211 | :param num_classes: 212 | Number of output classes. 213 | :param pretrained: 214 | Whether the model is pretrained on ImageNet. If true, model will use 215 | BatchNormalization. If false, model will use GroupNormalization in order 216 | to train with differential privacy. 217 | :param seed: 218 | Random seed for initializing the weights.
219 | 220 | :return: 221 | A ResNet18 keras model 222 | """ 223 | model = create_resnet( 224 | input_shape, 225 | num_classes, 226 | use_batch_norm=pretrained, 227 | repetitions=[2, 2, 2, 2], 228 | seed=seed) 229 | return model 230 | -------------------------------------------------------------------------------- /download_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # For licensing see accompanying LICENSE file. 4 | # Copyright (C) 2022 Apple Inc. All Rights Reserved. 5 | 6 | import os 7 | import sys 8 | import argparse 9 | import subprocess 10 | import multiprocessing 11 | import functools 12 | import logging 13 | 14 | from urllib.parse import urljoin 15 | from urllib.request import urlretrieve 16 | 17 | 18 | logger = logging.getLogger(name=__name__) 19 | 20 | DATA_URL = "https://docs-assets.developer.apple.com/ml-research/datasets/flair/" 21 | NUM_IMAGE_BATCHES = 43 22 | SMALL_IMAGE_URLS = [ 23 | urljoin(DATA_URL, f"images/small/small_images-{str(i).zfill(2)}.tar.gz") 24 | for i in range(NUM_IMAGE_BATCHES)] 25 | RAW_IMAGE_URLS = [ 26 | urljoin(DATA_URL, f"images/raw/images-{str(i).zfill(2)}.tar.gz") 27 | for i in range(NUM_IMAGE_BATCHES)] 28 | LABELS_AND_METADATA_URL = urljoin(DATA_URL, "labels/labels_and_metadata.json") 29 | LABEL_RELATIONSHIP_URL = urljoin(DATA_URL, "labels/label_relationship.txt") 30 | 31 | 32 | def extract_tar(compressed_path: str, dataset_dir: str, 33 | keep_archive_after_decompress: bool): 34 | subprocess.run(f"tar -zxf {compressed_path} -C {dataset_dir}".split(), 35 | check=True) 36 | if not keep_archive_after_decompress: 37 | os.remove(compressed_path) 38 | 39 | 40 | def decompress_images(dataset_dir: str, keep_archive_after_decompress: bool): 41 | compressed_paths = [os.path.join(dataset_dir, path) 42 | for path in os.listdir(dataset_dir) 43 | if path.endswith(".tar.gz")] 44 | decompress = functools.partial( 45 | extract_tar, 46 | dataset_dir=dataset_dir, 47 | keep_archive_after_decompress=keep_archive_after_decompress) 48 | with multiprocessing.Pool(multiprocessing.cpu_count()) as pool: 49 | pool.map(decompress, compressed_paths) 50 | 51 | 52 | if __name__ == '__main__': 53 | logging.basicConfig( 54 | stream=sys.stdout, 55 | level=logging.INFO, 56 | format='%(asctime)s %(levelname)s: %(message)s') 57 | 58 | parser = argparse.ArgumentParser( 59 | description='Download the images and labels of FLAIR dataset.') 60 | parser.add_argument("--dataset_dir", required=True, 61 | help="Path to directory of dataset to be downloaded") 62 | parser.add_argument("--download_raw", action="store_true", 63 | help="Whether to download the raw images, " 64 | "which need storage space ~1.2TB") 65 | parser.add_argument("--keep_archive_after_decompress", action="store_true", 66 | help="Whether to keep the image tarball archives") 67 | arguments = parser.parse_args() 68 | os.makedirs(arguments.dataset_dir, exist_ok=True) 69 | 70 | # download labels and metadata 71 | logger.info("Downloading labels...") 72 | urlretrieve( 73 | LABELS_AND_METADATA_URL, os.path.join( 74 | arguments.dataset_dir, os.path.basename(LABELS_AND_METADATA_URL))) 75 | urlretrieve( 76 | LABEL_RELATIONSHIP_URL, os.path.join( 77 | arguments.dataset_dir, os.path.basename(LABEL_RELATIONSHIP_URL))) 78 | # download and decompress all images 79 | for image_url in SMALL_IMAGE_URLS: 80 | logger.info("Downloading small image: {}".format(image_url)) 81 | urlretrieve(image_url, os.path.join( 82 | arguments.dataset_dir, 
os.path.basename(image_url))) 83 | if arguments.download_raw: 84 | for image_url in RAW_IMAGE_URLS: 85 | logger.info("Downloading raw image: {}".format(image_url)) 86 | urlretrieve(image_url, os.path.join( 87 | arguments.dataset_dir, os.path.basename(image_url))) 88 | logger.info("Decompressing images...") 89 | decompress_images(arguments.dataset_dir, 90 | arguments.keep_archive_after_decompress) 91 | -------------------------------------------------------------------------------- /explore_labels.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "80725a8d", 6 | "metadata": {}, 7 | "source": [ 8 | "## Explore FLAIR labels\n", 9 | "This notebook will explore the distribution of users and labels in FLAIR dataset.\n", 10 | "The `labels_and_metadata.json` file contains a list of `user_id`, `labels` and `fine_grained_labels` for each image. \n", 11 | "\n", 12 | "`fine_grained_labels` are human annotated labels from a taxonomy of 1,628 categories.\n", 13 | "`labels` reference higher-order categories than `fine_grained_labels`: the 1,628 `fine_grained_labels` map to 17 `labels`.\n", 14 | "\n", 15 | "For example, `fine_grained_labels: [\"waffle\", \"bread\"]` map to `labels: [\"food\"]`.\n", 16 | "It is expected that `fine_grained_labels` also contain the 17 coarse-grained categories as the subjects in an image might not be categorized into finer granularity.\n", 17 | "\n", 18 | "### Load the labels and metadata" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 1, 24 | "id": "9a7b5346", 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "name": "stdout", 29 | "output_type": "stream", 30 | "text": [ 31 | "Loaded metadata and labels for 429078 images\n", 32 | "Example metadata and labels for one image:\n", 33 | "{\n", 34 | " \"user_id\": \"59769174@N00\",\n", 35 | " \"image_id\": \"14913474848\",\n", 36 | " \"labels\": [\n", 37 | " \"equipment\",\n", 38 | " \"material\",\n", 39 | " \"structure\"\n", 40 | " ],\n", 41 | " \"partition\": \"train\",\n", 42 | " \"fine_grained_labels\": [\n", 43 | " \"bag\",\n", 44 | " \"document\",\n", 45 | " \"furniture\",\n", 46 | " \"material\",\n", 47 | " \"printed_page\"\n", 48 | " ]\n", 49 | "}\n" 50 | ] 51 | } 52 | ], 53 | "source": [ 54 | "import json\n", 55 | "import os\n", 56 | "\n", 57 | "dataset_dir = \"../flair-data/\" # replace with the path to directory that you downloaded the dataset\n", 58 | "\n", 59 | "with open(os.path.join(dataset_dir, \"labels_and_metadata.json\")) as f:\n", 60 | " metadata_list = json.load(f)\n", 61 | "\n", 62 | "print(f\"Loaded metadata and labels for {len(metadata_list)} images\")\n", 63 | "print(f\"Example metadata and labels for one image:\\n\" + json.dumps(metadata_list[0], indent=4))" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "id": "15b85ab7", 69 | "metadata": {}, 70 | "source": [ 71 | "### Count labels over users and images" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 2, 77 | "id": "0daf9341", 78 | "metadata": {}, 79 | "outputs": [ 80 | { 81 | "name": "stdout", 82 | "output_type": "stream", 83 | "text": [ 84 | "17 unique labels\n", 85 | "1628 unique fine-grained labels\n", 86 | "51414 users\n", 87 | "Number of train/val/test images: 345879/39239/43960\n", 88 | "Number of train/val/test users: 41131/5141/5142\n" 89 | ] 90 | } 91 | ], 92 | "source": [ 93 | "from collections import Counter, defaultdict\n", 94 | "\n", 95 | "# Counter for image label statistics\n", 
96 | "image_label_counter = []\n", 97 | "image_fine_grained_counter = []\n", 98 | "\n", 99 | "# Counters for label statistics\n", 100 | "label_counter = Counter()\n", 101 | "fine_grained_label_counter = Counter()\n", 102 | "\n", 103 | "# Counters for user label statistics\n", 104 | "user_image_counter = Counter()\n", 105 | "user_label_counter = defaultdict(Counter)\n", 106 | "user_fine_grained_label_counter = defaultdict(Counter)\n", 107 | "\n", 108 | "n_train_images, n_val_images, n_test_images = 0, 0, 0\n", 109 | "train_users, val_users, test_users = set(), set(), set()\n", 110 | "\n", 111 | "for metadata in metadata_list:\n", 112 | " image_label_counter.append(len(metadata[\"labels\"]))\n", 113 | " image_fine_grained_counter.append(len(metadata[\"fine_grained_labels\"]))\n", 114 | " # Increment count for overall label distribution\n", 115 | " label_counter.update(metadata[\"labels\"])\n", 116 | " fine_grained_label_counter.update(metadata[\"fine_grained_labels\"])\n", 117 | " # Increment count for user label distribution\n", 118 | " user_image_counter[metadata[\"user_id\"]] += 1\n", 119 | " user_label_counter[metadata[\"user_id\"]].update(metadata[\"labels\"])\n", 120 | " user_fine_grained_label_counter[metadata[\"user_id\"]].update(metadata[\"fine_grained_labels\"])\n", 121 | " # train/val/test counts\n", 122 | " if metadata[\"partition\"] == \"train\":\n", 123 | " train_users.add(metadata[\"user_id\"])\n", 124 | " n_train_images += 1\n", 125 | " if metadata[\"partition\"] == \"val\":\n", 126 | " val_users.add(metadata[\"user_id\"])\n", 127 | " n_val_images += 1\n", 128 | " if metadata[\"partition\"] == \"test\":\n", 129 | " test_users.add(metadata[\"user_id\"])\n", 130 | " n_test_images += 1\n", 131 | "\n", 132 | "print(f\"{len(label_counter)} unique labels\\n\" \n", 133 | " f\"{len(fine_grained_label_counter)} unique fine-grained labels\\n\" \n", 134 | " f\"{len(user_image_counter)} users\")\n", 135 | "print(f\"Number of train/val/test images: {n_train_images}/{n_val_images}/{n_test_images}\")\n", 136 | "print(f\"Number of train/val/test users: {len(train_users)}/{len(val_users)}/{len(test_users)}\")" 137 | ] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "id": "5561b2df", 142 | "metadata": {}, 143 | "source": [ 144 | "### Per-image label statistics\n", 145 | "\n", 146 | "The table below displays the per-image label statistics. \n", 147 | "On average, each image has 2.79 labels and 4.61 fine-grained labels. " 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 3, 153 | "id": "9a775ac1", 154 | "metadata": {}, 155 | "outputs": [ 156 | { 157 | "data": { 158 | "text/html": [ 159 | "
" 225 | ], 226 | "text/plain": [ 227 | "        label count  fine-grained label count\n", 228 | "count  429078.000000             429078.000000\n", 229 | "mean        2.785249                  4.614781\n", 230 | "std         1.149218                  2.729608\n", 231 | "min         1.000000                  1.000000\n", 232 | "25%         2.000000                  3.000000\n", 233 | "50%         3.000000                  4.000000\n", 234 | "75%         4.000000                  6.000000\n", 235 | "max         9.000000                 36.000000" 236 | ] 237 | }, 238 | "execution_count": 3, 239 | "metadata": {}, 240 | "output_type": "execute_result" 241 | } 242 | ], 243 | "source": [ 244 | "import pandas as pd\n", 245 | "\n", 246 | "pd.DataFrame(zip(image_label_counter, image_fine_grained_counter), columns=['label count', 'fine-grained label count']).describe()" 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "id": "eca0e5c3", 252 | "metadata": {}, 253 | "source": [ 254 | "### Histogram of coarse-grained labels\n", 255 | "The table below shows the counts of the 17 higher-order labels over all images." 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 4, 261 | "id": "27ea353e", 262 | "metadata": {}, 263 | "outputs": [ 264 | { 265 | "data": { 266 | "text/html": [ 267 | "
" 378 | ], 379 | "text/plain": [ 380 | "            label   count\n", 381 | "0       structure  228923\n", 382 | "1       equipment  225862\n", 383 | "2        material  139733\n", 384 | "3         outdoor  131322\n", 385 | "4           plant  123363\n", 386 | "5            food  110792\n", 387 | "6          animal   68858\n", 388 | "7          liquid   68677\n", 389 | "8             art   37230\n", 390 | "9   interior_room   32042\n", 391 | "10          light   11326\n", 392 | "11     recreation    5651\n", 393 | "12    celebration    3769\n", 394 | "13           fire    3011\n", 395 | "14          music    2571\n", 396 | "15          games    1093\n", 397 | "16       religion     866" 398 | ] 399 | }, 400 | "execution_count": 4, 401 | "metadata": {}, 402 | "output_type": "execute_result" 403 | } 404 | ], 405 | "source": [ 406 | "pd.DataFrame(label_counter.most_common(), columns=['label', 'count'])" 407 | ] 408 | }, 409 | { 410 | "cell_type": "markdown", 411 | "id": "a91b6034", 412 | "metadata": {}, 413 | "source": [ 414 | "### Histogram of the top 20 fine-grained labels\n", 415 | "The table below shows the counts of the 20 most common fine-grained labels over all images." 416 | ] 417 | }, 418 | { 419 | "cell_type": "code", 420 | "execution_count": 5, 421 | "id": "f78ed4a7", 422 | "metadata": {}, 423 | "outputs": [ 424 | { 425 | "data": { 426 | "text/html": [ 427 | "
" 553 | ], 554 | "text/plain": [ 555 | " Head fine-grained label count\n", 556 | "0 wood_processed 82336\n", 557 | "1 material 74807\n", 558 | "2 structure 61517\n", 559 | "3 grass 44741\n", 560 | "4 plate 42392\n", 561 | "5 plant 42009\n", 562 | "6 blue_sky 41043\n", 563 | "7 foliage 40793\n", 564 | "8 cloudy 36225\n", 565 | "9 tree 33984\n", 566 | "10 textile 33709\n", 567 | "11 building 32349\n", 568 | "12 rocks 28850\n", 569 | "13 container 25189\n", 570 | "14 land 24715\n", 571 | "15 sky 24669\n", 572 | "16 fence 24249\n", 573 | "17 food 23752\n", 574 | "18 interior_room 23254\n", 575 | "19 vegetation 23106" 576 | ] 577 | }, 578 | "execution_count": 5, 579 | "metadata": {}, 580 | "output_type": "execute_result" 581 | } 582 | ], 583 | "source": [ 584 | "sorted_fine_grained_labels = fine_grained_label_counter.most_common()\n", 585 | "pd.DataFrame(sorted_fine_grained_labels[:20], columns=['Head fine-grained label', 'count'])" 586 | ] 587 | }, 588 | { 589 | "cell_type": "markdown", 590 | "id": "2795d6e3", 591 | "metadata": {}, 592 | "source": [ 593 | "### Per-user image statistics\n", 594 | "The table below displays the per-user image statistics. \n", 595 | "On average, each user has 8.34 images. The distribution is head-heavy and long-tailed, where the median of images per user is 2." 596 | ] 597 | }, 598 | { 599 | "cell_type": "code", 600 | "execution_count": 6, 601 | "id": "d6f267b6", 602 | "metadata": {}, 603 | "outputs": [ 604 | { 605 | "data": { 606 | "text/html": [ 607 | "
" 664 | ], 665 | "text/plain": [ 666 | "       user image counts\n", 667 | "count       51414.000000\n", 668 | "mean            8.345548\n", 669 | "std            51.275980\n", 670 | "min             1.000000\n", 671 | "25%             1.000000\n", 672 | "50%             2.000000\n", 673 | "75%             5.000000\n", 674 | "max          4151.000000" 675 | ] 676 | }, 677 | "execution_count": 6, 678 | "metadata": {}, 679 | "output_type": "execute_result" 680 | } 681 | ], 682 | "source": [ 683 | "pd.DataFrame(user_image_counter.values(), columns=[\"user image counts\"]).describe()" 684 | ] 685 | }, 686 | { 687 | "cell_type": "markdown", 688 | "id": "914b0863", 689 | "metadata": {}, 690 | "source": [ 691 | "### Per-user label statistics\n", 692 | "The table below displays the per-user label statistics. \n", 693 | "On average, each user has 4.6 distinct labels." 694 | ] 695 | }, 696 | { 697 | "cell_type": "code", 698 | "execution_count": 7, 699 | "id": "f6a3a00a", 700 | "metadata": {}, 701 | "outputs": [ 702 | { 703 | "data": { 704 | "text/html": [ 705 | "
" 762 | ], 763 | "text/plain": [ 764 | "       user label counts\n", 765 | "count       51414.000000\n", 766 | "mean            4.608647\n", 767 | "std             2.811109\n", 768 | "min             1.000000\n", 769 | "25%             3.000000\n", 770 | "50%             4.000000\n", 771 | "75%             6.000000\n", 772 | "max            17.000000" 773 | ] 774 | }, 775 | "execution_count": 7, 776 | "metadata": {}, 777 | "output_type": "execute_result" 778 | } 779 | ], 780 | "source": [ 781 | "user_num_labels = [len(counter) for counter in user_label_counter.values()]\n", 782 | "pd.DataFrame(user_num_labels, columns=[\"user label counts\"]).describe()" 783 | ] 784 | }, 785 | { 786 | "cell_type": "markdown", 787 | "id": "f17863ac", 788 | "metadata": {}, 789 | "source": [ 790 | "### Per-user fine-grained label statistics\n", 791 | "The table below displays the per-user fine-grained label statistics. \n", 792 | "On average, each user has 16.8 distinct fine-grained labels." 793 | ] 794 | }, 795 | { 796 | "cell_type": "code", 797 | "execution_count": 8, 798 | "id": "0fd1047c", 799 | "metadata": {}, 800 | "outputs": [ 801 | { 802 | "data": { 803 | "text/html": [ 804 | "
" 861 | ], 862 | "text/plain": [ 863 | " user fine-grained label counts\n", 864 | "count 51414.000000\n", 865 | "mean 16.813242\n", 866 | "std 32.666895\n", 867 | "min 1.000000\n", 868 | "25% 4.000000\n", 869 | "50% 7.000000\n", 870 | "75% 16.000000\n", 871 | "max 838.000000" 872 | ] 873 | }, 874 | "execution_count": 8, 875 | "metadata": {}, 876 | "output_type": "execute_result" 877 | } 878 | ], 879 | "source": [ 880 | "user_num_fine_grained_labels = [len(counter) for counter in user_fine_grained_label_counter.values()]\n", 881 | "pd.DataFrame(user_num_fine_grained_labels, columns=[\"user fine-grained label counts\"]).describe()" 882 | ] 883 | } 884 | ], 885 | "metadata": { 886 | "kernelspec": { 887 | "display_name": "Python 3 (ipykernel)", 888 | "language": "python", 889 | "name": "python3" 890 | }, 891 | "language_info": { 892 | "codemirror_mode": { 893 | "name": "ipython", 894 | "version": 3 895 | }, 896 | "file_extension": ".py", 897 | "mimetype": "text/x-python", 898 | "name": "python", 899 | "nbconvert_exporter": "python", 900 | "pygments_lexer": "ipython3", 901 | "version": "3.8.9" 902 | } 903 | }, 904 | "nbformat": 4, 905 | "nbformat_minor": 5 906 | } 907 | -------------------------------------------------------------------------------- /prepare_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # For licensing see accompanying LICENSE file. 4 | # Copyright (C) 2022 Apple Inc. All Rights Reserved. 5 | 6 | import argparse 7 | import os 8 | import sys 9 | import logging 10 | 11 | import h5py 12 | import json 13 | import numpy as np 14 | 15 | from typing import Dict, List 16 | from collections import defaultdict, Counter 17 | from PIL import Image 18 | 19 | 20 | logger = logging.getLogger(name=__name__) 21 | 22 | LABEL_DELIMITER = '|' # Labels will be joined by delimiter and saved to hdf5 23 | LOG_INTERVAL = 500 # Log the preprocessing progress every interval steps 24 | 25 | 26 | def load_user_metadata(labels_file: str) -> Dict[str, List]: 27 | """ 28 | Load labels and metadata keyed by `user_id`. 29 | 30 | :param labels_file: 31 | A .json file with a list of labels and metadata dictionaries. Each 32 | dictionary has keys: `[image_id,user_id,labels,fine_grained_labels]`. 33 | * `image_id` is the ID of an image. 34 | * `user_id` is the ID of the user `image_id` belongs to. 35 | * `labels` is a list of 17 higher-order class labels. 36 | * `fine_grained_labels` is a list of 1,628 fine-grained class labels. 37 | :return: 38 | A dictionary where key is `user_id` and value is a list of labels and 39 | metadata for each image `user_id` owns. 40 | """ 41 | user_metadata = defaultdict(list) 42 | with open(labels_file) as f: 43 | metadata_list = json.load(f) 44 | 45 | for metadata in metadata_list: 46 | user_metadata[metadata["user_id"]].append(metadata) 47 | return user_metadata 48 | 49 | 50 | def preprocess_federated_dataset(image_dir: str, 51 | labels_file: str, 52 | output_file: str): 53 | """ 54 | Process images and labels into a HDF5 federated dataset where data is 55 | first split by train/test partitions and then split again by user ID. 56 | 57 | :param image_dir: 58 | Path to directory of images output from the script 59 | `download_dataset.sh`. 60 | :param labels_file: 61 | A .json file with a list of labels and metadata dictionaries. Each 62 | dictionary has keys: `[image_id,user_id,labels,fine_grained_labels]`. 63 | * `image_id` is the ID of an image. 
64 | * `user_id` is the ID of the user `image_id` belongs to. 65 | * `labels` is a list of 17 higher-order class labels. 66 | * `fine_grained_labels` is a list of ~1,600 fine-grained class labels. 67 | :param output_file: 68 | Output path for HDF5 file. Use the postfix `.hdf5`. 69 | """ 70 | logger.info('Preprocessing federated dataset.') 71 | user_metadata = load_user_metadata(labels_file) 72 | label_counter = Counter() 73 | fine_grained_label_counter = Counter() 74 | with h5py.File(output_file, 'w') as h5file: 75 | # Iterate through users of each partition. 76 | for i, user_id in enumerate(user_metadata): 77 | # Load and concatenate all images of a user. 78 | image_array, image_id_array = [], [] 79 | labels_array, fine_grained_labels_array = [], [] 80 | # Load and concatenate all images and labels of a user. 81 | for metadata in user_metadata[user_id]: 82 | image_id = metadata["image_id"] 83 | image = Image.open( 84 | os.path.join(image_dir, f"{image_id}.jpg")) 85 | image_array.append(np.asarray(image)) 86 | image_id_array.append(image_id) 87 | # Encode labels as a single string, separated by delimiter | 88 | labels_array.append(LABEL_DELIMITER.join(metadata["labels"])) 89 | fine_grained_labels_array.append( 90 | LABEL_DELIMITER.join(metadata["fine_grained_labels"])) 91 | # Update label counter 92 | label_counter.update(metadata["labels"]) 93 | fine_grained_label_counter.update( 94 | metadata["fine_grained_labels"]) 95 | 96 | partition = user_metadata[user_id][0]["partition"] 97 | # Multiple variable-length labels. Needs to be stored as a string. 98 | h5file[f'/{partition}/{user_id}/labels'] = np.asarray( 99 | labels_array, dtype='S') 100 | h5file[f'/{partition}/{user_id}/fine_grained_labels'] = np.asarray( 101 | fine_grained_labels_array, dtype='S') 102 | h5file[f'/{partition}/{user_id}/image_ids'] = np.asarray( 103 | image_id_array, dtype='S') 104 | # Tensor with dimensions [num_images,width,height,channels] 105 | h5file.create_dataset( 106 | f'/{partition}/{user_id}/images', data=np.stack(image_array)) 107 | 108 | if (i + 1) % LOG_INTERVAL == 0: 109 | logger.info("Processed {0}/{1} users".format( 110 | i + 1, len(user_metadata))) 111 | 112 | # Write metadata 113 | h5file['/metadata/label_counter'] = json.dumps(label_counter) 114 | h5file['/metadata/fine_grained_label_counter'] = json.dumps( 115 | fine_grained_label_counter) 116 | 117 | logger.info('Finished preprocess federated dataset successfully!') 118 | 119 | 120 | def preprocess_central_dataset(image_dir: str, 121 | labels_file: str, 122 | output_file: str): 123 | """ 124 | Process images and labels into a HDF5 (not federated) dataset where 125 | data is split by train/val/test partitions. 126 | 127 | Same parameters as `preprocess_federated_dataset`. 128 | """ 129 | logger.info('Preprocessing central dataset.') 130 | user_metadata = load_user_metadata(labels_file) 131 | label_counter = Counter() 132 | fine_grained_label_counter = Counter() 133 | with h5py.File(output_file, 'w') as h5file: 134 | # Iterate through users of each partition. 135 | for i, user_id in enumerate(user_metadata): 136 | # Load and concatenate all images of a user. 
137 |             for metadata in user_metadata[user_id]:
138 |                 image_id = metadata["image_id"]
139 |                 image = Image.open(
140 |                     os.path.join(image_dir, f"{image_id}.jpg"))
141 |                 partition = metadata["partition"]
142 |                 h5file.create_dataset(
143 |                     f'/{partition}/{image_id}/image', data=np.asarray(image))
144 |                 # Encode labels as a single string, separated by delimiter |
145 |                 h5file[f'/{partition}/{image_id}/labels'] = LABEL_DELIMITER.join(
146 |                     metadata["labels"])
147 |                 h5file[f'/{partition}/{image_id}/fine_grained_labels'] = (
148 |                     LABEL_DELIMITER.join(metadata["fine_grained_labels"]))
149 |                 h5file[f'/{partition}/{image_id}/user_id'] = user_id
150 |                 # Update label counter
151 |                 label_counter.update(metadata["labels"])
152 |                 fine_grained_label_counter.update(
153 |                     metadata["fine_grained_labels"])
154 | 
155 |             if (i + 1) % LOG_INTERVAL == 0:
156 |                 logger.info("Processed {0}/{1} users".format(
157 |                     i + 1, len(user_metadata)))
158 | 
159 |         # Write metadata
160 |         h5file['/metadata/label_counter'] = json.dumps(label_counter)
161 |         h5file['/metadata/fine_grained_label_counter'] = json.dumps(
162 |             fine_grained_label_counter)
163 | 
164 |     logger.info('Finished preprocessing central dataset successfully!')
165 | 
166 | 
167 | if __name__ == '__main__':
168 |     logging.basicConfig(
169 |         stream=sys.stdout,
170 |         level=logging.INFO,
171 |         format='%(asctime)s %(levelname)s: %(message)s')
172 | 
173 |     argument_parser = argparse.ArgumentParser(
174 |         description=
175 |         'Preprocess the images and labels of the FLAIR dataset into HDF5 files.')
176 |     argument_parser.add_argument(
177 |         '--dataset_dir',
178 |         required=True,
179 |         help='Path to directory of images and label file. '
180 |              'Can be downloaded using download_dataset.py')
181 |     argument_parser.add_argument(
182 |         '--output_file',
183 |         required=True,
184 |         help='Path to output HDF5 file that will be constructed by this script'
185 |     )
186 |     argument_parser.add_argument(
187 |         '--not_group_data_by_user',
188 |         action='store_true',
189 |         default=False,
190 |         help='If true, do not group data by user IDs. '
191 |              'If false, group data by user IDs to '
192 |              'make it suitable for federated learning.')
193 |     arguments = argument_parser.parse_args()
194 | 
195 |     image_dir = os.path.join(arguments.dataset_dir, "small_images")
196 |     labels_file = os.path.join(arguments.dataset_dir, "labels_and_metadata.json")
197 |     if arguments.not_group_data_by_user:
198 |         preprocess_central_dataset(image_dir, labels_file, arguments.output_file)
199 |     else:
200 |         preprocess_federated_dataset(image_dir, labels_file, arguments.output_file)
201 | 
--------------------------------------------------------------------------------
/prepare_tfrecords.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # For licensing see accompanying LICENSE file.
4 | # Copyright (C) 2022 Apple Inc. All Rights Reserved.
5 | 
6 | from typing import Dict, List, Tuple
7 | from collections import defaultdict, Counter
8 | 
9 | import os
10 | import sys
11 | import argparse
12 | import json
13 | import tensorflow as tf
14 | import logging
15 | 
16 | logger = logging.getLogger(name=__name__)
17 | 
18 | KEY_IMAGE_BYTES = 'image/encoded_jpeg'
19 | KEY_IMAGE_DECODED = 'image/decoded'
20 | KEY_LABELS = 'labels'
21 | KEY_FINE_GRAINED_LABELS = 'fine_grained_labels'
22 | LOG_INTERVAL = 500  # Log the preprocessing progress every interval steps
23 | 
24 | 
25 | def load_user_metadata_and_label_counters(
26 |         labels_file: str) -> Tuple[Dict, Counter, Counter]:
27 |     """
28 |     Load labels and metadata keyed by `user_id`, and label counts.
29 | 
30 |     :param labels_file:
31 |         A .json file with a list of labels and metadata dictionaries. Each
32 |         dictionary has keys: `[image_id,user_id,labels,fine_grained_labels]`.
33 |         * `image_id` is the ID of an image.
34 |         * `user_id` is the ID of the user `image_id` belongs to.
35 |         * `labels` is a list of 17 higher-order class labels.
36 |         * `fine_grained_labels` is a list of 1,628 fine-grained class labels.
37 |     :return:
38 |         Three dictionaries. The first maps `user_id` to a list of labels and
39 |         metadata for each image `user_id` owns. The second and third are
40 |         label counts for the coarse-grained and fine-grained taxonomies,
41 |         respectively.
42 |     """
43 |     user_metadata = defaultdict(list)
44 |     with open(labels_file) as f:
45 |         metadata_list = json.load(f)
46 | 
47 |     label_counter = Counter()
48 |     fine_grained_label_counter = Counter()
49 |     for metadata in metadata_list:
50 |         user_metadata[metadata["user_id"]].append(metadata)
51 |         label_counter.update(metadata["labels"])
52 |         fine_grained_label_counter.update(metadata["fine_grained_labels"])
53 |     return user_metadata, label_counter, fine_grained_label_counter
54 | 
55 | 
56 | def create_example(
57 |         image_bytes: bytes,
58 |         labels: List[int],
59 |         fine_grained_labels: List[int]
60 | ) -> tf.train.Example:
61 |     """Create a `tf.train.Example` for a given image and its labels."""
62 |     features = {
63 |         KEY_IMAGE_BYTES: tf.train.Feature(
64 |             bytes_list=tf.train.BytesList(value=[image_bytes])),
65 |         KEY_LABELS: tf.train.Feature(
66 |             int64_list=tf.train.Int64List(value=labels)),
67 |         KEY_FINE_GRAINED_LABELS: tf.train.Feature(
68 |             int64_list=tf.train.Int64List(value=fine_grained_labels))
69 |     }
70 |     return tf.train.Example(features=tf.train.Features(feature=features))
71 | 
72 | 
73 | def preprocess_federated_dataset(image_dir: str,
74 |                                  labels_file: str,
75 |                                  tfrecords_dir: str):
76 |     """
77 |     Process images and labels into tfrecords where data is first split by
78 |     train/test partitions and then split again by user ID. The label-to-index
79 |     mapping will be saved to `label_to_index.json` in `tfrecords_dir`.
80 | 
81 |     :param image_dir:
82 |         Path to directory of images output from the script
83 |         `download_dataset.py`.
84 |     :param labels_file:
85 |         A .json file with a list of labels and metadata dictionaries. Each
86 |         dictionary has keys: `[image_id,user_id,labels,fine_grained_labels]`.
87 |         * `image_id` is the ID of an image.
88 |         * `user_id` is the ID of the user `image_id` belongs to.
89 |         * `labels` is a list of 17 higher-order class labels.
90 |         * `fine_grained_labels` is a list of 1,628 fine-grained class labels.
91 |     :param tfrecords_dir:
92 |         Save directory path for tfrecords.
93 | """ 94 | logger.info('Preprocessing federated tfrecords.') 95 | os.makedirs(tfrecords_dir, exist_ok=True) 96 | (user_metadata, label_counter, 97 | fine_grained_label_counter) = load_user_metadata_and_label_counters(labels_file) 98 | label_to_index = { 99 | label: index for index, label 100 | in enumerate(sorted(label_counter.keys()))} 101 | fine_grained_label_to_index = { 102 | fine_grained_label: index for index, fine_grained_label 103 | in enumerate(sorted(fine_grained_label_counter.keys()))} 104 | 105 | with open(os.path.join(tfrecords_dir, "label_to_index.json"), "w") as f: 106 | json.dump({ 107 | "labels": label_to_index, 108 | "fine_grained_labels": fine_grained_label_to_index 109 | }, f, indent=4) 110 | 111 | for i, user_id in enumerate(user_metadata): 112 | partition = user_metadata[user_id][0]["partition"] 113 | 114 | # Load and concatenate all images and labels of a user. 115 | user_examples = [] 116 | for metadata in user_metadata[user_id]: 117 | image_id = metadata["image_id"] 118 | with open(os.path.join(image_dir, f"{image_id}.jpg"), 119 | 'rb') as f: 120 | image_bytes = f.read() 121 | example = create_example( 122 | image_bytes=image_bytes, 123 | labels=[label_to_index[label] for label in metadata["labels"]], 124 | fine_grained_labels=[ 125 | fine_grained_label_to_index[label] 126 | for label in metadata["fine_grained_labels"] 127 | ]) 128 | user_examples.append(example) 129 | 130 | partition_dir = os.path.join(tfrecords_dir, partition) 131 | os.makedirs(partition_dir, exist_ok=True) 132 | with tf.io.TFRecordWriter(os.path.join( 133 | partition_dir, f'{user_id}.tfrecords')) as writer: 134 | for example in user_examples: 135 | writer.write(example.SerializeToString()) 136 | 137 | if (i + 1) % LOG_INTERVAL == 0: 138 | logger.info("Processed {0}/{1} users".format( 139 | i + 1, len(user_metadata))) 140 | logger.info('Finished preprocess federated tfrecords successfully!') 141 | 142 | 143 | if __name__ == '__main__': 144 | logging.basicConfig( 145 | stream=sys.stdout, 146 | level=logging.INFO, 147 | format='%(asctime)s %(levelname)s: %(message)s') 148 | 149 | argument_parser = argparse.ArgumentParser( 150 | description= 151 | 'Preprocess the images and labels of FLAIR dataset into HDF5 files.') 152 | argument_parser.add_argument( 153 | '--dataset_dir', 154 | required=True, 155 | help='Path to directory of images and label file. ' 156 | 'Can be downloaded using download_dataset.py') 157 | argument_parser.add_argument( 158 | '--tfrecords_dir', 159 | required=True, 160 | help='Path to directory to save output tfrecords.' 161 | ) 162 | arguments = argument_parser.parse_args() 163 | 164 | image_dir = os.path.join(arguments.dataset_dir, "small_images") 165 | labels_file = os.path.join(arguments.dataset_dir, 166 | "labels_and_metadata.json") 167 | preprocess_federated_dataset(image_dir, labels_file, 168 | arguments.tfrecords_dir) 169 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | h5py~=3.1.0 2 | notebook~=6.4.10 3 | numpy~=1.21.4 4 | Pillow~=8.2.0 5 | tensorflow==2.8.0 6 | tensorflow-federated==0.20.0 7 | tensorflow-probability==0.16.0 8 | tensorflow_addons~=0.17.1 9 | pandas~=1.4.2 10 | --------------------------------------------------------------------------------