├── .gitignore ├── CONTRIBUTING.md ├── README.md ├── LICENSE └── camelyon17k_demo.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Distribution / packaging 7 | .Python 8 | build/ 9 | develop-eggs/ 10 | dist/ 11 | downloads/ 12 | eggs/ 13 | .eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | wheels/ 20 | share/python-wheels/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | MANIFEST 25 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | ## Contributor License Agreement 4 | 5 | Contributions to this project must be accompanied by a Contributor License 6 | Agreement. You (or your employer) retain the copyright to your contribution, 7 | this simply gives us permission to use and redistribute your contributions as 8 | part of the project. Head over to to see 9 | your current agreements on file or to sign a new one. 10 | 11 | You generally only need to submit a CLA once, so if you've already submitted one 12 | (even if it was for a different project), you probably don't need to do it 13 | again. 14 | 15 | ## Code reviews 16 | 17 | All submissions, including submissions by project members, require review. We 18 | use GitHub pull requests for this purpose. Consult 19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 20 | information on using pull requests. 21 | 22 | ## Community Guidelines 23 | 24 | This project follows [Google's Open Source Community 25 | Guidelines](https://opensource.google/conduct/). 
26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # augmentations_medical_images 2 | 3 | This code corresponds to our Nature Medicine paper on 4 | "Generative models improve fairness of medical classifiers under 5 | distribution shifts". 6 | In this codebase we release models (both classifiers and generative models) for 7 | the [Camelyon17K dataset](https://wilds.stanford.edu/datasets/). 8 | 9 | ## Usage 10 | 11 | All code is in CoLABS. 12 | - Camelyon17K: `./camelyon17k_demo.ipynb` 13 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/augmentations_medical_images/blob/master/camelyon17k_demo.ipynb). 14 | 15 | ## Citing this work 16 | 17 | ```tex 18 | @article{ktena2023generative, 19 | title={Generative models improve fairness of medical classifiers under distribution shifts}, 20 | author={Ktena, Ira* and Wiles, Olivia* and Albuquerque, Isabela and Rebuffi, Sylvestre-Alvise and Tanno, Ryutaro and Roy, Abhijit Guha and Azizi, Shekoofeh and Belgrave, Danielle and Kohli, Pushmeet and Karthikesalingam, Alan and Cemgil, Taylan and Gowal, Sven}, 21 | journal={arXiv preprint arXiv:2304.09218}, 22 | year={2023} 23 | } 24 | ``` 25 | 26 | ## License and disclaimer 27 | 28 | Copyright 2024 DeepMind Technologies Limited 29 | 30 | All software is licensed under the Apache License, Version 2.0 (Apache 2.0); 31 | you may not use this file except in compliance with the Apache 2.0 license. 32 | You may obtain a copy of the Apache 2.0 license at: 33 | https://www.apache.org/licenses/LICENSE-2.0 34 | 35 | All other materials are licensed under the Creative Commons Attribution 4.0 36 | International License (CC-BY). 
You may obtain a copy of the CC-BY license at: 37 | https://creativecommons.org/licenses/by/4.0/legalcode 38 | 39 | Unless required by applicable law or agreed to in writing, all software and 40 | materials distributed here under the Apache 2.0 or CC-BY licenses are 41 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, 42 | either express or implied. See the licenses for the specific language governing 43 | permissions and limitations under those licenses. 44 | 45 | This is not an official Google product. 46 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 
30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 
62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /camelyon17k_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "Y7FTPXzCYeaZ" 7 | }, 8 | "source": [ 9 | "Copyright 2024 Google LLC\n", 10 | "\n", 11 | "Licensed under the Apache License, Version 2.0 (the \"License\");\n", 12 | "you may not use this file except in compliance with the License.\n", 13 | "You may obtain a copy of the License at\n", 14 | "\n", 15 | " https://www.apache.org/licenses/LICENSE-2.0\n", 16 | "\n", 17 | "Unless required by applicable law or agreed to in writing, software\n", 18 | "distributed under the License is distributed on an \"AS IS\" BASIS,\n", 19 | "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", 20 | "See the License for the specific language governing permissions and\n", 21 | "limitations under the License." 
22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": { 27 | "id": "2HkRnXLvDEhG" 28 | }, 29 | "source": [ 30 | "# Models for Camelyon17K\n", 31 | "\n", 32 | "This CoLAB shows how to load and run models on Camelyon17K. In particular, we have two sets of models:\n", 33 | "\n", 34 | "1. *Generative*: here the models generate synthetic images with and without tumours across different hospitals.\n", 35 | "2. *Classification*: here the models take an image of a slide that may contain tumours and classify whether or not it does.\n", 36 | "\n", 37 | "The CoLAB is divided into two sections for these two use cases.\n", 38 | "\n", 39 | "We save our models using jax2tf for ease of use. Models can be downloaded with this [link](https://storage.googleapis.com/augmentations_medical_images/open_source/open_source.tgz). Note that this CoLAB was *NOT* used for any results in the paper but is provided to (1) show sample results with our exported models and (2) demonstrate how our pipeline operates.\n", 40 | "\n", 41 | "This code was run on a TPU, so it is unclear how feasible it is to run the different parts on a CPU.\n" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": { 48 | "id": "Lq8wR6UhMaKz" 49 | }, 50 | "outputs": [], 51 | "source": [ 52 | "# See instructions at https://github.com/google/jax#installation for how to install.\n", 53 | "# !pip install -U \"jax[cuda12_pip]\" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n", 54 | "\n", 55 | "!pip install matplotlib\n", 56 | "!pip install numpy pandas\n", 57 | "!pip install tensorflow\n", 58 | "!pip install tensorflow_datasets" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": { 65 | "id": "T04fYV8XDp4w" 66 | }, 67 | "outputs": [], 68 | "source": [ 69 | "# @title Imports\n", 70 | "import jax\n", 71 | "import matplotlib.pyplot as plt\n", 72 | "import numpy as np\n", 73 | "import os\n", 74 | 
"import pandas as pd\n", 75 | "import tensorflow as tf\n", 76 | "import tensorflow_datasets as tfds" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": { 83 | "id": "RaAUdCNzp2zr" 84 | }, 85 | "outputs": [], 86 | "source": [ 87 | "# @markdown Download the open source models.\n", 88 | "# @markdown Save them to `./open_source/`.\n", 89 | "\n", 90 | "# Path to the directory containing the downloaded open source models.\n", 91 | "base_path = './open_source/' # @param {type: 'string'}" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": { 97 | "id": "4DgcVeOZjB4e" 98 | }, 99 | "source": [ 100 | "## Image Generation\n", 101 | "\n", 102 | "Here we show how to sample from two different generative models (trained on the full Camelyon dataset or the *most skewed* version) and also load samples generated by those models. Both models were trained with unlabelled data as well as labelled data." 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "metadata": { 109 | "id": "biiogreIkfOv" 110 | }, 111 | "outputs": [], 112 | "source": [ 113 | "# @title Sampling code\n", 114 | "# @markdown Note that you will need a GPU or TPU to run this in a reasonable amount of time.\n", 115 | "\n", 116 | "file_name = 'skewed100_gendata' # @param {type: 'string'} ['gendata', 'skewed100_gendata']\n", 117 | "\n", 118 | "if file_name == 'skewed100_gendata':\n", 119 | "  model_name = '44644773_3_skewed100_gendatamodel'\n", 120 | "else:\n", 121 | "  model_name = '51586976_1_genmodel'\n", 122 | "with tf.device('TPU'):\n", 123 | "  restored_model = tf.saved_model.load(f'{base_path}/histopathology/models/{model_name}/')\n", 124 | "\n", 125 | "all_images = []\n", 126 | "# We use these hospital ids from the WILDS dataset.\n", 127 | "# Hospital ids [1, 2] are OOD Val and Test.\n", 128 | "for hospital in [0, 3, 4]:\n", 129 | "  for label_id in [0, 1]:\n", 130 | "    one_hot_hospital = jax.nn.one_hot(hospital, 5)[None, :]\n", 131 | "    one_hot_label = 
jax.nn.one_hot(label_id, 2)[None, :]\n", 132 | "    res = restored_model(np.zeros(1,), one_hot_label, one_hot_hospital)\n", 133 | "    all_images.append(res[0])\n", 134 | "\n" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "metadata": { 141 | "id": "eTMRad7HKPT-" 142 | }, 143 | "outputs": [], 144 | "source": [ 145 | "fig, ax = plt.subplots(3, 2, figsize=(10, 10))\n", 146 | "hospitals = [0, 3, 4]  # Same order as the sampling loop above.\n", 147 | "for i in range(3):\n", 148 | "  for j in range(2):\n", 149 | "    ax[i][j].imshow(all_images[i * 2 + j])\n", 150 | "    ax[i][j].axis('off')\n", 151 | "    ax[i][j].set_title(f'Label: {j}, Center: {hospitals[i]}')" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": { 157 | "id": "PatnjEJ_NEso" 158 | }, 159 | "source": [ 160 | "## Classification" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": { 167 | "id": "uco5VWoWDrCy" 168 | }, 169 | "outputs": [], 170 | "source": [ 171 | "# @title Load in a saved model and evaluate\n", 172 | "# @markdown This cell and the one below show how to load the exported models and run inference with them on the evaluation datasets.\n", 173 | "# @markdown We exported four histopathology models: the baseline and our model conditioned on the hospital and tumor label (with color augmentation),\n", 174 | "# @markdown each in the *most skewed* and *all data* settings.\n", 175 | "\n", 176 | "\n", 177 | "# @markdown Note that the following code gives results for a single model; in the paper we report results across five runs.\n", 178 | "\n", 179 | "model_name = 'baseline' # @param {type: 'string'} ['ours_multiclass', 'baseline', 'skewed100_baseline', 'skewed100_ours_multiclass']\n", 180 | "device = 'CPU' # @param {type: 'string'} ['CPU', 'GPU', 'TPU']\n", 181 | "with tf.device(device):\n", 182 | "  restored_model = tf.saved_model.load(f'{base_path}/histopathology/models/{model_name}')" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | 
"execution_count": null, 188 | "metadata": { 189 | "id": "zPOJUx4wNKy_" 190 | }, 191 | "outputs": [], 192 | "source": [ 193 | "# @markdown Create a tfds version of the Camelyon17 dataset:\n", 194 | "# @markdown Follow the instructions in the [WILDS code](https://github.com/p-lambda/wilds/blob/main/wilds/datasets/camelyon17_dataset.py)\n", 195 | "# @markdown to download the blob file, which includes images and metadata.\n", 196 | "_CAMELYON_LOCATION = './camelyon17/' # @param\n", 197 | "TEST_CENTER = 2\n", 198 | "VAL_CENTER = 1\n", 199 | "\n", 200 | "# @markdown Here we load the full dataset. In the paper we also created skewed versions (not shown here)\n", 201 | "# @markdown to demonstrate the robustness of our approach in those settings.\n", 202 | "\n", 203 | "def parse_function(filename, label, center):\n", 204 | "  image_string = tf.io.read_file(filename)\n", 205 | "  image_decoded = tf.image.decode_image(image_string)\n", 206 | "  image = tf.cast(image_decoded, tf.float32)\n", 207 | "  return {'image': image, 'label': label, 'center': center}\n", 208 | "\n", 209 | "def load_camelyon():\n", 210 | "  camelyon_path = os.path.join(_CAMELYON_LOCATION, 'metadata.csv')\n", 211 | "\n", 212 | "  metadata_df = pd.read_csv(camelyon_path, index_col=0, dtype={'patient': 'str'})\n", 213 | "  patches_location = f'{_CAMELYON_LOCATION}/patches/'\n", 214 | "  input_array = [\n", 215 | "      f'{patches_location}/patient_{patient}_node_{node}/patch_patient_{patient}_node_{node}_x_{x}_y_{y}.png'\n", 216 | "      for patient, node, x, y in\n", 217 | "      metadata_df.loc[:, ['patient', 'node', 'x_coord', 'y_coord']].itertuples(index=False, name=None)]\n", 218 | "  metadata_df['images'] = input_array\n", 219 | "\n", 220 | "  # Extract splits\n", 221 | "  split_dict = {\n", 222 | "      'train': 0,\n", 223 | "      'id_val': 1,\n", 224 | "      'test': 2,\n", 225 | "      'val': 3\n", 226 | "  }\n", 227 | "  val_center_mask = (metadata_df['center'] == VAL_CENTER)\n", 228 | "  test_center_mask = 
(metadata_df['center'] == TEST_CENTER)\n", 229 | "  metadata_df.loc[val_center_mask, 'split'] = split_dict['val']\n", 230 | "  metadata_df.loc[test_center_mask, 'split'] = split_dict['test']\n", 231 | "  return metadata_df\n", 232 | "\n", 233 | "camelyon_metadata = load_camelyon()\n", 234 | "\n", 235 | "def load_eval_dataset(batch_size, split='id_val'):\n", 236 | "  \"\"\"Loads a Camelyon eval split as a batched tf.data.Dataset.\"\"\"\n", 237 | "  if split == 'id_val':\n", 238 | "    split_id = 1\n", 239 | "  elif split == 'ood_test':\n", 240 | "    split_id = 2\n", 241 | "  elif split == 'ood_val':\n", 242 | "    split_id = 3\n", 243 | "  else:\n", 244 | "    raise ValueError(f'Unknown split: {split}')\n", 245 | "  eval_data = camelyon_metadata[camelyon_metadata['split'] == split_id]\n", 246 | "  files = eval_data['images'].values\n", 247 | "  labels = eval_data['tumor'].values\n", 248 | "  center = eval_data['center'].values\n", 249 | "  images = tf.constant(files)\n", 250 | "  labels = tf.constant(labels)\n", 251 | "  center = tf.constant(center)\n", 252 | "  dataset = tf.data.Dataset.from_tensor_slices((images, labels, center))\n", 253 | "  dataset = dataset.map(parse_function).batch(batch_size)\n", 254 | "  return dataset" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": null, 260 | "metadata": { 261 | "id": "0GnS7vf0mW4o" 262 | }, 263 | "outputs": [], 264 | "source": [ 265 | "for eval_dataset in ['id_val', 'ood_val', 'ood_test']:\n", 266 | "  print(f'Results for {eval_dataset}')\n", 267 | "  # Reset the accumulators so that each split is scored on its own.\n", 268 | "  predictions = []\n", 269 | "  center = []\n", 270 | "  true_labels = []\n", 271 | "\n", 272 | "  ds = load_eval_dataset(512, eval_dataset)\n", 273 | "  ds = tfds.as_numpy(ds)\n", 274 | "  for i, ds_item in enumerate(ds):\n", 275 | "    images = ds_item['image'].astype(np.float32) / 255.0\n", 276 | "    labels = ds_item['label']\n", 277 | "    centers = ds_item['center']\n", 278 | "\n", 279 | "    logits = restored_model(images)\n", 280 | "    predicted_label = np.argmax(logits, 
axis=-1)\n", 281 | "    predictions.append(predicted_label)\n", 282 | "    center.append(centers)\n", 283 | "    true_labels.append(labels)\n", 284 | "  print(\n", 285 | "      f'# samples: {np.concatenate(true_labels).shape[0]} in dataset'\n", 286 | "      f' {eval_dataset}'\n", 287 | "  )\n", 288 | "  print(\n", 289 | "      'Accuracy:'\n", 290 | "      f' {(np.concatenate(predictions) == np.concatenate(true_labels)).mean()}'\n", 291 | "  )\n", 292 | "\n", 293 | "  if eval_dataset == 'id_val':\n", 294 | "    centers = np.concatenate(center)\n", 295 | "    predictions = np.concatenate(predictions)\n", 296 | "    true_labels = np.concatenate(true_labels)\n", 297 | "    acc_center = [\n", 298 | "        (predictions[centers == c] == true_labels[centers == c]).mean()\n", 299 | "        for c in np.unique(centers)\n", 300 | "    ]\n", 301 | "    print(f'Fairness GAP: {max(acc_center) - min(acc_center)}')\n", 302 | "  print('\\n\\n')" 303 | ] 304 | }, 305 | { 306 | "cell_type": "markdown", 307 | "metadata": { 308 | "id": "tuax08JJaRIC" 309 | }, 310 | "source": [ 311 | "With the code above, you should get the following results.
Note that these are results with a *single* model (in the paper we reported mean and standard deviation across five seeds):\n", 312 | "\n", 313 | "| Model | Checkpoint name | Training setup | ID_VAL (%) | OOD_VAL (%) | OOD_TEST (%) | FAIRNESS_GAP |\n", 314 | "|--------|------|------|---------|---------|-----------|--------------|\n", 315 | "| Ours (Multi class) | `ours_multiclass` | All train | 98.0 | 94.2 | 94.8 | 0.006 |\n", 316 | "| Baseline | `baseline` | All train | 92.4 | 85.1 | 62.4 | 0.041 |\n", 317 | "| Ours (Multi class) | `skewed100_ours_multiclass` | Most skewed | 96.0 | 92.9 | 94.2 | 0.023 |\n", 318 | "| Baseline | `skewed100_baseline` | Most skewed | 75.7 | 88.6 | 64.3 | 0.464 |" 319 | ] 320 | }, 321 | { 322 | "cell_type": "code", 323 | "execution_count": null, 324 | "metadata": { 325 | "id": "5z7TuWS5xGTp" 326 | }, 327 | "outputs": [], 328 | "source": [] 329 | } 330 | ], 331 | "metadata": { 332 | "colab": { 333 | "last_runtime": { 334 | "kind": "private" 335 | }, 336 | "private_outputs": true, 337 | "provenance": [ 338 | { 339 | "file_id": "1erPHavpU5ntHNnV3j6rDTgxY5uexSPEn", 340 | "timestamp": 1705326055329 341 | }, 342 | { 343 | "file_id": "1eUViTtvvoV2b1bXgUwzJ_-Ip8__n7umg", 344 | "timestamp": 1705323298445 345 | }, 346 | { 347 | "file_id": "1VdphEI83Ioor-7HM7RD3cleZGSgsr8WN", 348 | "timestamp": 1691487304569 349 | } 350 | ], 351 | "toc_visible": true 352 | }, 353 | "kernelspec": { 354 | "display_name": "Python 3", 355 | "name": "python3" 356 | }, 357 | "language_info": { 358 | "name": "python" 359 | } 360 | }, 361 | "nbformat": 4, 362 | "nbformat_minor": 0 363 | } 364 | --------------------------------------------------------------------------------
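The fairness gap reported by the notebook's evaluation cell is the difference between the best and the worst per-center accuracy on a split. The following is a minimal NumPy-only sketch of that computation, not the notebook's own code: `per_center_accuracy` and `fairness_gap` are hypothetical helper names introduced here for illustration, and the toy arrays are made-up values, not results from the paper or the released models.

```python
import numpy as np


def per_center_accuracy(predictions, labels, centers):
  """Returns {center id: accuracy} for every center present in `centers`."""
  predictions = np.asarray(predictions)
  labels = np.asarray(labels)
  centers = np.asarray(centers)
  correct = predictions == labels
  return {
      int(c): float(correct[centers == c].mean()) for c in np.unique(centers)
  }


def fairness_gap(predictions, labels, centers):
  """Difference between the best and the worst per-center accuracy."""
  accuracies = per_center_accuracy(predictions, labels, centers)
  return max(accuracies.values()) - min(accuracies.values())


# Toy inputs (made up for illustration only). The center ids mimic the
# in-distribution hospitals [0, 3, 4] used in the notebook.
toy_predictions = [1, 0, 1, 1, 0, 0, 1, 0]
toy_labels = [1, 0, 0, 1, 0, 1, 1, 1]
toy_centers = [0, 0, 0, 0, 3, 3, 4, 4]

print(per_center_accuracy(toy_predictions, toy_labels, toy_centers))
print(fairness_gap(toy_predictions, toy_labels, toy_centers))
```

Keeping the computation in a function that takes the arrays for one split makes it harder to mix predictions from different splits by accident, which matters because the gap is only reported for `id_val`.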