├── .gitignore ├── LICENSE ├── README.md ├── notebooks │   ├── training.ipynb │   └── visualization.ipynb ├── preprocessing_requirements.txt ├── requirements.txt ├── rxrx │   ├── __init__.py │   ├── input.py │   ├── io.py │   ├── main.py │   ├── official_resnet.py │   ├── preprocess │   │   ├── __init__.py │   │   ├── images2tfrecords.py │   │   └── images2zarr.py │   └── utils.py ├── setup.py ├── test_requirements.txt └── tests     └── test_images2tfrecords.py /.gitignore: -------------------------------------------------------------------------------- 1 | scratches 2 | 3 | # editor 4 | *.swp 5 | 6 | # backup 7 | *.bak 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | .pytest_cache/ 14 | .mypy_cache/ 15 | 16 | # Possible secrets 17 | *.json 18 | *.yaml 19 | *.yml 20 | secrets.tar 21 | !.travis.yml 22 | !environment.yml 23 | deploy/kube/ 24 | deploy/terraform/ 25 | !conda.recipe/meta.yaml 26 | !configome.sample.yaml 27 | 28 | # Terraform 29 | .terraform 30 | *.tf 31 | *.tfstate 32 | *.backup 33 | 34 | # C extensions 35 | *.so 36 | 37 | # Distribution / packaging 38 | .Python 39 | env/ 40 | build/ 41 | develop-eggs/ 42 | dist/ 43 | downloads/ 44 | eggs/ 45 | .eggs/ 46 | lib/ 47 | lib64/ 48 | parts/ 49 | sdist/ 50 | var/ 51 | wheels/ 52 | *.egg-info/ 53 | .installed.cfg 54 | *.egg 55 | MANIFEST 56 | 57 | # PyInstaller 58 | # Usually these files are written by a python script from a template 59 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 60 | *.manifest 61 | *.spec 62 | 63 | # Installer logs 64 | pip-log.txt 65 | pip-delete-this-directory.txt 66 | 67 | # Unit test / coverage reports 68 | htmlcov/ 69 | .tox/ 70 | .coverage 71 | .coverage.* 72 | .cache 73 | nosetests.xml 74 | coverage.xml 75 | *.cover 76 | .hypothesis/ 77 | .pytest_cache/ 78 | 79 | # Translations 80 | *.mo 81 | *.pot 82 | 83 | # Django stuff: 84 | *.log 85 | local_settings.py 86 | db.sqlite3 87 | 88 | # Flask stuff: 89 | instance/ 90 | .webassets-cache 91 | 92 | # Scrapy stuff: 93 | .scrapy 94 | 95 | # Sphinx documentation 96 | docs/_build/ 97 | 98 | # PyBuilder 99 | target/ 100 | 101 | # Jupyter Notebook 102 | .ipynb_checkpoints 103 | 104 | # pyenv 105 | .python-version 106 | 107 | # celery beat schedule file 108 | celerybeat-schedule 109 | 110 | # SageMath parsed files 111 | *.sage.py 112 | 113 | # Environments 114 | .env 115 | .venv 116 | env/ 117 | venv/ 118 | ENV/ 119 | env.bak/ 120 | venv.bak/ 121 | 122 | # Spyder project settings 123 | .spyderproject 124 | .spyproject 125 | 126 | # Rope project settings 127 | .ropeproject 128 | 129 | # mkdocs documentation 130 | /site 131 | 132 | # mypy 133 | .mypy_cache/ 134 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2019, Recursion Pharmaceuticals. All rights reserved. 2 | 3 | Apache License 4 | Version 2.0, January 2004 5 | http://www.apache.org/licenses/ 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity.
For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. 
Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 
136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "[]" 184 | replaced with your own identifying information. (Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. We also recommend that a 187 | file or class name and description of purpose be included on the 188 | same "printed page" as the copyright notice for easier 189 | identification within third-party archives. 190 | 191 | Copyright 2019, Recursion Pharmaceuticals. 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 
195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. 204 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![scorecard-score](https://github.com/recursionpharma/octo-guard-badges/blob/trunk/badges/repo/rxrx1-utils/maturity_score.svg?raw=true)](https://infosec-docs.prod.rxrx.io/octoguard/scorecards/rxrx1-utils) 2 | [![scorecard-status](https://github.com/recursionpharma/octo-guard-badges/blob/trunk/badges/repo/rxrx1-utils/scorecard_status.svg?raw=true)](https://infosec-docs.prod.rxrx.io/octoguard/scorecards/rxrx1-utils) 3 | # rxrx1-utils 4 | 5 | Starter code for the CellSignal NeurIPS 2019 competition [hosted on Kaggle](https://www.kaggle.com/c/recursion-cellular-image-classification). 6 | 7 | To learn more about the dataset, please visit [RxRx.ai](http://rxrx.ai). 8 | 9 | ## Notebooks 10 | 11 | Here are some notebooks that illustrate how this code can be used. 12 | 13 | * [Image visualization][vis-notebook] 14 | * [Model training on TPUs][training-notebook] 15 | 16 | [vis-notebook]: https://colab.research.google.com/github/recursionpharma/rxrx1-utils/blob/trunk/notebooks/visualization.ipynb 17 | [training-notebook]: https://colab.research.google.com/github/recursionpharma/rxrx1-utils/blob/trunk/notebooks/training.ipynb 18 | 19 | ## Setup 20 | 21 | This starter code works with Python 2.7 and above. To install the deps needed for training and visualization, run: 22 | 23 | ``` 24 | pip install -r requirements.txt 25 | ``` 26 | 27 | If you plan on using the preprocessing functionality, you will also need to install additional deps: 28 | 29 | ``` 30 | pip install -r preprocessing_requirements.txt 31 | ``` 32 | 33 | ## Preprocessing 34 | 35 | Reading individual image files can become an IO bottleneck during training. This will be a common problem for anyone using this dataset, so we are also releasing example scripts that pack the images into TFRecords and `zarr` files. We are also making some pre-created TFRecords available in Google Cloud Storage. Read more about the [provided TFRecords below](#provided-tfrecords). 36 | 37 | 38 | ### images2tfrecords 39 | 40 | A script that packs raw images from the `rxrx1` dataset into `TFRecord`s. It runs either locally or on Google Cloud Dataflow. 41 | 42 | Run `python -m rxrx.preprocess.images2tfrecords --help` for usage instructions. 43 | 44 | 45 | ### images2zarr 46 | 47 | A script that packs raw images from the `rxrx1` dataset into `zarr`s. It currently only runs locally, but could easily be extended to run on Google Cloud Dataflow. 48 | 49 | This script packs each site image into a single `zarr`: instead of having to load 6 separate channel `png`s for a single image, all of those channels are saved together in one `zarr` file. 50 | You could extend the script to pack multiple images into a single `zarr` file, similar to what is done for `TFRecord`s. This is left as an exercise for the IO-bound reader. :) Read more about the Zarr format and library [here](https://zarr.readthedocs.io/en/stable/).
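Once packed, a site can be read back with a single load instead of six separate channel reads. Here is a minimal sketch of loading one of the resulting files (the output directory and the filename pattern `HUVEC-01_1_B02_s1.zarr` are assumptions for illustration; the actual layout depends on the arguments you pass to the script):

```
import zarr

# Hypothetical path: point this at wherever images2zarr wrote its output.
site = zarr.open('zarr/HUVEC-01_1_B02_s1.zarr', mode='r')
image = site[:]  # expected: a (512, 512, 6) uint8 array, all six channels together
print(image.shape, image.dtype)
```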
51 | 52 | Run `python -m rxrx.preprocess.images2zarr --help` for usage instructions. 53 | 54 | ## Training on TPUs 55 | 56 | This repo has barebones starter code showing how to train a model on the RxRx1 dataset using Google Cloud TPUs. 57 | 58 | The easiest way to see this in action is to look at this [notebook][training-notebook]. 59 | 60 | You can also spin up a VM to launch jobs from. To understand TPUs, the best place to start is the [TPU quickstart guide][tpu-quickstart]. The `ctpu` command is helpful, and you can find [its documentation here][ctpu-docs]. Note that you can easily [download and install ctpu][download-ctpu] on your local machine. 61 | 62 | [tpu-quickstart]: https://cloud.google.com/tpu/docs/quickstart 63 | [ctpu-docs]: https://cloud.google.com/tpu/docs/ctpu-reference 64 | [download-ctpu]: https://github.com/tensorflow/tpu/tree/master/tools/ctpu#download 65 | 66 | ### Example TPU workflow 67 | 68 | First spin up a VM: 69 | ``` 70 | ctpu up -vm-only -forward-agent -forward-ports -name my-tpu-vm 71 | ``` 72 | 73 | This command will create the VM and `ssh` you into it. Note the `-vm-only` flag: it spins up the VM separately from the TPU, which helps prevent spending money on idle TPUs. 74 | 75 | Next, set up the repo and install the dependencies: 76 | ``` 77 | git clone git@github.com:recursionpharma/rxrx1-utils.git 78 | cd rxrx1-utils 79 | pip install -r requirements.txt # optional if just training! 80 | ``` 81 | 82 | Note that if you are only training you can skip the `pip install`, since the VM already has all the needed deps. 83 | 84 | Next you need to spin up a TPU for training: 85 | ``` 86 | export TPU_NAME=my-tpu-v3-8 87 | ctpu up -name "$TPU_NAME" -preemptible -tpu-only -tpu-size v3-8 88 | ``` 89 | 90 | Once that is complete you can start a training job: 91 | ``` 92 | python -m rxrx.main --model-dir "gs://path-to-bucket/trial-id/" 93 | ``` 94 | You'll also want to launch `tensorboard` to check on the results: 95 | 96 | ``` 97 | tensorboard --logdir=gs://path-to-bucket/ 98 | ``` 99 | Since we used the `-forward-ports` flag in the `ctpu` command when starting the VM, you will be able to view `tensorboard` on your localhost. 100 | 101 | Once you are done with the TPU, be sure to delete it! 102 | ``` 103 | ctpu delete -name "$TPU_NAME" -tpu-only 104 | ``` 105 | 106 | You can then iterate on the code and spin up a new TPU when you are ready to try again. 107 | 108 | When you are done with your VM you can either stop it or delete it with the `ctpu` command, for example: 109 | ``` 110 | ctpu delete -name my-tpu-vm 111 | ``` 112 | 113 | ## Provided TFRecords 114 | 115 | As noted above, we are providing pre-created TFRecords. They live in the following buckets: 116 | 117 | ``` 118 | gs://rxrx1-us-central1/tfrecords 119 | gs://rxrx1-europe-west4/tfrecords 120 | ``` 121 | 122 | The data lives in two regional buckets because when you train with TPUs you want to read from a bucket in the same region as your TPU, so remember to use the appropriate one! 123 | 124 | The directory structure of the TFRecords is as follows: 125 | 126 | ``` 127 | └── tfrecords 128 | ├── by_exp_plate_site-42 129 | │   ├── HEPG2-10_p1_s1.tfrecord 130 | │   ├── HEPG2-10_p1_s2.tfrecord 131 | │   ├── …. 132 | │   ├── U2OS-03_p3_s2.tfrecord 133 | │   ├── U2OS-03_p4_s1.tfrecord 134 | │   └── U2OS-03_p4_s2.tfrecord 135 | └── random-42 136 | ├── train 137 | │   ├── 001.tfrecord 138 | │   ├── 002.tfrecord 139 | ….
140 | ``` 141 | The `random-42` directory denotes that the data has been split randomly across the different tfrecords, with each record holding ~1000 examples. The `42` is the random seed used to generate this partition. The example code in this repository uses this version of the data. 142 | 143 | In `by_exp_plate_site-42`, each TFRecord contains all of the images for a particular experiment, plate, and site grouping. Within each TFRecord the well addresses are in random order. The advantage of this grouping is that you can be selective about which experiments you train on. Due to the grouping, each TFRecord here holds only ~277 examples. 144 | 145 | For good training batch diversity it is recommended that you use the TF Dataset API to interleave examples from these various files. The provided `input_fn` in this repository already does this. 146 | -------------------------------------------------------------------------------- /notebooks/training.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "training.ipynb", 7 | "version": "0.3.2", 8 | "provenance": [], 9 | "collapsed_sections": [] 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "accelerator": "TPU" 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "metadata": { 21 | "id": "rasOe6jNJbFj", 22 | "colab_type": "text" 23 | }, 24 | "source": [ 25 | "# How to train a ResNet50 on RxRx1 using TPUs \n", 26 | "\n", 27 | "Colaboratory makes it easy to train models using [Cloud TPUs](https://cloud.google.com/tpu/), and this notebook demonstrates how to use the code in [rxrx1-utils](https://github.com/recursionpharma/rxrx1-utils) to train a ResNet50 on the RxRx1 image set using a Colab TPU.\n", 28 | "\n", 29 | "Be sure to select the TPU runtime before beginning!"
30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "metadata": { 35 | "id": "cKtZctcXJTAZ", 36 | "colab_type": "code", 37 | "colab": {} 38 | }, 39 | "source": [ 40 | "import json\n", 41 | "import os\n", 42 | "import sys\n", 43 | "import tensorflow as tf" 44 | ], 45 | "execution_count": 0, 46 | "outputs": [] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "metadata": { 51 | "id": "LNgr17uD0K--", 52 | "colab_type": "code", 53 | "outputId": "dc2b6d1b-fa7d-481d-c5ef-887876f6c27a", 54 | "colab": { 55 | "base_uri": "https://localhost:8080/", 56 | "height": 119 57 | } 58 | }, 59 | "source": [ 60 | "if 'google.colab' in sys.modules:\n", 61 | " !git clone https://github.com/recursionpharma/rxrx1-utils\n", 62 | " sys.path.append('/content/rxrx1-utils')\n", 63 | "\n", 64 | " from google.colab import auth\n", 65 | " auth.authenticate_user()\n", 66 | " \n", 67 | "from rxrx.main import main" 68 | ], 69 | "execution_count": 2, 70 | "outputs": [ 71 | { 72 | "output_type": "stream", 73 | "text": [ 74 | "Cloning into 'rxrx1-utils'...\n", 75 | "remote: Enumerating objects: 99, done.\u001b[K\n", 76 | "remote: Counting objects: 100% (99/99), done.\u001b[K\n", 77 | "remote: Compressing objects: 100% (53/53), done.\u001b[K\n", 78 | "remote: Total 99 (delta 48), reused 92 (delta 42), pack-reused 0\u001b[K\n", 79 | "Unpacking objects: 100% (99/99), done.\n" 80 | ], 81 | "name": "stdout" 82 | } 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": { 88 | "colab_type": "text", 89 | "id": "HrPeVFofzIdy" 90 | }, 91 | "source": [ 92 | "## Train\n", 93 | "\n", 94 | "Set `MODEL_DIR` to be a Google Cloud Storage bucket that you can write to. The code will write your checkpoints to this directory." 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "metadata": { 100 | "id": "Z9MjRJpwJTAw", 101 | "colab_type": "code", 102 | "outputId": "44655cb7-37c9-4a20-a689-6cda882224ba", 103 | "colab": { 104 | "base_uri": "https://localhost:8080/", 105 | "height": 1000 106 | } 107 | }, 108 | "source": [ 109 | "MODEL_DIR = 'gs://path/to/your/bucket'\n", 110 | "URL_BASE_PATH = 'gs://rxrx1-us-central1/tfrecords/random-42'\n", 111 | "\n", 112 | "# make sure we're in a TPU runtime\n", 113 | "assert 'COLAB_TPU_ADDR' in os.environ\n", 114 | "\n", 115 | "# set TPU-relevant args\n", 116 | "tpu_grpc = 'grpc://{}'.format(os.environ['COLAB_TPU_ADDR'])\n", 117 | "num_shards = 8 # colab uses Cloud TPU v2-8\n", 118 | "\n", 119 | "# upload credentials to the TPU\n", 120 | "with tf.Session(tpu_grpc) as sess:\n", 121 | " data = json.load(open('/content/adc.json'))\n", 122 | " tf.contrib.cloud.configure_gcs(sess, credentials=data)\n", 123 | "\n", 124 | "tf.logging.set_verbosity(tf.logging.INFO)\n", 125 | "\n", 126 | "main(use_tpu=True,\n", 127 | " tpu=tpu_grpc,\n", 128 | " gcp_project=None,\n", 129 | " tpu_zone=None,\n", 130 | " url_base_path=URL_BASE_PATH,\n", 131 | " use_cache=False,\n", 132 | " model_dir=MODEL_DIR,\n", 133 | " train_epochs=1,\n", 134 | " train_batch_size=512,\n", 135 | " num_train_images=73030,\n", 136 | " epochs_per_loop=1,\n", 137 | " log_step_count_epochs=1,\n", 138 | " num_cores=num_shards,\n", 139 | " data_format='channels_last',\n", 140 | " transpose_input=True,\n", 141 | " tf_precision='bfloat16',\n", 142 | " n_classes=1108,\n", 143 | " momentum=0.9,\n", 144 | " weight_decay=1e-4,\n", 145 | " base_learning_rate=0.2,\n", 146 | " warmup_epochs=5)" 147 | ], 148 | "execution_count": 3, 149 | "outputs": [ 150 | { 151 | "output_type": "stream", 152 | "text": [ 153 | "WARNING: Logging before flag parsing goes to
stderr.\n", 154 | "W0627 19:53:08.003671 139758653511552 deprecation_wrapper.py:119] From /content/rxrx1-utils/rxrx/main.py:280: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.\n", 155 | "\n", 156 | "I0627 19:53:08.005592 139758653511552 main.py:280] tpu: grpc://10.106.194.154:8470\n", 157 | "I0627 19:53:08.010348 139758653511552 main.py:283] gcp_project: None\n", 158 | "W0627 19:53:10.223041 139758653511552 estimator.py:1984] Estimator's model_fn (functools.partial(, n_classes=1108, num_train_images=73030, data_format='channels_last', transpose_input=True, train_batch_size=512, iterations_per_loop=142, tf_precision='bfloat16', momentum=0.9, weight_decay=0.0001, base_learning_rate=0.2, warmup_epochs=5, model_dir='gs://recursion-tpu-training/berton/rxrx1_test/my_test', use_tpu=True, resnet_depth=50)) includes params argument, but params are not passed to Estimator.\n", 159 | "I0627 19:53:10.225781 139758653511552 estimator.py:209] Using config: {'_model_dir': 'gs://recursion-tpu-training/berton/rxrx1_test/my_test', '_tf_random_seed': None, '_save_summary_steps': 142, '_save_checkpoints_steps': 142, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true\n", 160 | "cluster_def {\n", 161 | " job {\n", 162 | " name: \"worker\"\n", 163 | " tasks {\n", 164 | " key: 0\n", 165 | " value: \"10.106.194.154:8470\"\n", 166 | " }\n", 167 | " }\n", 168 | "}\n", 169 | "isolate_session_state: true\n", 170 | ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_service': None, '_cluster_spec': , '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.106.194.154:8470', '_evaluation_master': 'grpc://10.106.194.154:8470', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=142, num_shards=8, num_cores_per_replica=None, per_host_input_for_training=3, tpu_job_name=None, initial_infeed_sleep_secs=None, input_partition_dims=None, eval_training_input_configuration=2), '_cluster': }\n", 171 | "I0627 19:53:10.227086 139758653511552 tpu_context.py:209] _TPUContext: eval_on_tpu True\n", 172 | "I0627 19:53:10.233348 139758653511552 main.py:338] Train glob: gs://rxrx1-us-central1/tfrecords/random-42/train/*.tfrecord\n", 173 | "I0627 19:53:10.236295 139758653511552 main.py:351] Training for 142 steps (1.00 epochs in total). 
Current step 0.\n", 174 | "I0627 19:53:10.345630 139758653511552 tpu_system_metadata.py:78] Querying Tensorflow master (grpc://10.106.194.154:8470) for TPU system metadata.\n", 175 | "I0627 19:53:10.362607 139758653511552 tpu_system_metadata.py:148] Found TPU system:\n", 176 | "I0627 19:53:10.364023 139758653511552 tpu_system_metadata.py:149] *** Num TPU Cores: 8\n", 177 | "I0627 19:53:10.365448 139758653511552 tpu_system_metadata.py:150] *** Num TPU Workers: 1\n", 178 | "I0627 19:53:10.370308 139758653511552 tpu_system_metadata.py:152] *** Num TPU Cores Per Worker: 8\n", 179 | "I0627 19:53:10.373871 139758653511552 tpu_system_metadata.py:154] *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 9785620760386089044)\n", 180 | "I0627 19:53:10.379615 139758653511552 tpu_system_metadata.py:154] *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 17957700996873846002)\n", 181 | "I0627 19:53:10.387336 139758653511552 tpu_system_metadata.py:154] *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 7534058317997506011)\n", 182 | "I0627 19:53:10.388670 139758653511552 tpu_system_metadata.py:154] *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, 732442551779127628)\n", 183 | "I0627 19:53:10.390994 139758653511552 tpu_system_metadata.py:154] *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, 1258150734284970345)\n", 184 | "I0627 19:53:10.399887 139758653511552 tpu_system_metadata.py:154] *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 17179869184, 12528303070827221666)\n", 185 | "I0627 19:53:10.402188 139758653511552 tpu_system_metadata.py:154] *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 17179869184, 2385972351757131582)\n", 186 | "I0627 19:53:10.403303 139758653511552 tpu_system_metadata.py:154] *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 17179869184, 17540351673832642764)\n", 187 | "I0627 19:53:10.406220 139758653511552 tpu_system_metadata.py:154] *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 17179869184, 16943942441228138344)\n", 188 | "I0627 19:53:10.408085 139758653511552 tpu_system_metadata.py:154] *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 8589934592, 14676506593360444113)\n", 189 | "I0627 19:53:10.410236 139758653511552 tpu_system_metadata.py:154] *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, 10007324885980858573)\n", 190 | "W0627 19:53:10.435066 139758653511552 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.\n", 191 | "Instructions for updating:\n", 192 | "Use Variable.read_value. 
Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.\n", 193 | "I0627 19:53:10.457272 139758653511552 estimator.py:1145] Calling model_fn.\n", 194 | "W0627 19:53:10.489276 139758653511552 deprecation.py:323] From /content/rxrx1-utils/rxrx/input.py:94: shuffle_and_repeat (from tensorflow.contrib.data.python.ops.shuffle_ops) is deprecated and will be removed in a future version.\n", 195 | "Instructions for updating:\n", 196 | "Use `tf.data.experimental.shuffle_and_repeat(...)`.\n", 197 | "W0627 19:53:10.490580 139758653511552 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/data/python/ops/shuffle_ops.py:54: shuffle_and_repeat (from tensorflow.python.data.experimental.ops.shuffle_ops) is deprecated and will be removed in a future version.\n", 198 | "Instructions for updating:\n", 199 | "Use `tf.data.Dataset.shuffle(buffer_size, seed)` followed by `tf.data.Dataset.repeat(count)`. Static tf.data optimizations will take care of using the fused implementation.\n", 200 | "W0627 19:53:10.501322 139758653511552 deprecation.py:323] From /content/rxrx1-utils/rxrx/input.py:115: parallel_interleave (from tensorflow.contrib.data.python.ops.interleave_ops) is deprecated and will be removed in a future version.\n", 201 | "Instructions for updating:\n", 202 | "Use `tf.data.experimental.parallel_interleave(...)`.\n", 203 | "W0627 19:53:10.502724 139758653511552 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/data/python/ops/interleave_ops.py:77: parallel_interleave (from tensorflow.python.data.experimental.ops.interleave_ops) is deprecated and will be removed in a future version.\n", 204 | "Instructions for updating:\n", 205 | "Use `tf.data.Dataset.interleave(map_func, cycle_length, block_length, num_parallel_calls=tf.data.experimental.AUTOTUNE)` instead. If sloppy execution is desired, use `tf.data.Options.experimental_determinstic`.\n", 206 | "W0627 19:53:10.540040 139758653511552 deprecation.py:323] From /content/rxrx1-utils/rxrx/input.py:125: map_and_batch (from tensorflow.contrib.data.python.ops.batching) is deprecated and will be removed in a future version.\n", 207 | "Instructions for updating:\n", 208 | "Use `tf.data.experimental.map_and_batch(...)`.\n", 209 | "W0627 19:53:10.541434 139758653511552 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/data/python/ops/batching.py:273: map_and_batch (from tensorflow.python.data.experimental.ops.batching) is deprecated and will be removed in a future version.\n", 210 | "Instructions for updating:\n", 211 | "Use `tf.data.Dataset.map(map_func, num_parallel_calls)` followed by `tf.data.Dataset.batch(batch_size, drop_remainder)`. Static tf.data optimizations will take care of using the fused implementation.\n", 212 | "W0627 19:53:10.547796 139758653511552 deprecation_wrapper.py:119] From /content/rxrx1-utils/rxrx/input.py:48: The name tf.FixedLenFeature is deprecated. Please use tf.io.FixedLenFeature instead.\n", 213 | "\n", 214 | "W0627 19:53:10.550729 139758653511552 deprecation_wrapper.py:119] From /content/rxrx1-utils/rxrx/input.py:59: The name tf.parse_single_example is deprecated. 
Please use tf.io.parse_single_example instead.\n", 215 | "\n", 216 | "W0627 19:53:10.683037 139758653511552 deprecation.py:323] From /content/rxrx1-utils/rxrx/official_resnet.py:211: conv2d (from tensorflow.python.layers.convolutional) is deprecated and will be removed in a future version.\n", 217 | "Instructions for updating:\n", 218 | "Use `tf.keras.layers.Conv2D` instead.\n", 219 | "W0627 19:53:10.946627 139758653511552 deprecation.py:323] From /content/rxrx1-utils/rxrx/official_resnet.py:70: batch_normalization (from tensorflow.python.layers.normalization) is deprecated and will be removed in a future version.\n", 220 | "Instructions for updating:\n", 221 | "Use keras.layers.BatchNormalization instead. In particular, `tf.control_dependencies(tf.GraphKeys.UPDATE_OPS)` should not be used (consult the `tf.keras.layers.batch_normalization` documentation).\n", 222 | "W0627 19:53:11.076068 139758653511552 deprecation.py:323] From /content/rxrx1-utils/rxrx/official_resnet.py:413: max_pooling2d (from tensorflow.python.layers.pooling) is deprecated and will be removed in a future version.\n", 223 | "Instructions for updating:\n", 224 | "Use keras.layers.MaxPooling2D instead.\n", 225 | "W0627 19:53:15.471554 139758653511552 deprecation.py:323] From /content/rxrx1-utils/rxrx/official_resnet.py:442: average_pooling2d (from tensorflow.python.layers.pooling) is deprecated and will be removed in a future version.\n", 226 | "Instructions for updating:\n", 227 | "Use keras.layers.AveragePooling2D instead.\n", 228 | "W0627 19:53:15.481713 139758653511552 deprecation.py:323] From /content/rxrx1-utils/rxrx/official_resnet.py:449: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n", 229 | "Instructions for updating:\n", 230 | "Use keras.layers.dense instead.\n", 231 | "W0627 19:53:16.032349 139758653511552 deprecation_wrapper.py:119] From /content/rxrx1-utils/rxrx/main.py:125: The name tf.losses.softmax_cross_entropy is deprecated. Please use tf.compat.v1.losses.softmax_cross_entropy instead.\n", 232 | "\n", 233 | "W0627 19:53:16.092159 139758653511552 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/losses/losses_impl.py:121: add_dispatch_support..wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n", 234 | "Instructions for updating:\n", 235 | "Use tf.where in 2.0, which has the same broadcast rule as np.where\n", 236 | "W0627 19:53:16.112515 139758653511552 deprecation_wrapper.py:119] From /content/rxrx1-utils/rxrx/main.py:131: The name tf.trainable_variables is deprecated. Please use tf.compat.v1.trainable_variables instead.\n", 237 | "\n", 238 | "W0627 19:53:16.223887 139758653511552 deprecation_wrapper.py:119] From /content/rxrx1-utils/rxrx/main.py:138: The name tf.train.get_global_step is deprecated. Please use tf.compat.v1.train.get_global_step instead.\n", 239 | "\n", 240 | "W0627 19:53:16.239456 139758653511552 deprecation_wrapper.py:119] From /content/rxrx1-utils/rxrx/main.py:145: The name tf.train.cosine_decay_restarts is deprecated. Please use tf.compat.v1.train.cosine_decay_restarts instead.\n", 241 | "\n", 242 | "W0627 19:53:16.297146 139758653511552 deprecation_wrapper.py:119] From /content/rxrx1-utils/rxrx/main.py:155: The name tf.train.MomentumOptimizer is deprecated. 
Please use tf.compat.v1.train.MomentumOptimizer instead.\n", 243 | "\n", 244 | "W0627 19:53:16.298741 139758653511552 deprecation_wrapper.py:119] From /content/rxrx1-utils/rxrx/main.py:167: The name tf.get_collection is deprecated. Please use tf.compat.v1.get_collection instead.\n", 245 | "\n", 246 | "W0627 19:53:16.300213 139758653511552 deprecation_wrapper.py:119] From /content/rxrx1-utils/rxrx/main.py:167: The name tf.GraphKeys is deprecated. Please use tf.compat.v1.GraphKeys instead.\n", 247 | "\n", 248 | "I0627 19:53:22.770973 139758653511552 basic_session_run_hooks.py:541] Create CheckpointSaverHook.\n", 249 | "I0627 19:53:23.122148 139758653511552 estimator.py:1147] Done calling model_fn.\n", 250 | "I0627 19:53:25.609339 139758653511552 tpu_estimator.py:499] TPU job name worker\n", 251 | "I0627 19:53:27.020166 139758653511552 monitored_session.py:240] Graph was finalized.\n", 252 | "I0627 19:53:31.370273 139758653511552 session_manager.py:500] Running local_init_op.\n", 253 | "I0627 19:53:31.856439 139758653511552 session_manager.py:502] Done running local_init_op.\n", 254 | "I0627 19:53:39.210234 139758653511552 basic_session_run_hooks.py:606] Saving checkpoints for 0 into gs://recursion-tpu-training/berton/rxrx1_test/my_test/model.ckpt.\n", 255 | "W0627 19:53:47.069951 139758653511552 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py:741: Variable.load (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.\n", 256 | "Instructions for updating:\n", 257 | "Prefer Variable.assign which has equivalent behavior in 2.X.\n", 258 | "I0627 19:53:48.202275 139758653511552 util.py:98] Initialized dataset iterators in 0 seconds\n", 259 | "I0627 19:53:48.204611 139758653511552 session_support.py:332] Installing graceful shutdown hook.\n", 260 | "I0627 19:53:48.214668 139758653511552 session_support.py:82] Creating heartbeat manager for ['/job:worker/replica:0/task:0/device:CPU:0']\n", 261 | "I0627 19:53:48.222669 139758653511552 session_support.py:105] Configuring worker heartbeat: shutdown_mode: WAIT_FOR_COORDINATOR\n", 262 | "\n", 263 | "I0627 19:53:48.228841 139758653511552 tpu_estimator.py:557] Init TPU system\n", 264 | "I0627 19:53:52.087585 139758653511552 tpu_estimator.py:566] Initialized TPU in 3 seconds\n", 265 | "I0627 19:53:52.887957 139757323314944 tpu_estimator.py:514] Starting infeed thread controller.\n", 266 | "I0627 19:53:52.888870 139757314922240 tpu_estimator.py:533] Starting outfeed thread controller.\n", 267 | "I0627 19:53:53.404005 139758653511552 tpu_estimator.py:590] Enqueue next (142) batch(es) of data to infeed.\n", 268 | "I0627 19:53:53.406334 139758653511552 tpu_estimator.py:594] Dequeue next (142) batch(es) of data from outfeed.\n", 269 | "I0627 19:54:33.920991 139757314922240 tpu_estimator.py:275] Outfeed finished for iteration (0, 0)\n", 270 | "I0627 19:55:34.432938 139757314922240 tpu_estimator.py:275] Outfeed finished for iteration (0, 80)\n", 271 | "I0627 19:56:20.688752 139758653511552 basic_session_run_hooks.py:606] Saving checkpoints for 142 into gs://recursion-tpu-training/berton/rxrx1_test/my_test/model.ckpt.\n", 272 | "I0627 19:56:27.875086 139758653511552 basic_session_run_hooks.py:262] loss = 7.9618587, step = 142\n", 273 | "I0627 19:56:28.358581 139758653511552 tpu_estimator.py:598] Stop infeed thread controller\n", 274 | "I0627 19:56:28.359947 139758653511552 tpu_estimator.py:430] Shutting down InfeedController thread.\n", 275 | "I0627 
19:56:28.364674 139757323314944 tpu_estimator.py:425] InfeedController received shutdown signal, stopping.\n", 276 | "I0627 19:56:28.366270 139757323314944 tpu_estimator.py:530] Infeed thread finished, shutting down.\n", 277 | "I0627 19:56:28.369237 139758653511552 error_handling.py:96] infeed marked as finished\n", 278 | "I0627 19:56:28.371149 139758653511552 tpu_estimator.py:602] Stop output thread controller\n", 279 | "I0627 19:56:28.373593 139758653511552 tpu_estimator.py:430] Shutting down OutfeedController thread.\n", 280 | "I0627 19:56:28.377429 139757314922240 tpu_estimator.py:425] OutfeedController received shutdown signal, stopping.\n", 281 | "I0627 19:56:28.379346 139757314922240 tpu_estimator.py:541] Outfeed thread finished, shutting down.\n", 282 | "I0627 19:56:28.381241 139758653511552 error_handling.py:96] outfeed marked as finished\n", 283 | "I0627 19:56:28.382400 139758653511552 tpu_estimator.py:606] Shutdown TPU system.\n", 284 | "I0627 19:56:32.223384 139758653511552 estimator.py:368] Loss for final step: 7.9618587.\n", 285 | "I0627 19:56:32.225899 139758653511552 error_handling.py:96] training_loop marked as finished\n", 286 | "I0627 19:56:32.231190 139758653511552 main.py:358] Finished training up to step 142. Elapsed seconds 201.\n", 287 | "I0627 19:56:32.232772 139758653511552 main.py:363] Finished training up to step 142. Elapsed seconds 201.\n" 288 | ], 289 | "name": "stderr" 290 | } 291 | ] 292 | } 293 | ] 294 | } 295 | -------------------------------------------------------------------------------- /preprocessing_requirements.txt: -------------------------------------------------------------------------------- 1 | apache-beam[gcp] 2 | dask 3 | zarr 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | toolz 2 | scikit-image 3 | pandas 4 | google-cloud-storage 5 | google-api-python-client 6 | oauth2client 7 | tensorflow==1.13.1 8 | apache-beam[gcp] 9 | -------------------------------------------------------------------------------- /rxrx/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import sys 7 | 8 | 9 | def _add_tpu_models_to_path(): 10 | dir = os.path.dirname(os.path.realpath(__file__)) 11 | tpu_models_dir = os.path.abspath(os.path.join(dir, '..', 'tpu', 'models')) 12 | if tpu_models_dir not in sys.path: 13 | sys.path.insert(0, tpu_models_dir) 14 | 15 | 16 | _add_tpu_models_to_path() 17 | -------------------------------------------------------------------------------- /rxrx/input.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Efficient input pipeline using tf.data.Dataset. 16 | 17 | Original file: 18 | https://github.com/tensorflow/tpu/blob/master/models/official/resnet/imagenet_input.py 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | from functools import partial 26 | 27 | import tensorflow as tf 28 | 29 | 30 | def set_shapes(transpose_input, batch_size, images, labels): 31 | """Statically set the batch_size dimension.""" 32 | if transpose_input: 33 | images.set_shape(images.get_shape().merge_with( 34 | tf.TensorShape([None, None, None, batch_size]))) 35 | labels.set_shape( 36 | labels.get_shape().merge_with(tf.TensorShape([batch_size]))) 37 | else: 38 | images.set_shape(images.get_shape().merge_with( 39 | tf.TensorShape([batch_size, None, None, None]))) 40 | labels.set_shape( 41 | labels.get_shape().merge_with(tf.TensorShape([batch_size]))) 42 | 43 | return images, labels 44 | 45 | def parse_example(value, use_bfloat16=True, pixel_stats=None): 46 | 47 | keys_to_features = { 48 | 'image': tf.FixedLenFeature((), tf.string), 49 | 'well': tf.FixedLenFeature((), tf.string), 50 | 'well_type': tf.FixedLenFeature((), tf.string), 51 | 'plate': tf.FixedLenFeature((), tf.int64), 52 | 'site': tf.FixedLenFeature((), tf.int64), 53 | 'cell_type': tf.FixedLenFeature((), tf.string), 54 | 'sirna': tf.FixedLenFeature((), tf.int64), 55 | 'experiment': tf.FixedLenFeature((), tf.string) 56 | } 57 | 58 | image_shape = [512, 512, 6] 59 | parsed = tf.parse_single_example(value, keys_to_features) 60 | image_raw = tf.decode_raw(parsed['image'], tf.uint8) 61 | image = tf.reshape(image_raw, image_shape) 62 | image.set_shape(image_shape) 63 | 64 | if pixel_stats is not None: 65 | mean, std = pixel_stats 66 | image = (tf.cast(image, tf.float32) - mean) / std 67 | 68 | if use_bfloat16: 69 | image = tf.image.convert_image_dtype(image, dtype=tf.bfloat16) 70 | 71 | label = parsed["sirna"] 72 | 73 | return image, label 74 | 75 | 76 | DEFAULT_PARAMS = dict(batch_size=512) 77 | 78 | 79 | def input_fn(tf_records_glob, 80 | input_fn_params, 81 | params=None, 82 | use_bfloat16=False, 83 | pixel_stats=None, 84 | transpose_input=True, 85 | shuffle_buffer=2048): 86 | 87 | batch_size = params['batch_size'] 88 | 89 | filenames_dataset = tf.data.Dataset.list_files(tf_records_glob) 90 | 91 | def fetch_images(filenames): 92 | dataset = tf.data.TFRecordDataset( 93 | filenames, 94 | compression_type="GZIP", 95 | buffer_size=(1000 * 1000 * 96 | input_fn_params['tfrecord_dataset_buffer_size']), 97 | num_parallel_reads=input_fn_params[ 98 | 'tfrecord_dataset_num_parallel_reads']) 99 | return dataset 100 | 101 | images_dataset = filenames_dataset.apply( 102 | tf.contrib.data.parallel_interleave( 103 | fetch_images, 104 | cycle_length=input_fn_params['parallel_interleave_cycle_length'], 105 | block_length=input_fn_params['parallel_interleave_block_length'], 106 | sloppy=True, 107 | buffer_output_elements=input_fn_params[ 108 | 'parallel_interleave_buffer_output_elements'], 109 | prefetch_input_elements=input_fn_params[ 110 | 'parallel_interleave_prefetch_input_elements'])) 111 | 112 | images_dataset = images_dataset.shuffle(shuffle_buffer).repeat() 113 | 114 | # examples dataset 115 | dataset = images_dataset.apply( 116 | tf.contrib.data.map_and_batch( 117 | lambda value: parse_example(value, 118 | use_bfloat16=use_bfloat16, 119 | pixel_stats=pixel_stats),
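# map_and_batch fuses the parse and batch steps into one op; drop_remainder=True below keeps every batch the same static size, which TPU execution requires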
batch_size=batch_size, 121 | num_parallel_calls=input_fn_params['map_and_batch_num_parallel_calls'], 122 | drop_remainder=True)) 123 | 124 | # Transpose for performance on TPU 125 | if transpose_input: 126 | dataset = dataset.map( 127 | lambda images, labels: (tf.transpose(images, [1, 2, 3, 0]), labels), 128 | num_parallel_calls=input_fn_params['transpose_num_parallel_calls']) 129 | 130 | # Assign static batch size dimension 131 | dataset = dataset.map(partial(set_shapes, transpose_input, batch_size)) 132 | 133 | # Prefetch overlaps in-feed with training 134 | dataset = dataset.prefetch( 135 | buffer_size=input_fn_params['prefetch_buffer_size']) 136 | 137 | return dataset 138 | -------------------------------------------------------------------------------- /rxrx/io.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | 7 | import numpy as np 8 | from skimage.io import imread 9 | import pandas as pd 10 | 11 | import tensorflow as tf 12 | 13 | DEFAULT_BASE_PATH = 'gs://rxrx1-us-central1' 14 | DEFAULT_METADATA_BASE_PATH = os.path.join(DEFAULT_BASE_PATH, 'metadata') 15 | DEFAULT_IMAGES_BASE_PATH = os.path.join(DEFAULT_BASE_PATH, 'images') 16 | DEFAULT_CHANNELS = (1, 2, 3, 4, 5, 6) 17 | RGB_MAP = { 18 | 1: { 19 | 'rgb': np.array([19, 0, 249]), 20 | 'range': [0, 51] 21 | }, 22 | 2: { 23 | 'rgb': np.array([42, 255, 31]), 24 | 'range': [0, 107] 25 | }, 26 | 3: { 27 | 'rgb': np.array([255, 0, 25]), 28 | 'range': [0, 64] 29 | }, 30 | 4: { 31 | 'rgb': np.array([45, 255, 252]), 32 | 'range': [0, 191] 33 | }, 34 | 5: { 35 | 'rgb': np.array([250, 0, 253]), 36 | 'range': [0, 89] 37 | }, 38 | 6: { 39 | 'rgb': np.array([254, 255, 40]), 40 | 'range': [0, 191] 41 | } 42 | } 43 | 44 | 45 | def load_image(image_path): 46 | with tf.io.gfile.GFile(image_path, 'rb') as f: 47 | return imread(f, format='png') 48 | 49 | 50 | def load_images_as_tensor(image_paths, dtype=np.uint8): 51 | n_channels = len(image_paths) 52 | 53 | data = np.ndarray(shape=(512, 512, n_channels), dtype=dtype) 54 | 55 | for ix, img_path in enumerate(image_paths): 56 | data[:, :, ix] = load_image(img_path) 57 | 58 | return data 59 | 60 | 61 | def convert_tensor_to_rgb(t, channels=DEFAULT_CHANNELS, vmax=255, rgb_map=RGB_MAP): 62 | """ 63 | Converts and returns the image data as RGB image 64 | 65 | Parameters 66 | ---------- 67 | t : np.ndarray 68 | original image data 69 | channels : list of int 70 | channels to include 71 | vmax : int 72 | the max value used for scaling 73 | rgb_map : dict 74 | the color mapping for each channel 75 | See rxrx.io.RGB_MAP to see what the defaults are. 
76 | 77 | Returns 78 | ------- 79 | np.ndarray the image data of the site as RGB channels 80 | """ 81 | colored_channels = [] 82 | for i, channel in enumerate(channels): 83 | x = (t[:, :, i] / vmax) / \ 84 | ((rgb_map[channel]['range'][1] - rgb_map[channel]['range'][0]) / 255) + \ 85 | rgb_map[channel]['range'][0] / 255 86 | x = np.where(x > 1., 1., x) 87 | x_rgb = np.array( 88 | np.outer(x, rgb_map[channel]['rgb']).reshape(512, 512, 3), 89 | dtype=int) 90 | colored_channels.append(x_rgb) 91 | im = np.array(np.array(colored_channels).sum(axis=0), dtype=int) 92 | im = np.where(im > 255, 255, im) 93 | return im 94 | 95 | 96 | def image_path(dataset, 97 | experiment, 98 | plate, 99 | address, 100 | site, 101 | channel, 102 | base_path=DEFAULT_IMAGES_BASE_PATH): 103 | """ 104 | Returns the path of a channel image. 105 | 106 | Parameters 107 | ---------- 108 | dataset : str 109 | what subset of the data: train, test 110 | experiment : str 111 | experiment name 112 | plate : int 113 | plate number 114 | address : str 115 | plate address 116 | site : int 117 | site number 118 | channel : int 119 | channel number 120 | base_path : str 121 | the base path of the raw images 122 | 123 | Returns 124 | ------- 125 | str the path of the image 126 | """ 127 | return os.path.join(base_path, dataset, experiment, "Plate{}".format(plate), 128 | "{}_s{}_w{}.png".format(address, site, channel)) 129 | 130 | 131 | def load_site(dataset, 132 | experiment, 133 | plate, 134 | well, 135 | site, 136 | channels=DEFAULT_CHANNELS, 137 | base_path=DEFAULT_IMAGES_BASE_PATH): 138 | """ 139 | Returns the image data of a site 140 | 141 | Parameters 142 | ---------- 143 | dataset : str 144 | what subset of the data: train, test 145 | experiment : str 146 | experiment name 147 | plate : int 148 | plate number 149 | well : str 150 | well address 151 | site : int 152 | site number 153 | channels : list of int 154 | channels to include 155 | base_path : str 156 | the base path of the raw images 157 | 158 | Returns 159 | ------- 160 | np.ndarray the image data of the site 161 | """ 162 | channel_paths = [ 163 | image_path( 164 | dataset, experiment, plate, well, site, c, base_path=base_path) 165 | for c in channels 166 | ] 167 | return load_images_as_tensor(channel_paths) 168 | 169 | 170 | def load_site_as_rgb(dataset, 171 | experiment, 172 | plate, 173 | well, 174 | site, 175 | channels=DEFAULT_CHANNELS, 176 | base_path=DEFAULT_IMAGES_BASE_PATH, 177 | rgb_map=RGB_MAP): 178 | """ 179 | Loads and returns the image data as RGB image 180 | 181 | Parameters 182 | ---------- 183 | dataset : str 184 | what subset of the data: train, test 185 | experiment : str 186 | experiment name 187 | plate : int 188 | plate number 189 | well : str 190 | well address 191 | site : int 192 | site number 193 | channels : list of int 194 | channels to include 195 | base_path : str 196 | the base path of the raw images 197 | rgb_map : dict 198 | the color mapping for each channel 199 | See rxrx.io.RGB_MAP to see what the defaults are.
200 | 201 | Returns 202 | ------- 203 | np.ndarray the image data of the site as RGB channels 204 | """ 205 | x = load_site(dataset, experiment, plate, well, site, channels, base_path) 206 | return convert_tensor_to_rgb(x, channels, rgb_map=rgb_map) 207 | 208 | 209 | def _tf_read_csv(path): 210 | with tf.io.gfile.GFile(path, 'rb') as f: 211 | return pd.read_csv(f) 212 | 213 | 214 | def _load_dataset(base_path, dataset, include_controls=True): 215 | df = _tf_read_csv(os.path.join(base_path, dataset + '.csv')) 216 | if include_controls: 217 | controls = _tf_read_csv( 218 | os.path.join(base_path, dataset + '_controls.csv')) 219 | df['well_type'] = 'treatment' 220 | df = pd.concat([controls, df], sort=True) 221 | df['cell_type'] = df.experiment.str.split("-").apply(lambda a: a[0]) 222 | df['dataset'] = dataset 223 | dfs = [] 224 | for site in (1, 2): 225 | df = df.copy() 226 | df['site'] = site 227 | dfs.append(df) 228 | res = pd.concat(dfs).sort_values( 229 | by=['id_code', 'site']).set_index('id_code') 230 | return res 231 | 232 | 233 | def combine_metadata(base_path=DEFAULT_METADATA_BASE_PATH, 234 | include_controls=True): 235 | """ 236 | Combines all metadata files into a single dataframe and 237 | expands it to include sites, not just wells. 238 | 239 | Note that the dtype of sirna is a float due to the missing 240 | test values, but it should be treated as an int. 241 | 242 | Parameters 243 | ---------- 244 | base_path : str 245 | where the metadata files from Kaggle live 246 | include_controls : bool 247 | indicate if you want the controls included in the dataframe 248 | 249 | Returns 250 | ------- 251 | pandas.DataFrame the combined metadata 252 | """ 253 | df = pd.concat( 254 | [ 255 | _load_dataset( 256 | base_path, dataset, include_controls=include_controls) 257 | for dataset in ['test', 'train'] 258 | ], 259 | sort=True) 260 | return df 261 | -------------------------------------------------------------------------------- /rxrx/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Train a ResNet-50 model on RxRx1 on TPU.
16 | 
17 | Original file:
18 |     https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_main.py
19 | """
20 | 
21 | from __future__ import absolute_import
22 | from __future__ import division
23 | from __future__ import print_function
24 | 
25 | import functools
26 | import os
27 | import time
28 | import argparse
29 | 
30 | import numpy as np
31 | import tensorflow as tf
32 | from tensorflow.contrib import summary
33 | from tensorflow.python.estimator import estimator
34 | 
35 | from rxrx import input as rxinput
36 | from rxrx.official_resnet import resnet_v1
37 | 
38 | DEFAULT_INPUT_FN_PARAMS = {
39 |     'tfrecord_dataset_buffer_size': 256,
40 |     'tfrecord_dataset_num_parallel_reads': None,
41 |     'parallel_interleave_cycle_length': 32,
42 |     'parallel_interleave_block_length': 1,
43 |     'parallel_interleave_buffer_output_elements': None,
44 |     'parallel_interleave_prefetch_input_elements': None,
45 |     'map_and_batch_num_parallel_calls': 128,
46 |     'transpose_num_parallel_calls': 128,
47 |     'prefetch_buffer_size': tf.contrib.data.AUTOTUNE,
48 | }
49 | 
50 | # The per-channel means and standard deviations for each of the six channels
51 | GLOBAL_PIXEL_STATS = (np.array([6.74696984, 14.74640167, 10.51260864,
52 |                                 10.45369445, 5.49959796, 9.81545561]),
53 |                       np.array([7.95876312, 12.17305868, 5.86172946,
54 |                                 7.83451711, 4.701167, 5.43130431]))
55 | 
56 | 
57 | def resnet_model_fn(features, labels, mode, params, n_classes, num_train_images,
58 |                     data_format, transpose_input, train_batch_size,
59 |                     momentum, weight_decay, base_learning_rate, warmup_epochs,
60 |                     use_tpu, iterations_per_loop, model_dir, tf_precision,
61 |                     resnet_depth):
62 |   """The model_fn for ResNet to be used with TPUEstimator.
63 | 
64 |   Args:
65 |     features: `Tensor` of batched images
66 |     labels: `Tensor` of labels for the data samples
67 |     mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`
68 |     params: `dict` of parameters passed to the model from the TPUEstimator;
69 |         `params['batch_size']` is always provided and should be used as the
70 |         effective batch size. The remaining arguments are hyperparameters
71 |         bound via functools.partial in main().
72 | 
73 |   Returns:
74 |     A `TPUEstimatorSpec` for the model
75 |   """
76 |   if isinstance(features, dict):
77 |     features = features['feature']
78 | 
79 |   # In most cases, the default data format NCHW instead of NHWC should be
80 |   # used for a significant performance boost on GPU/TPU. NHWC should be used
81 |   # only if the network needs to be run on CPU since the pooling operations
82 |   # are only supported on NHWC.
83 |   if data_format == 'channels_first':
84 |     assert not transpose_input    # channels_first only for GPU
85 |     features = tf.transpose(features, [0, 3, 1, 2])
86 | 
87 |   if transpose_input and mode != tf.estimator.ModeKeys.PREDICT:
88 |     features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC
89 | 
90 |   # This nested function allows us to avoid duplicating the logic which
91 |   # builds the network for different values of --precision.
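  # (Context: 'bfloat16' runs the network body under a bfloat16 scope for TPU
  # speed and memory savings, and the logits are cast back to float32 below so
  # the softmax/cross-entropy math stays in full precision.)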
92 |   def build_network():
93 |     network = resnet_v1(
94 |         resnet_depth=resnet_depth,
95 |         num_classes=n_classes,
96 |         data_format=data_format)
97 |     return network(
98 |         inputs=features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))
99 | 
100 |   if tf_precision == 'bfloat16':
101 |     with tf.contrib.tpu.bfloat16_scope():
102 |       logits = build_network()
103 |     logits = tf.cast(logits, tf.float32)
104 |   elif tf_precision == 'float32':
105 |     logits = build_network()
106 | 
107 |   if mode == tf.estimator.ModeKeys.PREDICT:
108 |     predictions = {
109 |         'classes': tf.argmax(logits, axis=1),
110 |         'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
111 |     }
112 |     return tf.estimator.EstimatorSpec(
113 |         mode=mode,
114 |         predictions=predictions,
115 |         export_outputs={
116 |             'classify': tf.estimator.export.PredictOutput(predictions)
117 |         })
118 | 
119 |   # If necessary, in the model_fn, use params['batch_size'] instead of the
120 |   # batch-size flags (e.g. --train-batch-size).
121 |   batch_size = params['batch_size']   # pylint: disable=unused-variable
122 | 
123 |   # Calculate loss, which includes softmax cross entropy and L2 regularization.
124 |   one_hot_labels = tf.one_hot(labels, n_classes)
125 |   cross_entropy = tf.losses.softmax_cross_entropy(
126 |       logits=logits,
127 |       onehot_labels=one_hot_labels)
128 | 
129 |   # Add weight decay to the loss for non-batch-normalization variables.
130 |   loss = cross_entropy + weight_decay * tf.add_n([
131 |       tf.nn.l2_loss(v) for v in tf.trainable_variables()
132 |       if 'batch_normalization' not in v.name
133 |   ])
134 | 
135 |   host_call = None
136 |   if mode == tf.estimator.ModeKeys.TRAIN:
137 |     # Compute the current epoch and associated learning rate from global_step.
138 |     global_step = tf.train.get_global_step()
139 |     steps_per_epoch = tf.cast(num_train_images / train_batch_size, tf.float32)
140 |     current_epoch = (tf.cast(global_step, tf.float32) / steps_per_epoch)
141 |     warmup_steps = warmup_epochs * steps_per_epoch  # NOTE: currently unused; no LR warmup is applied.
142 | 
143 |     # Cosine learning-rate schedule that restarts every 10 epochs.
144 |     period = 10 * steps_per_epoch
145 |     learning_rate = tf.train.cosine_decay_restarts(base_learning_rate,
146 |                                                    global_step,
147 |                                                    period,
148 |                                                    t_mul=1.0,
149 |                                                    m_mul=1.0,
150 |                                                    alpha=0.0,
151 |                                                    name=None)
152 | 
153 | 
154 | 
155 |     optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
156 |                                            momentum=momentum,
157 |                                            use_nesterov=True)
158 | 
159 |     if use_tpu:
160 |       # When using TPU, wrap the optimizer with CrossShardOptimizer which
161 |       # handles synchronization details between different TPU cores. To the
162 |       # user, this should look like regular synchronous training.
163 |       optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
164 | 
165 |     # Batch normalization requires UPDATE_OPS to be added as a dependency to
166 |     # the train operation.
167 |     update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
168 |     with tf.control_dependencies(update_ops):
169 |       train_op = optimizer.minimize(loss, global_step)
170 | 
171 | 
172 |     def host_call_fn(gs, loss, lr, ce):
173 |       """Training host call. Creates scalar summaries for training metrics.
174 |       This function is executed on the CPU and should not directly reference
75 |       any Tensors in the rest of the `model_fn`. To pass Tensors from the
176 |       model to the `metric_fn`, provide as part of the `host_call`. See
177 |       https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
178 |       for more information.
179 |       Arguments should match the list of `Tensor` objects passed as the second
180 |       element in the tuple passed to `host_call`.
181 |       Args:
182 |         gs: `Tensor` with shape `[batch]` for the global_step.
183 |         loss: `Tensor` with shape `[batch]` for the training loss.
184 |         lr: `Tensor` with shape `[batch]` for the learning_rate.
185 |         ce: `Tensor` with shape `[batch]` for the current_epoch.
186 |       Returns:
187 |         List of summary ops to run on the CPU host.
188 |       """
189 |       gs = gs[0]
190 |       # Host call fns are executed iterations_per_loop times after one
191 |       # TPU loop is finished; setting max_queue to the same number of
192 |       # iterations makes the summary writer flush the data to storage
193 |       # only once per loop.
194 |       with summary.create_file_writer(model_dir,
195 |                                       max_queue=iterations_per_loop).as_default():
196 |         with summary.always_record_summaries():
197 |           summary.scalar('loss', loss[0], step=gs)
198 |           summary.scalar('learning_rate', lr[0], step=gs)
199 |           summary.scalar('current_epoch', ce[0], step=gs)
200 |           return summary.all_summary_ops()
201 | 
202 |     # To log the loss, current learning rate, and epoch for TensorBoard, the
203 |     # summary op needs to be run on the host CPU via host_call. host_call
204 |     # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
205 |     # dimension. These Tensors are implicitly concatenated to
206 |     # [params['batch_size']].
207 |     gs_t = tf.reshape(global_step, [1])
208 |     loss_t = tf.reshape(loss, [1])
209 |     lr_t = tf.reshape(learning_rate, [1])
210 |     ce_t = tf.reshape(current_epoch, [1])
211 | 
212 |     host_call = (host_call_fn, [gs_t, loss_t, lr_t, ce_t])
213 | 
214 |   else:
215 |     train_op = None
216 | 
217 |   eval_metrics = None
218 |   if mode == tf.estimator.ModeKeys.EVAL:
219 | 
220 |     def metric_fn(labels, logits):
221 |       """Evaluation metric function. Evaluates accuracy.
222 |       This function is executed on the CPU and should not directly reference
223 |       any Tensors in the rest of the `model_fn`. To pass Tensors from the model
224 |       to the `metric_fn`, provide as part of the `eval_metrics`. See
225 |       https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
226 |       for more information.
227 |       Arguments should match the list of `Tensor` objects passed as the second
228 |       element in the tuple passed to `eval_metrics`.
229 |       Args:
230 |         labels: `Tensor` with shape `[batch]`.
231 |         logits: `Tensor` with shape `[batch, num_classes]`.
232 |       Returns:
233 |         A dict of the metrics to return from evaluation.
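      For example (illustrative values): with labels = [2] and
      logits = [[0.1, 0.5, 0.2, 0.9, 0.3, 0.05]], the top-1 prediction is
      class 3, so top_1_accuracy records a miss, while class 2 is still among
      the five largest logits, so top_5_accuracy records a hit.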
234 |       """
235 |       predictions = tf.argmax(logits, axis=1)
236 |       top_1_accuracy = tf.metrics.accuracy(labels, predictions)
237 |       in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
238 |       top_5_accuracy = tf.metrics.mean(in_top_5)
239 | 
240 |       return {
241 |           'top_1_accuracy': top_1_accuracy,
242 |           'top_5_accuracy': top_5_accuracy,
243 |       }
244 | 
245 |     eval_metrics = (metric_fn, [labels, logits])
246 | 
247 |   return tf.contrib.tpu.TPUEstimatorSpec(
248 |       mode=mode,
249 |       loss=loss,
250 |       train_op=train_op,
251 |       host_call=host_call,
252 |       eval_metrics=eval_metrics)
253 | 
254 | def main(use_tpu,
255 |          tpu,
256 |          gcp_project,
257 |          tpu_zone,
258 |          url_base_path,
259 |          use_cache,
260 |          model_dir,
261 |          train_epochs,
262 |          train_batch_size,
263 |          num_train_images,
264 |          epochs_per_loop,
265 |          log_step_count_epochs,
266 |          num_cores,
267 |          data_format,
268 |          transpose_input,
269 |          tf_precision,
270 |          n_classes,
271 |          momentum,
272 |          weight_decay,
273 |          base_learning_rate,
274 |          warmup_epochs,
275 |          input_fn_params=DEFAULT_INPUT_FN_PARAMS,
276 |          resnet_depth=50):
277 | 
278 |     if use_tpu and (tpu is None):
279 |         tpu = os.getenv('TPU_NAME')
280 |         tf.logging.info('tpu: {}'.format(tpu))
281 |         if gcp_project is None:
282 |             gcp_project = os.getenv('TPU_PROJECT')
283 |             tf.logging.info('gcp_project: {}'.format(gcp_project))
284 | 
285 |     steps_per_epoch = (num_train_images // train_batch_size)
286 |     train_steps = steps_per_epoch * train_epochs
287 |     current_step = estimator._load_global_step_from_checkpoint_dir(model_dir) # pylint: disable=protected-access,line-too-long
288 |     iterations_per_loop = steps_per_epoch * epochs_per_loop
289 |     log_step_count_steps = steps_per_epoch * log_step_count_epochs
290 | 
291 | 
292 |     tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
293 |         tpu if (tpu or use_tpu) else '', zone=tpu_zone, project=gcp_project)
294 | 
295 | 
296 |     config = tf.contrib.tpu.RunConfig(
297 |         cluster=tpu_cluster_resolver,
298 |         model_dir=model_dir,
299 |         save_summary_steps=iterations_per_loop,
300 |         save_checkpoints_steps=iterations_per_loop,
301 |         log_step_count_steps=log_step_count_steps,
302 |         tpu_config=tf.contrib.tpu.TPUConfig(
303 |             iterations_per_loop=iterations_per_loop,
304 |             num_shards=num_cores,
305 |             per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.
306 |             PER_HOST_V2))  # pylint: disable=line-too-long
307 | 
308 |     model_fn = functools.partial(
309 |         resnet_model_fn,
310 |         n_classes=n_classes,
311 |         num_train_images=num_train_images,
312 |         data_format=data_format,
313 |         transpose_input=transpose_input,
314 |         train_batch_size=train_batch_size,
315 |         iterations_per_loop=iterations_per_loop,
316 |         tf_precision=tf_precision,
317 |         momentum=momentum,
318 |         weight_decay=weight_decay,
319 |         base_learning_rate=base_learning_rate,
320 |         warmup_epochs=warmup_epochs,
321 |         model_dir=model_dir,
322 |         use_tpu=use_tpu,
323 |         resnet_depth=resnet_depth)
324 | 
325 | 
326 |     resnet_classifier = tf.contrib.tpu.TPUEstimator(
327 |         use_tpu=use_tpu,
328 |         model_fn=model_fn,
329 |         config=config,
330 |         train_batch_size=train_batch_size,
331 |         export_to_tpu=False)
332 | 
333 | 
334 |     use_bfloat16 = (tf_precision == 'bfloat16')
335 | 
336 |     train_glob = os.path.join(url_base_path, 'train', '*.tfrecord')
337 | 
338 |     tf.logging.info("Train glob: {}".format(train_glob))
339 | 
340 |     train_input_fn = functools.partial(rxinput.input_fn,
341 |                                        input_fn_params=input_fn_params,
342 |                                        tf_records_glob=train_glob,
343 |                                        pixel_stats=GLOBAL_PIXEL_STATS,
344 |                                        transpose_input=transpose_input,
345 |                                        use_bfloat16=use_bfloat16)
346 | 
347 | 
348 | 
349 |     tf.logging.info('Training for %d steps (%.2f epochs in total). Current'
350 |                     ' step %d.', train_steps, train_steps / steps_per_epoch,
351 |                     current_step)
352 | 
353 |     start_timestamp = time.time()  # This time will include compilation time
354 | 
355 |     resnet_classifier.train(input_fn=train_input_fn, max_steps=train_steps)
356 | 
361 |     elapsed_time = int(time.time() - start_timestamp)
362 |     tf.logging.info('Finished training up to step %d. Elapsed seconds %d.',
363 |                     train_steps, elapsed_time)
364 | 
365 |     tf.logging.info('Exporting SavedModel.')
366 | 
367 |     def serving_input_receiver_fn():
368 |         features = {
369 |             'feature': tf.placeholder(dtype=tf.float32, shape=[None, 512, 512, 6]),
370 |         }
371 |         receiver_tensors = features
372 |         return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
373 | 
374 |     resnet_classifier.export_saved_model(os.path.join(model_dir, 'saved_model'), serving_input_receiver_fn)
375 | 
376 | 
377 | if __name__ == '__main__':
378 | 
379 |     p = argparse.ArgumentParser(description='Train ResNet on rxrx1')
380 |     # TPU Parameters
381 |     p.add_argument(
382 |         '--use-tpu',
383 |         type=bool,  # NOTE: argparse's type=bool treats any non-empty string as True; pass --use-tpu='' to disable.
384 |         default=True,
385 |         help=('Use TPU to execute the model for training and evaluation. If'
386 |               " --use-tpu='' (an empty value), will use whatever devices are available to"
387 |               ' TensorFlow by default (e.g. CPU and GPU)'))
388 |     p.add_argument(
389 |         '--tpu',
390 |         type=str,
391 |         default=None,
392 |         help=(
393 |             'The Cloud TPU to use for training.'
394 |             ' This should be either the name used when creating the Cloud TPU, '
395 |             'or a grpc://ip.address.of.tpu:8470 url.'))
396 |     p.add_argument(
397 |         '--gcp-project',
398 |         type=str,
399 |         default=None,
400 |         help=('Project name for the Cloud TPU-enabled project. '
401 |               'If not specified, we will attempt to automatically '
402 |               'detect the GCE project from metadata.'))
403 |     p.add_argument(
404 |         '--tpu-zone',
405 |         type=str,
406 |         default=None,
407 |         help=('GCE zone where the Cloud TPU is located in. '
408 |               'If not specified, we will attempt to automatically '
409 |               'detect the zone from metadata.'))
410 |     p.add_argument('--use-cache', type=bool, default=None)  # NOTE: accepted but currently unused by main().
411 |     # Dataset Parameters
412 |     p.add_argument(
413 |         '--url-base-path',
414 |         type=str,
415 |         default='gs://rxrx1-us-central1/tfrecords/random-42',
416 |         help=('Base path for tfrecord storage bucket url.'))
417 |     # Training parameters
418 |     p.add_argument(
419 |         '--model-dir',
420 |         type=str,
421 |         default=None,
422 |         help=(
423 |             'The Google Cloud Storage bucket where the model and training summaries are'
424 |             ' stored.'))
425 |     p.add_argument(
426 |         '--train-epochs',
427 |         type=int,
428 |         default=1,
429 |         help=(
430 |             'The total number of passes over the training set, where an epoch '
431 |             'is one pass through every training example. '
432 |             'Implicitly sets the total train steps.'))
433 |     p.add_argument(
434 |         '--num-train-images',
435 |         type=int,
436 |         default=73000,
437 |         help=('Total number of training images; used to compute steps per epoch.'))
438 |     p.add_argument(
439 |         '--train-batch-size',
440 |         type=int,
441 |         default=512,
442 |         help=('Batch size to use during training.'))
443 |     p.add_argument(
444 |         '--n-classes',
445 |         type=int,
446 |         default=1108,
447 |         help=('The number of label classes - typically will be 1108 '
448 |               'since there are 1108 experimental siRNA classes.'))
449 |     p.add_argument(
450 |         '--epochs-per-loop',
451 |         type=int,
452 |         default=1,
453 |         help=('The number of epochs to run on TPU before outfeeding metrics '
454 |               'to the CPU. Larger values will speed up training.'))
455 |     p.add_argument(
456 |         '--log-step-count-epochs',
457 |         type=int,
458 |         default=64,
459 |         help=('The number of epochs at '
460 |               'which global step information is logged.'))
461 |     p.add_argument(
462 |         '--num-cores',
463 |         type=int,
464 |         default=8,
465 |         help=('Number of TPU cores. For a single TPU device, this is 8 because '
466 |               'each TPU has 4 chips each with 2 cores.'))
467 |     p.add_argument(
468 |         '--data-format',
469 |         type=str,
470 |         default='channels_last',
471 |         choices=[
472 |             'channels_first',
473 |             'channels_last',
474 |         ],
475 |         help=('A flag to override the data format used in the model. '
476 |               'To run on CPU or TPU, channels_last should be used. '
477 |               'For GPU, channels_first will improve performance.'))
478 |     p.add_argument(
479 |         '--transpose-input',
480 |         type=bool,
481 |         default=True,
482 |         help=('Use TPU double transpose optimization.'))
483 |     p.add_argument(
484 |         '--tf-precision',
485 |         type=str,
486 |         default='bfloat16',
487 |         choices=['bfloat16', 'float32'],
488 |         help=('TensorFlow precision type used when defining the network.'))
489 | 
490 |     # Optimizer Parameters
491 | 
492 |     p.add_argument('--momentum', type=float, default=0.9)
493 |     p.add_argument('--weight-decay', type=float, default=1e-4)
494 |     p.add_argument(
495 |         '--base-learning-rate',
496 |         type=float,
497 |         default=0.2,
498 |         help=('Base learning rate when train batch size is 512. '
499 |               'Chosen to match the ResNet paper.'))
500 |     p.add_argument(
501 |         '--warmup-epochs',
502 |         type=int,
503 |         default=5,
504 |     )
505 |     args = p.parse_args()
506 |     args = vars(args)
507 |     tf.logging.set_verbosity(tf.logging.INFO)
508 |     tf.logging.info('Parsed args: ')
509 |     for k, v in args.items():
510 |         tf.logging.info('{} : {}'.format(k, v))
511 |     main(**args)
512 | 
-------------------------------------------------------------------------------- /rxrx/official_resnet.py: --------------------------------------------------------------------------------
1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains definitions for the post-activation form of Residual Networks.
16 | 
17 | Residual networks (ResNets) were proposed in:
18 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
19 |     Deep Residual Learning for Image Recognition. arXiv:1512.03385
20 | 
21 | Original file:
22 |     https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py
23 | """
24 | 
25 | from __future__ import absolute_import
26 | from __future__ import division
27 | from __future__ import print_function
28 | 
29 | import tensorflow as tf
30 | 
31 | BATCH_NORM_DECAY = 0.9
32 | BATCH_NORM_EPSILON = 1e-5
33 | 
34 | 
35 | def batch_norm_relu(inputs, is_training, relu=True, init_zero=False,
36 |                     data_format='channels_first'):
37 |   """Performs a batch normalization followed by a ReLU.
38 | 
39 |   Args:
40 |     inputs: `Tensor` of shape `[batch, channels, ...]`.
41 |     is_training: `bool` for whether the model is training.
42 |     relu: `bool` if False, omits the ReLU operation.
43 |     init_zero: `bool` if True, initializes scale parameter of batch
44 |         normalization with 0 instead of 1 (default).
45 |     data_format: `str` either "channels_first" for `[batch, channels, height,
46 |         width]` or "channels_last" for `[batch, height, width, channels]`.
47 | 
48 |   Returns:
49 |     A normalized `Tensor` with the same `data_format`.
50 |   """
51 |   if init_zero:
52 |     gamma_initializer = tf.zeros_initializer()
53 |   else:
54 |     gamma_initializer = tf.ones_initializer()
55 | 
56 |   if data_format == 'channels_first':
57 |     axis = 1
58 |   else:
59 |     axis = 3
60 | 
61 |   inputs = tf.layers.batch_normalization(
62 |       inputs=inputs,
63 |       axis=axis,
64 |       momentum=BATCH_NORM_DECAY,
65 |       epsilon=BATCH_NORM_EPSILON,
66 |       center=True,
67 |       scale=True,
68 |       training=is_training,
69 |       fused=True,
70 |       gamma_initializer=gamma_initializer)
71 | 
72 |   if relu:
73 |     inputs = tf.nn.relu(inputs)
74 |   return inputs
75 | 
76 | 
77 | def dropblock(net, is_training, keep_prob, dropblock_size,
78 |               data_format='channels_first'):
79 |   """DropBlock: a regularization method for convolutional neural networks.
80 | 
81 |   DropBlock is a form of structured dropout, where units in a contiguous
82 |   region of a feature map are dropped together. DropBlock works better than
83 |   dropout on convolutional layers due to the fact that activation units in
84 |   convolutional layers are spatially correlated.
85 |   See https://arxiv.org/pdf/1810.12890.pdf for details.
86 | 
87 |   Args:
88 |     net: `Tensor` input tensor.
89 |     is_training: `bool` for whether the model is training.
90 |     keep_prob: `float` or `Tensor` keep_prob parameter of DropBlock. "None"
91 |         means no DropBlock.
92 |     dropblock_size: `int` size of blocks to be dropped by DropBlock.
93 |     data_format: `str` either "channels_first" for `[batch, channels, height,
94 |         width]` or "channels_last" for `[batch, height, width, channels]`.
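  For example (illustrative numbers): with keep_prob=0.9, a 32x32 feature
  map, and dropblock_size=4, the derived seed drop rate gamma is
  (1 - 0.9) * 32**2 / 4**2 / (32 - 4 + 1)**2 ~= 0.0076, i.e. roughly 0.8% of
  the valid seed positions spawn a dropped 4x4 block.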
95 |   Returns:
96 |     A version of input tensor with DropBlock applied.
97 |   Raises:
98 |     ValueError: if width and height of the input tensor are not equal.
99 |   """
100 | 
101 |   if not is_training or keep_prob is None:
102 |     return net
103 | 
104 |   tf.logging.info('Applying DropBlock: dropblock_size {}, net.shape {}'.format(
105 |       dropblock_size, net.shape))
106 | 
107 |   if data_format == 'channels_last':
108 |     _, width, height, _ = net.get_shape().as_list()
109 |   else:
110 |     _, _, width, height = net.get_shape().as_list()
111 |   if width != height:
112 |     raise ValueError('Input tensor with width!=height is not supported.')
113 | 
114 |   dropblock_size = min(dropblock_size, width)
115 |   # seed_drop_rate is the gamma parameter of DropBlock.
116 |   seed_drop_rate = (1.0 - keep_prob) * width**2 / dropblock_size**2 / (
117 |       width - dropblock_size + 1)**2
118 | 
119 |   # Forces the block to be inside the feature map.
120 |   w_i, h_i = tf.meshgrid(tf.range(width), tf.range(width))
121 |   valid_block_center = tf.logical_and(
122 |       tf.logical_and(w_i >= int(dropblock_size // 2),
123 |                      w_i < width - (dropblock_size - 1) // 2),
124 |       tf.logical_and(h_i >= int(dropblock_size // 2),
125 |                      h_i < width - (dropblock_size - 1) // 2))
126 | 
127 |   valid_block_center = tf.expand_dims(valid_block_center, 0)
128 |   valid_block_center = tf.expand_dims(
129 |       valid_block_center, -1 if data_format == 'channels_last' else 0)
130 | 
131 |   randnoise = tf.random_uniform(net.shape, dtype=tf.float32)
132 |   block_pattern = (1 - tf.cast(valid_block_center, dtype=tf.float32) + tf.cast(
133 |       (1 - seed_drop_rate), dtype=tf.float32) + randnoise) >= 1
134 |   block_pattern = tf.cast(block_pattern, dtype=tf.float32)
135 | 
136 |   if dropblock_size == width:
137 |     block_pattern = tf.reduce_min(
138 |         block_pattern,
139 |         axis=[1, 2] if data_format == 'channels_last' else [2, 3],
140 |         keepdims=True)
141 |   else:
142 |     if data_format == 'channels_last':
143 |       ksize = [1, dropblock_size, dropblock_size, 1]
144 |     else:
145 |       ksize = [1, 1, dropblock_size, dropblock_size]
146 |     block_pattern = -tf.nn.max_pool(
147 |         -block_pattern, ksize=ksize, strides=[1, 1, 1, 1], padding='SAME',
148 |         data_format='NHWC' if data_format == 'channels_last' else 'NCHW')
149 | 
150 |   percent_ones = tf.cast(tf.reduce_sum((block_pattern)), tf.float32) / tf.cast(
151 |       tf.size(block_pattern), tf.float32)
152 | 
153 |   net = net / tf.cast(percent_ones, net.dtype) * tf.cast(
154 |       block_pattern, net.dtype)
155 |   return net
156 | 
157 | 
158 | def fixed_padding(inputs, kernel_size, data_format='channels_first'):
159 |   """Pads the input along the spatial dimensions independently of input size.
160 | 
161 |   Args:
162 |     inputs: `Tensor` of size `[batch, channels, height, width]` or
163 |         `[batch, height, width, channels]` depending on `data_format`.
164 |     kernel_size: `int` kernel size to be used for `conv2d` or `max_pool2d`
165 |         operations. Should be a positive integer.
166 |     data_format: `str` either "channels_first" for `[batch, channels, height,
167 |         width]` or "channels_last" for `[batch, height, width, channels]`.
168 | 
169 |   Returns:
170 |     A padded `Tensor` of the same `data_format` with size either intact
171 |     (if `kernel_size == 1`) or padded (if `kernel_size > 1`).
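    For example, kernel_size=7 gives pad_total=6, so pad_beg=3 and pad_end=3;
    kernel_size=4 gives pad_beg=1 and pad_end=2 (the extra pixel of padding
    lands on the bottom/right).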
172 |   """
173 |   pad_total = kernel_size - 1
174 |   pad_beg = pad_total // 2
175 |   pad_end = pad_total - pad_beg
176 |   if data_format == 'channels_first':
177 |     padded_inputs = tf.pad(inputs, [[0, 0], [0, 0],
178 |                                     [pad_beg, pad_end], [pad_beg, pad_end]])
179 |   else:
180 |     padded_inputs = tf.pad(inputs, [[0, 0], [pad_beg, pad_end],
181 |                                     [pad_beg, pad_end], [0, 0]])
182 | 
183 |   return padded_inputs
184 | 
185 | 
186 | def conv2d_fixed_padding(inputs, filters, kernel_size, strides,
187 |                          data_format='channels_first'):
188 |   """Strided 2-D convolution with explicit padding.
189 | 
190 |   The padding is consistent and is based only on `kernel_size`, not on the
191 |   dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone).
192 | 
193 |   Args:
194 |     inputs: `Tensor` of size `[batch, channels, height_in, width_in]`.
195 |     filters: `int` number of filters in the convolution.
196 |     kernel_size: `int` size of the kernel to be used in the convolution.
197 |     strides: `int` strides of the convolution.
198 |     data_format: `str` either "channels_first" for `[batch, channels, height,
199 |         width]` or "channels_last" for `[batch, height, width, channels]`.
200 | 
201 |   Returns:
202 |     A `Tensor` of shape `[batch, filters, height_out, width_out]`.
203 |   """
204 |   if strides > 1:
205 |     inputs = fixed_padding(inputs, kernel_size, data_format=data_format)
206 | 
207 |   return tf.layers.conv2d(
208 |       inputs=inputs, filters=filters, kernel_size=kernel_size, strides=strides,
209 |       padding=('SAME' if strides == 1 else 'VALID'), use_bias=False,
210 |       kernel_initializer=tf.variance_scaling_initializer(),
211 |       data_format=data_format)
212 | 
213 | 
214 | def residual_block(inputs, filters, is_training, strides,
215 |                    use_projection=False, data_format='channels_first',
216 |                    dropblock_keep_prob=None, dropblock_size=None):
217 |   """Standard building block for residual networks with BN after convolutions.
218 | 
219 |   Args:
220 |     inputs: `Tensor` of size `[batch, channels, height, width]`.
221 |     filters: `int` number of filters for both convolutions in the
222 |         block.
223 |     is_training: `bool` for whether the model is in training.
224 |     strides: `int` block stride. If greater than 1, this block will ultimately
225 |         downsample the input.
226 |     use_projection: `bool` for whether this block should use a projection
227 |         shortcut (versus the default identity shortcut). This is usually `True`
228 |         for the first block of a block group, which may change the number of
229 |         filters and the resolution.
230 |     data_format: `str` either "channels_first" for `[batch, channels, height,
231 |         width]` or "channels_last" for `[batch, height, width, channels]`.
232 |     dropblock_keep_prob: unused; needed to give method same signature as other
233 |         blocks
234 |     dropblock_size: unused; needed to give method same signature as other
235 |         blocks
236 |   Returns:
237 |     The output `Tensor` of the block.
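    (This basic two-convolution block is the variant `resnet_v1` below selects
    for ResNet-18/34; ResNet-50 and deeper use `bottleneck_block`.)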
238 |   """
239 |   del dropblock_keep_prob
240 |   del dropblock_size
241 |   shortcut = inputs
242 |   if use_projection:
243 |     # Projection shortcut in first layer to match filters and strides
244 |     shortcut = conv2d_fixed_padding(
245 |         inputs=inputs, filters=filters, kernel_size=1, strides=strides,
246 |         data_format=data_format)
247 |     shortcut = batch_norm_relu(shortcut, is_training, relu=False,
248 |                                data_format=data_format)
249 | 
250 |   inputs = conv2d_fixed_padding(
251 |       inputs=inputs, filters=filters, kernel_size=3, strides=strides,
252 |       data_format=data_format)
253 |   inputs = batch_norm_relu(inputs, is_training, data_format=data_format)
254 | 
255 |   inputs = conv2d_fixed_padding(
256 |       inputs=inputs, filters=filters, kernel_size=3, strides=1,
257 |       data_format=data_format)
258 |   inputs = batch_norm_relu(inputs, is_training, relu=False, init_zero=True,
259 |                            data_format=data_format)
260 | 
261 |   return tf.nn.relu(inputs + shortcut)
262 | 
263 | 
264 | def bottleneck_block(inputs, filters, is_training, strides,
265 |                      use_projection=False, data_format='channels_first',
266 |                      dropblock_keep_prob=None, dropblock_size=None):
267 |   """Bottleneck block variant for residual networks with BN after convolutions.
268 | 
269 |   Args:
270 |     inputs: `Tensor` of size `[batch, channels, height, width]`.
271 |     filters: `int` number of filters for the first two convolutions. Note that
272 |         the third and final convolution will use 4 times as many filters.
273 |     is_training: `bool` for whether the model is in training.
274 |     strides: `int` block stride. If greater than 1, this block will ultimately
275 |         downsample the input.
276 |     use_projection: `bool` for whether this block should use a projection
277 |         shortcut (versus the default identity shortcut). This is usually `True`
278 |         for the first block of a block group, which may change the number of
279 |         filters and the resolution.
280 |     data_format: `str` either "channels_first" for `[batch, channels, height,
281 |         width]` or "channels_last" for `[batch, height, width, channels]`.
282 |     dropblock_keep_prob: `float` or `Tensor` keep_prob parameter of DropBlock.
283 |         "None" means no DropBlock.
284 |     dropblock_size: `int` size parameter of DropBlock. Will not be used if
285 |         dropblock_keep_prob is "None".
286 | 
287 |   Returns:
288 |     The output `Tensor` of the block.
289 |   """
290 |   shortcut = inputs
291 |   if use_projection:
292 |     # Projection shortcut only in first block within a group. Bottleneck blocks
293 |     # end with 4 times the number of filters.
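    # (Concretely: with filters=64 the three convolutions below emit 64, 64,
    # and 256 feature maps via a 1x1 reduce, a 3x3, and a 1x1 expansion, so
    # the projection shortcut here must also produce 256 channels to match.)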
294 |     filters_out = 4 * filters
295 |     shortcut = conv2d_fixed_padding(
296 |         inputs=inputs, filters=filters_out, kernel_size=1, strides=strides,
297 |         data_format=data_format)
298 |     shortcut = batch_norm_relu(shortcut, is_training, relu=False,
299 |                                data_format=data_format)
300 |     shortcut = dropblock(
301 |         shortcut, is_training=is_training, data_format=data_format,
302 |         keep_prob=dropblock_keep_prob, dropblock_size=dropblock_size)
303 | 
304 |   inputs = conv2d_fixed_padding(
305 |       inputs=inputs, filters=filters, kernel_size=1, strides=1,
306 |       data_format=data_format)
307 |   inputs = batch_norm_relu(inputs, is_training, data_format=data_format)
308 |   inputs = dropblock(
309 |       inputs, is_training=is_training, data_format=data_format,
310 |       keep_prob=dropblock_keep_prob, dropblock_size=dropblock_size)
311 | 
312 |   inputs = conv2d_fixed_padding(
313 |       inputs=inputs, filters=filters, kernel_size=3, strides=strides,
314 |       data_format=data_format)
315 |   inputs = batch_norm_relu(inputs, is_training, data_format=data_format)
316 |   inputs = dropblock(
317 |       inputs, is_training=is_training, data_format=data_format,
318 |       keep_prob=dropblock_keep_prob, dropblock_size=dropblock_size)
319 | 
320 |   inputs = conv2d_fixed_padding(
321 |       inputs=inputs, filters=4 * filters, kernel_size=1, strides=1,
322 |       data_format=data_format)
323 |   inputs = batch_norm_relu(inputs, is_training, relu=False, init_zero=True,
324 |                            data_format=data_format)
325 |   inputs = dropblock(
326 |       inputs, is_training=is_training, data_format=data_format,
327 |       keep_prob=dropblock_keep_prob, dropblock_size=dropblock_size)
328 | 
329 |   return tf.nn.relu(inputs + shortcut)
330 | 
331 | 
332 | def block_group(inputs, filters, block_fn, blocks, strides, is_training, name,
333 |                 data_format='channels_first', dropblock_keep_prob=None,
334 |                 dropblock_size=None):
335 |   """Creates one group of blocks for the ResNet model.
336 | 
337 |   Args:
338 |     inputs: `Tensor` of size `[batch, channels, height, width]`.
339 |     filters: `int` number of filters for the first convolution of the layer.
340 |     block_fn: `function` for the block to use within the model
341 |     blocks: `int` number of blocks contained in the layer.
342 |     strides: `int` stride to use for the first convolution of the layer. If
343 |         greater than 1, this layer will downsample the input.
344 |     is_training: `bool` for whether the model is training.
345 |     name: `str` name for the Tensor output of the block layer.
346 |     data_format: `str` either "channels_first" for `[batch, channels, height,
347 |         width]` or "channels_last" for `[batch, height, width, channels]`.
348 |     dropblock_keep_prob: `float` or `Tensor` keep_prob parameter of DropBlock.
349 |         "None" means no DropBlock.
350 |     dropblock_size: `int` size parameter of DropBlock. Will not be used if
351 |         dropblock_keep_prob is "None".
352 | 
353 |   Returns:
354 |     The output `Tensor` of the block layer.
355 |   """
356 |   # Only the first block per block_group uses projection shortcut and strides.
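  # For example, with ResNet-50's layers == [3, 4, 6, 3], the third group runs
  # one projection block with the given strides followed by five identity
  # blocks with stride 1.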
357 |   inputs = block_fn(inputs, filters, is_training, strides,
358 |                     use_projection=True, data_format=data_format,
359 |                     dropblock_keep_prob=dropblock_keep_prob,
360 |                     dropblock_size=dropblock_size)
361 | 
362 |   for _ in range(1, blocks):
363 |     inputs = block_fn(inputs, filters, is_training, 1,
364 |                       data_format=data_format,
365 |                       dropblock_keep_prob=dropblock_keep_prob,
366 |                       dropblock_size=dropblock_size)
367 | 
368 |   return tf.identity(inputs, name)
369 | 
370 | 
371 | def resnet_v1_generator(block_fn, layers, num_classes,
372 |                         data_format='channels_first', dropblock_keep_probs=None,
373 |                         dropblock_size=None):
374 |   """Generator for ResNet v1 models.
375 | 
376 |   Args:
377 |     block_fn: `function` for the block to use within the model. Either
378 |         `residual_block` or `bottleneck_block`.
379 |     layers: list of 4 `int`s denoting the number of blocks to include in each
380 |         of the 4 block groups. Each group consists of blocks that take inputs of
381 |         the same resolution.
382 |     num_classes: `int` number of possible classes for image classification.
383 |     data_format: `str` either "channels_first" for `[batch, channels, height,
384 |         width]` or "channels_last" for `[batch, height, width, channels]`.
385 |     dropblock_keep_probs: `list` of 4 elements denoting keep_prob of DropBlock
386 |         for each block group. None indicates no DropBlock for the corresponding
387 |         block group.
388 |     dropblock_size: `int`: size parameter of DropBlock.
389 | 
390 |   Returns:
391 |     Model `function` that takes in `inputs` and `is_training` and returns the
392 |     output `Tensor` of the ResNet model.
393 | 
394 |   Raises:
395 |     ValueError: if dropblock_keep_probs is not None or a list of length 4.
396 |   """
397 |   if dropblock_keep_probs is None:
398 |     dropblock_keep_probs = [None] * 4
399 |   if not isinstance(dropblock_keep_probs,
400 |                     list) or len(dropblock_keep_probs) != 4:
401 |     raise ValueError('dropblock_keep_probs is not valid:', dropblock_keep_probs)
402 | 
403 |   def model(inputs, is_training):
404 |     """Creation of the model graph."""
405 |     inputs = conv2d_fixed_padding(
406 |         inputs=inputs, filters=64, kernel_size=7, strides=2,
407 |         data_format=data_format)
408 |     inputs = tf.identity(inputs, 'initial_conv')
409 |     inputs = batch_norm_relu(inputs, is_training, data_format=data_format)
410 | 
411 |     inputs = tf.layers.max_pooling2d(
412 |         inputs=inputs, pool_size=3, strides=2, padding='SAME',
413 |         data_format=data_format)
414 |     inputs = tf.identity(inputs, 'initial_max_pool')
415 | 
416 |     inputs = block_group(
417 |         inputs=inputs, filters=64, block_fn=block_fn, blocks=layers[0],
418 |         strides=1, is_training=is_training, name='block_group1',
419 |         data_format=data_format, dropblock_keep_prob=dropblock_keep_probs[0],
420 |         dropblock_size=dropblock_size)
421 |     inputs = block_group(
422 |         inputs=inputs, filters=128, block_fn=block_fn, blocks=layers[1],
423 |         strides=2, is_training=is_training, name='block_group2',
424 |         data_format=data_format, dropblock_keep_prob=dropblock_keep_probs[1],
425 |         dropblock_size=dropblock_size)
426 |     inputs = block_group(
427 |         inputs=inputs, filters=256, block_fn=block_fn, blocks=layers[2],
428 |         strides=2, is_training=is_training, name='block_group3',
429 |         data_format=data_format, dropblock_keep_prob=dropblock_keep_probs[2],
430 |         dropblock_size=dropblock_size)
431 |     inputs = block_group(
432 |         inputs=inputs, filters=512, block_fn=block_fn, blocks=layers[3],
433 |         strides=2, is_training=is_training, name='block_group4',
434 |         data_format=data_format, dropblock_keep_prob=dropblock_keep_probs[3],
435 |         dropblock_size=dropblock_size)
436 | 
437 |     # The activation is 7x7 so this is a global average pool.
438 |     # TODO(huangyp): reduce_mean will be faster.
439 |     pool_size = (inputs.shape[1], inputs.shape[2])
440 |     inputs = tf.layers.average_pooling2d(
441 |         inputs=inputs, pool_size=pool_size, strides=1, padding='VALID',
442 |         data_format=data_format)
443 |     inputs = tf.identity(inputs, 'final_avg_pool')
444 |     inputs = tf.reshape(
445 |         inputs, [-1, 2048 if block_fn is bottleneck_block else 512])
446 |     inputs = tf.layers.dense(
447 |         inputs=inputs,
448 |         units=num_classes,
449 |         kernel_initializer=tf.random_normal_initializer(stddev=.01))
450 |     inputs = tf.identity(inputs, 'final_dense')
451 |     return inputs
452 | 
453 |   model.default_image_size = 224
454 |   return model
455 | 
456 | 
457 | def resnet_v1(resnet_depth, num_classes, data_format='channels_first',
458 |               dropblock_keep_probs=None, dropblock_size=None):
459 |   """Returns the ResNet model for a given size and number of output classes."""
460 |   model_params = {
461 |       18: {'block': residual_block, 'layers': [2, 2, 2, 2]},
462 |       34: {'block': residual_block, 'layers': [3, 4, 6, 3]},
463 |       50: {'block': bottleneck_block, 'layers': [3, 4, 6, 3]},
464 |       101: {'block': bottleneck_block, 'layers': [3, 4, 23, 3]},
465 |       152: {'block': bottleneck_block, 'layers': [3, 8, 36, 3]},
466 |       200: {'block': bottleneck_block, 'layers': [3, 24, 36, 3]}
467 |   }
468 | 
469 |   if resnet_depth not in model_params:
470 |     raise ValueError('Not a valid resnet_depth:', resnet_depth)
471 | 
472 |   params = model_params[resnet_depth]
473 |   return resnet_v1_generator(
474 |       params['block'], params['layers'], num_classes,
475 |       dropblock_keep_probs=dropblock_keep_probs, dropblock_size=dropblock_size,
476 |       data_format=data_format)
477 | 
-------------------------------------------------------------------------------- /rxrx/preprocess/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/recursionpharma/rxrx1-utils/31e6f9b826c6de84f531cd2b4f67bfc158036d7d/rxrx/preprocess/__init__.py -------------------------------------------------------------------------------- /rxrx/preprocess/images2tfrecords.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import os
4 | import shutil
5 | import tempfile
6 | 
7 | import numpy as np
8 | import pandas as pd
9 | import tensorflow as tf
10 | import toolz as t
11 | 
12 | from .. import io as rio
13 | from .. import utils as rutils
14 | from ..io import DEFAULT_CHANNELS
15 | 
16 | TFRECORD_COMPRESSION = tf.python_io.TFRecordCompressionType.GZIP
17 | TFRECORD_OPTIONS = tf.python_io.TFRecordOptions(TFRECORD_COMPRESSION)
18 | VALID_DATASETS = {'train', 'test'}
19 | VALID_STRATEGIES = {'random', 'by_exp_plate_site'}
20 | 
21 | ### TensorFlow Helpers
22 | 
23 | 
24 | def bytes_feature(value):
25 |     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
26 | 
27 | 
28 | def string_feature(value):
29 |     return bytes_feature(value.encode('utf-8'))
30 | 
31 | 
32 | def int64_feature(value):
33 |     return tf.train.Feature(int64_list=tf.train.Int64List(
34 |         value=rutils.wrap(value)))
35 | 
36 | 
37 | def float_feature(value):
38 |     return tf.train.Feature(float_list=tf.train.FloatList(
39 |         value=rutils.wrap(value)))
40 | 
41 | 
42 | ### Conversion to TFExample and TFRecord logic
43 | 
44 | 
45 | def dict_to_tfexample(site):
46 |     """
47 |     Takes a dictionary of a site with all the metadata and the `image` data.
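    For illustration, a site dict might look like (values here are hypothetical):

        {'image': <np.ndarray of shape (512, 512, 6)>, 'well': 'K09',
         'well_type': 'treatment', 'experiment': 'HEPG2-01', 'plate': 2,
         'site': 1, 'cell_type': 'HEPG2', 'sirna': 637, 'dataset': 'train', ...}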
48 | 
49 |     Returns a TFExample.
50 |     """
51 | 
52 |     features = {
53 |         'image': bytes_feature(site['image'].tostring()),
54 |         'well': string_feature(site['well']),
55 |         'well_type': string_feature(site['well_type']),
56 |         'experiment': string_feature(site['experiment']),
57 |         'plate': int64_feature(site['plate']),
58 |         'site': int64_feature(site['site']),
59 |         'cell_type': string_feature(site['cell_type'])
60 |     }
61 | 
62 |     # Handle the case where the sirna is not known (test set)
63 |     if site["sirna"] is not None:
64 |         features["sirna"] = int64_feature(site["sirna"])
65 | 
66 |     return tf.train.Example(features=tf.train.Features(feature=features))
67 | 
68 | 
69 | def _pack_tfrecord(base_path,
70 |                    sites,
71 |                    dest_path,
72 |                    channels=DEFAULT_CHANNELS):
73 |     if not dest_path.startswith('gs://'):
74 |         os.makedirs(os.path.dirname(dest_path), exist_ok=True)
75 |     with tf.python_io.TFRecordWriter(
76 |             dest_path, options=TFRECORD_OPTIONS) as writer:
77 |         for site in sites:
78 |             data = rio.load_site(
79 |                 base_path=base_path,
80 |                 channels=channels,
81 |                 **rutils.select_keys(site, ('dataset', 'experiment', 'plate',
82 |                                             'well', 'site')))
83 |             example = dict_to_tfexample(t.assoc(site, 'image', data))
84 |             writer.write(example.SerializeToString())
85 | 
86 | 
87 | ### Strategies to pack the TFRecords differently and some helper functions
88 | #
89 | # Each strategy takes the metadata DataFrame and returns a list of
90 | # dictionaries containing `dest_path` and `sites` where
91 | #   `dest_path` - the full path of the destination TFRecord file
92 | #   `sites` - a list of all of the sites that should be packed into the
93 | #             destination path. Each `site` is a row, in dictionary form,
94 | #             from the metadata dataframe.
95 | #
96 | 
97 | 
98 | def _dataset_rs_dict(seed):
99 |     """Returns a dictionary of random states keyed by dataset.
100 |     A random state is created for every dataset regardless of whether it
101 |     will be processed. This is done to guarantee that the randomization is
102 |     deterministic and invariant to which datasets are being processed.
103 |     """
104 |     rs = np.random.RandomState(seed)
105 |     high = 2**32 - 1
106 |     return {
107 |         ds: np.random.RandomState(rs.randint(high))
108 |         for ds in sorted(VALID_DATASETS)
109 |     }
110 | 
111 | 
112 | def _correct_sirna_dtype(row):
113 |     if np.isnan(row['sirna']):
114 |         row['sirna'] = None
115 |     else:
116 |         row['sirna'] = int(row['sirna'])
117 |     return row
118 | 
119 | 
120 | def _random_partition(metadata_df,
121 |                       dest_path,
122 |                       sites_per_tfrecord=308,
123 |                       random_seed=42):
124 |     """
125 |     Randomly partitions each dataset into multiple TFRecords.
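    With the default seed the resulting layout looks like, e.g.:

        <dest_path>/random-42/train/001.tfrecord
        <dest_path>/random-42/train/002.tfrecord
        ...
        <dest_path>/random-42/test/001.tfrecord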
126 |     """
127 |     # make groupbys deterministic
128 |     metadata_df = metadata_df.sort_values(
129 |         ['dataset', 'experiment', 'plate', 'well', 'site'])
130 |     # get random states to make randomizations deterministic
131 |     rs_dict = _dataset_rs_dict(random_seed)
132 | 
133 |     to_pack = []
134 |     for dataset, df in metadata_df.groupby('dataset'):
135 |         df = (df.sort_values(['experiment', 'plate', 'well', 'site'])
136 |               .sample(frac=1.0, random_state=rs_dict[dataset]))
137 |         rows = [_correct_sirna_dtype(row) for row in df.to_dict(orient='records')]
138 |         sites_for_files = t.partition_all(sites_per_tfrecord, rows)
139 |         dataset_path = os.path.join(dest_path, 'random-{}'.format(random_seed), dataset)
140 |         for file_num, sites in enumerate(sites_for_files, 1):
141 |             dest_file = os.path.join(dataset_path, "{:03d}.tfrecord".format(file_num))
142 |             to_pack.append({'dest_path': dest_file, 'sites': sites})
143 |     return to_pack
144 | 
145 | 
146 | def _by_exp_plate_site(metadata_df, dest_path, random_seed=42):
147 |     """
148 |     Groups by experiment, plate, and site, packing each group into its own TFRecord.
149 |     """
150 |     # make groupbys deterministic
151 |     metadata_df = metadata_df.sort_values(
152 |         ['dataset', 'experiment', 'plate', 'well', 'site'])
153 |     # get random states to make randomizations deterministic
154 |     rs_dict = _dataset_rs_dict(random_seed)
155 | 
156 |     to_pack = []
157 |     for (dataset, exp, plate, site), df in metadata_df.groupby(
158 |         ['dataset', 'experiment', 'plate', 'site']):
159 |         df = (df.sort_values(['experiment', 'plate', 'well', 'site'])
160 |               .sample(frac=1.0, random_state=rs_dict[dataset]))
161 |         rows = [_correct_sirna_dtype(row) for row in df.to_dict(orient='records')]
162 | 
163 |         dest_file = os.path.join(dest_path, 'by_exp_plate_site-{}'.format(random_seed),
164 |                                  "{}_p{}_s{}.tfrecord".format(exp, plate, site))
165 |         to_pack.append({'dest_path': dest_file, 'sites': rows})
166 |     return to_pack
167 | 
168 | 
169 | def _sites_df(i, ix):
170 |     return pd.DataFrame([i] * len(ix), index=ix, columns=['site'])
171 | 
172 | 
173 | ### Main entry point and CLI logic
174 | 
175 | 
176 | def pack_tfrecords(images_path,
177 |                    metadata_df,
178 |                    num_workers,
179 |                    dest_path,
180 |                    strategies=['random', 'by_exp_plate_site'],
181 |                    channels=DEFAULT_CHANNELS,
182 |                    sites_per_tfrecord=308,
183 |                    random_seeds=[42],
184 |                    runner='dask',
185 |                    project=None,
186 |                    datasets=None):
187 |     if datasets is None:
188 |         datasets = [
189 |             ds.strip('/') for ds in tf.gfile.ListDirectory(images_path)
190 |             if ds.strip('/') in VALID_DATASETS
191 |         ]
192 | 
193 |     # Only consider metadata for the datasets we care about
194 |     metadata_df = metadata_df[metadata_df.dataset.isin(datasets)]
195 |     # only pack images for the treatment wells, not the controls!
196 |     metadata_df = metadata_df[metadata_df.well_type == "treatment"]
197 | 
198 |     strategies = set(strategies)
199 | 
200 |     if len(strategies - VALID_STRATEGIES) > 0:
201 |         raise ValueError(
202 |             'invalid strategies: {}. You may only provide a subset of {}'.format(strategies, VALID_STRATEGIES)
203 |         )
204 | 
205 |     to_pack = []
206 |     for random_seed in random_seeds:
207 |         if 'random' in strategies:
208 |             to_pack += _random_partition(
209 |                 metadata_df,
210 |                 dest_path,
211 |                 random_seed=random_seed,
212 |                 sites_per_tfrecord=sites_per_tfrecord)
213 | 
214 |         if 'by_exp_plate_site' in strategies:
215 |             to_pack += _by_exp_plate_site(
216 |                 metadata_df, dest_path, random_seed=random_seed)
217 | 
218 |     if runner == 'dask':
219 |         import dask
220 |         import dask.bag
221 | 
222 |         print('Distributing {} on dask'.format(len(to_pack)))
223 |         to_pack_bag = dask.bag.from_sequence(to_pack, npartitions=len(to_pack))
224 |         (to_pack_bag
225 |          .map(lambda kw: _pack_tfrecord(base_path=images_path,
226 |                                         channels=channels,
227 |                                         **kw))
228 |          .compute(num_workers=num_workers))
229 |         return [p['dest_path'] for p in to_pack]
230 |     else:
231 |         print('Distributing {} on {}'.format(len(to_pack), runner))
232 |         run_on_dataflow(to_pack, dest_path, images_path, channels, runner, project)
233 |         return None
234 | 
235 | 
236 | def run_on_dataflow(to_pack, dest_path, images_path, channels, runner, project):
237 | 
238 |     import apache_beam as beam
239 | 
240 |     options = {
241 |         'staging_location':
242 |             os.path.join(dest_path, 'tmp', 'staging'),
243 |         'temp_location':
244 |             os.path.join(dest_path, 'tmp'),
245 |         'job_name': ('rxrx1-' + os.getlogin().replace('.', '-') + '-' +
246 |                      datetime.datetime.now().strftime('%y%m%d-%H%M%S')),
247 |         'max_num_workers':
248 |             600,  # CHANGE AS NEEDED
249 |         'machine_type':
250 |             'n1-standard-4',
251 |         'save_main_session':
252 |             True,
253 |         'setup_file': (os.path.join(
254 |             os.path.dirname(os.path.abspath(__file__)), '../../setup.py')),
255 |         'runner':
256 |             runner,
257 |         'project':
258 |             project
259 |     }
260 |     opts = beam.pipeline.PipelineOptions(flags=[], **options)
261 |     with beam.Pipeline(runner, options=opts) as p:
262 |         (p
263 |          | 'find_images' >> beam.Create(to_pack)
264 |          | 'pack' >> beam.FlatMap(
265 |              lambda kw: _pack_tfrecord(base_path=images_path,
266 |                                        channels=channels,
267 |                                        **kw))
268 |         )
269 |     if runner == 'dataflow':
270 |         print(
271 |             'Submitting job ... Please monitor at https://console.cloud.google.com/dataflow'
272 |         )
273 | 
274 | 
275 | def cli():
276 |     parser = argparse.ArgumentParser(
277 |         formatter_class=argparse.RawTextHelpFormatter,
278 |         description="Packs the raw PNG images into TFRecords.")
279 |     parser.add_argument("--raw-images", type=str, help="Path of the raw images",
280 |                         default=rio.DEFAULT_IMAGES_BASE_PATH)
281 |     parser.add_argument(
282 |         "--metadata", type=str, help="Path to the metadata directory",
283 |         default=rio.DEFAULT_METADATA_BASE_PATH)
284 |     parser.add_argument(
285 |         "--num-workers",
286 |         type=int,
287 |         default=None,
288 |         help="Number of workers writing TFRecords. Defaults to the number of cores."
289 |     )
290 |     parser.add_argument(
291 |         "--random-seeds",
292 |         type=int,
293 |         nargs='+',
294 |         default=[42],
295 |         help="The seeds used to make the shuffling deterministic. Embedded in the dir name to allow multiple folds to be created."
296 |     )
297 |     parser.add_argument(
298 |         "--sites-per-tfrecord",
299 |         type=int,
300 |         default=1500,
301 |         help="Only used with the random strategy; indicates how many site images to pack into a single TFRecord"
302 |     )
303 |     parser.add_argument(
304 |         "--strategies",
305 |         nargs='+',
306 |         choices=VALID_STRATEGIES,
307 |         default=['random', 'by_exp_plate_site'],
308 |         help="""What strategies to use to pack up the records:
309 | \t`random` - Randomly partitions each dataset into multiple TFRecords.
310 | \t`by_exp_plate_site` - Groups by experiment, plate, and site, packing each group into its own TFRecord.
311 | """)
312 |     parser.add_argument(
313 |         "--dest-path",
314 |         type=str,
315 |         default="./tfrecords",
316 |         help="Destination directory where the TFRecords will be written")
317 |     parser.add_argument(
318 |         "--runner",
319 |         type=str,
320 |         default="dask",
321 |         choices={'dask', 'dataflow'},
322 |         help="Specify one of dask or dataflow")
323 |     parser.add_argument(
324 |         "--project",
325 |         type=str,
326 |         default=None,
327 |         help="If using dataflow, the project to bill")
328 |     args = parser.parse_args()
329 |     if args.runner == 'dataflow':
330 |         if not args.project:
331 |             raise ValueError('When using dataflow, you must specify --project')
332 | 
333 |     metadata_df = rio.combine_metadata(args.metadata)
334 |     if args.runner == 'dask':
335 |         from dask.diagnostics import ProgressBar
336 |         ProgressBar().register()
337 | 
338 |     pack_tfrecords(
339 |         images_path=args.raw_images,
340 |         metadata_df=metadata_df,
341 |         dest_path=args.dest_path,
342 |         strategies=args.strategies,
343 |         sites_per_tfrecord=args.sites_per_tfrecord,
344 |         random_seeds=args.random_seeds,
345 |         num_workers=args.num_workers,
346 |         runner=args.runner,
347 |         project=args.project)
348 | 
349 | 
350 | if __name__ == '__main__':
351 |     cli()
352 | 
-------------------------------------------------------------------------------- /rxrx/preprocess/images2zarr.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | 
4 | import dask
5 | import dask.bag
6 | import toolz as t
7 | from dask.diagnostics import ProgressBar
8 | 
9 | import zarr
10 | 
11 | from .. import io as rio
12 | 
13 | DEFAULT_COMPRESSION = {"cname": "zstd", "clevel": 3, "shuffle": 2}
14 | 
15 | 
16 | def zarrify(x, dest, chunk=512, compression=DEFAULT_COMPRESSION):
17 |     compressor = None
18 |     if compression:
19 |         compressor = zarr.Blosc(**compression)
20 |     os.makedirs(os.path.dirname(dest), exist_ok=True)
21 |     z = zarr.open(
22 |         dest,
23 |         mode="w",
24 |         shape=x.shape,
25 |         chunks=(chunk, chunk, None),
26 |         dtype="