├── LICENSE.txt ├── README.md ├── dataset_tool_h5.py ├── dnnlib ├── __init__.py ├── submission │ ├── __init__.py │ ├── _internal │ │ └── run.py │ ├── run_context.py │ └── submit.py ├── tflib │ ├── __init__.py │ ├── autosummary.py │ ├── network.py │ ├── optimizer.py │ └── tfutil.py └── util.py ├── download_kodak.py ├── img └── readme_figure.png └── selfsupervised_denoising.py /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | 3 | 4 | Attribution-NonCommercial 4.0 International 5 | 6 | ======================================================================= 7 | 8 | Creative Commons Corporation ("Creative Commons") is not a law firm and 9 | does not provide legal services or legal advice. Distribution of 10 | Creative Commons public licenses does not create a lawyer-client or 11 | other relationship. Creative Commons makes its licenses and related 12 | information available on an "as-is" basis. Creative Commons gives no 13 | warranties regarding its licenses, any material licensed under their 14 | terms and conditions, or any related information. Creative Commons 15 | disclaims all liability for damages resulting from their use to the 16 | fullest extent possible. 17 | 18 | Using Creative Commons Public Licenses 19 | 20 | Creative Commons public licenses provide a standard set of terms and 21 | conditions that creators and other rights holders may use to share 22 | original works of authorship and other material subject to copyright 23 | and certain other rights specified in the public license below. The 24 | following considerations are for informational purposes only, are not 25 | exhaustive, and do not form part of our licenses. 26 | 27 | Considerations for licensors: Our public licenses are 28 | intended for use by those authorized to give the public 29 | permission to use material in ways otherwise restricted by 30 | copyright and certain other rights. 
Our licenses are 31 | irrevocable. Licensors should read and understand the terms 32 | and conditions of the license they choose before applying it. 33 | Licensors should also secure all rights necessary before 34 | applying our licenses so that the public can reuse the 35 | material as expected. Licensors should clearly mark any 36 | material not subject to the license. This includes other CC- 37 | licensed material, or material used under an exception or 38 | limitation to copyright. More considerations for licensors: 39 | wiki.creativecommons.org/Considerations_for_licensors 40 | 41 | Considerations for the public: By using one of our public 42 | licenses, a licensor grants the public permission to use the 43 | licensed material under specified terms and conditions. If 44 | the licensor's permission is not necessary for any reason--for 45 | example, because of any applicable exception or limitation to 46 | copyright--then that use is not regulated by the license. Our 47 | licenses grant only permissions under copyright and certain 48 | other rights that a licensor has authority to grant. Use of 49 | the licensed material may still be restricted for other 50 | reasons, including because others have copyright or other 51 | rights in the material. A licensor may make special requests, 52 | such as asking that all changes be marked or described. 53 | Although not required by our licenses, you are encouraged to 54 | respect those requests where reasonable. More_considerations 55 | for the public: 56 | wiki.creativecommons.org/Considerations_for_licensees 57 | 58 | ======================================================================= 59 | 60 | Creative Commons Attribution-NonCommercial 4.0 International Public 61 | License 62 | 63 | By exercising the Licensed Rights (defined below), You accept and agree 64 | to be bound by the terms and conditions of this Creative Commons 65 | Attribution-NonCommercial 4.0 International Public License ("Public 66 | License"). 
To the extent this Public License may be interpreted as a 67 | contract, You are granted the Licensed Rights in consideration of Your 68 | acceptance of these terms and conditions, and the Licensor grants You 69 | such rights in consideration of benefits the Licensor receives from 70 | making the Licensed Material available under these terms and 71 | conditions. 72 | 73 | 74 | Section 1 -- Definitions. 75 | 76 | a. Adapted Material means material subject to Copyright and Similar 77 | Rights that is derived from or based upon the Licensed Material 78 | and in which the Licensed Material is translated, altered, 79 | arranged, transformed, or otherwise modified in a manner requiring 80 | permission under the Copyright and Similar Rights held by the 81 | Licensor. For purposes of this Public License, where the Licensed 82 | Material is a musical work, performance, or sound recording, 83 | Adapted Material is always produced where the Licensed Material is 84 | synched in timed relation with a moving image. 85 | 86 | b. Adapter's License means the license You apply to Your Copyright 87 | and Similar Rights in Your contributions to Adapted Material in 88 | accordance with the terms and conditions of this Public License. 89 | 90 | c. Copyright and Similar Rights means copyright and/or similar rights 91 | closely related to copyright including, without limitation, 92 | performance, broadcast, sound recording, and Sui Generis Database 93 | Rights, without regard to how the rights are labeled or 94 | categorized. For purposes of this Public License, the rights 95 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 96 | Rights. 97 | d. Effective Technological Measures means those measures that, in the 98 | absence of proper authority, may not be circumvented under laws 99 | fulfilling obligations under Article 11 of the WIPO Copyright 100 | Treaty adopted on December 20, 1996, and/or similar international 101 | agreements. 102 | 103 | e. 
Exceptions and Limitations means fair use, fair dealing, and/or 104 | any other exception or limitation to Copyright and Similar Rights 105 | that applies to Your use of the Licensed Material. 106 | 107 | f. Licensed Material means the artistic or literary work, database, 108 | or other material to which the Licensor applied this Public 109 | License. 110 | 111 | g. Licensed Rights means the rights granted to You subject to the 112 | terms and conditions of this Public License, which are limited to 113 | all Copyright and Similar Rights that apply to Your use of the 114 | Licensed Material and that the Licensor has authority to license. 115 | 116 | h. Licensor means the individual(s) or entity(ies) granting rights 117 | under this Public License. 118 | 119 | i. NonCommercial means not primarily intended for or directed towards 120 | commercial advantage or monetary compensation. For purposes of 121 | this Public License, the exchange of the Licensed Material for 122 | other material subject to Copyright and Similar Rights by digital 123 | file-sharing or similar means is NonCommercial provided there is 124 | no payment of monetary compensation in connection with the 125 | exchange. 126 | 127 | j. Share means to provide material to the public by any means or 128 | process that requires permission under the Licensed Rights, such 129 | as reproduction, public display, public performance, distribution, 130 | dissemination, communication, or importation, and to make material 131 | available to the public including in ways that members of the 132 | public may access the material from a place and at a time 133 | individually chosen by them. 134 | 135 | k. Sui Generis Database Rights means rights other than copyright 136 | resulting from Directive 96/9/EC of the European Parliament and of 137 | the Council of 11 March 1996 on the legal protection of databases, 138 | as amended and/or succeeded, as well as other essentially 139 | equivalent rights anywhere in the world. 
140 | 141 | l. You means the individual or entity exercising the Licensed Rights 142 | under this Public License. Your has a corresponding meaning. 143 | 144 | 145 | Section 2 -- Scope. 146 | 147 | a. License grant. 148 | 149 | 1. Subject to the terms and conditions of this Public License, 150 | the Licensor hereby grants You a worldwide, royalty-free, 151 | non-sublicensable, non-exclusive, irrevocable license to 152 | exercise the Licensed Rights in the Licensed Material to: 153 | 154 | a. reproduce and Share the Licensed Material, in whole or 155 | in part, for NonCommercial purposes only; and 156 | 157 | b. produce, reproduce, and Share Adapted Material for 158 | NonCommercial purposes only. 159 | 160 | 2. Exceptions and Limitations. For the avoidance of doubt, where 161 | Exceptions and Limitations apply to Your use, this Public 162 | License does not apply, and You do not need to comply with 163 | its terms and conditions. 164 | 165 | 3. Term. The term of this Public License is specified in Section 166 | 6(a). 167 | 168 | 4. Media and formats; technical modifications allowed. The 169 | Licensor authorizes You to exercise the Licensed Rights in 170 | all media and formats whether now known or hereafter created, 171 | and to make technical modifications necessary to do so. The 172 | Licensor waives and/or agrees not to assert any right or 173 | authority to forbid You from making technical modifications 174 | necessary to exercise the Licensed Rights, including 175 | technical modifications necessary to circumvent Effective 176 | Technological Measures. For purposes of this Public License, 177 | simply making modifications authorized by this Section 2(a) 178 | (4) never produces Adapted Material. 179 | 180 | 5. Downstream recipients. 181 | 182 | a. Offer from the Licensor -- Licensed Material. 
Every 183 | recipient of the Licensed Material automatically 184 | receives an offer from the Licensor to exercise the 185 | Licensed Rights under the terms and conditions of this 186 | Public License. 187 | 188 | b. No downstream restrictions. You may not offer or impose 189 | any additional or different terms or conditions on, or 190 | apply any Effective Technological Measures to, the 191 | Licensed Material if doing so restricts exercise of the 192 | Licensed Rights by any recipient of the Licensed 193 | Material. 194 | 195 | 6. No endorsement. Nothing in this Public License constitutes or 196 | may be construed as permission to assert or imply that You 197 | are, or that Your use of the Licensed Material is, connected 198 | with, or sponsored, endorsed, or granted official status by, 199 | the Licensor or others designated to receive attribution as 200 | provided in Section 3(a)(1)(A)(i). 201 | 202 | b. Other rights. 203 | 204 | 1. Moral rights, such as the right of integrity, are not 205 | licensed under this Public License, nor are publicity, 206 | privacy, and/or other similar personality rights; however, to 207 | the extent possible, the Licensor waives and/or agrees not to 208 | assert any such rights held by the Licensor to the limited 209 | extent necessary to allow You to exercise the Licensed 210 | Rights, but not otherwise. 211 | 212 | 2. Patent and trademark rights are not licensed under this 213 | Public License. 214 | 215 | 3. To the extent possible, the Licensor waives any right to 216 | collect royalties from You for the exercise of the Licensed 217 | Rights, whether directly or through a collecting society 218 | under any voluntary or waivable statutory or compulsory 219 | licensing scheme. In all other cases the Licensor expressly 220 | reserves any right to collect such royalties, including when 221 | the Licensed Material is used other than for NonCommercial 222 | purposes. 223 | 224 | 225 | Section 3 -- License Conditions. 
226 | 227 | Your exercise of the Licensed Rights is expressly made subject to the 228 | following conditions. 229 | 230 | a. Attribution. 231 | 232 | 1. If You Share the Licensed Material (including in modified 233 | form), You must: 234 | 235 | a. retain the following if it is supplied by the Licensor 236 | with the Licensed Material: 237 | 238 | i. identification of the creator(s) of the Licensed 239 | Material and any others designated to receive 240 | attribution, in any reasonable manner requested by 241 | the Licensor (including by pseudonym if 242 | designated); 243 | 244 | ii. a copyright notice; 245 | 246 | iii. a notice that refers to this Public License; 247 | 248 | iv. a notice that refers to the disclaimer of 249 | warranties; 250 | 251 | v. a URI or hyperlink to the Licensed Material to the 252 | extent reasonably practicable; 253 | 254 | b. indicate if You modified the Licensed Material and 255 | retain an indication of any previous modifications; and 256 | 257 | c. indicate the Licensed Material is licensed under this 258 | Public License, and include the text of, or the URI or 259 | hyperlink to, this Public License. 260 | 261 | 2. You may satisfy the conditions in Section 3(a)(1) in any 262 | reasonable manner based on the medium, means, and context in 263 | which You Share the Licensed Material. For example, it may be 264 | reasonable to satisfy the conditions by providing a URI or 265 | hyperlink to a resource that includes the required 266 | information. 267 | 268 | 3. If requested by the Licensor, You must remove any of the 269 | information required by Section 3(a)(1)(A) to the extent 270 | reasonably practicable. 271 | 272 | 4. If You Share Adapted Material You produce, the Adapter's 273 | License You apply must not prevent recipients of the Adapted 274 | Material from complying with this Public License. 275 | 276 | 277 | Section 4 -- Sui Generis Database Rights. 
278 | 279 | Where the Licensed Rights include Sui Generis Database Rights that 280 | apply to Your use of the Licensed Material: 281 | 282 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 283 | to extract, reuse, reproduce, and Share all or a substantial 284 | portion of the contents of the database for NonCommercial purposes 285 | only; 286 | 287 | b. if You include all or a substantial portion of the database 288 | contents in a database in which You have Sui Generis Database 289 | Rights, then the database in which You have Sui Generis Database 290 | Rights (but not its individual contents) is Adapted Material; and 291 | 292 | c. You must comply with the conditions in Section 3(a) if You Share 293 | all or a substantial portion of the contents of the database. 294 | 295 | For the avoidance of doubt, this Section 4 supplements and does not 296 | replace Your obligations under this Public License where the Licensed 297 | Rights include other Copyright and Similar Rights. 298 | 299 | 300 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 301 | 302 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 303 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 304 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 305 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 306 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 307 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 308 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 309 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 310 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 311 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 312 | 313 | b. 
TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 314 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 315 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 316 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 317 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 318 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 319 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 320 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 321 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 322 | 323 | c. The disclaimer of warranties and limitation of liability provided 324 | above shall be interpreted in a manner that, to the extent 325 | possible, most closely approximates an absolute disclaimer and 326 | waiver of all liability. 327 | 328 | 329 | Section 6 -- Term and Termination. 330 | 331 | a. This Public License applies for the term of the Copyright and 332 | Similar Rights licensed here. However, if You fail to comply with 333 | this Public License, then Your rights under this Public License 334 | terminate automatically. 335 | 336 | b. Where Your right to use the Licensed Material has terminated under 337 | Section 6(a), it reinstates: 338 | 339 | 1. automatically as of the date the violation is cured, provided 340 | it is cured within 30 days of Your discovery of the 341 | violation; or 342 | 343 | 2. upon express reinstatement by the Licensor. 344 | 345 | For the avoidance of doubt, this Section 6(b) does not affect any 346 | right the Licensor may have to seek remedies for Your violations 347 | of this Public License. 348 | 349 | c. For the avoidance of doubt, the Licensor may also offer the 350 | Licensed Material under separate terms or conditions or stop 351 | distributing the Licensed Material at any time; however, doing so 352 | will not terminate this Public License. 353 | 354 | d. 
Sections 1, 5, 6, 7, and 8 survive termination of this Public 355 | License. 356 | 357 | 358 | Section 7 -- Other Terms and Conditions. 359 | 360 | a. The Licensor shall not be bound by any additional or different 361 | terms or conditions communicated by You unless expressly agreed. 362 | 363 | b. Any arrangements, understandings, or agreements regarding the 364 | Licensed Material not stated herein are separate from and 365 | independent of the terms and conditions of this Public License. 366 | 367 | 368 | Section 8 -- Interpretation. 369 | 370 | a. For the avoidance of doubt, this Public License does not, and 371 | shall not be interpreted to, reduce, limit, restrict, or impose 372 | conditions on any use of the Licensed Material that could lawfully 373 | be made without permission under this Public License. 374 | 375 | b. To the extent possible, if any provision of this Public License is 376 | deemed unenforceable, it shall be automatically reformed to the 377 | minimum extent necessary to make it enforceable. If the provision 378 | cannot be reformed, it shall be severed from this Public License 379 | without affecting the enforceability of the remaining terms and 380 | conditions. 381 | 382 | c. No term or condition of this Public License will be waived and no 383 | failure to comply consented to unless expressly agreed to by the 384 | Licensor. 385 | 386 | d. Nothing in this Public License constitutes or may be interpreted 387 | as a limitation upon, or waiver of, any privileges and immunities 388 | that apply to the Licensor or You, including from the legal 389 | processes of any jurisdiction or authority. 390 | 391 | ======================================================================= 392 | 393 | Creative Commons is not a party to its public 394 | licenses. Notwithstanding, Creative Commons may elect to apply one of 395 | its public licenses to material it publishes and in those instances 396 | will be considered the "Licensor." 
The text of the Creative Commons 397 | public licenses is dedicated to the public domain under the CC0 Public 398 | Domain Dedication. Except for the limited purpose of indicating that 399 | material is shared under a Creative Commons public license or as 400 | otherwise permitted by the Creative Commons policies published at 401 | creativecommons.org/policies, Creative Commons does not authorize the 402 | use of the trademark "Creative Commons" or any other trademark or logo 403 | of Creative Commons without its prior written consent including, 404 | without limitation, in connection with any unauthorized modifications 405 | to any of its public licenses or any other arrangements, 406 | understandings, or agreements concerning use of licensed material. For 407 | the avoidance of doubt, this paragraph does not form part of the 408 | public licenses. 409 | 410 | Creative Commons may be contacted at creativecommons.org. 411 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # High-Quality Self-Supervised Deep Image Denoising - Official TensorFlow implementation of the NeurIPS 2019 paper 2 | 3 | **Samuli Laine** (NVIDIA), **Tero Karras** (NVIDIA), **Jaakko Lehtinen** (NVIDIA and Aalto University), **Timo Aila** (NVIDIA) 4 | 5 | **Abstract**: 6 | 7 | _We describe a novel method for training high-quality image denoising models based on unorganized collections of corrupted images. The training does not need access to clean reference images, or explicit pairs of corrupted images, and can thus be applied in situations where such data is unacceptably expensive or impossible to acquire. We build on a recent technique that removes the need for reference data by employing networks with a "blind spot" in the receptive field, and significantly improve two key aspects: image quality and training efficiency. 
Our result quality is on par with state-of-the-art neural network denoisers in the case of i.i.d. additive Gaussian noise, and not far behind with Poisson and impulse noise. We also successfully handle cases where parameters of the noise model are variable and/or unknown in both training and evaluation data._ 8 | 9 | ![Denoising comparison](img/readme_figure.png "Denoising comparison") 10 | 11 | ## Resources 12 | 13 | - [Paper](https://arxiv.org/abs/1901.10277) (arXiv) 14 | - [Pre-trained networks](https://drive.google.com/open?id=1tatE9WFNSqzLm_aso3Wy05j90_wkMmo4) 15 | 16 | All material is made available under [Creative Commons BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) license by NVIDIA Corporation. You can **use, redistribute, and adapt** the material for **non-commercial purposes**, as long as you give appropriate credit by **citing our paper** and **indicating any changes** that you've made. 17 | 18 | ## Python requirements 19 | 20 | This code was tested on: 21 | 22 | - Python 3.7 23 | - TensorFlow 1.14 24 | - [Anaconda 2019/07](https://www.anaconda.com/distribution/) 25 | 26 | ## Preparing training dataset 27 | 28 | Our networks have been trained with ImageNet validation set pruned to contain only images between 256x256 and 512x512 pixels in size, yielding 44328 images in total. 29 | To generate the training data hdf5 file, run: 30 | 31 | ``` 32 | # This runs through roughly 50K images and outputs a file called `imagenet_val.h5`. 33 | python dataset_tool_h5.py --input-dir "/ILSVRC2012_img_val" --out=imagenet_val.h5 34 | ``` 35 | 36 | A successful run of dataset_tool_h5.py should print the following upon completion: 37 | 38 | ``` 39 | <... 
snip ...> 40 | 49997 ./ImageNet/ILSVRC2012_img_val/ILSVRC2012_val_00002873.JPEG 41 | 49998 ./ImageNet/ILSVRC2012_img_val/ILSVRC2012_val_00031550.JPEG 42 | 49999 ./ImageNet/ILSVRC2012_img_val/ILSVRC2012_val_00009765.JPEG 43 | 44328/44328: ./ImageNet/ILSVRC2012_img_val/ILSVRC2012_val_00039330.JPEG 44 | Dataset statistics: 45 | Total pixels 8375905404 46 | Formats: 47 | RGB: 43471 images 48 | L: 857 images 49 | width,height buckets: 50 | >= 256x256: 44328 images 51 | ``` 52 | 53 | ## Preparing validation datasets 54 | 55 | Validation data is placed under a common directory. This location can be set using the `--dataset-dir <path>` command line argument. The examples below assume this location is at `$HOME/datasets`. 56 | 57 | **Kodak**. To download the [Kodak Lossless True Color Image Suite](http://r0k.us/graphics/kodak/), run: 58 | 59 | ``` 60 | python download_kodak.py --output-dir=$HOME/datasets/kodak 61 | ``` 62 | 63 | **BSD300**. From [Berkeley Segmentation Dataset and Benchmark](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds) download `BSDS300-images.tgz` and extract: 64 | 65 | ``` 66 | cd $HOME/datasets 67 | tar zxf ~/Downloads/BSDS300-images.tgz 68 | ``` 69 | 70 | **Set14**. From the [LapSRN project page](http://vllab.ucmerced.edu/wlai24/LapSRN) download `SR_testing_datasets.zip` and extract: 71 | 72 | ``` 73 | cd $HOME/datasets 74 | unzip ~/Downloads/SR_testing_datasets.zip 75 | ``` 76 | 77 | ## Running 78 | 79 | Run `python selfsupervised_denoising.py --help` for a complete listing of command line parameters and the list of supported training configurations. 80 | 81 | ``` 82 | usage: selfsupervised_denoising.py [-h] [--dataset-dir DATASET_DIR] 83 | [--train-h5 TRAIN_H5] 84 | [--validation-set VALIDATION_SET] 85 | [--eval EVAL] [--train TRAIN] 86 | 87 | Train or evaluate.
88 | 89 | optional arguments: 90 | -h, --help show this help message and exit 91 | --dataset-dir DATASET_DIR 92 | Path to validation set data 93 | --train-h5 TRAIN_H5 Specify training set .h5 filename 94 | --validation-set VALIDATION_SET 95 | Evaluation dataset 96 | --eval EVAL Evaluate validation set with the given network pickle 97 | --train TRAIN Train for the given config 98 | 99 | examples: 100 | # Train a network with gauss25-blindspot-sigma_global configuration 101 | python selfsupervised_denoising.py --train=gauss25-blindspot-sigma_global --dataset-dir=$HOME/datasets --validation-set=kodak --train-h5=imagenet_val_raw.h5 102 | 103 | # Evaluate a network using the Kodak dataset: 104 | python selfsupervised_denoising.py --eval=$HOME/pretrained/network-00012-gauss25-n2n.pickle --dataset-dir=$HOME/datasets --validation-set=kodak 105 | 106 | List of all configs: 107 | 108 | gauss25-n2c 109 | gauss25-n2n 110 | gauss25-blindspot-sigma_known 111 | ... 112 | ``` 113 | 114 | **Training**: 115 | 116 | To train a network, run: 117 | 118 | ``` 119 | python selfsupervised_denoising.py --dataset-dir=$HOME/datasets --validation-set=kodak --train=gauss25-blindspot-sigma_known --train-h5=imagenet_val.h5 120 | ``` 121 | 122 | The specified validation set is evaluated periodically during training. This can be used to roughly estimate convergence, but 123 | for reliable results the evaluation must be done using the evaluation mode below. 124 | 125 | Note that the default setting of a minibatch size of 4 on a single GPU requires a lot of memory. If you run out of memory, 126 | either decrease the minibatch size or run the code on multiple GPUs. The pre-trained networks were trained on 4 GPUs.
127 | 128 | **Evaluating**: 129 | 130 | To evaluate a trained network against one of the validation sets, run: 131 | 132 | ``` 133 | python selfsupervised_denoising.py --dataset-dir=$HOME/datasets --validation-set=kodak --eval=$HOME/datasets/pretrained/network-00013-gauss25-blindspot-sigma_known.pickle 134 | ``` 135 | 136 | In evaluation mode, the random seeds are fixed so that the generated noise is repeatable. This guarantees that each network 137 | is evaluated against the exact same images. In addition, the validation sets are replicated several times to obtain ~300 138 | total validation images. This is important especially for variable noise types, to ensure that each image is evaluated using 139 | various amounts of noise. Note that the noise and network types are inferred from the filename of the trained network. 140 | 141 | The evaluation results should match the paper. For example, the network used in the command-line example should give the following PSNRs: 142 | 143 | | Network | Kodak | BSD300 | Set14 | 144 | | ----------------------------- | -------- | -------- | -------- | 145 | | gauss25-blindspot-sigma_known | 32.45 dB | 31.03 dB | 31.25 dB | 146 | 147 | ## Acknowledgements 148 | 149 | We thank Arno Solin and Samuel Kaski for helpful comments, and Janne Hellsten and Tero Kuosmanen for the compute infrastructure. 150 | -------------------------------------------------------------------------------- /dataset_tool_h5.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
import glob
import os
import sys
import argparse
import h5py
import fnmatch

import PIL.Image
import numpy as np

from collections import defaultdict

# Statistics gathered by load_image() and printed at the end of main().
size_stats = defaultdict(int)    # size bucket label -> number of images loaded
format_stats = defaultdict(int)  # PIL mode before RGB conversion -> number of images loaded

# Filename patterns (matched with fnmatch) of the files included in the dataset.
_IMAGE_PATTERNS = ('*.JPEG', '*.jpg', '*.png')


def load_image(fname):
    """Load an image file as a uint8 RGB array in channel-first order.

    The image's original PIL mode is counted in ``format_stats`` and its size
    bucket ('< 256x256' or '>= 256x256') in ``size_stats``.

    Args:
        fname: Path of the image file.

    Returns:
        numpy.ndarray of dtype uint8 and shape (3, height, width).
    """
    # Context manager releases the file handle once the pixels are copied out;
    # without it every loaded image kept its file open.
    with PIL.Image.open(fname) as im:
        format_stats[im.mode] += 1
        if im.width < 256 or im.height < 256:
            size_stats['< 256x256'] += 1
        else:
            size_stats['>= 256x256'] += 1
        arr = np.array(im.convert('RGB'), dtype=np.uint8)
    assert len(arr.shape) == 3
    # (H, W, C) -> (C, H, W)
    return arr.transpose([2, 0, 1])


def filter_image_sizes(images):
    """Keep only images whose width and height both lie in [256, 512].

    Progress is printed every 100 files. Files that cannot be opened are
    reported and skipped.

    Args:
        images: Sequence of image file paths.

    Returns:
        List of ``(fname, width, height)`` tuples for accepted images, in input order.
    """
    filtered = []
    for idx, fname in enumerate(images):
        if (idx % 100) == 0:
            print('loading images', idx, '/', len(images))
        try:
            with PIL.Image.open(fname) as img:
                w, h = img.size
        except Exception:
            # Best-effort skip of unreadable files; unlike a bare ``except:``
            # this still lets KeyboardInterrupt/SystemExit abort the run.
            print('Could not load image', fname, 'skipping file..')
            continue
        if (w > 512 or h > 512) or (w < 256 or h < 256):
            continue
        filtered.append((fname, w, h))
    return filtered


examples='''examples:

python %(prog)s --input-dir=./ILSVRC2012_img_val --out=imagenet_val_raw.h5
'''


def _list_images(input_dir):
    """Return JPEG/jpg/png paths found recursively under input_dir, in a fixed shuffled order."""
    pattern = os.path.join(input_dir, '**/*')
    all_fnames = glob.glob(pattern, recursive=True)
    images = sorted(
        fname for fname in all_fnames
        if any(fnmatch.fnmatch(fname, pat) for pat in _IMAGE_PATTERNS)
    )
    # Sorting before the fixed-seed shuffle makes the order independent of glob order.
    np.random.RandomState(0xbadf00d).shuffle(images)
    return images


def _write_h5(out_path, filtered):
    """Write the images in `filtered` to an HDF5 file and return their total pixel count."""
    outdir = os.path.dirname(out_path)
    if outdir != '':
        os.makedirs(outdir, exist_ok=True)
    num_images = len(filtered)
    num_pixels_total = 0
    with h5py.File(out_path, 'w') as h5file:
        # Each image is stored flattened as a variable-length uint8 array;
        # the 'shapes' dataset records its (3, h, w) shape.
        dt = h5py.special_dtype(vlen=np.dtype('uint8'))
        dset_shapes = h5file.create_dataset('shapes', (num_images, 3), dtype=np.int32)
        dset_images = h5file.create_dataset('images', (num_images,), dtype=dt)
        for idx, (imgname, w, h) in enumerate(filtered):
            print("%d/%d: %s" % (idx + 1, num_images, imgname))
            dset_images[idx] = load_image(imgname).flatten()
            dset_shapes[idx] = (3, h, w)
            num_pixels_total += h * w
    return num_pixels_total


def _print_stats(num_pixels_total):
    """Print the pixel total and the format/size statistics collected by load_image()."""
    print('Dataset statistics:')
    print(' Total pixels', num_pixels_total)
    print(' Formats:')
    for key, count in format_stats.items():
        print(' %s: %d images' % (key, count))
    print(' width,height buckets:')
    for key, count in size_stats.items():
        print(' %s: %d images' % (key, count))


def main():
    """Command-line entry point: convert a directory of images into an HDF5 dataset.

    Exits with status 1 if --input-dir or --out is missing.
    """
    parser = argparse.ArgumentParser(
        description='Convert a set of image files into a HDF5 dataset file.',
        epilog=examples,
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument("--input-dir", help="Directory containing ImageNet images (can be glob pattern for subdirs)")
    parser.add_argument("--out", help="Filename of the output file")
    # type=int rejects bad values up front instead of failing after the size-filter pass.
    parser.add_argument("--max-files", type=int, help="Convert up to max-files images. Process all if unspecified.")
    args = parser.parse_args()

    if args.input_dir is None:
        print('Must specify input file directory with --input-dir')
        sys.exit(1)
    if args.out is None:
        print('Must specify output filename with --out')
        sys.exit(1)

    print('Loading image list from %s' % args.input_dir)
    images = _list_images(args.input_dir)

    filtered = filter_image_sizes(images)
    # `is not None` keeps the original meaning of an explicit `--max-files 0` (no images).
    if args.max_files is not None:
        filtered = filtered[0:args.max_files]

    num_pixels_total = _write_h5(args.out, filtered)
    _print_stats(num_pixels_total)


if __name__ == "__main__":
    main()
-------------------------------------------------------------------------------- /dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | from . import submission 9 | 10 | from .submission.run_context import RunContext 11 | 12 | from .submission.submit import SubmitTarget 13 | from .submission.submit import PathType 14 | from .submission.submit import SubmitConfig 15 | from .submission.submit import get_path_from_template 16 | from .submission.submit import submit_run 17 | 18 | from .util import EasyDict 19 | 20 | submit_config: SubmitConfig = None # Package level variable for SubmitConfig which is only valid when inside the run function. 21 | -------------------------------------------------------------------------------- /dnnlib/submission/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | from . import run_context 9 | from . import submit 10 | -------------------------------------------------------------------------------- /dnnlib/submission/_internal/run.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 
2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Helper for launching run functions in computing clusters. 9 | 10 | During the submit process, this file is copied to the appropriate run dir. 11 | When the job is launched in the cluster, this module is the first thing that 12 | is run inside the docker container. 13 | """ 14 | 15 | import os 16 | import pickle 17 | import sys 18 | 19 | # PYTHONPATH should have been set so that the run_dir/src is in it 20 | import dnnlib 21 | 22 | def main(): 23 | if not len(sys.argv) >= 4: 24 | raise RuntimeError("This script needs three arguments: run_dir, task_name and host_name!") 25 | 26 | run_dir = str(sys.argv[1]) 27 | task_name = str(sys.argv[2]) 28 | host_name = str(sys.argv[3]) 29 | 30 | submit_config_path = os.path.join(run_dir, "submit_config.pkl") 31 | 32 | # SubmitConfig should have been pickled to the run dir 33 | if not os.path.exists(submit_config_path): 34 | raise RuntimeError("SubmitConfig pickle file does not exist!") 35 | 36 | submit_config: dnnlib.SubmitConfig = pickle.load(open(submit_config_path, "rb")) 37 | dnnlib.submission.submit.set_user_name_override(submit_config.user_name) 38 | 39 | submit_config.task_name = task_name 40 | submit_config.host_name = host_name 41 | 42 | dnnlib.submission.submit.run_wrapper(submit_config) 43 | 44 | if __name__ == "__main__": 45 | main() 46 | -------------------------------------------------------------------------------- /dnnlib/submission/run_context.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. 
To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Helpers for managing the run/training loop.""" 9 | 10 | import datetime 11 | import json 12 | import os 13 | import pprint 14 | import time 15 | import types 16 | 17 | from typing import Any 18 | 19 | from . import submit 20 | 21 | 22 | class RunContext(object): 23 | """Helper class for managing the run/training loop. 24 | 25 | The context will hide the implementation details of a basic run/training loop. 26 | It will set things up properly, tell if run should be stopped, and then cleans up. 27 | User should call update periodically and use should_stop to determine if run should be stopped. 28 | 29 | Args: 30 | submit_config: The SubmitConfig that is used for the current run. 31 | config_module: The whole config module that is used for the current run. 32 | max_epoch: Optional cached value for the max_epoch variable used in update. 
33 | """ 34 | 35 | def __init__(self, submit_config: submit.SubmitConfig, config_module: types.ModuleType = None, max_epoch: Any = None): 36 | self.submit_config = submit_config 37 | self.should_stop_flag = False 38 | self.has_closed = False 39 | self.start_time = time.time() 40 | self.last_update_time = time.time() 41 | self.last_update_interval = 0.0 42 | self.max_epoch = max_epoch 43 | 44 | # pretty print the all the relevant content of the config module to a text file 45 | if config_module is not None: 46 | with open(os.path.join(submit_config.run_dir, "config.txt"), "w") as f: 47 | filtered_dict = {k: v for k, v in config_module.__dict__.items() if not k.startswith("_") and not isinstance(v, (types.ModuleType, types.FunctionType, types.LambdaType, submit.SubmitConfig, type))} 48 | pprint.pprint(filtered_dict, stream=f, indent=4, width=200, compact=False) 49 | 50 | # write out details about the run to a text file 51 | self.run_txt_data = {"task_name": submit_config.task_name, "host_name": submit_config.host_name, "start_time": datetime.datetime.now().isoformat(sep=" ")} 52 | with open(os.path.join(submit_config.run_dir, "run.txt"), "w") as f: 53 | pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False) 54 | 55 | def __enter__(self) -> "RunContext": 56 | return self 57 | 58 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 59 | self.close() 60 | 61 | def update(self, loss: Any = 0, cur_epoch: Any = 0, max_epoch: Any = None) -> None: 62 | """Do general housekeeping and keep the state of the context up-to-date. 
63 | Should be called often enough but not in a tight loop.""" 64 | assert not self.has_closed 65 | 66 | self.last_update_interval = time.time() - self.last_update_time 67 | self.last_update_time = time.time() 68 | 69 | if os.path.exists(os.path.join(self.submit_config.run_dir, "abort.txt")): 70 | self.should_stop_flag = True 71 | 72 | max_epoch_val = self.max_epoch if max_epoch is None else max_epoch 73 | 74 | def should_stop(self) -> bool: 75 | """Tell whether a stopping condition has been triggered one way or another.""" 76 | return self.should_stop_flag 77 | 78 | def get_time_since_start(self) -> float: 79 | """How much time has passed since the creation of the context.""" 80 | return time.time() - self.start_time 81 | 82 | def get_time_since_last_update(self) -> float: 83 | """How much time has passed since the last call to update.""" 84 | return time.time() - self.last_update_time 85 | 86 | def get_last_update_interval(self) -> float: 87 | """How much time passed between the previous two calls to update.""" 88 | return self.last_update_interval 89 | 90 | def close(self) -> None: 91 | """Close the context and clean up. 92 | Should only be called once.""" 93 | if not self.has_closed: 94 | # update the run.txt with stopping time 95 | self.run_txt_data["stop_time"] = datetime.datetime.now().isoformat(sep=" ") 96 | with open(os.path.join(self.submit_config.run_dir, "run.txt"), "w") as f: 97 | pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False) 98 | 99 | self.has_closed = True 100 | -------------------------------------------------------------------------------- /dnnlib/submission/submit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. 
To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Submit a function to be run either locally or in a computing cluster.""" 9 | 10 | import copy 11 | import io 12 | import os 13 | import pathlib 14 | import pickle 15 | import platform 16 | import pprint 17 | import re 18 | import shutil 19 | import time 20 | import traceback 21 | 22 | import zipfile 23 | 24 | from enum import Enum 25 | 26 | from .. import util 27 | from ..util import EasyDict 28 | 29 | 30 | class SubmitTarget(Enum): 31 | """The target where the function should be run. 32 | 33 | LOCAL: Run it locally. 34 | """ 35 | LOCAL = 1 36 | 37 | 38 | class PathType(Enum): 39 | """Determines in which format should a path be formatted. 40 | 41 | WINDOWS: Format with Windows style. 42 | LINUX: Format with Linux/Posix style. 43 | AUTO: Use current OS type to select either WINDOWS or LINUX. 44 | """ 45 | WINDOWS = 1 46 | LINUX = 2 47 | AUTO = 3 48 | 49 | 50 | _user_name_override = None 51 | 52 | 53 | class SubmitConfig(util.EasyDict): 54 | """Strongly typed config dict needed to submit runs. 55 | 56 | Attributes: 57 | run_dir_root: Path to the run dir root. Can be optionally templated with tags. Needs to always be run through get_path_from_template. 58 | run_desc: Description of the run. Will be used in the run dir and task name. 59 | run_dir_ignore: List of file patterns used to ignore files when copying files to the run dir. 60 | run_dir_extra_files: List of (abs_path, rel_path) tuples of file paths. rel_path root will be the src directory inside the run dir. 61 | submit_target: Submit target enum value. Used to select where the run is actually launched. 62 | num_gpus: Number of GPUs used/requested for the run. 63 | print_info: Whether to print debug information when submitting. 64 | ask_confirmation: Whether to ask a confirmation before submitting. 
65 | run_id: Automatically populated value during submit. 66 | run_name: Automatically populated value during submit. 67 | run_dir: Automatically populated value during submit. 68 | run_func_name: Automatically populated value during submit. 69 | run_func_kwargs: Automatically populated value during submit. 70 | user_name: Automatically populated value during submit. Can be set by the user which will then override the automatic value. 71 | task_name: Automatically populated value during submit. 72 | host_name: Automatically populated value during submit. 73 | """ 74 | 75 | def __init__(self): 76 | super().__init__() 77 | 78 | # run (set these) 79 | self.run_dir_root = "" # should always be passed through get_path_from_template 80 | self.run_desc = "" 81 | self.run_dir_ignore = ["__pycache__", "*.pyproj", "*.sln", "*.suo", ".cache", ".idea", ".vs", ".vscode"] 82 | self.run_dir_extra_files = None 83 | 84 | # submit (set these) 85 | self.submit_target = SubmitTarget.LOCAL 86 | self.num_gpus = 1 87 | self.print_info = False 88 | self.ask_confirmation = False 89 | 90 | # (automatically populated) 91 | self.run_id = None 92 | self.run_name = None 93 | self.run_dir = None 94 | self.run_func_name = None 95 | self.run_func_kwargs = None 96 | self.user_name = None 97 | self.task_name = None 98 | self.host_name = "localhost" 99 | 100 | 101 | def get_path_from_template(path_template: str, path_type: PathType = PathType.AUTO) -> str: 102 | """Replace tags in the given path template and return either Windows or Linux formatted path.""" 103 | # automatically select path type depending on running OS 104 | if path_type == PathType.AUTO: 105 | if platform.system() == "Windows": 106 | path_type = PathType.WINDOWS 107 | elif platform.system() == "Linux": 108 | path_type = PathType.LINUX 109 | else: 110 | raise RuntimeError("Unknown platform") 111 | 112 | path_template = path_template.replace("", get_user_name()) 113 | 114 | # return correctly formatted path 115 | if path_type == 
PathType.WINDOWS: 116 | return str(pathlib.PureWindowsPath(path_template)) 117 | elif path_type == PathType.LINUX: 118 | return str(pathlib.PurePosixPath(path_template)) 119 | else: 120 | raise RuntimeError("Unknown platform") 121 | 122 | 123 | def get_template_from_path(path: str) -> str: 124 | """Convert a normal path back to its template representation.""" 125 | # replace all path parts with the template tags 126 | path = path.replace("\\", "/") 127 | return path 128 | 129 | 130 | def convert_path(path: str, path_type: PathType = PathType.AUTO) -> str: 131 | """Convert a normal path to template and the convert it back to a normal path with given path type.""" 132 | path_template = get_template_from_path(path) 133 | path = get_path_from_template(path_template, path_type) 134 | return path 135 | 136 | 137 | def set_user_name_override(name: str) -> None: 138 | """Set the global username override value.""" 139 | global _user_name_override 140 | _user_name_override = name 141 | 142 | 143 | def get_user_name(): 144 | """Get the current user name.""" 145 | if _user_name_override is not None: 146 | return _user_name_override 147 | elif platform.system() == "Windows": 148 | return os.getlogin() 149 | elif platform.system() == "Linux": 150 | try: 151 | import pwd # pylint: disable=import-error 152 | return pwd.getpwuid(os.geteuid()).pw_name # pylint: disable=no-member 153 | except: 154 | return "unknown" 155 | else: 156 | raise RuntimeError("Unknown platform") 157 | 158 | 159 | def _create_run_dir_local(submit_config: SubmitConfig) -> str: 160 | """Create a new run dir with increasing ID number at the start.""" 161 | run_dir_root = get_path_from_template(submit_config.run_dir_root, PathType.AUTO) 162 | 163 | if not os.path.exists(run_dir_root): 164 | print("Creating the run dir root: {}".format(run_dir_root)) 165 | os.makedirs(run_dir_root) 166 | 167 | submit_config.run_id = _get_next_run_id_local(run_dir_root) 168 | submit_config.run_name = 
"{0:05d}-{1}".format(submit_config.run_id, submit_config.run_desc) 169 | run_dir = os.path.join(run_dir_root, submit_config.run_name) 170 | 171 | if os.path.exists(run_dir): 172 | raise RuntimeError("The run dir already exists! ({0})".format(run_dir)) 173 | 174 | print("Creating the run dir: {}".format(run_dir)) 175 | os.makedirs(run_dir) 176 | 177 | return run_dir 178 | 179 | 180 | def _get_next_run_id_local(run_dir_root: str) -> int: 181 | """Reads all directory names in a given directory (non-recursive) and returns the next (increasing) run id. Assumes IDs are numbers at the start of the directory names.""" 182 | dir_names = [d for d in os.listdir(run_dir_root) if os.path.isdir(os.path.join(run_dir_root, d))] 183 | r = re.compile("^\\d+") # match one or more digits at the start of the string 184 | run_id = 0 185 | 186 | for dir_name in dir_names: 187 | m = r.match(dir_name) 188 | 189 | if m is not None: 190 | i = int(m.group()) 191 | run_id = max(run_id, i + 1) 192 | 193 | return run_id 194 | 195 | 196 | def _populate_run_dir(run_dir: str, submit_config: SubmitConfig) -> None: 197 | """Copy all necessary files into the run dir. Assumes that the dir exists, is local, and is writable.""" 198 | print("Copying files to the run dir") 199 | files = [] 200 | 201 | run_func_module_dir_path = util.get_module_dir_by_obj_name(submit_config.run_func_name) 202 | assert '.' 
in submit_config.run_func_name 203 | for _idx in range(submit_config.run_func_name.count('.') - 1): 204 | run_func_module_dir_path = os.path.dirname(run_func_module_dir_path) 205 | files += util.list_dir_recursively_with_ignore(run_func_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=False) 206 | 207 | dnnlib_module_dir_path = util.get_module_dir_by_obj_name("dnnlib") 208 | files += util.list_dir_recursively_with_ignore(dnnlib_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=True) 209 | 210 | if submit_config.run_dir_extra_files is not None: 211 | files += submit_config.run_dir_extra_files 212 | 213 | files = [(f[0], os.path.join(run_dir, "src", f[1])) for f in files] 214 | files += [(os.path.join(dnnlib_module_dir_path, "submission", "_internal", "run.py"), os.path.join(run_dir, "run.py"))] 215 | 216 | util.copy_files_and_create_dirs(files) 217 | 218 | pickle.dump(submit_config, open(os.path.join(run_dir, "submit_config.pkl"), "wb")) 219 | 220 | with open(os.path.join(run_dir, "submit_config.txt"), "w") as f: 221 | pprint.pprint(submit_config, stream=f, indent=4, width=200, compact=False) 222 | 223 | 224 | def run_wrapper(submit_config: SubmitConfig) -> None: 225 | """Wrap the actual run function call for handling logging, exceptions, typing, etc.""" 226 | is_local = submit_config.submit_target == SubmitTarget.LOCAL 227 | 228 | checker = None 229 | 230 | # when running locally, redirect stderr to stdout, log stdout to a file, and force flushing 231 | if is_local: 232 | logger = util.Logger(file_name=os.path.join(submit_config.run_dir, "log.txt"), file_mode="w", should_flush=True) 233 | else: # when running in a cluster, redirect stderr to stdout, and just force flushing (log writing is handled by run.sh) 234 | logger = util.Logger(file_name=None, should_flush=True) 235 | 236 | import dnnlib 237 | dnnlib.submit_config = submit_config 238 | 239 | try: 240 | print("dnnlib: Running {0}() on 
{1}...".format(submit_config.run_func_name, submit_config.host_name)) 241 | start_time = time.time() 242 | util.call_func_by_name(func_name=submit_config.run_func_name, submit_config=submit_config, **submit_config.run_func_kwargs) 243 | print("dnnlib: Finished {0}() in {1}.".format(submit_config.run_func_name, util.format_time(time.time() - start_time))) 244 | except: 245 | if is_local: 246 | raise 247 | else: 248 | traceback.print_exc() 249 | 250 | log_src = os.path.join(submit_config.run_dir, "log.txt") 251 | log_dst = os.path.join(get_path_from_template(submit_config.run_dir_root), "{0}-error.txt".format(submit_config.run_name)) 252 | shutil.copyfile(log_src, log_dst) 253 | finally: 254 | open(os.path.join(submit_config.run_dir, "_finished.txt"), "w").close() 255 | 256 | dnnlib.submit_config = None 257 | logger.close() 258 | 259 | if checker is not None: 260 | checker.stop() 261 | 262 | 263 | def submit_run(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None: 264 | """Create a run dir, gather files related to the run, copy files to the run dir, and launch the run in appropriate place.""" 265 | submit_config = copy.copy(submit_config) 266 | 267 | if submit_config.user_name is None: 268 | submit_config.user_name = get_user_name() 269 | 270 | submit_config.run_func_name = run_func_name 271 | submit_config.run_func_kwargs = run_func_kwargs 272 | 273 | assert submit_config.submit_target == SubmitTarget.LOCAL 274 | if submit_config.submit_target in {SubmitTarget.LOCAL}: 275 | run_dir = _create_run_dir_local(submit_config) 276 | 277 | submit_config.task_name = "{0}-{1:05d}-{2}".format(submit_config.user_name, submit_config.run_id, submit_config.run_desc) 278 | submit_config.run_dir = run_dir 279 | _populate_run_dir(run_dir, submit_config) 280 | 281 | if submit_config.print_info: 282 | print("\nSubmit config:\n") 283 | pprint.pprint(submit_config, indent=4, width=200, compact=False) 284 | print() 285 | 286 | if submit_config.ask_confirmation: 287 
| if not util.ask_yes_no("Continue submitting the job?"): 288 | return 289 | 290 | run_wrapper(submit_config) 291 | -------------------------------------------------------------------------------- /dnnlib/tflib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | from . import autosummary 9 | from . import network 10 | from . import optimizer 11 | from . import tfutil 12 | 13 | from .tfutil import * 14 | from .network import Network 15 | 16 | from .optimizer import Optimizer 17 | -------------------------------------------------------------------------------- /dnnlib/tflib/autosummary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Helper for adding automatically tracked values to Tensorboard. 9 | 10 | Autosummary creates an identity op that internally keeps track of the input 11 | values and automatically shows up in TensorBoard. The reported value 12 | represents an average over input components. The average is accumulated 13 | constantly over time and flushed when save_summaries() is called. 14 | 15 | Notes: 16 | - The output tensor must be used as an input for something else in the 17 | graph. 
Otherwise, the autosummary op will not get executed, and the average 18 | value will not get accumulated. 19 | - It is perfectly fine to include autosummaries with the same name in 20 | several places throughout the graph, even if they are executed concurrently. 21 | - It is ok to also pass in a python scalar or numpy array. In this case, it 22 | is added to the average immediately. 23 | """ 24 | 25 | from collections import OrderedDict 26 | import numpy as np 27 | import tensorflow as tf 28 | from tensorboard import summary as summary_lib 29 | from tensorboard.plugins.custom_scalar import layout_pb2 30 | 31 | from . import tfutil 32 | from .tfutil import TfExpression 33 | from .tfutil import TfExpressionEx 34 | 35 | _dtype = tf.float64 36 | _vars = OrderedDict() # name => [var, ...] 37 | _immediate = OrderedDict() # name => update_op, update_value 38 | _finalized = False 39 | _merge_op = None 40 | 41 | 42 | def _create_var(name: str, value_expr: TfExpression) -> TfExpression: 43 | """Internal helper for creating autosummary accumulators.""" 44 | assert not _finalized 45 | name_id = name.replace("/", "_") 46 | v = tf.cast(value_expr, _dtype) 47 | 48 | if v.shape.is_fully_defined(): 49 | size = np.prod(tfutil.shape_to_list(v.shape)) 50 | size_expr = tf.constant(size, dtype=_dtype) 51 | else: 52 | size = None 53 | size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype)) 54 | 55 | if size == 1: 56 | if v.shape.ndims != 0: 57 | v = tf.reshape(v, []) 58 | v = [size_expr, v, tf.square(v)] 59 | else: 60 | v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))] 61 | v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype)) 62 | 63 | with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None): 64 | var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False) # [sum(1), sum(x), sum(x**2)] 65 | update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v)) 66 | 
67 | if name in _vars: 68 | _vars[name].append(var) 69 | else: 70 | _vars[name] = [var] 71 | return update_op 72 | 73 | 74 | def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None) -> TfExpressionEx: 75 | """Create a new autosummary. 76 | 77 | Args: 78 | name: Name to use in TensorBoard 79 | value: TensorFlow expression or python value to track 80 | passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node. 81 | 82 | Example use of the passthru mechanism: 83 | 84 | n = autosummary('l2loss', loss, passthru=n) 85 | 86 | This is a shorthand for the following code: 87 | 88 | with tf.control_dependencies([autosummary('l2loss', loss)]): 89 | n = tf.identity(n) 90 | """ 91 | tfutil.assert_tf_initialized() 92 | name_id = name.replace("/", "_") 93 | 94 | if tfutil.is_tf_expression(value): 95 | with tf.name_scope("summary_" + name_id), tf.device(value.device): 96 | update_op = _create_var(name, value) 97 | with tf.control_dependencies([update_op]): 98 | return tf.identity(value if passthru is None else passthru) 99 | 100 | else: # python scalar or numpy array 101 | if name not in _immediate: 102 | with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None): 103 | update_value = tf.placeholder(_dtype) 104 | update_op = _create_var(name, update_value) 105 | _immediate[name] = update_op, update_value 106 | 107 | update_op, update_value = _immediate[name] 108 | tfutil.run(update_op, {update_value: value}) 109 | return value if passthru is None else passthru 110 | 111 | 112 | def finalize_autosummaries() -> None: 113 | """Create the necessary ops to include autosummaries in TensorBoard report. 114 | Note: This should be done only once per graph. 
115 | """ 116 | global _finalized 117 | tfutil.assert_tf_initialized() 118 | 119 | if _finalized: 120 | return None 121 | 122 | _finalized = True 123 | tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list]) 124 | 125 | # Create summary ops. 126 | with tf.device(None), tf.control_dependencies(None): 127 | for name, vars_list in _vars.items(): 128 | name_id = name.replace("/", "_") 129 | with tfutil.absolute_name_scope("Autosummary/" + name_id): 130 | moments = tf.add_n(vars_list) 131 | moments /= moments[0] 132 | with tf.control_dependencies([moments]): # read before resetting 133 | reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list] 134 | with tf.name_scope(None), tf.control_dependencies(reset_ops): # reset before reporting 135 | mean = moments[1] 136 | std = tf.sqrt(moments[2] - tf.square(moments[1])) 137 | tf.summary.scalar(name, mean) 138 | tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std) 139 | tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std) 140 | 141 | # Group by category and chart name. 142 | cat_dict = OrderedDict() 143 | for series_name in sorted(_vars.keys()): 144 | p = series_name.split("/") 145 | cat = p[0] if len(p) >= 2 else "" 146 | chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1] 147 | if cat not in cat_dict: 148 | cat_dict[cat] = OrderedDict() 149 | if chart not in cat_dict[cat]: 150 | cat_dict[cat][chart] = [] 151 | cat_dict[cat][chart].append(series_name) 152 | 153 | # Setup custom_scalar layout. 
154 | categories = [] 155 | for cat_name, chart_dict in cat_dict.items(): 156 | charts = [] 157 | for chart_name, series_names in chart_dict.items(): 158 | series = [] 159 | for series_name in series_names: 160 | series.append(layout_pb2.MarginChartContent.Series( 161 | value=series_name, 162 | lower="xCustomScalars/" + series_name + "/margin_lo", 163 | upper="xCustomScalars/" + series_name + "/margin_hi")) 164 | margin = layout_pb2.MarginChartContent(series=series) 165 | charts.append(layout_pb2.Chart(title=chart_name, margin=margin)) 166 | categories.append(layout_pb2.Category(title=cat_name, chart=charts)) 167 | layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories)) 168 | return layout 169 | 170 | def save_summaries(file_writer, global_step=None): 171 | """Call FileWriter.add_summary() with all summaries in the default graph, 172 | automatically finalizing and merging them on the first call. 173 | """ 174 | global _merge_op 175 | tfutil.assert_tf_initialized() 176 | 177 | if _merge_op is None: 178 | layout = finalize_autosummaries() 179 | if layout is not None: 180 | file_writer.add_summary(layout) 181 | with tf.device(None), tf.control_dependencies(None): 182 | _merge_op = tf.summary.merge_all() 183 | 184 | file_writer.add_summary(_merge_op.eval(), global_step) 185 | -------------------------------------------------------------------------------- /dnnlib/tflib/network.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
7 | 8 | """Helper for managing networks.""" 9 | 10 | import types 11 | import inspect 12 | import re 13 | import uuid 14 | import sys 15 | import numpy as np 16 | import tensorflow as tf 17 | 18 | from collections import OrderedDict 19 | from typing import Any, List, Tuple, Union 20 | 21 | from . import tfutil 22 | from .. import util 23 | 24 | from .tfutil import TfExpression, TfExpressionEx 25 | 26 | _import_handlers = [] # Custom import handlers for dealing with legacy data in pickle import. 27 | _import_module_src = dict() # Source code for temporary modules created during pickle import. 28 | 29 | 30 | def import_handler(handler_func): 31 | """Function decorator for declaring custom import handlers.""" 32 | _import_handlers.append(handler_func) 33 | return handler_func 34 | 35 | 36 | class Network: 37 | """Generic network abstraction. 38 | 39 | Acts as a convenience wrapper for a parameterized network construction 40 | function, providing several utility methods and convenient access to 41 | the inputs/outputs/weights. 42 | 43 | Network objects can be safely pickled and unpickled for long-term 44 | archival purposes. The pickling works reliably as long as the underlying 45 | network construction function is defined in a standalone Python module 46 | that has no side effects or application-specific imports. 47 | 48 | Args: 49 | name: Network name. Used to select TensorFlow name and variable scopes. 50 | func_name: Fully qualified name of the underlying network construction function, or a top-level function object. 51 | static_kwargs: Keyword arguments to be passed in to the network construction function. 52 | 53 | Attributes: 54 | name: User-specified name, defaults to build func name if None. 55 | scope: Unique TensorFlow scope containing template graph and variables, derived from the user-specified name. 56 | static_kwargs: Arguments passed to the user-supplied build func. 57 | components: Container for sub-networks. 
Passed to the build func, and retained between calls. 58 | num_inputs: Number of input tensors. 59 | num_outputs: Number of output tensors. 60 | input_shapes: Input tensor shapes (NC or NCHW), including minibatch dimension. 61 | output_shapes: Output tensor shapes (NC or NCHW), including minibatch dimension. 62 | input_shape: Short-hand for input_shapes[0]. 63 | output_shape: Short-hand for output_shapes[0]. 64 | input_templates: Input placeholders in the template graph. 65 | output_templates: Output tensors in the template graph. 66 | input_names: Name string for each input. 67 | output_names: Name string for each output. 68 | own_vars: Variables defined by this network (local_name => var), excluding sub-networks. 69 | vars: All variables (local_name => var). 70 | trainables: All trainable variables (local_name => var). 71 | var_global_to_local: Mapping from variable global names to local names. 72 | """ 73 | 74 | def __init__(self, name: str = None, func_name: Any = None, **static_kwargs): 75 | tfutil.assert_tf_initialized() 76 | assert isinstance(name, str) or name is None 77 | assert func_name is not None 78 | assert isinstance(func_name, str) or util.is_top_level_function(func_name) 79 | assert util.is_pickleable(static_kwargs) 80 | 81 | self._init_fields() 82 | self.name = name 83 | self.static_kwargs = util.EasyDict(static_kwargs) 84 | 85 | # Locate the user-specified network build function. 86 | if util.is_top_level_function(func_name): 87 | func_name = util.get_top_level_function_name(func_name) 88 | module, self._build_func_name = util.get_module_from_obj_name(func_name) 89 | self._build_func = util.get_obj_from_module(module, self._build_func_name) 90 | assert callable(self._build_func) 91 | 92 | # Dig up source code for the module containing the build function. 
93 | self._build_module_src = _import_module_src.get(module, None) 94 | if self._build_module_src is None: 95 | self._build_module_src = inspect.getsource(module) 96 | 97 | # Init TensorFlow graph. 98 | self._init_graph() 99 | self.reset_own_vars() 100 | 101 | def _init_fields(self) -> None: 102 | self.name = None 103 | self.scope = None 104 | self.static_kwargs = util.EasyDict() 105 | self.components = util.EasyDict() 106 | self.num_inputs = 0 107 | self.num_outputs = 0 108 | self.input_shapes = [[]] 109 | self.output_shapes = [[]] 110 | self.input_shape = [] 111 | self.output_shape = [] 112 | self.input_templates = [] 113 | self.output_templates = [] 114 | self.input_names = [] 115 | self.output_names = [] 116 | self.own_vars = OrderedDict() 117 | self.vars = OrderedDict() 118 | self.trainables = OrderedDict() 119 | self.var_global_to_local = OrderedDict() 120 | 121 | self._build_func = None # User-supplied build function that constructs the network. 122 | self._build_func_name = None # Name of the build function. 123 | self._build_module_src = None # Full source code of the module containing the build function. 124 | self._run_cache = dict() # Cached graph data for Network.run(). 125 | 126 | def _init_graph(self) -> None: 127 | # Collect inputs. 128 | self.input_names = [] 129 | 130 | for param in inspect.signature(self._build_func).parameters.values(): 131 | if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty: 132 | self.input_names.append(param.name) 133 | 134 | self.num_inputs = len(self.input_names) 135 | assert self.num_inputs >= 1 136 | 137 | # Choose name and scope. 138 | if self.name is None: 139 | self.name = self._build_func_name 140 | assert re.match("^[A-Za-z0-9_.\\-]*$", self.name) 141 | with tf.name_scope(None): 142 | self.scope = tf.get_default_graph().unique_name(self.name, mark_as_used=True) 143 | 144 | # Finalize build func kwargs. 
145 | build_kwargs = dict(self.static_kwargs) 146 | build_kwargs["is_template_graph"] = True 147 | build_kwargs["components"] = self.components 148 | 149 | # Build template graph. 150 | with tfutil.absolute_variable_scope(self.scope, reuse=tf.AUTO_REUSE), tfutil.absolute_name_scope(self.scope): # ignore surrounding scopes 151 | assert tf.get_variable_scope().name == self.scope 152 | assert tf.get_default_graph().get_name_scope() == self.scope 153 | with tf.control_dependencies(None): # ignore surrounding control dependencies 154 | self.input_templates = [tf.placeholder(tf.float32, name=name) for name in self.input_names] 155 | out_expr = self._build_func(*self.input_templates, **build_kwargs) 156 | 157 | # Collect outputs. 158 | assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple) 159 | self.output_templates = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr) 160 | self.num_outputs = len(self.output_templates) 161 | assert self.num_outputs >= 1 162 | assert all(tfutil.is_tf_expression(t) for t in self.output_templates) 163 | 164 | # Perform sanity checks. 165 | if any(t.shape.ndims is None for t in self.input_templates): 166 | raise ValueError("Network input shapes not defined. Please call x.set_shape() for each input.") 167 | if any(t.shape.ndims is None for t in self.output_templates): 168 | raise ValueError("Network output shapes not defined. Please call x.set_shape() where applicable.") 169 | if any(not isinstance(comp, Network) for comp in self.components.values()): 170 | raise ValueError("Components of a Network must be Networks themselves.") 171 | if len(self.components) != len(set(comp.name for comp in self.components.values())): 172 | raise ValueError("Components of a Network must have unique names.") 173 | 174 | # List inputs and outputs. 
175 | self.input_shapes = [tfutil.shape_to_list(t.shape) for t in self.input_templates] 176 | self.output_shapes = [tfutil.shape_to_list(t.shape) for t in self.output_templates] 177 | self.input_shape = self.input_shapes[0] 178 | self.output_shape = self.output_shapes[0] 179 | self.output_names = [t.name.split("/")[-1].split(":")[0] for t in self.output_templates] 180 | 181 | # List variables. 182 | self.own_vars = OrderedDict((var.name[len(self.scope) + 1:].split(":")[0], var) for var in tf.global_variables(self.scope + "/")) 183 | self.vars = OrderedDict(self.own_vars) 184 | self.vars.update((comp.name + "/" + name, var) for comp in self.components.values() for name, var in comp.vars.items()) 185 | self.trainables = OrderedDict((name, var) for name, var in self.vars.items() if var.trainable) 186 | self.var_global_to_local = OrderedDict((var.name.split(":")[0], name) for name, var in self.vars.items()) 187 | 188 | def reset_own_vars(self) -> None: 189 | """Re-initialize all variables of this network, excluding sub-networks.""" 190 | tfutil.run([var.initializer for var in self.own_vars.values()]) 191 | 192 | def reset_vars(self) -> None: 193 | """Re-initialize all variables of this network, including sub-networks.""" 194 | tfutil.run([var.initializer for var in self.vars.values()]) 195 | 196 | def reset_trainables(self) -> None: 197 | """Re-initialize all trainable variables of this network, including sub-networks.""" 198 | tfutil.run([var.initializer for var in self.trainables.values()]) 199 | 200 | def get_output_for(self, *in_expr: TfExpression, return_as_list: bool = False, **dynamic_kwargs) -> Union[TfExpression, List[TfExpression]]: 201 | """Construct TensorFlow expression(s) for the output(s) of this network, given the input expression(s).""" 202 | assert len(in_expr) == self.num_inputs 203 | assert not all(expr is None for expr in in_expr) 204 | 205 | # Finalize build func kwargs. 
206 | build_kwargs = dict(self.static_kwargs) 207 | build_kwargs.update(dynamic_kwargs) 208 | build_kwargs["is_template_graph"] = False 209 | build_kwargs["components"] = self.components 210 | 211 | # Build TensorFlow graph to evaluate the network. 212 | with tfutil.absolute_variable_scope(self.scope, reuse=True), tf.name_scope(self.name): 213 | assert tf.get_variable_scope().name == self.scope 214 | valid_inputs = [expr for expr in in_expr if expr is not None] 215 | final_inputs = [] 216 | for expr, name, shape in zip(in_expr, self.input_names, self.input_shapes): 217 | if expr is not None: 218 | expr = tf.identity(expr, name=name) 219 | else: 220 | expr = tf.zeros([tf.shape(valid_inputs[0])[0]] + shape[1:], name=name) 221 | final_inputs.append(expr) 222 | out_expr = self._build_func(*final_inputs, **build_kwargs) 223 | 224 | # Propagate input shapes back to the user-specified expressions. 225 | for expr, final in zip(in_expr, final_inputs): 226 | if isinstance(expr, tf.Tensor): 227 | expr.set_shape(final.shape) 228 | 229 | # Express outputs in the desired format. 
230 | assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple) 231 | if return_as_list: 232 | out_expr = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr) 233 | return out_expr 234 | 235 | def get_var_local_name(self, var_or_global_name: Union[TfExpression, str]) -> str: 236 | """Get the local name of a given variable, without any surrounding name scopes.""" 237 | assert tfutil.is_tf_expression(var_or_global_name) or isinstance(var_or_global_name, str) 238 | global_name = var_or_global_name if isinstance(var_or_global_name, str) else var_or_global_name.name 239 | return self.var_global_to_local[global_name] 240 | 241 | def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfExpression: 242 | """Find variable by local or global name.""" 243 | assert tfutil.is_tf_expression(var_or_local_name) or isinstance(var_or_local_name, str) 244 | return self.vars[var_or_local_name] if isinstance(var_or_local_name, str) else var_or_local_name 245 | 246 | def get_var(self, var_or_local_name: Union[TfExpression, str]) -> np.ndarray: 247 | """Get the value of a given variable as NumPy array. 248 | Note: This method is very inefficient -- prefer to use tflib.run(list_of_vars) whenever possible.""" 249 | return self.find_var(var_or_local_name).eval() 250 | 251 | def set_var(self, var_or_local_name: Union[TfExpression, str], new_value: Union[int, float, np.ndarray]) -> None: 252 | """Set the value of a given variable based on the given NumPy array. 
253 | Note: This method is very inefficient -- prefer to use tflib.set_vars() whenever possible.""" 254 | tfutil.set_vars({self.find_var(var_or_local_name): new_value}) 255 | 256 | def __getstate__(self) -> dict: 257 | """Pickle export.""" 258 | state = dict() 259 | state["version"] = 3 260 | state["name"] = self.name 261 | state["static_kwargs"] = dict(self.static_kwargs) 262 | state["components"] = dict(self.components) 263 | state["build_module_src"] = self._build_module_src 264 | state["build_func_name"] = self._build_func_name 265 | state["variables"] = list(zip(self.own_vars.keys(), tfutil.run(list(self.own_vars.values())))) 266 | return state 267 | 268 | def __setstate__(self, state: dict) -> None: 269 | """Pickle import.""" 270 | # pylint: disable=attribute-defined-outside-init 271 | tfutil.assert_tf_initialized() 272 | self._init_fields() 273 | 274 | # Execute custom import handlers. 275 | for handler in _import_handlers: 276 | state = handler(state) 277 | 278 | # Set basic fields. 279 | assert state["version"] in [2, 3] 280 | self.name = state["name"] 281 | self.static_kwargs = util.EasyDict(state["static_kwargs"]) 282 | self.components = util.EasyDict(state.get("components", {})) 283 | self._build_module_src = state["build_module_src"] 284 | self._build_func_name = state["build_func_name"] 285 | 286 | # Create temporary module from the imported source code. 287 | module_name = "_tflib_network_import_" + uuid.uuid4().hex 288 | module = types.ModuleType(module_name) 289 | sys.modules[module_name] = module 290 | _import_module_src[module] = self._build_module_src 291 | exec(self._build_module_src, module.__dict__) # pylint: disable=exec-used 292 | 293 | # Locate network build function in the temporary module. 294 | self._build_func = util.get_obj_from_module(module, self._build_func_name) 295 | assert callable(self._build_func) 296 | 297 | # Init TensorFlow graph. 
298 | self._init_graph() 299 | self.reset_own_vars() 300 | tfutil.set_vars({self.find_var(name): value for name, value in state["variables"]}) 301 | 302 | def clone(self, name: str = None, **new_static_kwargs) -> "Network": 303 | """Create a clone of this network with its own copy of the variables.""" 304 | # pylint: disable=protected-access 305 | net = object.__new__(Network) 306 | net._init_fields() 307 | net.name = name if name is not None else self.name 308 | net.static_kwargs = util.EasyDict(self.static_kwargs) 309 | net.static_kwargs.update(new_static_kwargs) 310 | net._build_module_src = self._build_module_src 311 | net._build_func_name = self._build_func_name 312 | net._build_func = self._build_func 313 | net._init_graph() 314 | net.copy_vars_from(self) 315 | return net 316 | 317 | def copy_own_vars_from(self, src_net: "Network") -> None: 318 | """Copy the values of all variables from the given network, excluding sub-networks.""" 319 | names = [name for name in self.own_vars.keys() if name in src_net.own_vars] 320 | tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names})) 321 | 322 | def copy_vars_from(self, src_net: "Network") -> None: 323 | """Copy the values of all variables from the given network, including sub-networks.""" 324 | names = [name for name in self.vars.keys() if name in src_net.vars] 325 | tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names})) 326 | 327 | def copy_trainables_from(self, src_net: "Network") -> None: 328 | """Copy the values of all trainable variables from the given network, including sub-networks.""" 329 | names = [name for name in self.trainables.keys() if name in src_net.trainables] 330 | tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names})) 331 | 332 | def convert(self, new_func_name: str, new_name: str = None, **new_static_kwargs) -> "Network": 333 | """Create new network with the given parameters, and copy all variables from this 
network.""" 334 | if new_name is None: 335 | new_name = self.name 336 | static_kwargs = dict(self.static_kwargs) 337 | static_kwargs.update(new_static_kwargs) 338 | net = Network(name=new_name, func_name=new_func_name, **static_kwargs) 339 | net.copy_vars_from(self) 340 | return net 341 | 342 | def setup_as_moving_average_of(self, src_net: "Network", beta: TfExpressionEx = 0.99, beta_nontrainable: TfExpressionEx = 0.0) -> tf.Operation: 343 | """Construct a TensorFlow op that updates the variables of this network 344 | to be slightly closer to those of the given network.""" 345 | with tfutil.absolute_name_scope(self.scope + "/_MovingAvg"): 346 | ops = [] 347 | for name, var in self.vars.items(): 348 | if name in src_net.vars: 349 | cur_beta = beta if name in self.trainables else beta_nontrainable 350 | new_value = tfutil.lerp(src_net.vars[name], var, cur_beta) 351 | ops.append(var.assign(new_value)) 352 | return tf.group(*ops) 353 | 354 | def run(self, 355 | *in_arrays: Tuple[Union[np.ndarray, None], ...], 356 | input_transform: dict = None, 357 | output_transform: dict = None, 358 | return_as_list: bool = False, 359 | print_progress: bool = False, 360 | minibatch_size: int = None, 361 | num_gpus: int = 1, 362 | assume_frozen: bool = False, 363 | **dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]: 364 | """Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s). 365 | 366 | Args: 367 | input_transform: A dict specifying a custom transformation to be applied to the input tensor(s) before evaluating the network. 368 | The dict must contain a 'func' field that points to a top-level function. The function is called with the input 369 | TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs. 370 | output_transform: A dict specifying a custom transformation to be applied to the output tensor(s) after evaluating the network. 
371 | The dict must contain a 'func' field that points to a top-level function. The function is called with the output 372 | TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs. 373 | return_as_list: True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs. 374 | print_progress: Print progress to the console? Useful for very large input arrays. 375 | minibatch_size: Maximum minibatch size to use, None = disable batching. 376 | num_gpus: Number of GPUs to use. 377 | assume_frozen: Improve multi-GPU performance by assuming that the trainable parameters will remain unchanged between calls. 378 | dynamic_kwargs: Additional keyword arguments to be passed into the network build function. 379 | """ 380 | assert len(in_arrays) == self.num_inputs 381 | assert not all(arr is None for arr in in_arrays) 382 | assert input_transform is None or util.is_top_level_function(input_transform["func"]) 383 | assert output_transform is None or util.is_top_level_function(output_transform["func"]) 384 | output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs) 385 | num_items = in_arrays[0].shape[0] 386 | if minibatch_size is None: 387 | minibatch_size = num_items 388 | 389 | # Construct unique hash key from all arguments that affect the TensorFlow graph. 390 | key = dict(input_transform=input_transform, output_transform=output_transform, num_gpus=num_gpus, assume_frozen=assume_frozen, dynamic_kwargs=dynamic_kwargs) 391 | def unwind_key(obj): 392 | if isinstance(obj, dict): 393 | return [(key, unwind_key(value)) for key, value in sorted(obj.items())] 394 | if callable(obj): 395 | return util.get_top_level_function_name(obj) 396 | return obj 397 | key = repr(unwind_key(key)) 398 | 399 | # Build graph.
400 | if key not in self._run_cache: 401 | with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None): 402 | with tf.device("/cpu:0"): 403 | in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names] 404 | in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr])) 405 | 406 | out_split = [] 407 | for gpu in range(num_gpus): 408 | with tf.device("/gpu:%d" % gpu): 409 | net_gpu = self.clone() if assume_frozen else self 410 | in_gpu = in_split[gpu] 411 | 412 | if input_transform is not None: 413 | in_kwargs = dict(input_transform) 414 | in_gpu = in_kwargs.pop("func")(*in_gpu, **in_kwargs) 415 | in_gpu = [in_gpu] if tfutil.is_tf_expression(in_gpu) else list(in_gpu) 416 | 417 | assert len(in_gpu) == self.num_inputs 418 | out_gpu = net_gpu.get_output_for(*in_gpu, return_as_list=True, **dynamic_kwargs) 419 | 420 | if output_transform is not None: 421 | out_kwargs = dict(output_transform) 422 | out_gpu = out_kwargs.pop("func")(*out_gpu, **out_kwargs) 423 | out_gpu = [out_gpu] if tfutil.is_tf_expression(out_gpu) else list(out_gpu) 424 | 425 | assert len(out_gpu) == self.num_outputs 426 | out_split.append(out_gpu) 427 | 428 | with tf.device("/cpu:0"): 429 | out_expr = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)] 430 | self._run_cache[key] = in_expr, out_expr 431 | 432 | # Run minibatches. 
433 | in_expr, out_expr = self._run_cache[key] 434 | out_arrays = [np.empty([num_items] + tfutil.shape_to_list(expr.shape)[1:], expr.dtype.name) for expr in out_expr] 435 | 436 | for mb_begin in range(0, num_items, minibatch_size): 437 | if print_progress: 438 | print("\r%d / %d" % (mb_begin, num_items), end="") 439 | 440 | mb_end = min(mb_begin + minibatch_size, num_items) 441 | mb_num = mb_end - mb_begin 442 | mb_in = [src[mb_begin : mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)] 443 | mb_out = tf.get_default_session().run(out_expr, dict(zip(in_expr, mb_in))) 444 | 445 | for dst, src in zip(out_arrays, mb_out): 446 | dst[mb_begin: mb_end] = src 447 | 448 | # Done. 449 | if print_progress: 450 | print("\r%d / %d" % (num_items, num_items)) 451 | 452 | if not return_as_list: 453 | out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(out_arrays) 454 | return out_arrays 455 | 456 | def list_ops(self) -> List[TfExpression]: 457 | include_prefix = self.scope + "/" 458 | exclude_prefix = include_prefix + "_" 459 | ops = tf.get_default_graph().get_operations() 460 | ops = [op for op in ops if op.name.startswith(include_prefix)] 461 | ops = [op for op in ops if not op.name.startswith(exclude_prefix)] 462 | return ops 463 | 464 | def list_layers(self) -> List[Tuple[str, TfExpression, List[TfExpression]]]: 465 | """Returns a list of (layer_name, output_expr, trainable_vars) tuples corresponding to 466 | individual layers of the network. Mainly intended to be used for reporting.""" 467 | layers = [] 468 | 469 | def recurse(scope, parent_ops, parent_vars, level): 470 | # Ignore specific patterns. 471 | if any(p in scope for p in ["/Shape", "/strided_slice", "/Cast", "/concat", "/Assign"]): 472 | return 473 | 474 | # Filter ops and vars by scope. 
475 | global_prefix = scope + "/" 476 | local_prefix = global_prefix[len(self.scope) + 1:] 477 | cur_ops = [op for op in parent_ops if op.name.startswith(global_prefix) or op.name == global_prefix[:-1]] 478 | cur_vars = [(name, var) for name, var in parent_vars if name.startswith(local_prefix) or name == local_prefix[:-1]] 479 | if not cur_ops and not cur_vars: 480 | return 481 | 482 | # Filter out all ops related to variables. 483 | for var in [op for op in cur_ops if op.type.startswith("Variable")]: 484 | var_prefix = var.name + "/" 485 | cur_ops = [op for op in cur_ops if not op.name.startswith(var_prefix)] 486 | 487 | # Scope does not contain ops as immediate children => recurse deeper. 488 | contains_direct_ops = any("/" not in op.name[len(global_prefix):] and op.type != "Identity" for op in cur_ops) 489 | if (level == 0 or not contains_direct_ops) and (len(cur_ops) + len(cur_vars)) > 1: 490 | visited = set() 491 | for rel_name in [op.name[len(global_prefix):] for op in cur_ops] + [name[len(local_prefix):] for name, _var in cur_vars]: 492 | token = rel_name.split("/")[0] 493 | if token not in visited: 494 | recurse(global_prefix + token, cur_ops, cur_vars, level + 1) 495 | visited.add(token) 496 | return 497 | 498 | # Report layer. 
499 | layer_name = scope[len(self.scope) + 1:] 500 | layer_output = cur_ops[-1].outputs[0] if cur_ops else cur_vars[-1][1] 501 | layer_trainables = [var for _name, var in cur_vars if var.trainable] 502 | layers.append((layer_name, layer_output, layer_trainables)) 503 | 504 | recurse(self.scope, self.list_ops(), list(self.vars.items()), 0) 505 | return layers 506 | 507 | def print_layers(self, title: str = None, hide_layers_with_no_params: bool = False) -> None: 508 | """Print a summary table of the network structure.""" 509 | rows = [[title if title is not None else self.name, "Params", "OutputShape", "WeightShape"]] 510 | rows += [["---"] * 4] 511 | total_params = 0 512 | 513 | for layer_name, layer_output, layer_trainables in self.list_layers(): 514 | num_params = sum(np.prod(tfutil.shape_to_list(var.shape)) for var in layer_trainables) 515 | weights = [var for var in layer_trainables if var.name.endswith("/weight:0")] 516 | weights.sort(key=lambda x: len(x.name)) 517 | if len(weights) == 0 and len(layer_trainables) == 1: 518 | weights = layer_trainables 519 | total_params += num_params 520 | 521 | if not hide_layers_with_no_params or num_params != 0: 522 | num_params_str = str(num_params) if num_params > 0 else "-" 523 | output_shape_str = str(layer_output.shape) 524 | weight_shape_str = str(weights[0].shape) if len(weights) >= 1 else "-" 525 | rows += [[layer_name, num_params_str, output_shape_str, weight_shape_str]] 526 | 527 | rows += [["---"] * 4] 528 | rows += [["Total", str(total_params), "", ""]] 529 | 530 | widths = [max(len(cell) for cell in column) for column in zip(*rows)] 531 | print() 532 | for row in rows: 533 | print(" ".join(cell + " " * (width - len(cell)) for cell, width in zip(row, widths))) 534 | print() 535 | 536 | def setup_weight_histograms(self, title: str = None) -> None: 537 | """Construct summary ops to include histograms of all trainable parameters in TensorBoard.""" 538 | if title is None: 539 | title = self.name 540 | 541 | with 
tf.name_scope(None), tf.device(None), tf.control_dependencies(None): 542 | for local_name, var in self.trainables.items(): 543 | if "/" in local_name: 544 | p = local_name.split("/") 545 | name = title + "_" + p[-1] + "/" + "_".join(p[:-1]) 546 | else: 547 | name = title + "_toplevel/" + local_name 548 | 549 | tf.summary.histogram(name, var) 550 | 551 | #---------------------------------------------------------------------------- 552 | # Backwards-compatible emulation of legacy output transformation in Network.run(). 553 | 554 | _print_legacy_warning = True 555 | 556 | def _handle_legacy_output_transforms(output_transform, dynamic_kwargs): 557 | global _print_legacy_warning 558 | legacy_kwargs = ["out_mul", "out_add", "out_shrink", "out_dtype"] 559 | if not any(kwarg in dynamic_kwargs for kwarg in legacy_kwargs): 560 | return output_transform, dynamic_kwargs 561 | 562 | if _print_legacy_warning: 563 | _print_legacy_warning = False 564 | print() 565 | print("WARNING: Old-style output transformations in Network.run() are deprecated.") 566 | print("Consider using 'output_transform=dict(func=tflib.convert_images_to_uint8)'") 567 | print("instead of 'out_mul=127.5, out_add=127.5, out_dtype=np.uint8'.") 568 | print() 569 | assert output_transform is None 570 | 571 | new_kwargs = dict(dynamic_kwargs) 572 | new_transform = {kwarg: new_kwargs.pop(kwarg) for kwarg in legacy_kwargs if kwarg in dynamic_kwargs} 573 | new_transform["func"] = _legacy_output_transform_func 574 | return new_transform, new_kwargs 575 | 576 | def _legacy_output_transform_func(*expr, out_mul=1.0, out_add=0.0, out_shrink=1, out_dtype=None): 577 | if out_mul != 1.0: 578 | expr = [x * out_mul for x in expr] 579 | 580 | if out_add != 0.0: 581 | expr = [x + out_add for x in expr] 582 | 583 | if out_shrink > 1: 584 | ksize = [1, 1, out_shrink, out_shrink] 585 | expr = [tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") for x in expr] 586 | 587 | if out_dtype is not None: 588 
| if tf.as_dtype(out_dtype).is_integer: 589 | expr = [tf.round(x) for x in expr] 590 | expr = [tf.saturate_cast(x, out_dtype) for x in expr] 591 | return expr 592 | -------------------------------------------------------------------------------- /dnnlib/tflib/optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Helper wrapper for a Tensorflow optimizer.""" 9 | 10 | import numpy as np 11 | import tensorflow as tf 12 | 13 | from collections import OrderedDict 14 | from typing import List, Union 15 | 16 | from . import autosummary 17 | from . import tfutil 18 | from .. import util 19 | 20 | from .tfutil import TfExpression, TfExpressionEx 21 | 22 | try: 23 | # TensorFlow 1.13 24 | from tensorflow.python.ops import nccl_ops 25 | except: 26 | # Older TensorFlow versions 27 | import tensorflow.contrib.nccl as nccl_ops 28 | 29 | class Optimizer: 30 | """A Wrapper for tf.train.Optimizer. 31 | 32 | Automatically takes care of: 33 | - Gradient averaging for multi-GPU training. 34 | - Dynamic loss scaling and typecasts for FP16 training. 35 | - Ignoring corrupted gradients that contain NaNs/Infs. 36 | - Reporting statistics. 37 | - Well-chosen default settings. 38 | """ 39 | 40 | def __init__(self, 41 | name: str = "Train", 42 | tf_optimizer: str = "tf.train.AdamOptimizer", 43 | learning_rate: TfExpressionEx = 0.001, 44 | use_loss_scaling: bool = False, 45 | loss_scaling_init: float = 64.0, 46 | loss_scaling_inc: float = 0.0005, 47 | loss_scaling_dec: float = 1.0, 48 | **kwargs): 49 | 50 | # Init fields. 
51 | self.name = name 52 | self.learning_rate = tf.convert_to_tensor(learning_rate) 53 | self.id = self.name.replace("/", ".") 54 | self.scope = tf.get_default_graph().unique_name(self.id) 55 | self.optimizer_class = util.get_obj_by_name(tf_optimizer) 56 | self.optimizer_kwargs = dict(kwargs) 57 | self.use_loss_scaling = use_loss_scaling 58 | self.loss_scaling_init = loss_scaling_init 59 | self.loss_scaling_inc = loss_scaling_inc 60 | self.loss_scaling_dec = loss_scaling_dec 61 | self._grad_shapes = None # [shape, ...] 62 | self._dev_opt = OrderedDict() # device => optimizer 63 | self._dev_grads = OrderedDict() # device => [[(grad, var), ...], ...] 64 | self._dev_ls_var = OrderedDict() # device => variable (log2 of loss scaling factor) 65 | self._updates_applied = False 66 | 67 | def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None: 68 | """Register the gradients of the given loss function with respect to the given variables. 69 | Intended to be called once per GPU.""" 70 | assert not self._updates_applied 71 | 72 | # Validate arguments. 73 | if isinstance(trainable_vars, dict): 74 | trainable_vars = list(trainable_vars.values()) # allow passing in Network.trainables as vars 75 | 76 | assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1 77 | assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss]) 78 | 79 | if self._grad_shapes is None: 80 | self._grad_shapes = [tfutil.shape_to_list(var.shape) for var in trainable_vars] 81 | 82 | assert len(trainable_vars) == len(self._grad_shapes) 83 | assert all(tfutil.shape_to_list(var.shape) == var_shape for var, var_shape in zip(trainable_vars, self._grad_shapes)) 84 | 85 | dev = loss.device 86 | 87 | assert all(var.device == dev for var in trainable_vars) 88 | 89 | # Register device and compute gradients. 
90 | with tf.name_scope(self.id + "_grad"), tf.device(dev): 91 | if dev not in self._dev_opt: 92 | opt_name = self.scope.replace("/", "_") + "_opt%d" % len(self._dev_opt) 93 | assert callable(self.optimizer_class) 94 | self._dev_opt[dev] = self.optimizer_class(name=opt_name, learning_rate=self.learning_rate, **self.optimizer_kwargs) 95 | self._dev_grads[dev] = [] 96 | 97 | loss = self.apply_loss_scaling(tf.cast(loss, tf.float32)) 98 | grads = self._dev_opt[dev].compute_gradients(loss, trainable_vars, gate_gradients=tf.train.Optimizer.GATE_NONE) # disable gating to reduce memory usage 99 | grads = [(g, v) if g is not None else (tf.zeros_like(v), v) for g, v in grads] # replace disconnected gradients with zeros 100 | self._dev_grads[dev].append(grads) 101 | 102 | def apply_updates(self) -> tf.Operation: 103 | """Construct training op to update the registered variables based on their gradients.""" 104 | tfutil.assert_tf_initialized() 105 | assert not self._updates_applied 106 | self._updates_applied = True 107 | devices = list(self._dev_grads.keys()) 108 | total_grads = sum(len(grads) for grads in self._dev_grads.values()) 109 | assert len(devices) >= 1 and total_grads >= 1 110 | ops = [] 111 | 112 | with tfutil.absolute_name_scope(self.scope): 113 | # Cast gradients to FP32 and calculate partial sum within each device. 114 | dev_grads = OrderedDict() # device => [(grad, var), ...] 115 | 116 | for dev_idx, dev in enumerate(devices): 117 | with tf.name_scope("ProcessGrads%d" % dev_idx), tf.device(dev): 118 | sums = [] 119 | 120 | for gv in zip(*self._dev_grads[dev]): 121 | assert all(v is gv[0][1] for g, v in gv) 122 | g = [tf.cast(g, tf.float32) for g, v in gv] 123 | g = g[0] if len(g) == 1 else tf.add_n(g) 124 | sums.append((g, gv[0][1])) 125 | 126 | dev_grads[dev] = sums 127 | 128 | # Sum gradients across devices. 
129 | if len(devices) > 1: 130 | with tf.name_scope("SumAcrossGPUs"), tf.device(None): 131 | for var_idx, grad_shape in enumerate(self._grad_shapes): 132 | g = [dev_grads[dev][var_idx][0] for dev in devices] 133 | 134 | if np.prod(grad_shape): # nccl does not support zero-sized tensors 135 | g = nccl_ops.all_sum(g) 136 | 137 | for dev, gg in zip(devices, g): 138 | dev_grads[dev][var_idx] = (gg, dev_grads[dev][var_idx][1]) 139 | 140 | # Apply updates separately on each device. 141 | for dev_idx, (dev, grads) in enumerate(dev_grads.items()): 142 | with tf.name_scope("ApplyGrads%d" % dev_idx), tf.device(dev): 143 | # Scale gradients as needed. 144 | if self.use_loss_scaling or total_grads > 1: 145 | with tf.name_scope("Scale"): 146 | coef = tf.constant(np.float32(1.0 / total_grads), name="coef") 147 | coef = self.undo_loss_scaling(coef) 148 | grads = [(g * coef, v) for g, v in grads] 149 | 150 | # Check for overflows. 151 | with tf.name_scope("CheckOverflow"): 152 | grad_ok = tf.reduce_all(tf.stack([tf.reduce_all(tf.is_finite(g)) for g, v in grads])) 153 | 154 | # Update weights and adjust loss scaling. 155 | with tf.name_scope("UpdateWeights"): 156 | # pylint: disable=cell-var-from-loop 157 | opt = self._dev_opt[dev] 158 | ls_var = self.get_loss_scaling_var(dev) 159 | 160 | if not self.use_loss_scaling: 161 | ops.append(tf.cond(grad_ok, lambda: opt.apply_gradients(grads), tf.no_op)) 162 | else: 163 | ops.append(tf.cond(grad_ok, 164 | lambda: tf.group(tf.assign_add(ls_var, self.loss_scaling_inc), opt.apply_gradients(grads)), 165 | lambda: tf.group(tf.assign_sub(ls_var, self.loss_scaling_dec)))) 166 | 167 | # Report statistics on the last device. 
168 | if dev == devices[-1]: 169 | with tf.name_scope("Statistics"): 170 | ops.append(autosummary.autosummary(self.id + "/learning_rate", self.learning_rate)) 171 | ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(grad_ok, 0, 1))) 172 | 173 | if self.use_loss_scaling: 174 | ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", ls_var)) 175 | 176 | # Initialize variables and group everything into a single op. 177 | self.reset_optimizer_state() 178 | tfutil.init_uninitialized_vars(list(self._dev_ls_var.values())) 179 | 180 | return tf.group(*ops, name="TrainingOp") 181 | 182 | def reset_optimizer_state(self) -> None: 183 | """Reset internal state of the underlying optimizer.""" 184 | tfutil.assert_tf_initialized() 185 | tfutil.run([var.initializer for opt in self._dev_opt.values() for var in opt.variables()]) 186 | 187 | def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]: 188 | """Get or create variable representing log2 of the current dynamic loss scaling factor.""" 189 | if not self.use_loss_scaling: 190 | return None 191 | 192 | if device not in self._dev_ls_var: 193 | with tfutil.absolute_name_scope(self.scope + "/LossScalingVars"), tf.control_dependencies(None): 194 | self._dev_ls_var[device] = tf.Variable(np.float32(self.loss_scaling_init), name="loss_scaling_var") 195 | 196 | return self._dev_ls_var[device] 197 | 198 | def apply_loss_scaling(self, value: TfExpression) -> TfExpression: 199 | """Apply dynamic loss scaling for the given expression.""" 200 | assert tfutil.is_tf_expression(value) 201 | 202 | if not self.use_loss_scaling: 203 | return value 204 | 205 | return value * tfutil.exp2(self.get_loss_scaling_var(value.device)) 206 | 207 | def undo_loss_scaling(self, value: TfExpression) -> TfExpression: 208 | """Undo the effect of dynamic loss scaling for the given expression.""" 209 | assert tfutil.is_tf_expression(value) 210 | 211 | if not self.use_loss_scaling: 212 | return value 213 | 214 | 
return value * tfutil.exp2(-self.get_loss_scaling_var(value.device)) # pylint: disable=invalid-unary-operand-type 215 | -------------------------------------------------------------------------------- /dnnlib/tflib/tfutil.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Miscellaneous helper utils for Tensorflow.""" 9 | 10 | import os 11 | import numpy as np 12 | import tensorflow as tf 13 | 14 | # Silence deprecation warnings from TensorFlow 1.13 onwards 15 | import logging 16 | logging.getLogger('tensorflow').setLevel(logging.ERROR) 17 | import tensorflow.contrib 18 | tf.contrib = tensorflow.contrib 19 | 20 | from typing import Any, Iterable, List, Union 21 | 22 | TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation] 23 | """A type that represents a valid Tensorflow expression.""" 24 | 25 | TfExpressionEx = Union[TfExpression, int, float, np.ndarray] 26 | """A type that can be converted to a valid Tensorflow expression.""" 27 | 28 | 29 | def run(*args, **kwargs) -> Any: 30 | """Run the specified ops in the default session.""" 31 | assert_tf_initialized() 32 | return tf.get_default_session().run(*args, **kwargs) 33 | 34 | 35 | def is_tf_expression(x: Any) -> bool: 36 | """Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation.""" 37 | return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation)) 38 | 39 | 40 | def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]: 41 | """Convert a Tensorflow shape to a list of ints.""" 42 | return [dim.value for dim in shape] 43 | 44 | 45 | def 
flatten(x: TfExpressionEx) -> TfExpression: 46 | """Shortcut function for flattening a tensor.""" 47 | with tf.name_scope("Flatten"): 48 | return tf.reshape(x, [-1]) 49 | 50 | 51 | def log2(x: TfExpressionEx) -> TfExpression: 52 | """Logarithm in base 2.""" 53 | with tf.name_scope("Log2"): 54 | return tf.log(x) * np.float32(1.0 / np.log(2.0)) 55 | 56 | 57 | def exp2(x: TfExpressionEx) -> TfExpression: 58 | """Exponent in base 2.""" 59 | with tf.name_scope("Exp2"): 60 | return tf.exp(x * np.float32(np.log(2.0))) 61 | 62 | 63 | def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx: 64 | """Linear interpolation.""" 65 | with tf.name_scope("Lerp"): 66 | return a + (b - a) * t 67 | 68 | 69 | def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression: 70 | """Linear interpolation with clip.""" 71 | with tf.name_scope("LerpClip"): 72 | return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0) 73 | 74 | 75 | def absolute_name_scope(scope: str) -> tf.name_scope: 76 | """Forcefully enter the specified name scope, ignoring any surrounding scopes.""" 77 | return tf.name_scope(scope + "/") 78 | 79 | 80 | def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope: 81 | """Forcefully enter the specified variable scope, ignoring any surrounding scopes.""" 82 | return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False) 83 | 84 | 85 | def _sanitize_tf_config(config_dict: dict = None) -> dict: 86 | # Defaults. 87 | cfg = dict() 88 | cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is. 89 | cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is. 90 | cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info. 
91 | cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used. 92 | cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed. 93 | 94 | # User overrides. 95 | if config_dict is not None: 96 | cfg.update(config_dict) 97 | return cfg 98 | 99 | 100 | def init_tf(config_dict: dict = None) -> None: 101 | """Initialize TensorFlow session using good default settings.""" 102 | # Skip if already initialized. 103 | if tf.get_default_session() is not None: 104 | return 105 | 106 | # Setup config dict and random seeds. 107 | cfg = _sanitize_tf_config(config_dict) 108 | np_random_seed = cfg["rnd.np_random_seed"] 109 | if np_random_seed is not None: 110 | np.random.seed(np_random_seed) 111 | tf_random_seed = cfg["rnd.tf_random_seed"] 112 | if tf_random_seed == "auto": 113 | tf_random_seed = np.random.randint(1 << 31) 114 | if tf_random_seed is not None: 115 | tf.set_random_seed(tf_random_seed) 116 | 117 | # Setup environment variables. 118 | for key, value in list(cfg.items()): 119 | fields = key.split(".") 120 | if fields[0] == "env": 121 | assert len(fields) == 2 122 | os.environ[fields[1]] = str(value) 123 | 124 | # Create default TensorFlow session. 125 | create_session(cfg, force_as_default=True) 126 | 127 | 128 | def assert_tf_initialized(): 129 | """Check that TensorFlow session has been initialized.""" 130 | if tf.get_default_session() is None: 131 | raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().") 132 | 133 | 134 | def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session: 135 | """Create tf.Session based on config dict.""" 136 | # Setup TensorFlow config proto. 
137 | cfg = _sanitize_tf_config(config_dict) 138 | config_proto = tf.ConfigProto() 139 | for key, value in cfg.items(): 140 | fields = key.split(".") 141 | if fields[0] not in ["rnd", "env"]: 142 | obj = config_proto 143 | for field in fields[:-1]: 144 | obj = getattr(obj, field) 145 | setattr(obj, fields[-1], value) 146 | 147 | # Create session. 148 | session = tf.Session(config=config_proto) 149 | if force_as_default: 150 | # pylint: disable=protected-access 151 | session._default_session = session.as_default() 152 | session._default_session.enforce_nesting = False 153 | session._default_session.__enter__() # pylint: disable=no-member 154 | 155 | return session 156 | 157 | 158 | def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None: 159 | """Initialize all tf.Variables that have not already been initialized. 160 | 161 | Equivalent to the following, but more efficient and does not bloat the tf graph: 162 | tf.variables_initializer(tf.report_uninitialized_variables()).run() 163 | """ 164 | assert_tf_initialized() 165 | if target_vars is None: 166 | target_vars = tf.global_variables() 167 | 168 | test_vars = [] 169 | test_ops = [] 170 | 171 | with tf.control_dependencies(None): # ignore surrounding control_dependencies 172 | for var in target_vars: 173 | assert is_tf_expression(var) 174 | 175 | try: 176 | tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0")) 177 | except KeyError: 178 | # Op does not exist => variable may be uninitialized. 179 | test_vars.append(var) 180 | 181 | with absolute_name_scope(var.name.split(":")[0]): 182 | test_ops.append(tf.is_variable_initialized(var)) 183 | 184 | init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited] 185 | run([var.initializer for var in init_vars]) 186 | 187 | 188 | def set_vars(var_to_value_dict: dict) -> None: 189 | """Set the values of given tf.Variables. 
190 | 191 | Equivalent to the following, but more efficient and does not bloat the tf graph: 192 | tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()] 193 | """ 194 | assert_tf_initialized() 195 | ops = [] 196 | feed_dict = {} 197 | 198 | for var, value in var_to_value_dict.items(): 199 | assert is_tf_expression(var) 200 | 201 | try: 202 | setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0")) # look for existing op 203 | except KeyError: 204 | with absolute_name_scope(var.name.split(":")[0]): 205 | with tf.control_dependencies(None): # ignore surrounding control_dependencies 206 | setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter") # create new setter 207 | 208 | ops.append(setter) 209 | feed_dict[setter.op.inputs[1]] = value 210 | 211 | run(ops, feed_dict) 212 | 213 | 214 | def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs): 215 | """Create tf.Variable with large initial value without bloating the tf graph.""" 216 | assert_tf_initialized() 217 | assert isinstance(initial_value, np.ndarray) 218 | zeros = tf.zeros(initial_value.shape, initial_value.dtype) 219 | var = tf.Variable(zeros, *args, **kwargs) 220 | set_vars({var: initial_value}) 221 | return var 222 | 223 | 224 | def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False): 225 | """Convert a minibatch of images from uint8 to float32 with configurable dynamic range. 226 | Can be used as an input transformation for Network.run(). 227 | """ 228 | images = tf.cast(images, tf.float32) 229 | if nhwc_to_nchw: 230 | images = tf.transpose(images, [0, 3, 1, 2]) 231 | return (images - drange[0]) * ((drange[1] - drange[0]) / 255) 232 | 233 | 234 | def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1): 235 | """Convert a minibatch of images from float32 to uint8 with configurable dynamic range. 
236 | Can be used as an output transformation for Network.run(). 237 | """ 238 | images = tf.cast(images, tf.float32) 239 | if shrink > 1: 240 | ksize = [1, 1, shrink, shrink] 241 | images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") 242 | if nchw_to_nhwc: 243 | images = tf.transpose(images, [0, 2, 3, 1]) 244 | scale = 255 / (drange[1] - drange[0]) 245 | images = images * scale + (0.5 - drange[0] * scale) 246 | return tf.saturate_cast(images, tf.uint8) 247 | -------------------------------------------------------------------------------- /dnnlib/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
"""Miscellaneous utility classes and functions."""

import ctypes
import fnmatch
import importlib
import inspect
import numpy as np
import os
import shutil
import sys
import types
import io
import pickle
import re
import requests
import html
import hashlib
import glob
import uuid

from distutils.util import strtobool
from typing import Any, List, Tuple, Union


# Util classes
# ------------------------------------------------------------------------------------------


class EasyDict(dict):
    """Convenience class that behaves like a dict but allows access with the attribute syntax."""

    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError:
            # getattr()/hasattr() expect AttributeError; hide the internal KeyError from the traceback.
            raise AttributeError(name) from None

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]


class Logger(object):
    """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.

    Constructing a Logger immediately installs it as sys.stdout and sys.stderr;
    close() (also called when leaving a `with` block) restores the previous streams.
    """

    def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
        self.file = None

        if file_name is not None:
            self.file = open(file_name, file_mode)

        self.should_flush = should_flush
        self.stdout = sys.stdout
        self.stderr = sys.stderr

        sys.stdout = self
        sys.stderr = self

    def __enter__(self) -> "Logger":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.close()

    def write(self, text: str) -> None:
        """Write text to stdout (and a file) and optionally flush."""
        if len(text) == 0:  # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
            return

        if self.file is not None:
            self.file.write(text)

        self.stdout.write(text)

        if self.should_flush:
            self.flush()

    def flush(self) -> None:
        """Flush written text to both stdout and a file, if open."""
        if self.file is not None:
            self.file.flush()

        self.stdout.flush()

    def close(self) -> None:
        """Flush, close possible files, and remove stdout/stderr mirroring.

        Safe to call more than once (e.g. explicitly and then again via __exit__).
        """
        self.flush()

        # if using multiple loggers, prevent closing in wrong order
        if sys.stdout is self:
            sys.stdout = self.stdout
        if sys.stderr is self:
            sys.stderr = self.stderr

        if self.file is not None:
            self.file.close()
            # Drop the reference so a later flush()/close() does not touch the closed file
            # (flushing a closed file raises "I/O operation on closed file").
            self.file = None


# Small util functions
# ------------------------------------------------------------------------------------------


def format_time(seconds: Union[int, float]) -> str:
    """Convert the seconds to human readable string with days, hours, minutes and seconds."""
    s = int(np.rint(seconds))

    if s < 60:
        return "{0}s".format(s)
    elif s < 60 * 60:
        return "{0}m {1:02}s".format(s // 60, s % 60)
    elif s < 24 * 60 * 60:
        return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
    else:
        return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)


def ask_yes_no(question: str) -> bool:
    """Ask the user the question until the user inputs a valid answer."""
    while True:
        try:
            print("{0} [y/n]".format(question))
            # strtobool() returns 1/0; convert so the result really is the annotated bool.
            return bool(strtobool(input().lower()))
        except ValueError:
            pass


def tuple_product(t: Tuple) -> Any:
    """Calculate the product of the tuple elements."""
    result = 1

    for v in t:
        result *= v

    return result


_str_to_ctype = {
    "uint8": ctypes.c_ubyte,
    "uint16": ctypes.c_uint16,
    "uint32": ctypes.c_uint32,
    "uint64": ctypes.c_uint64,
    "int8": ctypes.c_byte,
    "int16": ctypes.c_int16,
    "int32": ctypes.c_int32,
    "int64": ctypes.c_int64,
    "float32": ctypes.c_float,
    "float64": ctypes.c_double
}


def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
    """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
    type_str = None

    if isinstance(type_obj, str):
        type_str = type_obj
    elif hasattr(type_obj, "__name__"):
        type_str = type_obj.__name__
    elif hasattr(type_obj, "name"):
        type_str = type_obj.name
    else:
        raise RuntimeError("Cannot infer type name from input")

    assert type_str in _str_to_ctype.keys()

    my_dtype = np.dtype(type_str)
    my_ctype = _str_to_ctype[type_str]

    assert my_dtype.itemsize == ctypes.sizeof(my_ctype)

    return my_dtype, my_ctype


def is_pickleable(obj: Any) -> bool:
    """Return True if `obj` can be serialized with pickle.dump(), False otherwise."""
    try:
        with io.BytesIO() as stream:
            pickle.dump(obj, stream)
        return True
    except Exception:  # PicklingError, TypeError, AttributeError, ... -- but never swallow KeyboardInterrupt
        return False


# Functionality to import modules/objects by name, and call functions by name
# ------------------------------------------------------------------------------------------

def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
    """Searches for the underlying module behind the name to some python object.
    Returns the module and the object name (original name with module part removed)."""

    # allow convenience shorthands, substitute them by full names
    # (the dot is escaped so that e.g. "tfutil.x" is not rewritten to "tensorflow.til.x")
    obj_name = re.sub(r"^np\.", "numpy.", obj_name)
    obj_name = re.sub(r"^tf\.", "tensorflow.", obj_name)

    # list alternatives for (module_name, local_obj_name)
    parts = obj_name.split(".")
    name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]

    # try each alternative in turn
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
            return module, local_obj_name
        except Exception:
            pass

    # maybe some of the modules themselves contain errors?
    for module_name, _local_obj_name in name_pairs:
        try:
            importlib.import_module(module_name) # may raise ImportError
        except ImportError:
            if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
                raise

    # maybe the requested attribute is missing?
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
        except ImportError:
            pass

    # we are out of luck, but we have no idea why
    raise ImportError(obj_name)


def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
    """Traverses the object name and returns the last (rightmost) python object."""
    if obj_name == '':
        return module
    obj = module
    for part in obj_name.split("."):
        obj = getattr(obj, part)
    return obj


def get_obj_by_name(name: str) -> Any:
    """Finds the python object with the given name."""
    module, obj_name = get_module_from_obj_name(name)
    return get_obj_from_module(module, obj_name)


def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
    """Finds the python object with the given name and calls it as a function."""
    assert func_name is not None
    func_obj = get_obj_by_name(func_name)
    assert callable(func_obj)
    return func_obj(*args, **kwargs)


def get_module_dir_by_obj_name(obj_name: str) -> str:
    """Get the directory path of the module containing the given object name."""
    module, _ = get_module_from_obj_name(obj_name)
    return os.path.dirname(inspect.getfile(module))


def is_top_level_function(obj: Any) -> bool:
    """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.

    Callables without __name__/__module__ (e.g. functools.partial objects) yield False instead of raising.
    """
    if not callable(obj):
        return False
    name = getattr(obj, "__name__", None)
    module = sys.modules.get(getattr(obj, "__module__", None))
    return name is not None and module is not None and name in module.__dict__


def get_top_level_function_name(obj: Any) -> str:
    """Return the fully-qualified name of a top-level function."""
    assert is_top_level_function(obj)
    return obj.__module__ + "." + obj.__name__


# File system helpers
# ------------------------------------------------------------------------------------------

def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
    """List all files recursively in a given directory while ignoring given file and directory names.
    Returns list of tuples containing both absolute and relative paths."""
    assert os.path.isdir(dir_path)
    base_name = os.path.basename(os.path.normpath(dir_path))

    if ignores is None:
        ignores = []

    result = []

    for root, dirs, files in os.walk(dir_path, topdown=True):
        for ignore_ in ignores:
            dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]

            # dirs need to be edited in-place
            for d in dirs_to_remove:
                dirs.remove(d)

            files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]

        absolute_paths = [os.path.join(root, f) for f in files]
        relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]

        if add_base_to_relative:
            relative_paths = [os.path.join(base_name, p) for p in relative_paths]

        assert len(absolute_paths) == len(relative_paths)
        result += zip(absolute_paths, relative_paths)

    return result


def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
    """Takes in a list of tuples of (src, dst) paths and copies files.
    Will create all necessary directories."""
    for file in files:
        src, dst = file[0], file[1]
        target_dir_name = os.path.dirname(dst)

        # will create all intermediate-level directories; a bare file name has no
        # directory part, and os.makedirs("") would raise FileNotFoundError
        if target_dir_name:
            os.makedirs(target_dir_name, exist_ok=True)

        shutil.copyfile(src, dst)


# URL helpers
# ------------------------------------------------------------------------------------------

def is_url(obj: Any) -> bool:
    """Determine whether the given object is a valid URL string."""
    if not isinstance(obj, str) or "://" not in obj:
        return False
    try:
        res = requests.compat.urlparse(obj)
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
        res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
    except Exception:
        return False
    return True


def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True) -> Any:
    """Download the given URL and return a binary-mode file object to access the data.

    Args:
        url: URL to download; must satisfy is_url().
        cache_dir: Optional directory used as a download cache, keyed by the MD5 of the URL.
        num_attempts: Total number of download attempts; the last error is re-raised.
        verbose: Print progress to stdout.

    Returns:
        An open binary file (cache hit) or an io.BytesIO holding the downloaded bytes.
    """
    assert is_url(url)
    assert num_attempts >= 1

    # Lookup from cache.
    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
    if cache_dir is not None:
        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
        if len(cache_files) == 1:
            return open(cache_files[0], "rb")

    # Download.
    url_name = None
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    # Small responses may be a Google Drive HTML interstitial instead of the payload.
                    # Decode leniently: a small *binary* payload is not valid UTF-8 and must not
                    # raise UnicodeDecodeError (which would be retried until the download fails).
                    if len(res.content) < 8192:
                        content_str = res.content.decode("utf-8", errors="replace")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
                            if len(links) == 1:
                                url = requests.compat.urljoin(url, links[0])
                            raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError("Google Drive quota exceeded")

                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except Exception:  # retry download errors, but let KeyboardInterrupt/SystemExit through
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Save to cache.
    if cache_dir is not None:
        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
        os.makedirs(cache_dir, exist_ok=True)
        with open(temp_file, "wb") as f:
            f.write(url_data)
        os.replace(temp_file, cache_file) # atomic

    # Return data as file object.
    return io.BytesIO(url_data)

# --------------------------------------------------------------------------------
# /download_kodak.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

import os
import sys
import argparse

from urllib.request import urlretrieve

examples='''examples:

  python %(prog)s --output-dir=./tmp
'''

def main():
    """Download the 24 Kodak images (kodim01.png .. kodim24.png) into --output-dir."""
    parser = argparse.ArgumentParser(
        description='Download the Kodak dataset .PNG image files.',
        epilog=examples,
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument("--output-dir", help="Directory where to save the Kodak dataset .PNGs")
    args = parser.parse_args()

    if args.output_dir is None:
        # The script writes .PNG files (not tfrecords); report the usage error on stderr.
        print('Must specify output directory where to store the Kodak .PNGs with --output-dir', file=sys.stderr)
        sys.exit(1)

    os.makedirs(args.output_dir, exist_ok=True)

    for i in range(1, 25):
        imgname = 'kodim%02d.png' % i
        url = "http://r0k.us/graphics/kodak/kodak/" + imgname
        print('Downloading', url)
        urlretrieve(url, os.path.join(args.output_dir, imgname))
    print('Kodak validation set successfully downloaded.')

if __name__ == "__main__":
    main()

# --------------------------------------------------------------------------------
# /img/readme_figure.png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/NVlabs/selfsupervised-denoising/a41e0dbf135f532c892834b2aeedeaae79d9be9b/img/readme_figure.png
# --------------------------------------------------------------------------------
# /selfsupervised_denoising.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

import argparse
import os
import sys
import time
import numpy as np
import imageio

import h5py
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # must be set before TensorFlow is imported
import tensorflow as tf
import PIL.Image
import math
import glob
import pickle
import re

import dnnlib
import dnnlib.tflib
import dnnlib.tflib.tfutil as tfutil
from dnnlib.tflib.autosummary import autosummary
import dnnlib.submission.submit as submit

#----------------------------------------------------------------------------
# Misc helpers.

def init_tf(seed=None):
    """Create the default TF session (once) and seed TF's RNG.

    The seed is drawn from NumPy's RNG when `seed` is None. Does nothing if a
    default session already exists.
    """
    config_dict = {'graph_options.place_pruned_graph': True, 'gpu_options.allow_growth': True}
    if tf.get_default_session() is None:
        tf.set_random_seed(np.random.randint(1 << 31) if (seed is None) else seed)
        tfutil.create_session(config_dict, force_as_default=True)

def adjust_dynamic_range(data, drange_in, drange_out):
    """Linearly map `data` from range `drange_in` to `drange_out`; no-op when the ranges are equal."""
    if drange_in != drange_out:
        scale = (np.float32(drange_out[1]) - np.float32(drange_out[0])) / (np.float32(drange_in[1]) - np.float32(drange_in[0]))
        bias = (np.float32(drange_out[0]) - np.float32(drange_in[0]) * scale)
        data = data * scale + bias
    return data

def convert_to_pil_image(image, drange=[0,1]):
    """Convert an HW or CHW array with values in `drange` to an 8-bit PIL image (RGB or L)."""
    assert image.ndim == 2 or image.ndim == 3
    if image.ndim == 3:
        if image.shape[0] == 1:
            image = image[0] # grayscale CHW => HW
        else:
            image = image.transpose(1, 2, 0) # CHW -> HWC

    image = adjust_dynamic_range(image, drange, [0, 255])
    image = np.rint(image).clip(0, 255).astype(np.uint8)
    fmt = 'RGB' if image.ndim == 3 else 'L'
    return PIL.Image.fromarray(image, fmt)

def save_image(image, filename, drange=[0,1], quality=95):
    """Save an image array to `filename`; JPEG files (.jpg/.jpeg, any case) use `quality`."""
    img = convert_to_pil_image(image, drange)
    # Decide by the real extension: a substring test misclassified names like "x.jpg.png"
    # and missed ".jpeg"/".JPG", which then silently lost the requested quality.
    if os.path.splitext(filename)[1].lower() in ('.jpg', '.jpeg'):
        img.save(filename,"JPEG", quality=quality, optimize=True)
    else:
        img.save(filename)

def save_snapshot(submit_config, net, fname_postfix):
    """Pickle `net` to <run_dir>/network-<fname_postfix>.pickle."""
    dump_fname = os.path.join(submit_config.run_dir, "network-%s.pickle" % fname_postfix)
    with open(dump_fname, "wb") as f:
        pickle.dump(net, f)

def compute_ramped_lrate(i, iteration_count, ramp_up_fraction, ramp_down_fraction, learning_rate):
    """Learning rate at iteration `i` with a cosine ramp-up and a squared-cosine ramp-down.

    The ramp-up covers the first `ramp_up_fraction` of `iteration_count` iterations,
    the ramp-down the last `ramp_down_fraction`; a fraction of 0 disables that ramp.
    """
    if ramp_up_fraction > 0.0:
        ramp_up_end_iter = iteration_count * ramp_up_fraction
        if i <= ramp_up_end_iter:
            t = (i / ramp_up_fraction) / iteration_count
            learning_rate = learning_rate * (0.5 - np.cos(t * np.pi)/2)

    if ramp_down_fraction > 0.0:
        ramp_down_start_iter = iteration_count * (1 - ramp_down_fraction)
        if i >= ramp_down_start_iter:
            t = ((i - ramp_down_start_iter) / ramp_down_fraction) / iteration_count
            learning_rate = learning_rate * (0.5 + np.cos(t * np.pi)/2)**2

    return learning_rate

def clip_to_uint8(arr):
    """Map [0,1] values to rounded, clipped uint8 (NumPy in -> NumPy out, TF in -> TF out)."""
    if isinstance(arr, np.ndarray):
        return np.clip(arr * 255.0 + 0.5, 0, 255).astype(np.uint8)
    x = tf.clip_by_value(arr * 255.0 + 0.5, 0, 255)
    return tf.cast(x, tf.uint8)

def calculate_psnr(a, b, axis=None):
    """PSNR in dB between `a` and `b` after quantizing both to uint8 (NumPy or TF inputs)."""
    a, b = [clip_to_uint8(x) for x in [a, b]]
    if isinstance(a, np.ndarray):
        a, b = [x.astype(np.float32) for x in [a, b]]
        x = np.mean((a - b)**2, axis=axis)
        return np.log10((255 * 255) / x) * 10.0
    a, b = [tf.cast(x, tf.float32) for x in [a, b]]
    x = tf.reduce_mean((a - b)**2, axis=axis)
    return tf.log((255 * 255) / x) * (10.0 / math.log(10))

#----------------------------------------------------------------------------

def poisson(x, lam):
    """Sample Poisson noise with rate `x * lam`, rescaled by 1/lam; `lam <= 0` gives zeros."""
    if lam > 0.0:
        return np.random.poisson(x * lam) / lam
    return 0.0 * x

#----------------------------------------------------------------------------
# Number of channels enforcer while retaining dtype.
def set_color_channels(x, num_channels):
    """Return the channel-first array `x` with exactly `num_channels` channels.

    Axis 0 of `x` must hold 1, 3 or 4 channels; a 4th (alpha) channel is
    dropped. A 1-channel input is replicated to 3 channels and a 3-channel
    input is averaged down to 1 channel. Integer dtypes are preserved (the
    average is rounded back to the input dtype).

    Args:
        x: numpy array whose first axis is the channel axis.
        num_channels: Desired channel count, 1 or 3.
    """
    # The tile/mean conversions below can only produce 1 or 3 channels.
    assert num_channels in [1, 3]
    assert x.shape[0] in [1, 3, 4]
    x = x[:min(x.shape[0], 3)] # drop possible alpha channel
    if x.shape[0] == num_channels:
        return x
    elif x.shape[0] == 1:
        return np.tile(x, [3, 1, 1])
    y = np.mean(x, axis=0, keepdims=True)
    if np.issubdtype(x.dtype, np.integer):
        y = np.round(y).astype(x.dtype)
    return y

#----------------------------------------------------------------------------

def load_datasets(num_channels, dataset_dir, train_dataset, validation_dataset, prune_dataset=None):
    """Load the HDF5 training set and an image-file validation set.

    Args:
        num_channels: Channel count (1 or 3) enforced on every image via set_color_channels().
        dataset_dir: Root directory that contains the validation dataset folders.
        train_dataset: Path template of the HDF5 training file, or None to skip training data.
        validation_dataset: One of 'kodak', 'bsd300', 'set14'.
        prune_dataset: If not None, load at most this many leading training images.

    Returns:
        (train_images, validation_images, validation_image_size). The image lists
        hold channel-first arrays; validation_image_size is [H, W], square and
        rounded up to a multiple of 32 so that every validation image fits.

    Raises:
        ValueError: if `validation_dataset` is unknown or matches no image files.
    """
    validation_paths = {
        'kodak': os.path.join(dataset_dir, 'kodak', '*.png'),
        'bsd300': os.path.join(dataset_dir, 'BSDS300', 'images/test/*.jpg'), # Just the 100 test images
        'set14': os.path.join(dataset_dir, 'Set14', '*.png'),
    }
    # Validate up front instead of failing with an UnboundLocalError after the (slow) training load.
    if validation_dataset not in validation_paths:
        raise ValueError("Unknown validation dataset '%s'; expected one of %s." % (validation_dataset, sorted(validation_paths)))

    # Training set.
    if train_dataset is None:
        print("Not loading training data.")
        train_images = []
    else:
        fn = submit.get_path_from_template(train_dataset)
        print("Loading training dataset from '%s'." % fn)

        # Context manager: the HDF5 file is closed even if reading fails midway.
        with h5py.File(fn, 'r') as h5file:
            num = h5file['images'].shape[0]
            print("Dataset contains %d images." % num)

            if prune_dataset is not None:
                # Never ask for more rows than exist (would raise IndexError below).
                num = min(num, prune_dataset)
                print("Pruned down to %d first images." % num)

            # Load the images in chunks of `bs` rows per HDF5 read.
            train_images = [None] * num
            bs = 1024
            for i in range(0, num, bs):
                sys.stdout.write("\r%d / %d .." % (i, num))
                n = min(bs, num - i)
                img = h5file['images'][i : i+n]
                shp = h5file['shapes'][i : i+n]
                for j in range(n):
                    # Each stored image row is reshaped using its entry in 'shapes'.
                    train_images[i+j] = set_color_channels(np.reshape(img[j], shp[j]), num_channels)

        print("\nLoading done.")

    # Validation set.
    fn = submit.get_path_from_template(validation_paths[validation_dataset])
    print("Loading validation dataset from '%s'." % fn)
    # Sorted so the validation order (and anything indexed by it) is deterministic.
    validation_files = sorted(glob.glob(fn))
    if not validation_files:
        raise ValueError("No validation images found matching '%s'." % fn)
    validation_images = [imageio.imread(x, ignoregamma=True) for x in validation_files]
    validation_images = [x[..., np.newaxis] if len(x.shape) == 2 else x for x in validation_images] # Add channel axis to grayscale images.
    validation_images = [x.transpose([2, 0, 1]) for x in validation_images]
    validation_images = [set_color_channels(x, num_channels) for x in validation_images] # Enforce RGB/grayscale mode.
    print("Loaded %d images." % len(validation_images))

    # Pad the validation images to size.
    validation_image_size = [max([x.shape[axis] for x in validation_images]) for axis in [1, 2]]
    validation_image_size = [(x + 31) // 32 * 32 for x in validation_image_size] # Round up to a multiple of 32.
    validation_image_size = [max(validation_image_size) for x in validation_image_size] # Square it up for the rotators.
    print("Validation image padded size = [%d, %d]." % (validation_image_size[0], validation_image_size[1]))

    return train_images, validation_images, validation_image_size

#----------------------------------------------------------------------------
# Backbone autoencoder network, optional blind spot.
180 | 181 | def analysis_network(image, num_output_components, blindspot, zero_last=False): 182 | 183 | def conv(n, name, n_out, size=3, gain=np.sqrt(2), zero_weights=False): 184 | if blindspot: assert (size % 2) == 1 185 | ofs = 0 if (not blindspot) else size // 2 186 | 187 | with tf.variable_scope(name): 188 | wshape = [size, size, n.shape[1].value, n_out] 189 | wstd = gain / np.sqrt(np.prod(wshape[:-1])) # He init. 190 | W = tf.get_variable('W', shape=wshape, initializer=(tf.initializers.zeros() if zero_weights else tf.initializers.random_normal(0., wstd))) 191 | b = tf.get_variable('b', shape=[n_out], initializer=tf.initializers.zeros()) 192 | if ofs > 0: n = tf.pad(n, [[0, 0], [0, 0], [ofs, 0], [0, 0]]) 193 | n = tf.nn.conv2d(n, W, strides=[1]*4, padding='SAME', data_format='NCHW') + tf.reshape(b, [1, -1, 1, 1]) 194 | if ofs > 0: n = n[:, :, :-ofs, :] 195 | return n 196 | 197 | def up(n, name): 198 | with tf.name_scope(name): 199 | s = tf.shape(n) 200 | s = [-1, n.shape[1], s[2], s[3]] 201 | n = tf.reshape(n, [s[0], s[1], s[2], 1, s[3], 1]) 202 | n = tf.tile(n, [1, 1, 1, 2, 1, 2]) 203 | n = tf.reshape(n, [s[0], s[1], s[2] * 2, s[3] * 2]) 204 | return n 205 | 206 | def down(n, name): 207 | with tf.name_scope(name): 208 | if blindspot: # Shift and pad if blindspot. 209 | n = tf.pad(n[:, :, :-1, :], [[0, 0], [0, 0], [1, 0], [0, 0]]) 210 | n = tf.nn.max_pool(n, ksize=[1, 1, 2, 2], strides=[1, 1, 2, 2], padding='SAME', data_format='NCHW') 211 | return n 212 | 213 | def rotate(x, angle): 214 | if angle == 0: return x 215 | elif angle == 90: return tf.transpose(x[:, :, :, ::-1], [0, 1, 3, 2]) 216 | elif angle == 180: return x[:, :, ::-1, ::-1] 217 | elif angle == 270: return tf.transpose(x[:, :, ::-1, :], [0, 1, 3, 2]) 218 | 219 | def concat(name, layers): 220 | return tf.concat(layers, axis=1, name=name) 221 | 222 | def LR(n, alpha=0.1): 223 | return tf.nn.leaky_relu(n, alpha=alpha, name='lrelu') 224 | 225 | # Input stage. 
226 | 227 | if not blindspot: 228 | x = image 229 | else: 230 | x = tf.concat([rotate(image, a) for a in [0, 90, 180, 270]], axis=0) 231 | 232 | # Encoder part. 233 | 234 | pool0 = x 235 | x = LR(conv(x, 'enc_conv0', 48)) 236 | x = LR(conv(x, 'enc_conv1', 48)) 237 | x = down(x, 'pool1'); pool1 = x 238 | 239 | x = LR(conv(x, 'enc_conv2', 48)) 240 | x = down(x, 'pool2'); pool2 = x 241 | 242 | x = LR(conv(x, 'enc_conv3', 48)) 243 | x = down(x, 'pool3'); pool3 = x 244 | 245 | x = LR(conv(x, 'enc_conv4', 48)) 246 | x = down(x, 'pool4'); pool4 = x 247 | 248 | x = LR(conv(x, 'enc_conv5', 48)) 249 | x = down(x, 'pool5') 250 | 251 | x = LR(conv(x, 'enc_conv6', 48)) 252 | 253 | # Decoder part. 254 | 255 | x = up(x, 'upsample5') 256 | x = concat('concat5', [x, pool4]) 257 | x = LR(conv(x, 'dec_conv5a', 96)) 258 | x = LR(conv(x, 'dec_conv5b', 96)) 259 | 260 | x = up(x, 'upsample4') 261 | x = concat('concat4', [x, pool3]) 262 | x = LR(conv(x, 'dec_conv4a', 96)) 263 | x = LR(conv(x, 'dec_conv4b', 96)) 264 | 265 | x = up(x, 'upsample3') 266 | x = concat('concat3', [x, pool2]) 267 | x = LR(conv(x, 'dec_conv3a', 96)) 268 | x = LR(conv(x, 'dec_conv3b', 96)) 269 | 270 | x = up(x, 'upsample2') 271 | x = concat('concat2', [x, pool1]) 272 | x = LR(conv(x, 'dec_conv2a', 96)) 273 | x = LR(conv(x, 'dec_conv2b', 96)) 274 | 275 | x = up(x, 'upsample1') 276 | x = concat('concat1', [x, pool0]) 277 | 278 | # Output stages. 279 | 280 | if blindspot: 281 | # Blind-spot output stages. 282 | x = LR(conv(x, 'dec_conv1a', 96)) 283 | x = LR(conv(x, 'dec_conv1b', 96)) 284 | x = tf.pad(x[:, :, :-1, :], [[0, 0], [0, 0], [1, 0], [0, 0]]) # Shift and pad. 285 | x = tf.split(x, 4, axis=0) # Split into rotations. 286 | x = [rotate(y, a) for y, a in zip(x, [0, 270, 180, 90])] # Counterrotate. 287 | x = tf.concat(x, axis=1) # Combine on channel axis. 
288 | x = LR(conv(x, 'nin_a', 96*4, size=1)) 289 | x = LR(conv(x, 'nin_b', 96, size=1)) 290 | x = conv(x, 'nin_c', num_output_components, size=1, gain=1.0, zero_weights=zero_last) 291 | else: 292 | # Baseline network with postprocessing layers -- keep feature maps and distill with 1x1 convolutions. 293 | x = LR(conv(x, 'dec_conv1a', 96)) 294 | x = LR(conv(x, 'dec_conv1b', 96)) 295 | x = LR(conv(x, 'nin_a', 96, size=1)) 296 | x = LR(conv(x, 'nin_b', 96, size=1)) 297 | x = conv(x, 'nin_c', num_output_components, size=1, gain=1.0, zero_weights=zero_last) 298 | 299 | # Return results. 300 | 301 | return x 302 | 303 | #---------------------------------------------------------------------------- 304 | 305 | def blindspot_pipeline(noisy_in, 306 | noise_params_in, 307 | diagonal_covariance = False, 308 | input_shape = None, 309 | noise_style = None, 310 | noise_params = None, 311 | **_kwargs): 312 | 313 | num_channels = input_shape[1] 314 | assert num_channels in [1, 3] 315 | assert noise_style in ['gauss', 'poisson', 'impulse'] 316 | assert noise_params in ['known', 'global', 'per_image'] 317 | 318 | # Shapes. 319 | noisy_in.set_shape(input_shape) 320 | noise_params_in.set_shape(input_shape[:1] + [1, 1, 1]) 321 | 322 | # Clean data distribution. 323 | num_output_components = num_channels + (num_channels * (num_channels + 1)) // 2 # Means, triangular A. 324 | if diagonal_covariance: 325 | num_output_components = num_channels * 2 # Means, diagonal of A. 326 | net_out = analysis_network(noisy_in, num_output_components, blindspot=True) 327 | net_out = tf.cast(net_out, tf.float64) 328 | mu_x = net_out[:, 0:num_channels, ...] # Means (NCHW). 329 | A_c = net_out[:, num_channels:num_output_components, ...] # Components ot triangular A. 
330 | if num_channels == 1: 331 | sigma_x = A_c ** 2 # N1HW 332 | elif num_channels == 3: 333 | A_c = tf.transpose(A_c, [0, 2, 3, 1]) # NHWC 334 | if diagonal_covariance: 335 | c00 = A_c[..., 0]**2 336 | c11 = A_c[..., 1]**2 337 | c22 = A_c[..., 2]**2 338 | zro = tf.zeros_like(c00) 339 | c0 = tf.stack([c00, zro, zro], axis=-1) # NHW3 340 | c1 = tf.stack([zro, c11, zro], axis=-1) # NHW3 341 | c2 = tf.stack([zro, zro, c22], axis=-1) # NHW3 342 | else: 343 | # Calculate A^T * A 344 | c00 = A_c[..., 0]**2 + A_c[..., 1]**2 + A_c[..., 2]**2 # NHW 345 | c01 = A_c[..., 1]*A_c[..., 3] + A_c[..., 2]*A_c[..., 4] 346 | c02 = A_c[..., 2]*A_c[..., 5] 347 | c11 = A_c[..., 3]**2 + A_c[..., 4]**2 348 | c12 = A_c[..., 4]*A_c[..., 5] 349 | c22 = A_c[..., 5]**2 350 | c0 = tf.stack([c00, c01, c02], axis=-1) # NHW3 351 | c1 = tf.stack([c01, c11, c12], axis=-1) # NHW3 352 | c2 = tf.stack([c02, c12, c22], axis=-1) # NHW3 353 | sigma_x = tf.stack([c0, c1, c2], axis=-1) # NHW33 354 | 355 | # Data on which noise parameter estimation is based. 356 | if noise_params == 'global': 357 | # Global constant over the entire dataset. 358 | noise_est_out = tf.get_variable('noise_data', shape=[1, 1, 1, 1], initializer=tf.initializers.constant(0.0)) # 1111 359 | noise_est_out = tf.cast(noise_est_out, tf.float64) 360 | elif noise_params == 'per_image': 361 | # Separate analysis network. 362 | with tf.variable_scope('param_estimation_net'): 363 | noise_est_out = analysis_network(noisy_in, 1, blindspot=False, zero_last=True) # N1HW 364 | noise_est_out = tf.reduce_mean(noise_est_out, axis=[2, 3], keepdims=True) # N111 365 | noise_est_out = tf.cast(noise_est_out, tf.float64) 366 | 367 | # Cast remaining data into float64. 368 | noisy_in = tf.cast(noisy_in, tf.float64) 369 | noise_params_in = tf.cast(noise_params_in, tf.float64) 370 | 371 | # Remap noise estimate to ensure it is always positive and starts near zero. 
372 | if noise_params != 'known': 373 | noise_est_out = tf.nn.softplus(noise_est_out - 4.0) + 1e-3 374 | 375 | # Distill noise parameters from learned/known data. 376 | if noise_style == 'gauss': 377 | if noise_params == 'known': 378 | noise_std = tf.maximum(noise_params_in, 1e-3) # N111 379 | else: 380 | noise_std = noise_est_out 381 | elif noise_style == 'poisson': # Simple signal-dependent Poisson approximation [Hasinoff 2012]. 382 | if noise_params == 'known': 383 | noise_std = (tf.maximum(mu_x, tf.constant(1e-3, tf.float64)) / noise_params_in) ** 0.5 # NCHW 384 | else: 385 | noise_std = (tf.maximum(mu_x, tf.constant(1e-3, tf.float64)) * noise_est_out) ** 0.5 # NCHW 386 | elif noise_style == 'impulse': 387 | if noise_params == 'known': 388 | noise_std = noise_params_in # N111, actually the alpha. 389 | else: 390 | noise_std = noise_est_out 391 | 392 | # Casts and vars. 393 | noise_std = tf.cast(noise_std, tf.float64) 394 | I = tf.eye(num_channels, batch_shape=[1, 1, 1], dtype=tf.float64) 395 | Ieps = I * tf.constant(1e-6, dtype=tf.float64) 396 | zero64 = tf.constant(0.0, dtype=tf.float64) 397 | 398 | # Helpers. 399 | def batch_mvmul(m, v): # Batched (M * v). 400 | return tf.reduce_sum(m * v[..., tf.newaxis, :], axis=-1) 401 | def batch_vtmv(v, m): # Batched (v^T * M * v). 402 | return tf.reduce_sum(v[..., :, tf.newaxis] * v[..., tf.newaxis, :] * m, axis=[-2, -1]) 403 | def batch_vvt(v): # Batched (v * v^T). 404 | return v[..., :, tf.newaxis] * v[..., tf.newaxis, :] 405 | 406 | # Negative log-likelihood loss and posterior mean estimation. 407 | if noise_style in ['gauss', 'poisson']: 408 | if num_channels == 1: 409 | sigma_n = noise_std**2 # N111 / N1HW 410 | sigma_y = sigma_x + sigma_n # N1HW. Total variance. 411 | loss_out = ((noisy_in - mu_x) ** 2) / sigma_y + tf.log(sigma_y) # N1HW 412 | pme_out = (noisy_in * sigma_x + mu_x * sigma_n) / (sigma_x + sigma_n) # N1HW 413 | net_std_out = (sigma_x**0.5)[:, 0, ...] # NHW 414 | noise_std_out = noise_std[:, 0, ...] 
# N11 / NHW 415 | if noise_params != 'known': 416 | loss_out = loss_out - 0.1 * noise_std # Balance regularization. 417 | else: 418 | # Training loss. 419 | sigma_n = tf.transpose(noise_std**2, [0, 2, 3, 1])[..., tf.newaxis] * I # NHWC1 * NHWCC = NHWCC 420 | sigma_y = sigma_x + sigma_n # NHWCC, total covariance matrix. Cannot be singular because sigma_n is at least a small diagonal. 421 | sigma_y_inv = tf.linalg.inv(sigma_y) # NHWCC 422 | mu_x2 = tf.transpose(mu_x, [0, 2, 3, 1]) # NHWC 423 | noisy_in2 = tf.transpose(noisy_in, [0, 2, 3, 1]) # NHWC 424 | diff = (noisy_in2 - mu_x2) # NHWC 425 | diff = -0.5 * batch_vtmv(diff, sigma_y_inv) # NHW 426 | dets = tf.linalg.det(sigma_y) # NHW 427 | dets = tf.maximum(zero64, dets) # NHW. Avoid division by zero and negative square roots. 428 | loss_out = 0.5 * tf.log(dets) - diff # NHW 429 | if noise_params != 'known': 430 | loss_out = loss_out - 0.1 * tf.reduce_mean(noise_std, axis=1) # Balance regularization. 431 | 432 | # Posterior mean estimate. 433 | sigma_x_inv = tf.linalg.inv(sigma_x + Ieps) # NHWCC 434 | sigma_n_inv = tf.linalg.inv(sigma_n + Ieps) # NHWCC 435 | pme_c1 = tf.linalg.inv(sigma_x_inv + sigma_n_inv + Ieps) # NHWCC 436 | pme_c2 = batch_mvmul(sigma_x_inv, mu_x2) # NHWCC * NHWC -> NHWC 437 | pme_c2 = pme_c2 + batch_mvmul(sigma_n_inv, noisy_in2) # NHWC 438 | pme_out = batch_mvmul(pme_c1, pme_c2) # NHWC 439 | pme_out = tf.transpose(pme_out, [0, 3, 1, 2]) # NCHW 440 | 441 | # Summary statistics. 442 | net_std_out = tf.maximum(zero64, tf.linalg.det(sigma_x))**(1.0/6.0) # NHW 443 | noise_std_out = tf.maximum(zero64, tf.linalg.det(sigma_n))**(1.0/6.0) # N11 / NHW 444 | elif noise_style == 'impulse': 445 | alpha = noise_std # N111. 446 | if num_channels == 1: 447 | raise NotImplementedError 448 | else: 449 | # Preliminaries. 450 | sigma_x = sigma_x + Ieps # NHWCC. Inflate by epsilon. 
451 | sigma_x_inv = tf.linalg.inv(sigma_x) # NHWCC 452 | mu_x2 = tf.transpose(mu_x, [0, 2, 3, 1]) # NHWC 453 | noisy_in2 = tf.transpose(noisy_in, [0, 2, 3, 1]) # NHWC 454 | diff = (noisy_in2 - mu_x2) # NHWC 455 | diff = batch_vtmv(diff, sigma_x_inv) # NHW 456 | dets = tf.linalg.det(sigma_x) # NHW 457 | dets = tf.maximum(tf.constant(1e-9, dtype=tf.float64), dets) # NHW. Avoid division by zero and negative square roots. 458 | g = tf.exp(-0.5 * diff) / ((2.0 * np.pi)**num_channels * dets)**0.5 # NHW 459 | g = g[..., tf.newaxis] # NHW1 460 | 461 | # Posterior mean estimate. 462 | h = (1.0 - alpha) * g # NHW1 463 | pme_out = (alpha * mu_x2 + h * noisy_in2) / (alpha + h) 464 | pme_out = tf.transpose(pme_out, [0, 3, 1, 2]) # NCHW 465 | 466 | # Training loss with the modified stats. 467 | mu_y2 = alpha * .5 + (1.0 - alpha) * mu_x2 # NHWC 468 | alpha = alpha[..., tf.newaxis] # n1111 469 | sigma_y = alpha * (1.0/4.0 + I/12.0) + (1.0 - alpha) * (sigma_x + batch_vvt(mu_x2)) - batch_vvt(mu_y2) # NHWCC 470 | sigma_y_inv = tf.linalg.inv(sigma_y) # NHWCC 471 | diff = (noisy_in2 - mu_y2) # NHWC 472 | diff = batch_vtmv(diff, sigma_y_inv) # NHW 473 | dets = tf.linalg.det(sigma_y) # NHW 474 | dets = tf.maximum(tf.constant(1e-9, dtype=tf.float64), dets) # NHW 475 | loss_out = diff + tf.log(dets) # NHW 476 | 477 | # Summary statistics. 478 | net_std_out = tf.maximum(zero64, tf.linalg.det(sigma_x))**(1.0/6.0) # NHW. Cube root of volumetric scaling factor. 479 | noise_std_out = alpha[..., 0, 0] / 255.0 * 100.0 # N11 / NHW. Shows as percentage in output. 
480 | 481 | return mu_x, pme_out, loss_out, net_std_out, noise_std_out 482 | 483 | #---------------------------------------------------------------------------- 484 | 485 | def simple_pipeline(clean_in, 486 | noisy_in, 487 | L_exponent_in, 488 | noise_style = None, 489 | input_shape = None, 490 | blindspot = False, 491 | noisy_targets = False, 492 | **_kwargs): 493 | 494 | clean_in.set_shape(input_shape) 495 | noisy_in.set_shape(input_shape) 496 | L_exponent_in.set_shape([]) 497 | 498 | x = analysis_network(noisy_in, input_shape[1], blindspot=blindspot) 499 | 500 | if noise_style == 'impulse' and noisy_targets: # Cannot use L2 loss because mean changes 501 | loss_out = (tf.abs(x - clean_in) + 1e-8) ** L_exponent_in 502 | else: 503 | loss_out = (x - clean_in) ** 2.0 504 | 505 | net_std_out, noise_std_out = [tf.zeros_like(noisy_in) for x in range(2)] 506 | return x, x, loss_out, net_std_out, noise_std_out 507 | 508 | #---------------------------------------------------------------------------- 509 | 510 | def get_scrambled_indices(num, bs): 511 | assert num > 0 512 | i, x = 0, [] 513 | while True: 514 | res = x[i : i + bs] 515 | i += bs 516 | while len(res) < bs: 517 | x = list(np.arange(num)) 518 | np.random.shuffle(x) 519 | i = bs - len(res) 520 | res += x[:i] 521 | yield res 522 | 523 | #---------------------------------------------------------------------------- 524 | 525 | def random_crop_numpy(img, crop_size): 526 | y = np.random.randint(img.shape[1] - crop_size + 1) 527 | x = np.random.randint(img.shape[2] - crop_size + 1) 528 | return img[:, y : y+crop_size, x : x+crop_size] 529 | 530 | #---------------------------------------------------------------------------- 531 | # Noise implementations. 
532 | #---------------------------------------------------------------------------- 533 | 534 | operation_seed_counter = 0 535 | def noisify(x, style): 536 | def get_seed(): 537 | global operation_seed_counter 538 | operation_seed_counter += 1 539 | return operation_seed_counter 540 | 541 | if style.startswith('gauss'): # Gaussian noise with constant/variable std.dev. 542 | params = [float(p) / 255.0 for p in style.replace('gauss', '', 1).split('_')] 543 | if len(params) == 1: 544 | std = params[0] 545 | elif len(params) == 2: 546 | min_std, max_std = params 547 | std = tf.random_uniform(shape=[tf.shape(x)[0], 1, 1, 1], minval=min_std, maxval=max_std, seed=get_seed()) 548 | return x + tf.random_normal(shape=tf.shape(x), seed=get_seed()) * std, std 549 | elif style.startswith('poisson'): # Poisson noise with constant/variable lambda. 550 | params = [float(p) for p in style.replace('poisson', '', 1).split('_')] 551 | if len(params) == 1: 552 | lam = params[0] 553 | elif len(params) == 2: 554 | min_lam, max_lam = params 555 | lam = tf.random_uniform(shape=[tf.shape(x)[0], 1, 1, 1], minval=min_lam, maxval=max_lam, seed=get_seed()) 556 | x = x * lam 557 | with tf.device("/cpu:0"): 558 | x = tf.random_poisson(x, [1], seed=get_seed()) 559 | return x[0] / lam, lam 560 | elif style.startswith('impulse'): # Random replacement with constant/variable alpha. 
561 | params = [float(p) * 0.01 for p in style.replace('impulse', '', 1).split('_')] 562 | msh = tf.shape(x[:, :1, ...]) 563 | if len(params) == 1: 564 | alpha = params[0] 565 | keep_mask = tf.where(tf.random_uniform(shape=msh, seed=get_seed()) >= alpha, tf.ones(shape=msh), tf.zeros(shape=msh)) 566 | elif len(params) == 2: 567 | min_alpha, max_alpha = params 568 | alpha = tf.random_uniform(shape=[tf.shape(x)[0], 1, 1, 1], minval=min_alpha, maxval=max_alpha, seed=get_seed()) 569 | keep_mask = tf.where(tf.random_uniform(shape=msh, seed=get_seed()) >= tf.ones(shape=msh) * alpha, tf.ones(shape=msh), tf.zeros(shape=msh)) 570 | noise = tf.random_uniform(shape=tf.shape(x), seed=get_seed()) 571 | return x * keep_mask + noise * (1.0 - keep_mask), alpha 572 | 573 | #---------------------------------------------------------------------------- 574 | # Training loop. 575 | #---------------------------------------------------------------------------- 576 | 577 | def train(submit_config, 578 | num_iter = 1000000, 579 | train_resolution = 256, 580 | minibatch_size = 4, 581 | learning_rate = 3e-4, 582 | rampup_fraction = 0.1, 583 | rampdown_fraction = 0.3, 584 | snapshot_every = 0, # Export network snapshot every n images (must be divisible by minibatch). 585 | pipeline = None, 586 | diagonal_covariance = False, # Force non-diagonal covariances to zero (per-channel univariate). 587 | noise_style = None, 588 | noise_params = None, # 'known', 'global', 'per_image' 589 | train_dataset = None, 590 | validation_dataset = None, 591 | validation_repeats = 1, 592 | prune_dataset = None, 593 | num_channels = None, 594 | print_interval = 1000, 595 | eval_interval = 10000, 596 | eval_network = None, 597 | config_name = None, 598 | dataset_dir = None): 599 | 600 | # Are we in evaluation mode? 601 | eval_mode = eval_network is not None 602 | 603 | # Initialize Tensorflow. 604 | if eval_mode: 605 | init_tf(0) # Use fixed seeds if evaluating a network. 
606 | np.random.seed(0) 607 | else: 608 | init_tf() # Use a random random seed. 609 | 610 | # Get going. 611 | ctx = dnnlib.RunContext(submit_config) 612 | run_dir = submit_config.run_dir 613 | img_dir = os.path.join(run_dir, 'img') 614 | os.makedirs(img_dir, exist_ok=True) 615 | 616 | # Load the data. 617 | train_images, validation_images, validation_image_size = load_datasets(num_channels, dataset_dir, None if eval_mode else train_dataset, validation_dataset, prune_dataset) 618 | 619 | # Repeat validation set if asked to. 620 | original_validation_image_count = len(validation_images) # Avoid exporting the duplicate images. 621 | if validation_repeats > 1: 622 | print("Repeating the validation set %d times." % validation_repeats) 623 | validation_images = validation_images * validation_repeats 624 | validation_image_size = validation_image_size * validation_repeats 625 | 626 | # Construct the network. 627 | input_shape = [None, num_channels, None, None] 628 | with tf.device("/gpu:0"): 629 | if eval_mode: 630 | print("Evaluating network '%s'." 
% eval_network) 631 | with open(eval_network, 'rb') as f: 632 | net = pickle.load(f) 633 | else: 634 | if noise_style.startswith('gauss'): net_noise_style = 'gauss' 635 | if noise_style.startswith('poisson'): net_noise_style = 'poisson' 636 | if noise_style.startswith('impulse'): net_noise_style = 'impulse' 637 | 638 | if pipeline == 'blindspot': 639 | net = dnnlib.tflib.Network('net', 'selfsupervised_denoising.blindspot_pipeline', input_shape=input_shape, noise_style=net_noise_style, noise_params=noise_params, diagonal_covariance=diagonal_covariance) 640 | elif pipeline == 'blindspot_mean': 641 | net = dnnlib.tflib.Network('net', 'selfsupervised_denoising.simple_pipeline', input_shape=input_shape, noise_style=net_noise_style, blindspot=True, noisy_targets=True) 642 | elif pipeline == 'n2c': 643 | net = dnnlib.tflib.Network('net', 'selfsupervised_denoising.simple_pipeline', input_shape=input_shape, noise_style=net_noise_style, blindspot=False, noisy_targets=False) 644 | elif pipeline == 'n2n': 645 | net = dnnlib.tflib.Network('net', 'selfsupervised_denoising.simple_pipeline', input_shape=input_shape, noise_style=net_noise_style, blindspot=False, noisy_targets=True) 646 | 647 | # Data splits. 648 | with tf.name_scope('Inputs'), tf.device("/cpu:0"): 649 | learning_rate_in = tf.placeholder(tf.float32, name='learning_rate_in', shape=[]) 650 | L_exponent_in = tf.placeholder(tf.float32, name='L_exponent', shape=[]) 651 | clean_in = tf.placeholder(tf.float32, shape=input_shape) 652 | clean_in_split = tf.split(clean_in, submit_config.num_gpus) 653 | 654 | # Optimizer. 655 | opt = dnnlib.tflib.Optimizer(tf_optimizer='tf.train.AdamOptimizer', learning_rate=learning_rate_in, beta1=0.9, beta2=0.99) 656 | 657 | # Per-gpu stuff. 658 | train_loss = 0. 659 | train_psnr = 0. 660 | train_psnr_pme = 0. 
661 | gpu_outputs = [] 662 | for gpu in range(submit_config.num_gpus): 663 | with tf.device("/gpu:%d" % gpu): 664 | net_gpu = net if gpu == 0 else net.clone() 665 | clean_in_gpu = clean_in_split[gpu] 666 | noisy_in_gpu, noise_coeff = noisify(clean_in_gpu, noise_style) 667 | 668 | if pipeline == 'blindspot_mean': 669 | reference_in_gpu = noisy_in_gpu 670 | elif pipeline == 'n2n': 671 | reference_in_gpu, _ = noisify(clean_in_gpu, noise_style) # Another noise instantiation. 672 | else: 673 | reference_in_gpu = clean_in_gpu 674 | 675 | noise_coeff = tf.zeros([tf.shape(noisy_in_gpu)[0], 1, 1, 1]) + noise_coeff # Broadcast to [n, 1, 1, 1] shape. 676 | 677 | # Support for networks that were exported from an older version of code and loaded for evaluation purposes. 678 | if net.num_inputs == 5: 679 | mu_x, pme_out, loss_out, net_std_out, noise_std_out, _ = net_gpu.get_output_for(reference_in_gpu, noisy_in_gpu, noise_coeff, tf.constant(1e-6, dtype=tf.float32), tf.constant(1e-1, dtype=tf.float32)) 680 | else: 681 | if pipeline == 'blindspot': 682 | if net.num_inputs == 3: 683 | mu_x, pme_out, loss_out, net_std_out, noise_std_out, _ = net_gpu.get_output_for(noisy_in_gpu, noise_coeff, L_exponent_in) # Previous version. 684 | else: 685 | mu_x, pme_out, loss_out, net_std_out, noise_std_out = net_gpu.get_output_for(noisy_in_gpu, noise_coeff) 686 | else: 687 | if net.num_inputs == 4: 688 | mu_x, pme_out, loss_out, net_std_out, noise_std_out, _ = net_gpu.get_output_for(reference_in_gpu, noisy_in_gpu, noise_coeff, L_exponent_in) # Previous version. 689 | else: 690 | mu_x, pme_out, loss_out, net_std_out, noise_std_out = net_gpu.get_output_for(reference_in_gpu, noisy_in_gpu, L_exponent_in) 691 | 692 | gpu_outputs.append([mu_x, pme_out, loss_out, net_std_out, noise_std_out, noisy_in_gpu]) 693 | 694 | # Loss. 695 | loss = tf.reduce_mean(loss_out) 696 | 697 | # PSNR during training. 
698 | psnr = tf.reduce_mean(calculate_psnr(mu_x, clean_in_gpu, axis=[1, 2, 3])) 699 | psnr_pme = tf.reduce_mean(calculate_psnr(pme_out, clean_in_gpu, axis=[1, 2, 3])) 700 | with tf.control_dependencies([autosummary("train_loss", loss), autosummary("train_psnr", psnr), autosummary("train_psnr_pme", psnr_pme)]): 701 | opt.register_gradients(loss, net_gpu.trainables) 702 | 703 | # Accumulation not on the GPU. 704 | train_loss += loss / submit_config.num_gpus 705 | train_psnr += psnr / submit_config.num_gpus 706 | train_psnr_pme += psnr_pme / submit_config.num_gpus 707 | 708 | # Total outputs. 709 | mu_x_out, pme_out, loss_out, net_std_out, noise_std_out, noisy_out = [tf.concat(x, axis=0) for x in zip(*gpu_outputs)] 710 | 711 | # Train step op. 712 | train_step = opt.apply_updates() 713 | 714 | # Create a log file for Tensorboard. 715 | if not eval_mode: 716 | summary_log = tf.summary.FileWriter(run_dir) 717 | summary_log.add_graph(tf.get_default_graph()) 718 | 719 | # Training image index generator. 720 | index_generator = get_scrambled_indices(len(train_images), minibatch_size) 721 | 722 | # Init stats. 723 | print_last, eval_last = 0, 0 724 | loss_acc, loss_n = 0., 0. 725 | psnr_acc, psnr_pme_acc = 0., 0. 726 | std_net_acc, std_noise_acc = 0., 0. 727 | valid_psnr_mu, valid_psnr_pme = 0., 0. 728 | t_start = time.time() 729 | 730 | # Train. 731 | if eval_mode: 732 | print('Evaluating network with %d images.' % len(validation_images)) 733 | else: 734 | print('Training for %d images.' % num_iter) 735 | 736 | for n in range(0, num_iter + minibatch_size, minibatch_size): 737 | if ctx.should_stop(): 738 | break 739 | 740 | # Save snapshot. 741 | if (n > 0) and (snapshot_every > 0) and (n % snapshot_every == 0): 742 | save_snapshot(submit_config, net, '%08d' % n) 743 | 744 | # Set up training step. 
745 | lr = compute_ramped_lrate(n, num_iter, rampup_fraction, rampdown_fraction, learning_rate) 746 | L_exponent = 0.5 if eval_mode else max(0.5, 2.0 - 2.0 * n / num_iter) 747 | 748 | # Training step unless in evaluation mode. 749 | if not eval_mode: 750 | # Get clean images from training set. 751 | clean = np.zeros([minibatch_size, num_channels, train_resolution, train_resolution], dtype=np.uint8) 752 | for i, j in enumerate(next(index_generator)): 753 | clean[i] = random_crop_numpy(train_images[j], train_resolution) 754 | clean = adjust_dynamic_range(clean, [0, 255], [0.0, 1.0]) 755 | 756 | # Run training step. 757 | feed_dict = {clean_in: clean, learning_rate_in: lr, L_exponent_in: L_exponent} 758 | loss_val, psnr_val, psnr_pme_val, net_std_val, noise_std_val, _ = tfutil.run([train_loss, train_psnr, train_psnr_pme, net_std_out, noise_std_out, train_step], feed_dict) 759 | 760 | # Accumulate stats. 761 | loss_acc += loss_val 762 | psnr_acc += psnr_val 763 | psnr_pme_acc += psnr_pme_val 764 | std_net_acc += np.mean(net_std_val) 765 | std_noise_acc += np.mean(noise_std_val) 766 | loss_n += 1.0 767 | 768 | # Print. 769 | if n == 0 or n >= print_last + print_interval: 770 | loss_n = max(loss_n, 1.0) 771 | loss_acc /= loss_n 772 | psnr_acc /= loss_n 773 | psnr_pme_acc /= loss_n 774 | std_net_acc = std_net_acc / loss_n * 255.0 775 | std_noise_acc = std_noise_acc / loss_n * 255.0 776 | t_iter = time.time() - t_start 777 | print("%8d: time=%6.2f, loss=%8.4f, train_psnr=%8.4f, train_psnr_pme=%8.4f, std_net=%8.4f, std_noise=%8.4f" % (n, t_iter, loss_acc, psnr_acc, psnr_pme_acc, autosummary('std_net', std_net_acc), autosummary('std_noise', std_noise_acc)), end='') 778 | ctx.update(loss='%.2f %.2f' % (psnr_pme_acc, valid_psnr_pme), cur_epoch=n, max_epoch=num_iter) 779 | print_last += print_interval if (n > 0) else 0 780 | loss_acc, loss_n = 0., 0. 781 | psnr_acc, psnr_pme_acc = 0., 0. 782 | std_net_acc, std_noise_acc = 0., 0. 
783 | t_start = time.time() 784 | 785 | # Measure and export validation images. 786 | if n == 0 or n >= eval_last + eval_interval or n == num_iter: 787 | valid_psnr_mu = 0. 788 | valid_psnr_pme = 0. 789 | bs = submit_config.num_gpus # Validation batch size. 790 | for idx0 in range(0, len(validation_images), bs): 791 | num = min(bs, len(validation_images) - idx0) 792 | idx = list(range(idx0, idx0 + bs)) 793 | idx = [min(x, len(validation_images) - 1) for x in idx] 794 | val_input = [] 795 | val_sz = [] 796 | for i in idx: 797 | img = validation_images[i][np.newaxis, ...] 798 | img = adjust_dynamic_range(img, [0, 255], [0.0, 1.0]) 799 | sz = img.shape[2:] 800 | img = np.pad(img, [[0, 0], [0, 0], [0, validation_image_size[0] - sz[0]], [0, validation_image_size[1] - sz[1]]], 'reflect') 801 | val_input.append(img) 802 | val_sz.append(sz) 803 | val_input = np.concatenate(val_input, axis=0) # Batch of validation images. 804 | 805 | # Run the actual step. 806 | feed_dict = {clean_in: val_input} 807 | mu_x, net_std, pme, noisy = tfutil.run([mu_x_out, net_std_out, pme_out, noisy_out], feed_dict) 808 | 809 | # Process the result images. 810 | for i, j in enumerate(idx[:num]): 811 | crop_val_input, crop_mu_x, crop_pme, crop_noisy = [x[i, :, :val_sz[i][0], :val_sz[i][1]] for x in [val_input, mu_x, pme, noisy]] 812 | crop_net_std = net_std[i, :val_sz[i][0], :val_sz[i][1]] # HW grayscale 813 | crop_net_std /= 10.0 / 255.0 # white = 10 ULPs in U8. 814 | valid_psnr_mu += calculate_psnr(crop_mu_x, crop_val_input) / len(validation_images) 815 | valid_psnr_pme += calculate_psnr(crop_pme, crop_val_input) / len(validation_images) 816 | 817 | if (eval_mode and (j < original_validation_image_count)) or ((not eval_mode) and (j == len(validation_images) - 1)): # Export last image, or all if evaluating. 
818 | k, ext = (j, 'png') if eval_mode else (n, 'jpg') 819 | def save_img(name, img): save_image(img, os.path.join(img_dir, 'img-%07d-%s.%s' % (k, name, ext)), [0.0, 1.0]) 820 | save_img('a_nsy', crop_noisy) # Noisy input 821 | save_img('b_out', crop_mu_x) # Predicted mean 822 | save_img('b_out2', crop_pme) # Posterior mean estimate (actual output) 823 | save_img('b_std', crop_net_std) # Predicted std. dev 824 | save_img('c_cln', crop_val_input) # Clean reference image 825 | 826 | # Validation pass completed. 827 | 828 | print(", valid_psnr_mu=%8.4f, valid_psnr_pme=%8.4f" % (valid_psnr_mu, valid_psnr_pme), end='') 829 | eval_last += eval_interval if (n > 0) else 0 830 | 831 | # Exit if evaluation mode. 832 | if eval_mode: 833 | print("\nEvaluation done, exiting.") 834 | print("RESULT %8.4f" % valid_psnr_pme) 835 | ctx.close() 836 | return 837 | 838 | # Finish printing. 839 | autosummary('valid_psnr_mu', valid_psnr_mu) 840 | autosummary('valid_psnr_pme', valid_psnr_pme) 841 | dnnlib.tflib.autosummary.save_summaries(summary_log, n) 842 | print("") 843 | 844 | # Save the result. 845 | save_snapshot(submit_config, net, 'final-'+config_name) 846 | 847 | # Done. 
848 | summary_log.close() 849 | ctx.close() 850 | 851 | 852 | #---------------------------------------------------------------------------- 853 | config_lst = [ 854 | dict(eval_id = '00011', noise_style='gauss25', num_iter=2000000, pipeline='n2c'), 855 | dict(eval_id = '00012', noise_style='gauss25', num_iter=2000000, pipeline='n2n'), 856 | dict(eval_id = '00013', noise_style='gauss25', num_iter=2000000, pipeline='blindspot', noise_params='known'), 857 | dict(eval_id = '00014', noise_style='gauss25', num_iter=2000000, pipeline='blindspot', noise_params='global'), 858 | dict(eval_id = '00015', noise_style='gauss25', num_iter=2000000, pipeline='blindspot', noise_params='known', diagonal_covariance=True), 859 | dict(eval_id = '00016', noise_style='gauss25', num_iter=2000000, pipeline='blindspot', noise_params='global', diagonal_covariance=True), 860 | dict(eval_id = '00017', noise_style='gauss25', num_iter=2000000, pipeline='blindspot_mean'), 861 | dict(eval_id = '00018', noise_style='gauss5_50', num_iter=2000000, pipeline='n2c'), 862 | dict(eval_id = '00019', noise_style='gauss5_50', num_iter=2000000, pipeline='n2n'), 863 | dict(eval_id = '00020', noise_style='gauss5_50', num_iter=2000000, pipeline='blindspot', noise_params='known'), 864 | dict(eval_id = '00021', noise_style='gauss5_50', num_iter=2000000, pipeline='blindspot', noise_params='per_image'), 865 | dict(eval_id = '00022', noise_style='gauss5_50', num_iter=2000000, pipeline='blindspot', noise_params='known', diagonal_covariance=True), 866 | dict(eval_id = '00023', noise_style='gauss5_50', num_iter=2000000, pipeline='blindspot', noise_params='per_image', diagonal_covariance=True), 867 | dict(eval_id = '00024', noise_style='gauss5_50', num_iter=2000000, pipeline='blindspot_mean'), 868 | dict(eval_id = '00030', noise_style='poisson30', num_iter=2000000, pipeline='n2c'), 869 | dict(eval_id = '00031', noise_style='poisson30', num_iter=2000000, pipeline='n2n'), 870 | dict(eval_id = '00032', 
noise_style='poisson30', num_iter=2000000, pipeline='blindspot', noise_params='known'), 871 | dict(eval_id = '00033', noise_style='poisson30', num_iter=2000000, pipeline='blindspot', noise_params='global'), 872 | dict(eval_id = '00034', noise_style='poisson30', num_iter=2000000, pipeline='blindspot_mean'), 873 | dict(eval_id = '00035', noise_style='poisson5_50', num_iter=2000000, pipeline='n2c'), 874 | dict(eval_id = '00036', noise_style='poisson5_50', num_iter=2000000, pipeline='n2n'), 875 | dict(eval_id = '00037', noise_style='poisson5_50', num_iter=2000000, pipeline='blindspot', noise_params='known'), 876 | dict(eval_id = '00038', noise_style='poisson5_50', num_iter=2000000, pipeline='blindspot', noise_params='per_image'), 877 | dict(eval_id = '00039', noise_style='poisson5_50', num_iter=2000000, pipeline='blindspot_mean'), 878 | dict(eval_id = '00050', noise_style='impulse50', pipeline='n2c', num_iter=16000000), 879 | dict(eval_id = '00051', noise_style='impulse50', pipeline='n2n', num_iter=16000000), 880 | dict(eval_id = '00052', noise_style='impulse50', pipeline='blindspot', noise_params='known', num_iter=4000000), 881 | dict(eval_id = '00053', noise_style='impulse50', pipeline='blindspot', noise_params='global', num_iter=4000000), 882 | dict(eval_id = '00054', noise_style='impulse50', pipeline='blindspot_mean', num_iter=8000000), 883 | dict(eval_id = '00055', noise_style='impulse0_100', pipeline='n2c', num_iter=16000000), 884 | dict(eval_id = '00056', noise_style='impulse0_100', pipeline='n2n', num_iter=16000000), 885 | dict(eval_id = '00057', noise_style='impulse0_100', pipeline='blindspot', noise_params='known', num_iter=4000000), 886 | dict(eval_id = '00058', noise_style='impulse0_100', pipeline='blindspot', noise_params='per_image', num_iter=4000000), 887 | dict(eval_id = '00059', noise_style='impulse0_100', pipeline='blindspot_mean', num_iter=8000000), 888 | dict(eval_id = '00180', noise_style='gauss25', num_channels=1, num_iter=2000000, 
pipeline='n2c'), 889 | dict(eval_id = '00181', noise_style='gauss25', num_channels=1, num_iter=2000000, pipeline='blindspot', noise_params='known'), 890 | dict(eval_id = '00182', noise_style='gauss25', num_channels=1, num_iter=2000000, pipeline='blindspot', noise_params='global'), 891 | dict(eval_id = '00183', noise_style='gauss5_50', num_channels=1, num_iter=2000000, pipeline='n2c'), 892 | dict(eval_id = '00184', noise_style='gauss5_50', num_channels=1, num_iter=2000000, pipeline='blindspot', noise_params='known'), 893 | dict(eval_id = '00185', noise_style='gauss5_50', num_channels=1, num_iter=2000000, pipeline='blindspot', noise_params='per_image'), 894 | dict(eval_id = '00188', noise_style='poisson30', num_channels=1, num_iter=2000000, pipeline='n2c'), 895 | dict(eval_id = '00189', noise_style='poisson30', num_channels=1, num_iter=2000000, pipeline='blindspot', noise_params='known'), 896 | dict(eval_id = '00190', noise_style='poisson30', num_channels=1, num_iter=2000000, pipeline='blindspot', noise_params='global'), 897 | dict(eval_id = '00191', noise_style='poisson5_50', num_channels=1, num_iter=2000000, pipeline='n2c'), 898 | dict(eval_id = '00192', noise_style='poisson5_50', num_channels=1, num_iter=2000000, pipeline='blindspot', noise_params='known'), 899 | dict(eval_id = '00193', noise_style='poisson5_50', num_channels=1, num_iter=2000000, pipeline='blindspot', noise_params='per_image', snapshot_every=100000), # A bit unstable. 
900 | ] 901 | 902 | 903 | def make_config_name(c): 904 | num_channels = c.get('num_channels', 3) 905 | diag = c.get('diagonal_covariance', False) 906 | is_blindspot = c['pipeline'] == 'blindspot' 907 | sigma = '-sigma_'+c['noise_params'] if is_blindspot else '' 908 | return c['noise_style']+'-'+c['pipeline']+('_diag' if diag else '')+sigma+('-mono' if num_channels == 1 else '') 909 | 910 | # ------------------------------------------------------------------------------------------ 911 | def cli_examples(configs): 912 | return '''examples: 913 | # Train a network with gauss25-blindspot-sigma_global configuration 914 | python %(prog)s --train=gauss25-blindspot-sigma_global --dataset-dir=$HOME/datasets --validation-set=kodak --train-h5=imagenet_val_raw.h5 915 | 916 | # Evaluate a network using the BSD300 dataset: 917 | python %(prog)s --eval=$HOME/pretrained/network-00012-gauss25-n2n.pickle --dataset-dir=$HOME/datasets --validation-set=kodak 918 | 919 | List of all configs: 920 | 921 | ''' + '\n '.join(configs) 922 | 923 | def main(): 924 | sc = dnnlib.SubmitConfig() 925 | sc.run_dir_root = 'results' 926 | sc.run_dir_ignore += ['datasets', 'results'] 927 | 928 | config_map = {} 929 | selected_config = None 930 | config_names = [] 931 | for c in config_lst: 932 | cfg_name = make_config_name(c) 933 | assert cfg_name not in config_map 934 | config_map[cfg_name] = c 935 | config_names.append(cfg_name) 936 | 937 | parser = argparse.ArgumentParser( 938 | description='Train or evaluate.', 939 | epilog=cli_examples(config_names), 940 | formatter_class=argparse.RawDescriptionHelpFormatter 941 | ) 942 | parser.add_argument('--dataset-dir', help='Path to validation set data') 943 | parser.add_argument('--train-h5', help='Specify training set .h5 filename') 944 | parser.add_argument('--validation-set', help='Evaluation dataset', default='kodak') 945 | parser.add_argument('--eval', help='Evaluate validation set with the given network pickle') 946 | parser.add_argument('--train', 
help='Train for the given config') 947 | args = parser.parse_args() 948 | 949 | eval_sets = { 950 | 'kodak': dict(validation_repeats=10), 951 | 'bsd300': dict(validation_repeats=3), 952 | 'set14': dict(validation_repeats=20) 953 | } 954 | if args.validation_set not in eval_sets: 955 | print ('Validation set specified with --validation-set not in one of: ' + ', '.join(eval_sets)) 956 | sys.exit(1) 957 | 958 | if args.dataset_dir is None: 959 | print ('Must specify validation dataset path with --dataset-dir') 960 | sys.exit(1) 961 | if not os.path.isdir(args.dataset_dir): 962 | print ('Directory specified with --dataset-dir does not seem to exist.') 963 | sys.exit(1) 964 | 965 | config_name = None 966 | if args.train: 967 | if args.eval is not None: 968 | print ('Use either --train or --eval') 969 | sys.exit(1) 970 | if args.train_h5 is None: 971 | print ('Must specify training dataset with --train-h5 when training') 972 | sys.exit(1) 973 | config_name = args.train 974 | elif args.eval: 975 | pickle_name = args.eval 976 | pickle_re = re.compile('^network-(?:[0-9]+|final)-(.+)\\.pickle') 977 | m = pickle_re.match(os.path.basename(pickle_name)) 978 | if m is None: 979 | print ('network pickle name must contain network config string') 980 | sys.exit(1) 981 | config_name = m.group(1) 982 | else: 983 | print ('Must use either --train or --eval') 984 | sys.exit(1) 985 | 986 | 987 | if config_name not in config_map: 988 | print ('unknown config', config_name) 989 | sys.exit(1) 990 | 991 | validation_repeats = eval_sets[args.validation_set]['validation_repeats'] if args.eval else 1 992 | 993 | # Common configuration for all runs. 994 | config = dnnlib.EasyDict( 995 | train_dataset = args.train_h5, # Training set. 996 | validation_dataset = args.validation_set, # Dataset used to monitor validation convergence during training. 997 | validation_repeats = validation_repeats, 998 | num_channels = 3, # RGB. 
999 | train_resolution = 256, 1000 | minibatch_size = 4, 1001 | learning_rate = 3e-4, 1002 | config_name = config_name, 1003 | dataset_dir = args.dataset_dir 1004 | ) 1005 | 1006 | selected_config = config_map[config_name] 1007 | config.update(**selected_config) 1008 | if args.eval is not None: 1009 | config['eval_network'] = args.eval 1010 | del config['eval_id'] 1011 | 1012 | #---------------------------------------------------------------------------- 1013 | 1014 | # Execute. 1015 | sc.run_desc = 'eval' if config.get('eval_network') else 'train' 1016 | 1017 | # Decorate run_desc. 1018 | sc.run_desc += '-ilsvrc' 1019 | if config.get('prune_dataset'): sc.run_desc += '_%d' % config.prune_dataset 1020 | sc.run_desc += '-%s' % config.validation_dataset 1021 | sc.run_desc += '-%dc' % config.num_channels 1022 | sc.run_desc += '-%s' % config.noise_style 1023 | if config.minibatch_size != 4: sc.run_desc += '-mb%d' % config.minibatch_size 1024 | if config.learning_rate != 3e-4: sc.run_desc += '-lr%g' % config.learning_rate 1025 | if config.num_iter >= 1000000: 1026 | sc.run_desc += '-iter%dm' % (config.num_iter // 1000000) 1027 | elif config.num_iter >= 1000: 1028 | sc.run_desc += '-iter%dk' % (config.num_iter // 1000) 1029 | else: 1030 | sc.run_desc += '-iter%d' % config.num_iter 1031 | sc.run_desc += '-%s' % config.pipeline 1032 | if config.get('diagonal_covariance'): sc.run_desc += 'Diag' 1033 | if config.pipeline == 'blindspot': 1034 | sc.run_desc += '-%s' % config.noise_params 1035 | if config.train_resolution != 256: sc.run_desc += '-res%d' % config.train_resolution 1036 | 1037 | if config.get('eval_network'): sc.run_desc += '-EVAL_%s' % config_name 1038 | if config.get('eval_network'): sc.run_dir_root += '/_eval' 1039 | 1040 | # Submit. 
1041 | submit.submit_run(sc, 'selfsupervised_denoising.train', **config) 1042 | 1043 | #---------------------------------------------------------------------------- 1044 | 1045 | if __name__ == "__main__": 1046 | main() 1047 | --------------------------------------------------------------------------------