├── .gitignore ├── README.md ├── external ├── LICENSE.txt ├── dataset_tool_h5.py └── download_kodak.py ├── figures └── example.png ├── models ├── final-n2c-gauss25.wt ├── final-n2c-gauss25_nc.wt ├── final-n2n-gauss25.wt ├── final-n2n-gauss25_nc.wt ├── final-n2v-gauss25.wt ├── final-n2v-gauss25_nc.wt ├── final-ssdn-gauss25-sigma_const.wt ├── final-ssdn-gauss25-sigma_known.wt ├── final-ssdn-gauss25-sigma_var.wt ├── final-ssdn-gauss25_nc-sigma_const.wt ├── final-ssdn-gauss25_nc-sigma_known.wt └── final-ssdn-gauss25_nc-sigma_var.wt ├── report ├── fancyhdr.sty ├── figures │ ├── gauss25_train_psnr_relative.pgf │ └── gauss25_val_psnr_relative.pgf ├── iclr2020_conference.bst ├── iclr2020_conference.sty ├── images │ ├── CBM3D_cropped.png │ ├── Clean.png │ ├── Clean_cropped.png │ ├── N2C.png │ ├── N2C_cropped.png │ ├── N2N_cropped.png │ ├── N2V_cropped.png │ ├── Noisy.png │ ├── Noisy_cropped.png │ ├── SSDN_cropped.png │ └── SSDN_mu_cropped.png ├── math_commands.tex ├── natbib.sty ├── references.bib ├── report.pdf └── report.tex ├── setup.cfg └── ssdn ├── setup.py ├── ssdn ├── __init__.py ├── __main__.py ├── cfg.py ├── cli │ ├── __init__.py │ ├── cli.py │ └── cmds │ │ ├── __init__.py │ │ ├── cmd.py │ │ ├── eval.py │ │ └── train.py ├── datasets │ ├── __init__.py │ ├── folder.py │ ├── hdf5.py │ ├── noise_wrapper.py │ └── sampler.py ├── denoiser.py ├── eval.py ├── logging_helper.py ├── models │ ├── __init__.py │ ├── noise_network.py │ └── utility.py ├── params.py ├── train.py ├── utils │ ├── __init__.py │ ├── data.py │ ├── data_format.py │ ├── n2v_loss.py │ ├── n2v_ups.py │ ├── noise.py │ ├── pickle_fix.py │ ├── transforms.py │ └── utils.py └── version.py └── tests └── test_sampler.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.gitignore.io/api/python 3 | # Edit at https://www.gitignore.io/?templates=python 4 | 5 | .DS_Store 6 | BSDS300/ 7 | kodak/ 8 | 9 | ### Python ### 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | dl-group-cw.code-workspace 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | pip-wheel-metadata/ 34 | share/python-wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .nox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # pipenv 80 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 81 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 82 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 83 | # install all needed dependencies. 
84 | #Pipfile.lock 85 | 86 | # celery beat schedule file 87 | celerybeat-schedule 88 | 89 | # SageMath parsed files 90 | *.sage.py 91 | 92 | # Spyder project settings 93 | .spyderproject 94 | .spyproject 95 | 96 | # Rope project settings 97 | .ropeproject 98 | 99 | # Mr Developer 100 | .mr.developer.cfg 101 | .project 102 | .pydevproject 103 | 104 | # mkdocs documentation 105 | /site 106 | 107 | # mypy 108 | .mypy_cache/ 109 | .dmypy.json 110 | dmypy.json 111 | 112 | # Pyre type checker 113 | .pyre/ 114 | 115 | # End of https://www.gitignore.io/api/python 116 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # High-Quality Self-Supervised Deep Image Denoising - Unofficial PyTorch implementation of the NeurIPS 2019 paper 2 | Reimplementers: 3 | **David Jones**, **Richard Crosland**, **Jason Barrett** | University of Southampton [ECS] 4 | 5 | Codebase for the reimplementation of the focus [paper](https://arxiv.org/abs/1901.10277) as well as the discussed baselines: Noise2Clean, [Noise2Noise](https://arxiv.org/abs/1803.04189), and [Noise2Void](https://arxiv.org/abs/1811.10980). 6 | 7 | ![Denoising comparison](figures/example.png "Denoising Comparison") 8 | 9 | ## Resources 10 | 11 | - [Original Paper](https://arxiv.org/abs/1901.10277) (arXiv) 12 | - [Official TensorFlow Source](https://github.com/NVlabs/selfsupervised-denoising) (GitHub) 13 | 14 | ## Python requirements 15 | This code was tested on: 16 | - Python 3.7 17 | - [PyTorch](https://pytorch.org/get-started/locally/) 1.4.0 [CUDA 10.0 / CPU] 18 | - [Anaconda 2020/02](https://www.anaconda.com/distribution/) 19 | 20 | ## Installation 21 | 1. Create an Anaconda/venv environment (optional) 22 | 2. Install PyTorch 23 | 3. Install the SSDN package and dependencies: ```pip install -e ssdn``` 24 | 25 | 26 | ## Preparing datasets 27 | Dataset download/preparation is handled using the original implementation's methods; these are provided as tools in the `external` directory. Dataloading methods expect either an HDF5 file or a folder dataset. Networks train with a fixed patch size; inputs are padded or randomly cropped to reach this size. The dataset can be filtered to contain only images between 256x256 and 512x512 pixels using `dataset_tool_h5.py`. Both the original paper and this reimplementation use this tool on the ImageNet validation set to create the training set. 28 | To generate the training data HDF5 file, run: 29 | ``` 30 | python dataset_tool_h5.py --input-dir "/ILSVRC2012_img_val" --out=ilsvrc_val.h5 31 | ``` 32 | 33 | A successful run of dataset_tool_h5.py should result in a `.h5` file containing 44328 images. 34 | 35 | ### Validation datasets used: 36 | 37 | **Kodak**. To download the [Kodak Lossless True Color Image Suite](http://r0k.us/graphics/kodak/), run: 38 | ``` 39 | python download_kodak.py --output-dir={output_dir}/kodak 40 | ``` 41 | 42 | **BSD300**. From [Berkeley Segmentation Dataset and Benchmark](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds) download `BSDS300-images.tgz` and extract. 43 | 44 | **Set14**. From [LapSRN project page](http://vllab.ucmerced.edu/wlai24/LapSRN) download `SR_testing_datasets.zip` and extract. 45 | 46 | 47 | ## Running 48 | The denoiser is exposed as a CLI accessible via the ```ssdn``` command.
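Before starting a run it can be worth sanity-checking the training file generated in the dataset preparation step above. The following is a minimal sketch (not part of the original codebase) that assumes the layout written by `dataset_tool_h5.py`: an `images` dataset of flattened `uint8` vectors and a `shapes` dataset of `(channels, height, width)` triples; the filename matches the example command above.
```
# Minimal sanity check of the HDF5 training file produced by dataset_tool_h5.py.
# Assumes the layout written by that tool: an 'images' dataset of flattened
# uint8 vectors and a 'shapes' dataset of (channels, height, width) triples.
import h5py
import numpy as np

with h5py.File("ilsvrc_val.h5", "r") as f:    # filename from the example command above
    shapes = f["shapes"][:]
    print("image count:", len(f["images"]))   # expected: 44328 for the ImageNet validation set
    # Rebuild the first image to confirm the flattened data matches its stored shape.
    c, h, w = shapes[0]
    first = np.asarray(f["images"][0], dtype=np.uint8).reshape(c, h, w)
    print("first image:", first.shape, first.dtype)
```
The printed image count should match the 44328 images mentioned in the dataset preparation section.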
49 | 50 | ### Training: 51 | To train a network, run: 52 | ``` 53 | ssdn train start [-h] --train_dataset TRAIN_DATASET 54 | [--validation_dataset VALIDATION_DATASET] --iterations 55 | ITERATIONS [--eval_interval EVAL_INTERVAL] 56 | [--checkpoint_interval CHECKPOINT_INTERVAL] 57 | [--print_interval PRINT_INTERVAL] 58 | [--train_batch_size TRAIN_BATCH_SIZE] 59 | [--validation_batch_size VALIDATION_BATCH_SIZE] 60 | [--patch_size PATCH_SIZE] --algorithm 61 | {n2c,n2n,n2v,ssdn,ssdn_u_only} --noise_style 62 | NOISE_STYLE [--noise_value {known,const,var}] [--mono] 63 | [--diagonal] [--runs_dir RUNS_DIR] 64 | 65 | The following arguments are required: --train_dataset/-t, --iterations/-i, --algorithm/-a, --noise_style/-n, --noise_value (when --algorithm=ssdn) 66 | ``` 67 | Note that the validation dataset is optional; it can be omitted, but providing one is helpful for monitoring convergence. Where a parameter is not provided, the default in `cfg.py` will be used. 68 | 69 | --- 70 | 71 | Training creates model checkpoints containing the full training state at specified intervals (`.training` files). When training completes, a final output is created containing only the network weights and the configuration used to create them (`.wt` file). A run can be resumed from its latest `.training` file using: 72 | ``` 73 | ssdn train resume [-h] [--train_dataset TRAIN_DATASET] 74 | [--validation_dataset VALIDATION_DATASET] 75 | [--iterations ITERATIONS] 76 | [--eval_interval EVAL_INTERVAL] 77 | [--checkpoint_interval CHECKPOINT_INTERVAL] 78 | [--print_interval PRINT_INTERVAL] 79 | [--train_batch_size TRAIN_BATCH_SIZE] 80 | [--validation_batch_size VALIDATION_BATCH_SIZE] 81 | [--patch_size PATCH_SIZE] 82 | run_dir 83 | The following arguments are required: run_dir (positional) 84 | ``` 85 | 86 | 87 | 88 | --- 89 | 90 | Further options can be viewed using: `ssdn train {cmd} --help` where `{cmd}` is `start` or `resume`. 91 | 92 | ### Evaluating: 93 | To evaluate a trained network against one of the validation sets, run: 94 | ``` 95 | ssdn eval [-h] --model MODEL --dataset DATASET [--runs_dir RUNS_DIR] 96 | [--batch_size BATCH_SIZE] 97 | The following arguments are required: --model/-m, --dataset/-d 98 | ``` 99 | --- 100 | Further options can be viewed using: `ssdn eval --help` 101 | 102 | ### Extra notes: 103 | 104 | The network will attempt to use all available GPUs, with `cuda0` acting as the master and the batch distributed across the remaining devices. To avoid this, restrict the visible GPUs using: 105 | ``` 106 | CUDA_VISIBLE_DEVICES=#,#,# ssdn ... 107 | ``` 108 | 109 | --- 110 | 111 | During execution an events file is generated containing all training metrics. This can be viewed using TensorBoard. 112 | 113 | When executing remotely it may be preferable to expose this to a local machine. The suggested method is `ssh` port forwarding: 114 | ``` 115 | ssh -L 16006:127.0.0.1:6006 {username}@{remote} 116 | $ tensorboard --logdir runs 117 | # Connect locally at: http://127.0.0.1:16006/ 118 | ``` 119 | 120 | -------------------------------------------------------------------------------- /external/LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | 3 | 4 | Attribution-NonCommercial 4.0 International 5 | 6 | ======================================================================= 7 | 8 | Creative Commons Corporation ("Creative Commons") is not a law firm and 9 | does not provide legal services or legal advice.
Distribution of 10 | Creative Commons public licenses does not create a lawyer-client or 11 | other relationship. Creative Commons makes its licenses and related 12 | information available on an "as-is" basis. Creative Commons gives no 13 | warranties regarding its licenses, any material licensed under their 14 | terms and conditions, or any related information. Creative Commons 15 | disclaims all liability for damages resulting from their use to the 16 | fullest extent possible. 17 | 18 | Using Creative Commons Public Licenses 19 | 20 | Creative Commons public licenses provide a standard set of terms and 21 | conditions that creators and other rights holders may use to share 22 | original works of authorship and other material subject to copyright 23 | and certain other rights specified in the public license below. The 24 | following considerations are for informational purposes only, are not 25 | exhaustive, and do not form part of our licenses. 26 | 27 | Considerations for licensors: Our public licenses are 28 | intended for use by those authorized to give the public 29 | permission to use material in ways otherwise restricted by 30 | copyright and certain other rights. Our licenses are 31 | irrevocable. Licensors should read and understand the terms 32 | and conditions of the license they choose before applying it. 33 | Licensors should also secure all rights necessary before 34 | applying our licenses so that the public can reuse the 35 | material as expected. Licensors should clearly mark any 36 | material not subject to the license. This includes other CC- 37 | licensed material, or material used under an exception or 38 | limitation to copyright. More considerations for licensors: 39 | wiki.creativecommons.org/Considerations_for_licensors 40 | 41 | Considerations for the public: By using one of our public 42 | licenses, a licensor grants the public permission to use the 43 | licensed material under specified terms and conditions. If 44 | the licensor's permission is not necessary for any reason--for 45 | example, because of any applicable exception or limitation to 46 | copyright--then that use is not regulated by the license. Our 47 | licenses grant only permissions under copyright and certain 48 | other rights that a licensor has authority to grant. Use of 49 | the licensed material may still be restricted for other 50 | reasons, including because others have copyright or other 51 | rights in the material. A licensor may make special requests, 52 | such as asking that all changes be marked or described. 53 | Although not required by our licenses, you are encouraged to 54 | respect those requests where reasonable. More_considerations 55 | for the public: 56 | wiki.creativecommons.org/Considerations_for_licensees 57 | 58 | ======================================================================= 59 | 60 | Creative Commons Attribution-NonCommercial 4.0 International Public 61 | License 62 | 63 | By exercising the Licensed Rights (defined below), You accept and agree 64 | to be bound by the terms and conditions of this Creative Commons 65 | Attribution-NonCommercial 4.0 International Public License ("Public 66 | License"). To the extent this Public License may be interpreted as a 67 | contract, You are granted the Licensed Rights in consideration of Your 68 | acceptance of these terms and conditions, and the Licensor grants You 69 | such rights in consideration of benefits the Licensor receives from 70 | making the Licensed Material available under these terms and 71 | conditions. 
72 | 73 | 74 | Section 1 -- Definitions. 75 | 76 | a. Adapted Material means material subject to Copyright and Similar 77 | Rights that is derived from or based upon the Licensed Material 78 | and in which the Licensed Material is translated, altered, 79 | arranged, transformed, or otherwise modified in a manner requiring 80 | permission under the Copyright and Similar Rights held by the 81 | Licensor. For purposes of this Public License, where the Licensed 82 | Material is a musical work, performance, or sound recording, 83 | Adapted Material is always produced where the Licensed Material is 84 | synched in timed relation with a moving image. 85 | 86 | b. Adapter's License means the license You apply to Your Copyright 87 | and Similar Rights in Your contributions to Adapted Material in 88 | accordance with the terms and conditions of this Public License. 89 | 90 | c. Copyright and Similar Rights means copyright and/or similar rights 91 | closely related to copyright including, without limitation, 92 | performance, broadcast, sound recording, and Sui Generis Database 93 | Rights, without regard to how the rights are labeled or 94 | categorized. For purposes of this Public License, the rights 95 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 96 | Rights. 97 | d. Effective Technological Measures means those measures that, in the 98 | absence of proper authority, may not be circumvented under laws 99 | fulfilling obligations under Article 11 of the WIPO Copyright 100 | Treaty adopted on December 20, 1996, and/or similar international 101 | agreements. 102 | 103 | e. Exceptions and Limitations means fair use, fair dealing, and/or 104 | any other exception or limitation to Copyright and Similar Rights 105 | that applies to Your use of the Licensed Material. 106 | 107 | f. Licensed Material means the artistic or literary work, database, 108 | or other material to which the Licensor applied this Public 109 | License. 110 | 111 | g. Licensed Rights means the rights granted to You subject to the 112 | terms and conditions of this Public License, which are limited to 113 | all Copyright and Similar Rights that apply to Your use of the 114 | Licensed Material and that the Licensor has authority to license. 115 | 116 | h. Licensor means the individual(s) or entity(ies) granting rights 117 | under this Public License. 118 | 119 | i. NonCommercial means not primarily intended for or directed towards 120 | commercial advantage or monetary compensation. For purposes of 121 | this Public License, the exchange of the Licensed Material for 122 | other material subject to Copyright and Similar Rights by digital 123 | file-sharing or similar means is NonCommercial provided there is 124 | no payment of monetary compensation in connection with the 125 | exchange. 126 | 127 | j. Share means to provide material to the public by any means or 128 | process that requires permission under the Licensed Rights, such 129 | as reproduction, public display, public performance, distribution, 130 | dissemination, communication, or importation, and to make material 131 | available to the public including in ways that members of the 132 | public may access the material from a place and at a time 133 | individually chosen by them. 134 | 135 | k. 
Sui Generis Database Rights means rights other than copyright 136 | resulting from Directive 96/9/EC of the European Parliament and of 137 | the Council of 11 March 1996 on the legal protection of databases, 138 | as amended and/or succeeded, as well as other essentially 139 | equivalent rights anywhere in the world. 140 | 141 | l. You means the individual or entity exercising the Licensed Rights 142 | under this Public License. Your has a corresponding meaning. 143 | 144 | 145 | Section 2 -- Scope. 146 | 147 | a. License grant. 148 | 149 | 1. Subject to the terms and conditions of this Public License, 150 | the Licensor hereby grants You a worldwide, royalty-free, 151 | non-sublicensable, non-exclusive, irrevocable license to 152 | exercise the Licensed Rights in the Licensed Material to: 153 | 154 | a. reproduce and Share the Licensed Material, in whole or 155 | in part, for NonCommercial purposes only; and 156 | 157 | b. produce, reproduce, and Share Adapted Material for 158 | NonCommercial purposes only. 159 | 160 | 2. Exceptions and Limitations. For the avoidance of doubt, where 161 | Exceptions and Limitations apply to Your use, this Public 162 | License does not apply, and You do not need to comply with 163 | its terms and conditions. 164 | 165 | 3. Term. The term of this Public License is specified in Section 166 | 6(a). 167 | 168 | 4. Media and formats; technical modifications allowed. The 169 | Licensor authorizes You to exercise the Licensed Rights in 170 | all media and formats whether now known or hereafter created, 171 | and to make technical modifications necessary to do so. The 172 | Licensor waives and/or agrees not to assert any right or 173 | authority to forbid You from making technical modifications 174 | necessary to exercise the Licensed Rights, including 175 | technical modifications necessary to circumvent Effective 176 | Technological Measures. For purposes of this Public License, 177 | simply making modifications authorized by this Section 2(a) 178 | (4) never produces Adapted Material. 179 | 180 | 5. Downstream recipients. 181 | 182 | a. Offer from the Licensor -- Licensed Material. Every 183 | recipient of the Licensed Material automatically 184 | receives an offer from the Licensor to exercise the 185 | Licensed Rights under the terms and conditions of this 186 | Public License. 187 | 188 | b. No downstream restrictions. You may not offer or impose 189 | any additional or different terms or conditions on, or 190 | apply any Effective Technological Measures to, the 191 | Licensed Material if doing so restricts exercise of the 192 | Licensed Rights by any recipient of the Licensed 193 | Material. 194 | 195 | 6. No endorsement. Nothing in this Public License constitutes or 196 | may be construed as permission to assert or imply that You 197 | are, or that Your use of the Licensed Material is, connected 198 | with, or sponsored, endorsed, or granted official status by, 199 | the Licensor or others designated to receive attribution as 200 | provided in Section 3(a)(1)(A)(i). 201 | 202 | b. Other rights. 203 | 204 | 1. Moral rights, such as the right of integrity, are not 205 | licensed under this Public License, nor are publicity, 206 | privacy, and/or other similar personality rights; however, to 207 | the extent possible, the Licensor waives and/or agrees not to 208 | assert any such rights held by the Licensor to the limited 209 | extent necessary to allow You to exercise the Licensed 210 | Rights, but not otherwise. 211 | 212 | 2. 
Patent and trademark rights are not licensed under this 213 | Public License. 214 | 215 | 3. To the extent possible, the Licensor waives any right to 216 | collect royalties from You for the exercise of the Licensed 217 | Rights, whether directly or through a collecting society 218 | under any voluntary or waivable statutory or compulsory 219 | licensing scheme. In all other cases the Licensor expressly 220 | reserves any right to collect such royalties, including when 221 | the Licensed Material is used other than for NonCommercial 222 | purposes. 223 | 224 | 225 | Section 3 -- License Conditions. 226 | 227 | Your exercise of the Licensed Rights is expressly made subject to the 228 | following conditions. 229 | 230 | a. Attribution. 231 | 232 | 1. If You Share the Licensed Material (including in modified 233 | form), You must: 234 | 235 | a. retain the following if it is supplied by the Licensor 236 | with the Licensed Material: 237 | 238 | i. identification of the creator(s) of the Licensed 239 | Material and any others designated to receive 240 | attribution, in any reasonable manner requested by 241 | the Licensor (including by pseudonym if 242 | designated); 243 | 244 | ii. a copyright notice; 245 | 246 | iii. a notice that refers to this Public License; 247 | 248 | iv. a notice that refers to the disclaimer of 249 | warranties; 250 | 251 | v. a URI or hyperlink to the Licensed Material to the 252 | extent reasonably practicable; 253 | 254 | b. indicate if You modified the Licensed Material and 255 | retain an indication of any previous modifications; and 256 | 257 | c. indicate the Licensed Material is licensed under this 258 | Public License, and include the text of, or the URI or 259 | hyperlink to, this Public License. 260 | 261 | 2. You may satisfy the conditions in Section 3(a)(1) in any 262 | reasonable manner based on the medium, means, and context in 263 | which You Share the Licensed Material. For example, it may be 264 | reasonable to satisfy the conditions by providing a URI or 265 | hyperlink to a resource that includes the required 266 | information. 267 | 268 | 3. If requested by the Licensor, You must remove any of the 269 | information required by Section 3(a)(1)(A) to the extent 270 | reasonably practicable. 271 | 272 | 4. If You Share Adapted Material You produce, the Adapter's 273 | License You apply must not prevent recipients of the Adapted 274 | Material from complying with this Public License. 275 | 276 | 277 | Section 4 -- Sui Generis Database Rights. 278 | 279 | Where the Licensed Rights include Sui Generis Database Rights that 280 | apply to Your use of the Licensed Material: 281 | 282 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 283 | to extract, reuse, reproduce, and Share all or a substantial 284 | portion of the contents of the database for NonCommercial purposes 285 | only; 286 | 287 | b. if You include all or a substantial portion of the database 288 | contents in a database in which You have Sui Generis Database 289 | Rights, then the database in which You have Sui Generis Database 290 | Rights (but not its individual contents) is Adapted Material; and 291 | 292 | c. You must comply with the conditions in Section 3(a) if You Share 293 | all or a substantial portion of the contents of the database. 294 | 295 | For the avoidance of doubt, this Section 4 supplements and does not 296 | replace Your obligations under this Public License where the Licensed 297 | Rights include other Copyright and Similar Rights. 
298 | 299 | 300 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 301 | 302 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 303 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 304 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 305 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 306 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 307 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 308 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 309 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 310 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 311 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 312 | 313 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 314 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 315 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 316 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 317 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 318 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 319 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 320 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 321 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 322 | 323 | c. The disclaimer of warranties and limitation of liability provided 324 | above shall be interpreted in a manner that, to the extent 325 | possible, most closely approximates an absolute disclaimer and 326 | waiver of all liability. 327 | 328 | 329 | Section 6 -- Term and Termination. 330 | 331 | a. This Public License applies for the term of the Copyright and 332 | Similar Rights licensed here. However, if You fail to comply with 333 | this Public License, then Your rights under this Public License 334 | terminate automatically. 335 | 336 | b. Where Your right to use the Licensed Material has terminated under 337 | Section 6(a), it reinstates: 338 | 339 | 1. automatically as of the date the violation is cured, provided 340 | it is cured within 30 days of Your discovery of the 341 | violation; or 342 | 343 | 2. upon express reinstatement by the Licensor. 344 | 345 | For the avoidance of doubt, this Section 6(b) does not affect any 346 | right the Licensor may have to seek remedies for Your violations 347 | of this Public License. 348 | 349 | c. For the avoidance of doubt, the Licensor may also offer the 350 | Licensed Material under separate terms or conditions or stop 351 | distributing the Licensed Material at any time; however, doing so 352 | will not terminate this Public License. 353 | 354 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 355 | License. 356 | 357 | 358 | Section 7 -- Other Terms and Conditions. 359 | 360 | a. The Licensor shall not be bound by any additional or different 361 | terms or conditions communicated by You unless expressly agreed. 362 | 363 | b. Any arrangements, understandings, or agreements regarding the 364 | Licensed Material not stated herein are separate from and 365 | independent of the terms and conditions of this Public License. 366 | 367 | 368 | Section 8 -- Interpretation. 369 | 370 | a. 
For the avoidance of doubt, this Public License does not, and 371 | shall not be interpreted to, reduce, limit, restrict, or impose 372 | conditions on any use of the Licensed Material that could lawfully 373 | be made without permission under this Public License. 374 | 375 | b. To the extent possible, if any provision of this Public License is 376 | deemed unenforceable, it shall be automatically reformed to the 377 | minimum extent necessary to make it enforceable. If the provision 378 | cannot be reformed, it shall be severed from this Public License 379 | without affecting the enforceability of the remaining terms and 380 | conditions. 381 | 382 | c. No term or condition of this Public License will be waived and no 383 | failure to comply consented to unless expressly agreed to by the 384 | Licensor. 385 | 386 | d. Nothing in this Public License constitutes or may be interpreted 387 | as a limitation upon, or waiver of, any privileges and immunities 388 | that apply to the Licensor or You, including from the legal 389 | processes of any jurisdiction or authority. 390 | 391 | ======================================================================= 392 | 393 | Creative Commons is not a party to its public 394 | licenses. Notwithstanding, Creative Commons may elect to apply one of 395 | its public licenses to material it publishes and in those instances 396 | will be considered the "Licensor." The text of the Creative Commons 397 | public licenses is dedicated to the public domain under the CC0 Public 398 | Domain Dedication. Except for the limited purpose of indicating that 399 | material is shared under a Creative Commons public license or as 400 | otherwise permitted by the Creative Commons policies published at 401 | creativecommons.org/policies, Creative Commons does not authorize the 402 | use of the trademark "Creative Commons" or any other trademark or logo 403 | of Creative Commons without its prior written consent including, 404 | without limitation, in connection with any unauthorized modifications 405 | to any of its public licenses or any other arrangements, 406 | understandings, or agreements concerning use of licensed material. For 407 | the avoidance of doubt, this paragraph does not form part of the 408 | public licenses. 409 | 410 | Creative Commons may be contacted at creativecommons.org. -------------------------------------------------------------------------------- /external/dataset_tool_h5.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
7 | 8 | import glob 9 | import os 10 | import sys 11 | import argparse 12 | import h5py 13 | import fnmatch 14 | 15 | import PIL.Image 16 | import numpy as np 17 | 18 | from collections import defaultdict 19 | 20 | size_stats = defaultdict(int) 21 | format_stats = defaultdict(int) 22 | 23 | 24 | def load_image(fname): 25 | global format_stats, size_stats 26 | im = PIL.Image.open(fname) 27 | format_stats[im.mode] += 1 28 | if (im.width < 256 or im.height < 256): 29 | size_stats['< 256x256'] += 1 30 | else: 31 | size_stats['>= 256x256'] += 1 32 | arr = np.array(im.convert('RGB'), dtype=np.uint8) 33 | # print(arr.shape) 34 | # print(arr) 35 | # arr = arr.transpose([2, 0, 1]) 36 | # print(arr[:, 1:5, 1:3]) 37 | # print(arr.shape) 38 | # exit() 39 | assert len(arr.shape) == 3 40 | return arr.transpose([2, 0, 1]) 41 | 42 | 43 | def filter_image_sizes(images): 44 | filtered = [] 45 | for idx, fname in enumerate(images): 46 | if (idx % 100) == 0: 47 | print ('loading images', idx, '/', len(images)) 48 | try: 49 | with PIL.Image.open(fname) as img: 50 | w = img.size[0] 51 | h = img.size[1] 52 | if (w > 512 or h > 512) or (w < 256 or h < 256): 53 | continue 54 | filtered.append((fname, w, h)) 55 | except: 56 | print ('Could not load image', fname, 'skipping file..') 57 | return filtered 58 | 59 | 60 | examples='''examples: 61 | 62 | python %(prog)s --input-dir=./ILSVRC2012_img_val --out=imagenet_val_raw.h5 63 | ''' 64 | 65 | def main(): 66 | parser = argparse.ArgumentParser( 67 | description='Convert a set of image files into a HDF5 dataset file.', 68 | epilog=examples, 69 | formatter_class=argparse.RawDescriptionHelpFormatter 70 | ) 71 | parser.add_argument("--input-dir", help="Directory containing ImageNet images (can be glob pattern for subdirs)") 72 | parser.add_argument("--out", help="Filename of the output file") 73 | parser.add_argument("--max-files", help="Convert up to max-files images. 
Process all if unspecified.") 74 | args = parser.parse_args() 75 | 76 | if args.input_dir is None: 77 | print ('Must specify input file directory with --input-dir') 78 | sys.exit(1) 79 | if args.out is None: 80 | print ('Must specify output filename with --out') 81 | sys.exit(1) 82 | 83 | print ('Loading image list from %s' % args.input_dir) 84 | images = [] 85 | pattern = os.path.join(args.input_dir, '**/*') 86 | all_fnames = glob.glob(pattern, recursive=True) 87 | for fname in all_fnames: 88 | # include only JPEG/jpg/png 89 | if fnmatch.fnmatch(fname, '*.JPEG') or fnmatch.fnmatch(fname, '*.jpg') or fnmatch.fnmatch(fname, '*.png'): 90 | images.append(fname) 91 | images = sorted(images) 92 | np.random.RandomState(0xbadf00d).shuffle(images) 93 | 94 | filtered = filter_image_sizes(images) 95 | if args.max_files: 96 | filtered = filtered[0:int(args.max_files)] 97 | 98 | # ---------------------------------------------------------- 99 | outdir = os.path.dirname(args.out) 100 | if outdir != '': 101 | os.makedirs(outdir, exist_ok=True) 102 | num_images = len(filtered) 103 | num_pixels_total = 0 104 | with h5py.File(args.out, 'w') as h5file: 105 | dt = h5py.special_dtype(vlen=np.dtype('uint8')) 106 | dset_shapes = h5file.create_dataset('shapes', (num_images, 3), dtype=np.int32) 107 | dset_images = h5file.create_dataset('images', (num_images,), dtype=dt) 108 | for (idx, (imgname, w, h)) in enumerate(filtered): 109 | print ("%d/%d: %s" % (idx+1, len(filtered), imgname)) 110 | dset_images[idx] = load_image(imgname).flatten() 111 | dset_shapes[idx] = (3, h, w) 112 | num_pixels_total += h*w 113 | 114 | print ('Dataset statistics:') 115 | print (' Total pixels', num_pixels_total) 116 | print (' Formats:') 117 | for key in format_stats: 118 | print (' %s: %d images' % (key, format_stats[key])) 119 | print (' width,height buckets:') 120 | for key in size_stats: 121 | print (' %s: %d images' % (key, size_stats[key])) 122 | 123 | 124 | if __name__ == "__main__": 125 | main() 126 | -------------------------------------------------------------------------------- /external/download_kodak.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
7 | 8 | import os 9 | import sys 10 | import argparse 11 | 12 | from urllib.request import urlretrieve 13 | 14 | examples='''examples: 15 | 16 | python %(prog)s --output-dir=./tmp 17 | ''' 18 | 19 | def main(): 20 | parser = argparse.ArgumentParser( 21 | description='Download the Kodak dataset .PNG image files.', 22 | epilog=examples, 23 | formatter_class=argparse.RawDescriptionHelpFormatter 24 | 25 | ) 26 | parser.add_argument("--output-dir", help="Directory where to save the Kodak dataset .PNGs") 27 | args = parser.parse_args() 28 | 29 | if args.output_dir is None: 30 | print ('Must specify output directory where to store tfrecords with --output-dir') 31 | sys.exit(1) 32 | 33 | os.makedirs(args.output_dir, exist_ok=True) 34 | 35 | for i in range(1, 25): 36 | imgname = 'kodim%02d.png' % i 37 | url = "http://r0k.us/graphics/kodak/kodak/" + imgname 38 | print ('Downloading', url) 39 | urlretrieve(url, os.path.join(args.output_dir, imgname)) 40 | print ('Kodak validation set successfully downloaded.') 41 | 42 | if __name__ == "__main__": 43 | main() 44 | -------------------------------------------------------------------------------- /figures/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/selfsupervised-denoising/566c83ad9779b3fc4b0788727da37de5e29f774f/figures/example.png -------------------------------------------------------------------------------- /models/final-n2c-gauss25.wt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/selfsupervised-denoising/566c83ad9779b3fc4b0788727da37de5e29f774f/models/final-n2c-gauss25.wt -------------------------------------------------------------------------------- /models/final-n2c-gauss25_nc.wt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/selfsupervised-denoising/566c83ad9779b3fc4b0788727da37de5e29f774f/models/final-n2c-gauss25_nc.wt -------------------------------------------------------------------------------- /models/final-n2n-gauss25.wt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/selfsupervised-denoising/566c83ad9779b3fc4b0788727da37de5e29f774f/models/final-n2n-gauss25.wt -------------------------------------------------------------------------------- /models/final-n2n-gauss25_nc.wt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/selfsupervised-denoising/566c83ad9779b3fc4b0788727da37de5e29f774f/models/final-n2n-gauss25_nc.wt -------------------------------------------------------------------------------- /models/final-n2v-gauss25.wt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/selfsupervised-denoising/566c83ad9779b3fc4b0788727da37de5e29f774f/models/final-n2v-gauss25.wt -------------------------------------------------------------------------------- /models/final-n2v-gauss25_nc.wt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/selfsupervised-denoising/566c83ad9779b3fc4b0788727da37de5e29f774f/models/final-n2v-gauss25_nc.wt -------------------------------------------------------------------------------- /models/final-ssdn-gauss25-sigma_const.wt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/selfsupervised-denoising/566c83ad9779b3fc4b0788727da37de5e29f774f/models/final-ssdn-gauss25-sigma_const.wt -------------------------------------------------------------------------------- /models/final-ssdn-gauss25-sigma_known.wt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/selfsupervised-denoising/566c83ad9779b3fc4b0788727da37de5e29f774f/models/final-ssdn-gauss25-sigma_known.wt -------------------------------------------------------------------------------- /models/final-ssdn-gauss25-sigma_var.wt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/selfsupervised-denoising/566c83ad9779b3fc4b0788727da37de5e29f774f/models/final-ssdn-gauss25-sigma_var.wt -------------------------------------------------------------------------------- /models/final-ssdn-gauss25_nc-sigma_const.wt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/selfsupervised-denoising/566c83ad9779b3fc4b0788727da37de5e29f774f/models/final-ssdn-gauss25_nc-sigma_const.wt -------------------------------------------------------------------------------- /models/final-ssdn-gauss25_nc-sigma_known.wt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/selfsupervised-denoising/566c83ad9779b3fc4b0788727da37de5e29f774f/models/final-ssdn-gauss25_nc-sigma_known.wt -------------------------------------------------------------------------------- /models/final-ssdn-gauss25_nc-sigma_var.wt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/selfsupervised-denoising/566c83ad9779b3fc4b0788727da37de5e29f774f/models/final-ssdn-gauss25_nc-sigma_var.wt -------------------------------------------------------------------------------- /report/iclr2020_conference.sty: -------------------------------------------------------------------------------- 1 | %%%% ICLR Macros (LaTex) 2 | %%%% Adapted by Hugo Larochelle from the NIPS stylefile Macros 3 | %%%% Style File 4 | %%%% Dec 12, 1990 Rev Aug 14, 1991; Sept, 1995; April, 1997; April, 1999; October 2014 5 | 6 | % This file can be used with Latex2e whether running in main mode, or 7 | % 2.09 compatibility mode. 8 | % 9 | % If using main mode, you need to include the commands 10 | % \documentclass{article} 11 | % \usepackage{iclr14submit_e,times} 12 | % 13 | 14 | % Change the overall width of the page. If these parameters are 15 | % changed, they will require corresponding changes in the 16 | % maketitle section. 
17 | % 18 | \usepackage{eso-pic} % used by \AddToShipoutPicture 19 | \RequirePackage{fancyhdr} 20 | \RequirePackage{natbib} 21 | 22 | % modification to natbib citations 23 | \setcitestyle{authoryear,round,citesep={;},aysep={,},yysep={;}} 24 | 25 | \renewcommand{\topfraction}{0.95} % let figure take up nearly whole page 26 | \renewcommand{\textfraction}{0.05} % let figure take up nearly whole page 27 | 28 | % Define iclrfinal, set to true if iclrfinalcopy is defined 29 | \newif\ificlrfinal 30 | \iclrfinalfalse 31 | \def\iclrfinalcopy{\iclrfinaltrue} 32 | \font\iclrtenhv = phvb at 8pt 33 | 34 | % Specify the dimensions of each page 35 | 36 | \setlength{\paperheight}{11in} 37 | \setlength{\paperwidth}{8.5in} 38 | 39 | 40 | \oddsidemargin .5in % Note \oddsidemargin = \evensidemargin 41 | \evensidemargin .5in 42 | \marginparwidth 0.07 true in 43 | %\marginparwidth 0.75 true in 44 | %\topmargin 0 true pt % Nominal distance from top of page to top of 45 | %\topmargin 0.125in 46 | \topmargin -0.625in 47 | \addtolength{\headsep}{0.25in} 48 | \textheight 9.0 true in % Height of text (including footnotes & figures) 49 | \textwidth 5.5 true in % Width of text line. 50 | \widowpenalty=10000 51 | \clubpenalty=10000 52 | 53 | % \thispagestyle{empty} \pagestyle{empty} 54 | \flushbottom \sloppy 55 | 56 | % We're never going to need a table of contents, so just flush it to 57 | % save space --- suggested by drstrip@sandia-2 58 | \def\addcontentsline#1#2#3{} 59 | 60 | % Title stuff, taken from deproc. 61 | \def\maketitle{\par 62 | \begingroup 63 | \def\thefootnote{\fnsymbol{footnote}} 64 | \def\@makefnmark{\hbox to 0pt{$^{\@thefnmark}$\hss}} % for perfect author 65 | % name centering 66 | % The footnote-mark was overlapping the footnote-text, 67 | % added the following to fix this problem (MK) 68 | \long\def\@makefntext##1{\parindent 1em\noindent 69 | \hbox to1.8em{\hss $\m@th ^{\@thefnmark}$}##1} 70 | \@maketitle \@thanks 71 | \endgroup 72 | \setcounter{footnote}{0} 73 | \let\maketitle\relax \let\@maketitle\relax 74 | \gdef\@thanks{}\gdef\@author{}\gdef\@title{}\let\thanks\relax} 75 | 76 | % The toptitlebar has been raised to top-justify the first page 77 | 78 | \usepackage{fancyhdr} 79 | \pagestyle{fancy} 80 | \fancyhead{} 81 | 82 | % Title (includes both anonimized and non-anonimized versions) 83 | \def\@maketitle{\vbox{\hsize\textwidth 84 | %\linewidth\hsize \vskip 0.1in \toptitlebar \centering 85 | {\LARGE\sc \@title\par} 86 | %\bottomtitlebar % \vskip 0.1in % minus 87 | \ificlrfinal 88 | \lhead{COMP6248 Reproducibility Challenge} 89 | \def\And{\end{tabular}\hfil\linebreak[0]\hfil 90 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\ignorespaces}% 91 | \def\AND{\end{tabular}\hfil\linebreak[4]\hfil 92 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\ignorespaces}% 93 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\@author\end{tabular}% 94 | \else 95 | \lhead{Under review as a conference paper at ICLR 2020} 96 | \def\And{\end{tabular}\hfil\linebreak[0]\hfil 97 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\ignorespaces}% 98 | \def\AND{\end{tabular}\hfil\linebreak[4]\hfil 99 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\ignorespaces}% 100 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}Anonymous authors\\Paper under double-blind review\end{tabular}% 101 | \fi 102 | \vskip 0.3in minus 0.1in}} 103 | 104 | \renewenvironment{abstract}{\vskip.075in\centerline{\large\sc 105 | Abstract}\vspace{0.5ex}\begin{quote}}{\par\end{quote}\vskip 1ex} 106 | 107 | % sections with less space 108 | \def\section{\@startsection {section}{1}{\z@}{-2.0ex plus 
109 | -0.5ex minus -.2ex}{1.5ex plus 0.3ex 110 | minus0.2ex}{\large\sc\raggedright}} 111 | 112 | \def\subsection{\@startsection{subsection}{2}{\z@}{-1.8ex plus 113 | -0.5ex minus -.2ex}{0.8ex plus .2ex}{\normalsize\sc\raggedright}} 114 | \def\subsubsection{\@startsection{subsubsection}{3}{\z@}{-1.5ex 115 | plus -0.5ex minus -.2ex}{0.5ex plus 116 | .2ex}{\normalsize\sc\raggedright}} 117 | \def\paragraph{\@startsection{paragraph}{4}{\z@}{1.5ex plus 118 | 0.5ex minus .2ex}{-1em}{\normalsize\bf}} 119 | \def\subparagraph{\@startsection{subparagraph}{5}{\z@}{1.5ex plus 120 | 0.5ex minus .2ex}{-1em}{\normalsize\sc}} 121 | \def\subsubsubsection{\vskip 122 | 5pt{\noindent\normalsize\rm\raggedright}} 123 | 124 | 125 | % Footnotes 126 | \footnotesep 6.65pt % 127 | \skip\footins 9pt plus 4pt minus 2pt 128 | \def\footnoterule{\kern-3pt \hrule width 12pc \kern 2.6pt } 129 | \setcounter{footnote}{0} 130 | 131 | % Lists and paragraphs 132 | \parindent 0pt 133 | \topsep 4pt plus 1pt minus 2pt 134 | \partopsep 1pt plus 0.5pt minus 0.5pt 135 | \itemsep 2pt plus 1pt minus 0.5pt 136 | \parsep 2pt plus 1pt minus 0.5pt 137 | \parskip .5pc 138 | 139 | 140 | %\leftmargin2em 141 | \leftmargin3pc 142 | \leftmargini\leftmargin \leftmarginii 2em 143 | \leftmarginiii 1.5em \leftmarginiv 1.0em \leftmarginv .5em 144 | 145 | %\labelsep \labelsep 5pt 146 | 147 | \def\@listi{\leftmargin\leftmargini} 148 | \def\@listii{\leftmargin\leftmarginii 149 | \labelwidth\leftmarginii\advance\labelwidth-\labelsep 150 | \topsep 2pt plus 1pt minus 0.5pt 151 | \parsep 1pt plus 0.5pt minus 0.5pt 152 | \itemsep \parsep} 153 | \def\@listiii{\leftmargin\leftmarginiii 154 | \labelwidth\leftmarginiii\advance\labelwidth-\labelsep 155 | \topsep 1pt plus 0.5pt minus 0.5pt 156 | \parsep \z@ \partopsep 0.5pt plus 0pt minus 0.5pt 157 | \itemsep \topsep} 158 | \def\@listiv{\leftmargin\leftmarginiv 159 | \labelwidth\leftmarginiv\advance\labelwidth-\labelsep} 160 | \def\@listv{\leftmargin\leftmarginv 161 | \labelwidth\leftmarginv\advance\labelwidth-\labelsep} 162 | \def\@listvi{\leftmargin\leftmarginvi 163 | \labelwidth\leftmarginvi\advance\labelwidth-\labelsep} 164 | 165 | \abovedisplayskip 7pt plus2pt minus5pt% 166 | \belowdisplayskip \abovedisplayskip 167 | \abovedisplayshortskip 0pt plus3pt% 168 | \belowdisplayshortskip 4pt plus3pt minus3pt% 169 | 170 | % Less leading in most fonts (due to the narrow columns) 171 | % The choices were between 1-pt and 1.5-pt leading 172 | %\def\@normalsize{\@setsize\normalsize{11pt}\xpt\@xpt} % got rid of @ (MK) 173 | \def\normalsize{\@setsize\normalsize{11pt}\xpt\@xpt} 174 | \def\small{\@setsize\small{10pt}\ixpt\@ixpt} 175 | \def\footnotesize{\@setsize\footnotesize{10pt}\ixpt\@ixpt} 176 | \def\scriptsize{\@setsize\scriptsize{8pt}\viipt\@viipt} 177 | \def\tiny{\@setsize\tiny{7pt}\vipt\@vipt} 178 | \def\large{\@setsize\large{14pt}\xiipt\@xiipt} 179 | \def\Large{\@setsize\Large{16pt}\xivpt\@xivpt} 180 | \def\LARGE{\@setsize\LARGE{20pt}\xviipt\@xviipt} 181 | \def\huge{\@setsize\huge{23pt}\xxpt\@xxpt} 182 | \def\Huge{\@setsize\Huge{28pt}\xxvpt\@xxvpt} 183 | 184 | \def\toptitlebar{\hrule height4pt\vskip .25in\vskip-\parskip} 185 | 186 | \def\bottomtitlebar{\vskip .29in\vskip-\parskip\hrule height1pt\vskip 187 | .09in} % 188 | %Reduced second vskip to compensate for adding the strut in \@author 189 | 190 | 191 | %% % Vertical Ruler 192 | %% % This code is, largely, from the CVPR 2010 conference style file 193 | %% % ----- define vruler 194 | %% \makeatletter 195 | %% \newbox\iclrrulerbox 196 | %% \newcount\iclrrulercount 
197 | %% \newdimen\iclrruleroffset 198 | %% \newdimen\cv@lineheight 199 | %% \newdimen\cv@boxheight 200 | %% \newbox\cv@tmpbox 201 | %% \newcount\cv@refno 202 | %% \newcount\cv@tot 203 | %% % NUMBER with left flushed zeros \fillzeros[] 204 | %% \newcount\cv@tmpc@ \newcount\cv@tmpc 205 | %% \def\fillzeros[#1]#2{\cv@tmpc@=#2\relax\ifnum\cv@tmpc@<0\cv@tmpc@=-\cv@tmpc@\fi 206 | %% \cv@tmpc=1 % 207 | %% \loop\ifnum\cv@tmpc@<10 \else \divide\cv@tmpc@ by 10 \advance\cv@tmpc by 1 \fi 208 | %% \ifnum\cv@tmpc@=10\relax\cv@tmpc@=11\relax\fi \ifnum\cv@tmpc@>10 \repeat 209 | %% \ifnum#2<0\advance\cv@tmpc1\relax-\fi 210 | %% \loop\ifnum\cv@tmpc<#1\relax0\advance\cv@tmpc1\relax\fi \ifnum\cv@tmpc<#1 \repeat 211 | %% \cv@tmpc@=#2\relax\ifnum\cv@tmpc@<0\cv@tmpc@=-\cv@tmpc@\fi \relax\the\cv@tmpc@}% 212 | %% % \makevruler[][][][][] 213 | %% \def\makevruler[#1][#2][#3][#4][#5]{\begingroup\offinterlineskip 214 | %% \textheight=#5\vbadness=10000\vfuzz=120ex\overfullrule=0pt% 215 | %% \global\setbox\iclrrulerbox=\vbox to \textheight{% 216 | %% {\parskip=0pt\hfuzz=150em\cv@boxheight=\textheight 217 | %% \cv@lineheight=#1\global\iclrrulercount=#2% 218 | %% \cv@tot\cv@boxheight\divide\cv@tot\cv@lineheight\advance\cv@tot2% 219 | %% \cv@refno1\vskip-\cv@lineheight\vskip1ex% 220 | %% \loop\setbox\cv@tmpbox=\hbox to0cm{{\iclrtenhv\hfil\fillzeros[#4]\iclrrulercount}}% 221 | %% \ht\cv@tmpbox\cv@lineheight\dp\cv@tmpbox0pt\box\cv@tmpbox\break 222 | %% \advance\cv@refno1\global\advance\iclrrulercount#3\relax 223 | %% \ifnum\cv@refno<\cv@tot\repeat}}\endgroup}% 224 | %% \makeatother 225 | %% % ----- end of vruler 226 | 227 | %% % \makevruler[][][][][] 228 | %% \def\iclrruler#1{\makevruler[12pt][#1][1][3][0.993\textheight]\usebox{\iclrrulerbox}} 229 | %% \AddToShipoutPicture{% 230 | %% \ificlrfinal\else 231 | %% \iclrruleroffset=\textheight 232 | %% \advance\iclrruleroffset by -3.7pt 233 | %% \color[rgb]{.7,.7,.7} 234 | %% \AtTextUpperLeft{% 235 | %% \put(\LenToUnit{-35pt},\LenToUnit{-\iclrruleroffset}){%left ruler 236 | %% \iclrruler{\iclrrulercount}} 237 | %% } 238 | %% \fi 239 | %% } 240 | %%% To add a vertical bar on the side 241 | %\AddToShipoutPicture{ 242 | %\AtTextLowerLeft{ 243 | %\hspace*{-1.8cm} 244 | %\colorbox[rgb]{0.7,0.7,0.7}{\small \parbox[b][\textheight]{0.1cm}{}}} 245 | %} 246 | -------------------------------------------------------------------------------- /report/images/CBM3D_cropped.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/selfsupervised-denoising/566c83ad9779b3fc4b0788727da37de5e29f774f/report/images/CBM3D_cropped.png -------------------------------------------------------------------------------- /report/images/Clean.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/selfsupervised-denoising/566c83ad9779b3fc4b0788727da37de5e29f774f/report/images/Clean.png -------------------------------------------------------------------------------- /report/images/Clean_cropped.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/selfsupervised-denoising/566c83ad9779b3fc4b0788727da37de5e29f774f/report/images/Clean_cropped.png -------------------------------------------------------------------------------- /report/images/N2C.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/selfsupervised-denoising/566c83ad9779b3fc4b0788727da37de5e29f774f/report/images/N2C.png -------------------------------------------------------------------------------- /report/images/N2C_cropped.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/selfsupervised-denoising/566c83ad9779b3fc4b0788727da37de5e29f774f/report/images/N2C_cropped.png -------------------------------------------------------------------------------- /report/images/N2N_cropped.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/selfsupervised-denoising/566c83ad9779b3fc4b0788727da37de5e29f774f/report/images/N2N_cropped.png -------------------------------------------------------------------------------- /report/images/N2V_cropped.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/selfsupervised-denoising/566c83ad9779b3fc4b0788727da37de5e29f774f/report/images/N2V_cropped.png -------------------------------------------------------------------------------- /report/images/Noisy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/selfsupervised-denoising/566c83ad9779b3fc4b0788727da37de5e29f774f/report/images/Noisy.png -------------------------------------------------------------------------------- /report/images/Noisy_cropped.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/selfsupervised-denoising/566c83ad9779b3fc4b0788727da37de5e29f774f/report/images/Noisy_cropped.png -------------------------------------------------------------------------------- /report/images/SSDN_cropped.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/selfsupervised-denoising/566c83ad9779b3fc4b0788727da37de5e29f774f/report/images/SSDN_cropped.png -------------------------------------------------------------------------------- /report/images/SSDN_mu_cropped.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/selfsupervised-denoising/566c83ad9779b3fc4b0788727da37de5e29f774f/report/images/SSDN_mu_cropped.png -------------------------------------------------------------------------------- /report/math_commands.tex: -------------------------------------------------------------------------------- 1 | %%%%% NEW MATH DEFINITIONS %%%%% 2 | 3 | \usepackage{amsmath,amsfonts,bm} 4 | 5 | % Mark sections of captions for referring to divisions of figures 6 | \newcommand{\figleft}{{\em (Left)}} 7 | \newcommand{\figcenter}{{\em (Center)}} 8 | \newcommand{\figright}{{\em (Right)}} 9 | \newcommand{\figtop}{{\em (Top)}} 10 | \newcommand{\figbottom}{{\em (Bottom)}} 11 | \newcommand{\captiona}{{\em (a)}} 12 | \newcommand{\captionb}{{\em (b)}} 13 | \newcommand{\captionc}{{\em (c)}} 14 | \newcommand{\captiond}{{\em (d)}} 15 | 16 | % Highlight a newly defined term 17 | \newcommand{\newterm}[1]{{\bf #1}} 18 | 19 | 20 
| % Figure reference, lower-case. 21 | \def\figref#1{figure~\ref{#1}} 22 | % Figure reference, capital. For start of sentence 23 | \def\Figref#1{Figure~\ref{#1}} 24 | \def\twofigref#1#2{figures \ref{#1} and \ref{#2}} 25 | \def\quadfigref#1#2#3#4{figures \ref{#1}, \ref{#2}, \ref{#3} and \ref{#4}} 26 | % Section reference, lower-case. 27 | \def\secref#1{section~\ref{#1}} 28 | % Section reference, capital. 29 | \def\Secref#1{Section~\ref{#1}} 30 | % Reference to two sections. 31 | \def\twosecrefs#1#2{sections \ref{#1} and \ref{#2}} 32 | % Reference to three sections. 33 | \def\secrefs#1#2#3{sections \ref{#1}, \ref{#2} and \ref{#3}} 34 | % Reference to an equation, lower-case. 35 | \def\eqref#1{equation~\ref{#1}} 36 | % Reference to an equation, upper case 37 | \def\Eqref#1{Equation~\ref{#1}} 38 | % A raw reference to an equation---avoid using if possible 39 | \def\plaineqref#1{\ref{#1}} 40 | % Reference to a chapter, lower-case. 41 | \def\chapref#1{chapter~\ref{#1}} 42 | % Reference to an equation, upper case. 43 | \def\Chapref#1{Chapter~\ref{#1}} 44 | % Reference to a range of chapters 45 | \def\rangechapref#1#2{chapters\ref{#1}--\ref{#2}} 46 | % Reference to an algorithm, lower-case. 47 | \def\algref#1{algorithm~\ref{#1}} 48 | % Reference to an algorithm, upper case. 49 | \def\Algref#1{Algorithm~\ref{#1}} 50 | \def\twoalgref#1#2{algorithms \ref{#1} and \ref{#2}} 51 | \def\Twoalgref#1#2{Algorithms \ref{#1} and \ref{#2}} 52 | % Reference to a part, lower case 53 | \def\partref#1{part~\ref{#1}} 54 | % Reference to a part, upper case 55 | \def\Partref#1{Part~\ref{#1}} 56 | \def\twopartref#1#2{parts \ref{#1} and \ref{#2}} 57 | 58 | \def\ceil#1{\lceil #1 \rceil} 59 | \def\floor#1{\lfloor #1 \rfloor} 60 | \def\1{\bm{1}} 61 | \newcommand{\train}{\mathcal{D}} 62 | \newcommand{\valid}{\mathcal{D_{\mathrm{valid}}}} 63 | \newcommand{\test}{\mathcal{D_{\mathrm{test}}}} 64 | 65 | \def\eps{{\epsilon}} 66 | 67 | 68 | % Random variables 69 | \def\reta{{\textnormal{$\eta$}}} 70 | \def\ra{{\textnormal{a}}} 71 | \def\rb{{\textnormal{b}}} 72 | \def\rc{{\textnormal{c}}} 73 | \def\rd{{\textnormal{d}}} 74 | \def\re{{\textnormal{e}}} 75 | \def\rf{{\textnormal{f}}} 76 | \def\rg{{\textnormal{g}}} 77 | \def\rh{{\textnormal{h}}} 78 | \def\ri{{\textnormal{i}}} 79 | \def\rj{{\textnormal{j}}} 80 | \def\rk{{\textnormal{k}}} 81 | \def\rl{{\textnormal{l}}} 82 | % rm is already a command, just don't name any random variables m 83 | \def\rn{{\textnormal{n}}} 84 | \def\ro{{\textnormal{o}}} 85 | \def\rp{{\textnormal{p}}} 86 | \def\rq{{\textnormal{q}}} 87 | \def\rr{{\textnormal{r}}} 88 | \def\rs{{\textnormal{s}}} 89 | \def\rt{{\textnormal{t}}} 90 | \def\ru{{\textnormal{u}}} 91 | \def\rv{{\textnormal{v}}} 92 | \def\rw{{\textnormal{w}}} 93 | \def\rx{{\textnormal{x}}} 94 | \def\ry{{\textnormal{y}}} 95 | \def\rz{{\textnormal{z}}} 96 | 97 | % Random vectors 98 | \def\rvepsilon{{\mathbf{\epsilon}}} 99 | \def\rvtheta{{\mathbf{\theta}}} 100 | \def\rva{{\mathbf{a}}} 101 | \def\rvb{{\mathbf{b}}} 102 | \def\rvc{{\mathbf{c}}} 103 | \def\rvd{{\mathbf{d}}} 104 | \def\rve{{\mathbf{e}}} 105 | \def\rvf{{\mathbf{f}}} 106 | \def\rvg{{\mathbf{g}}} 107 | \def\rvh{{\mathbf{h}}} 108 | \def\rvu{{\mathbf{i}}} 109 | \def\rvj{{\mathbf{j}}} 110 | \def\rvk{{\mathbf{k}}} 111 | \def\rvl{{\mathbf{l}}} 112 | \def\rvm{{\mathbf{m}}} 113 | \def\rvn{{\mathbf{n}}} 114 | \def\rvo{{\mathbf{o}}} 115 | \def\rvp{{\mathbf{p}}} 116 | \def\rvq{{\mathbf{q}}} 117 | \def\rvr{{\mathbf{r}}} 118 | \def\rvs{{\mathbf{s}}} 119 | \def\rvt{{\mathbf{t}}} 120 | \def\rvu{{\mathbf{u}}} 121 
| \def\rvv{{\mathbf{v}}} 122 | \def\rvw{{\mathbf{w}}} 123 | \def\rvx{{\mathbf{x}}} 124 | \def\rvy{{\mathbf{y}}} 125 | \def\rvz{{\mathbf{z}}} 126 | 127 | % Elements of random vectors 128 | \def\erva{{\textnormal{a}}} 129 | \def\ervb{{\textnormal{b}}} 130 | \def\ervc{{\textnormal{c}}} 131 | \def\ervd{{\textnormal{d}}} 132 | \def\erve{{\textnormal{e}}} 133 | \def\ervf{{\textnormal{f}}} 134 | \def\ervg{{\textnormal{g}}} 135 | \def\ervh{{\textnormal{h}}} 136 | \def\ervi{{\textnormal{i}}} 137 | \def\ervj{{\textnormal{j}}} 138 | \def\ervk{{\textnormal{k}}} 139 | \def\ervl{{\textnormal{l}}} 140 | \def\ervm{{\textnormal{m}}} 141 | \def\ervn{{\textnormal{n}}} 142 | \def\ervo{{\textnormal{o}}} 143 | \def\ervp{{\textnormal{p}}} 144 | \def\ervq{{\textnormal{q}}} 145 | \def\ervr{{\textnormal{r}}} 146 | \def\ervs{{\textnormal{s}}} 147 | \def\ervt{{\textnormal{t}}} 148 | \def\ervu{{\textnormal{u}}} 149 | \def\ervv{{\textnormal{v}}} 150 | \def\ervw{{\textnormal{w}}} 151 | \def\ervx{{\textnormal{x}}} 152 | \def\ervy{{\textnormal{y}}} 153 | \def\ervz{{\textnormal{z}}} 154 | 155 | % Random matrices 156 | \def\rmA{{\mathbf{A}}} 157 | \def\rmB{{\mathbf{B}}} 158 | \def\rmC{{\mathbf{C}}} 159 | \def\rmD{{\mathbf{D}}} 160 | \def\rmE{{\mathbf{E}}} 161 | \def\rmF{{\mathbf{F}}} 162 | \def\rmG{{\mathbf{G}}} 163 | \def\rmH{{\mathbf{H}}} 164 | \def\rmI{{\mathbf{I}}} 165 | \def\rmJ{{\mathbf{J}}} 166 | \def\rmK{{\mathbf{K}}} 167 | \def\rmL{{\mathbf{L}}} 168 | \def\rmM{{\mathbf{M}}} 169 | \def\rmN{{\mathbf{N}}} 170 | \def\rmO{{\mathbf{O}}} 171 | \def\rmP{{\mathbf{P}}} 172 | \def\rmQ{{\mathbf{Q}}} 173 | \def\rmR{{\mathbf{R}}} 174 | \def\rmS{{\mathbf{S}}} 175 | \def\rmT{{\mathbf{T}}} 176 | \def\rmU{{\mathbf{U}}} 177 | \def\rmV{{\mathbf{V}}} 178 | \def\rmW{{\mathbf{W}}} 179 | \def\rmX{{\mathbf{X}}} 180 | \def\rmY{{\mathbf{Y}}} 181 | \def\rmZ{{\mathbf{Z}}} 182 | 183 | % Elements of random matrices 184 | \def\ermA{{\textnormal{A}}} 185 | \def\ermB{{\textnormal{B}}} 186 | \def\ermC{{\textnormal{C}}} 187 | \def\ermD{{\textnormal{D}}} 188 | \def\ermE{{\textnormal{E}}} 189 | \def\ermF{{\textnormal{F}}} 190 | \def\ermG{{\textnormal{G}}} 191 | \def\ermH{{\textnormal{H}}} 192 | \def\ermI{{\textnormal{I}}} 193 | \def\ermJ{{\textnormal{J}}} 194 | \def\ermK{{\textnormal{K}}} 195 | \def\ermL{{\textnormal{L}}} 196 | \def\ermM{{\textnormal{M}}} 197 | \def\ermN{{\textnormal{N}}} 198 | \def\ermO{{\textnormal{O}}} 199 | \def\ermP{{\textnormal{P}}} 200 | \def\ermQ{{\textnormal{Q}}} 201 | \def\ermR{{\textnormal{R}}} 202 | \def\ermS{{\textnormal{S}}} 203 | \def\ermT{{\textnormal{T}}} 204 | \def\ermU{{\textnormal{U}}} 205 | \def\ermV{{\textnormal{V}}} 206 | \def\ermW{{\textnormal{W}}} 207 | \def\ermX{{\textnormal{X}}} 208 | \def\ermY{{\textnormal{Y}}} 209 | \def\ermZ{{\textnormal{Z}}} 210 | 211 | % Vectors 212 | \def\vzero{{\bm{0}}} 213 | \def\vone{{\bm{1}}} 214 | \def\vmu{{\bm{\mu}}} 215 | \def\vtheta{{\bm{\theta}}} 216 | \def\va{{\bm{a}}} 217 | \def\vb{{\bm{b}}} 218 | \def\vc{{\bm{c}}} 219 | \def\vd{{\bm{d}}} 220 | \def\ve{{\bm{e}}} 221 | \def\vf{{\bm{f}}} 222 | \def\vg{{\bm{g}}} 223 | \def\vh{{\bm{h}}} 224 | \def\vi{{\bm{i}}} 225 | \def\vj{{\bm{j}}} 226 | \def\vk{{\bm{k}}} 227 | \def\vl{{\bm{l}}} 228 | \def\vm{{\bm{m}}} 229 | \def\vn{{\bm{n}}} 230 | \def\vo{{\bm{o}}} 231 | \def\vp{{\bm{p}}} 232 | \def\vq{{\bm{q}}} 233 | \def\vr{{\bm{r}}} 234 | \def\vs{{\bm{s}}} 235 | \def\vt{{\bm{t}}} 236 | \def\vu{{\bm{u}}} 237 | \def\vv{{\bm{v}}} 238 | \def\vw{{\bm{w}}} 239 | \def\vx{{\bm{x}}} 240 | \def\vy{{\bm{y}}} 241 | \def\vz{{\bm{z}}} 242 | 243 | % 
Elements of vectors 244 | \def\evalpha{{\alpha}} 245 | \def\evbeta{{\beta}} 246 | \def\evepsilon{{\epsilon}} 247 | \def\evlambda{{\lambda}} 248 | \def\evomega{{\omega}} 249 | \def\evmu{{\mu}} 250 | \def\evpsi{{\psi}} 251 | \def\evsigma{{\sigma}} 252 | \def\evtheta{{\theta}} 253 | \def\eva{{a}} 254 | \def\evb{{b}} 255 | \def\evc{{c}} 256 | \def\evd{{d}} 257 | \def\eve{{e}} 258 | \def\evf{{f}} 259 | \def\evg{{g}} 260 | \def\evh{{h}} 261 | \def\evi{{i}} 262 | \def\evj{{j}} 263 | \def\evk{{k}} 264 | \def\evl{{l}} 265 | \def\evm{{m}} 266 | \def\evn{{n}} 267 | \def\evo{{o}} 268 | \def\evp{{p}} 269 | \def\evq{{q}} 270 | \def\evr{{r}} 271 | \def\evs{{s}} 272 | \def\evt{{t}} 273 | \def\evu{{u}} 274 | \def\evv{{v}} 275 | \def\evw{{w}} 276 | \def\evx{{x}} 277 | \def\evy{{y}} 278 | \def\evz{{z}} 279 | 280 | % Matrix 281 | \def\mA{{\bm{A}}} 282 | \def\mB{{\bm{B}}} 283 | \def\mC{{\bm{C}}} 284 | \def\mD{{\bm{D}}} 285 | \def\mE{{\bm{E}}} 286 | \def\mF{{\bm{F}}} 287 | \def\mG{{\bm{G}}} 288 | \def\mH{{\bm{H}}} 289 | \def\mI{{\bm{I}}} 290 | \def\mJ{{\bm{J}}} 291 | \def\mK{{\bm{K}}} 292 | \def\mL{{\bm{L}}} 293 | \def\mM{{\bm{M}}} 294 | \def\mN{{\bm{N}}} 295 | \def\mO{{\bm{O}}} 296 | \def\mP{{\bm{P}}} 297 | \def\mQ{{\bm{Q}}} 298 | \def\mR{{\bm{R}}} 299 | \def\mS{{\bm{S}}} 300 | \def\mT{{\bm{T}}} 301 | \def\mU{{\bm{U}}} 302 | \def\mV{{\bm{V}}} 303 | \def\mW{{\bm{W}}} 304 | \def\mX{{\bm{X}}} 305 | \def\mY{{\bm{Y}}} 306 | \def\mZ{{\bm{Z}}} 307 | \def\mBeta{{\bm{\beta}}} 308 | \def\mPhi{{\bm{\Phi}}} 309 | \def\mLambda{{\bm{\Lambda}}} 310 | \def\mSigma{{\bm{\Sigma}}} 311 | 312 | % Tensor 313 | \DeclareMathAlphabet{\mathsfit}{\encodingdefault}{\sfdefault}{m}{sl} 314 | \SetMathAlphabet{\mathsfit}{bold}{\encodingdefault}{\sfdefault}{bx}{n} 315 | \newcommand{\tens}[1]{\bm{\mathsfit{#1}}} 316 | \def\tA{{\tens{A}}} 317 | \def\tB{{\tens{B}}} 318 | \def\tC{{\tens{C}}} 319 | \def\tD{{\tens{D}}} 320 | \def\tE{{\tens{E}}} 321 | \def\tF{{\tens{F}}} 322 | \def\tG{{\tens{G}}} 323 | \def\tH{{\tens{H}}} 324 | \def\tI{{\tens{I}}} 325 | \def\tJ{{\tens{J}}} 326 | \def\tK{{\tens{K}}} 327 | \def\tL{{\tens{L}}} 328 | \def\tM{{\tens{M}}} 329 | \def\tN{{\tens{N}}} 330 | \def\tO{{\tens{O}}} 331 | \def\tP{{\tens{P}}} 332 | \def\tQ{{\tens{Q}}} 333 | \def\tR{{\tens{R}}} 334 | \def\tS{{\tens{S}}} 335 | \def\tT{{\tens{T}}} 336 | \def\tU{{\tens{U}}} 337 | \def\tV{{\tens{V}}} 338 | \def\tW{{\tens{W}}} 339 | \def\tX{{\tens{X}}} 340 | \def\tY{{\tens{Y}}} 341 | \def\tZ{{\tens{Z}}} 342 | 343 | 344 | % Graph 345 | \def\gA{{\mathcal{A}}} 346 | \def\gB{{\mathcal{B}}} 347 | \def\gC{{\mathcal{C}}} 348 | \def\gD{{\mathcal{D}}} 349 | \def\gE{{\mathcal{E}}} 350 | \def\gF{{\mathcal{F}}} 351 | \def\gG{{\mathcal{G}}} 352 | \def\gH{{\mathcal{H}}} 353 | \def\gI{{\mathcal{I}}} 354 | \def\gJ{{\mathcal{J}}} 355 | \def\gK{{\mathcal{K}}} 356 | \def\gL{{\mathcal{L}}} 357 | \def\gM{{\mathcal{M}}} 358 | \def\gN{{\mathcal{N}}} 359 | \def\gO{{\mathcal{O}}} 360 | \def\gP{{\mathcal{P}}} 361 | \def\gQ{{\mathcal{Q}}} 362 | \def\gR{{\mathcal{R}}} 363 | \def\gS{{\mathcal{S}}} 364 | \def\gT{{\mathcal{T}}} 365 | \def\gU{{\mathcal{U}}} 366 | \def\gV{{\mathcal{V}}} 367 | \def\gW{{\mathcal{W}}} 368 | \def\gX{{\mathcal{X}}} 369 | \def\gY{{\mathcal{Y}}} 370 | \def\gZ{{\mathcal{Z}}} 371 | 372 | % Sets 373 | \def\sA{{\mathbb{A}}} 374 | \def\sB{{\mathbb{B}}} 375 | \def\sC{{\mathbb{C}}} 376 | \def\sD{{\mathbb{D}}} 377 | % Don't use a set called E, because this would be the same as our symbol 378 | % for expectation. 
379 | \def\sF{{\mathbb{F}}} 380 | \def\sG{{\mathbb{G}}} 381 | \def\sH{{\mathbb{H}}} 382 | \def\sI{{\mathbb{I}}} 383 | \def\sJ{{\mathbb{J}}} 384 | \def\sK{{\mathbb{K}}} 385 | \def\sL{{\mathbb{L}}} 386 | \def\sM{{\mathbb{M}}} 387 | \def\sN{{\mathbb{N}}} 388 | \def\sO{{\mathbb{O}}} 389 | \def\sP{{\mathbb{P}}} 390 | \def\sQ{{\mathbb{Q}}} 391 | \def\sR{{\mathbb{R}}} 392 | \def\sS{{\mathbb{S}}} 393 | \def\sT{{\mathbb{T}}} 394 | \def\sU{{\mathbb{U}}} 395 | \def\sV{{\mathbb{V}}} 396 | \def\sW{{\mathbb{W}}} 397 | \def\sX{{\mathbb{X}}} 398 | \def\sY{{\mathbb{Y}}} 399 | \def\sZ{{\mathbb{Z}}} 400 | 401 | % Entries of a matrix 402 | \def\emLambda{{\Lambda}} 403 | \def\emA{{A}} 404 | \def\emB{{B}} 405 | \def\emC{{C}} 406 | \def\emD{{D}} 407 | \def\emE{{E}} 408 | \def\emF{{F}} 409 | \def\emG{{G}} 410 | \def\emH{{H}} 411 | \def\emI{{I}} 412 | \def\emJ{{J}} 413 | \def\emK{{K}} 414 | \def\emL{{L}} 415 | \def\emM{{M}} 416 | \def\emN{{N}} 417 | \def\emO{{O}} 418 | \def\emP{{P}} 419 | \def\emQ{{Q}} 420 | \def\emR{{R}} 421 | \def\emS{{S}} 422 | \def\emT{{T}} 423 | \def\emU{{U}} 424 | \def\emV{{V}} 425 | \def\emW{{W}} 426 | \def\emX{{X}} 427 | \def\emY{{Y}} 428 | \def\emZ{{Z}} 429 | \def\emSigma{{\Sigma}} 430 | 431 | % entries of a tensor 432 | % Same font as tensor, without \bm wrapper 433 | \newcommand{\etens}[1]{\mathsfit{#1}} 434 | \def\etLambda{{\etens{\Lambda}}} 435 | \def\etA{{\etens{A}}} 436 | \def\etB{{\etens{B}}} 437 | \def\etC{{\etens{C}}} 438 | \def\etD{{\etens{D}}} 439 | \def\etE{{\etens{E}}} 440 | \def\etF{{\etens{F}}} 441 | \def\etG{{\etens{G}}} 442 | \def\etH{{\etens{H}}} 443 | \def\etI{{\etens{I}}} 444 | \def\etJ{{\etens{J}}} 445 | \def\etK{{\etens{K}}} 446 | \def\etL{{\etens{L}}} 447 | \def\etM{{\etens{M}}} 448 | \def\etN{{\etens{N}}} 449 | \def\etO{{\etens{O}}} 450 | \def\etP{{\etens{P}}} 451 | \def\etQ{{\etens{Q}}} 452 | \def\etR{{\etens{R}}} 453 | \def\etS{{\etens{S}}} 454 | \def\etT{{\etens{T}}} 455 | \def\etU{{\etens{U}}} 456 | \def\etV{{\etens{V}}} 457 | \def\etW{{\etens{W}}} 458 | \def\etX{{\etens{X}}} 459 | \def\etY{{\etens{Y}}} 460 | \def\etZ{{\etens{Z}}} 461 | 462 | % The true underlying data generating distribution 463 | \newcommand{\pdata}{p_{\rm{data}}} 464 | % The empirical distribution defined by the training set 465 | \newcommand{\ptrain}{\hat{p}_{\rm{data}}} 466 | \newcommand{\Ptrain}{\hat{P}_{\rm{data}}} 467 | % The model distribution 468 | \newcommand{\pmodel}{p_{\rm{model}}} 469 | \newcommand{\Pmodel}{P_{\rm{model}}} 470 | \newcommand{\ptildemodel}{\tilde{p}_{\rm{model}}} 471 | % Stochastic autoencoder distributions 472 | \newcommand{\pencode}{p_{\rm{encoder}}} 473 | \newcommand{\pdecode}{p_{\rm{decoder}}} 474 | \newcommand{\precons}{p_{\rm{reconstruct}}} 475 | 476 | \newcommand{\laplace}{\mathrm{Laplace}} % Laplace distribution 477 | 478 | \newcommand{\E}{\mathbb{E}} 479 | \newcommand{\Ls}{\mathcal{L}} 480 | \newcommand{\R}{\mathbb{R}} 481 | \newcommand{\emp}{\tilde{p}} 482 | \newcommand{\lr}{\alpha} 483 | \newcommand{\reg}{\lambda} 484 | \newcommand{\rect}{\mathrm{rectifier}} 485 | \newcommand{\softmax}{\mathrm{softmax}} 486 | \newcommand{\sigmoid}{\sigma} 487 | \newcommand{\softplus}{\zeta} 488 | \newcommand{\KL}{D_{\mathrm{KL}}} 489 | \newcommand{\Var}{\mathrm{Var}} 490 | \newcommand{\standarderror}{\mathrm{SE}} 491 | \newcommand{\Cov}{\mathrm{Cov}} 492 | % Wolfram Mathworld says $L^2$ is for function spaces and $\ell^2$ is for vectors 493 | % But then they seem to use $L^2$ for vectors throughout the site, and so does 494 | % wikipedia. 
495 | \newcommand{\normlzero}{L^0} 496 | \newcommand{\normlone}{L^1} 497 | \newcommand{\normltwo}{L^2} 498 | \newcommand{\normlp}{L^p} 499 | \newcommand{\normmax}{L^\infty} 500 | 501 | \newcommand{\parents}{Pa} % See usage in notation.tex. Chosen to match Daphne's book. 502 | 503 | \DeclareMathOperator*{\argmax}{arg\,max} 504 | \DeclareMathOperator*{\argmin}{arg\,min} 505 | 506 | \DeclareMathOperator{\sign}{sign} 507 | \DeclareMathOperator{\Tr}{Tr} 508 | \let\ab\allowbreak 509 | -------------------------------------------------------------------------------- /report/references.bib: -------------------------------------------------------------------------------- 1 | @incollection{Bengio+chapter2007, 2 | author = {Bengio, Yoshua and LeCun, Yann}, 3 | booktitle = {Large Scale Kernel Machines}, 4 | publisher = {MIT Press}, 5 | title = {Scaling Learning Algorithms Towards {AI}}, 6 | year = {2007} 7 | } 8 | 9 | @article{Hinton06, 10 | author = {Hinton, Geoffrey E. and Osindero, Simon and Teh, Yee Whye}, 11 | journal = {Neural Computation}, 12 | pages = {1527--1554}, 13 | title = {A Fast Learning Algorithm for Deep Belief Nets}, 14 | volume = {18}, 15 | year = {2006} 16 | } 17 | 18 | @book{goodfellow2016deep, 19 | title={Deep learning}, 20 | author={Goodfellow, Ian and Bengio, Yoshua and Courville, Aaron and Bengio, Yoshua}, 21 | volume={1}, 22 | year={2016}, 23 | publisher={MIT Press} 24 | } 25 | 26 | 27 | @article{noise2void, 28 | author = {Alexander Krull and 29 | Tim{-}Oliver Buchholz and 30 | Florian Jug}, 31 | title = {Noise2Void - Learning Denoising from Single Noisy Images}, 32 | journal = {CoRR}, 33 | volume = {abs/1811.10980}, 34 | year = {2018}, 35 | url = {http://arxiv.org/abs/1811.10980}, 36 | archivePrefix = {arXiv}, 37 | eprint = {1811.10980}, 38 | timestamp = {Fri, 30 Nov 2018 12:44:28 +0100}, 39 | biburl = {https://dblp.org/rec/journals/corr/abs-1811-10980.bib}, 40 | bibsource = {dblp computer science bibliography, https://dblp.org} 41 | } 42 | 43 | @article{noise2noise, 44 | author = {Jaakko Lehtinen and 45 | Jacob Munkberg and 46 | Jon Hasselgren and 47 | Samuli Laine and 48 | Tero Karras and 49 | Miika Aittala and 50 | Timo Aila}, 51 | title = {Noise2Noise: Learning Image Restoration without Clean Data}, 52 | journal = {CoRR}, 53 | volume = {abs/1803.04189}, 54 | year = {2018}, 55 | url = {http://arxiv.org/abs/1803.04189}, 56 | archivePrefix = {arXiv}, 57 | eprint = {1803.04189}, 58 | timestamp = {Mon, 13 Aug 2018 16:46:38 +0200}, 59 | biburl = {https://dblp.org/rec/journals/corr/abs-1803-04189.bib}, 60 | bibsource = {dblp computer science bibliography, https://dblp.org} 61 | } 62 | 63 | @article{ssdn, 64 | author = {Samuli Laine and 65 | Jaakko Lehtinen and 66 | Timo Aila}, 67 | title = {Self-Supervised Deep Image Denoising}, 68 | journal = {CoRR}, 69 | volume = {abs/1901.10277}, 70 | year = {2019}, 71 | url = {http://arxiv.org/abs/1901.10277}, 72 | archivePrefix = {arXiv}, 73 | eprint = {1901.10277}, 74 | timestamp = {Sat, 02 Feb 2019 16:56:00 +0100}, 75 | biburl = {https://dblp.org/rec/journals/corr/abs-1901-10277.bib}, 76 | bibsource = {dblp computer science bibliography, https://dblp.org} 77 | } 78 | 79 | 80 | 81 | @article{unet, 82 | author = {Olaf Ronneberger and 83 | Philipp Fischer and 84 | Thomas Brox}, 85 | title = {U-Net: Convolutional Networks for Biomedical Image Segmentation}, 86 | journal = {CoRR}, 87 | volume = {abs/1505.04597}, 88 | year = {2015}, 89 | url = {http://arxiv.org/abs/1505.04597}, 90 | archivePrefix = {arXiv}, 91 | eprint = {1505.04597}, 92 | 
timestamp = {Mon, 13 Aug 2018 16:46:52 +0200}, 93 | biburl = {https://dblp.org/rec/journals/corr/RonnebergerFB15.bib}, 94 | bibsource = {dblp computer science bibliography, https://dblp.org} 95 | } 96 | 97 | 98 | @inproceedings{cbm3d, 99 | author = {Dabov, Kostadin and Foi, Alessandro and Katkovnik, Vladimir and Egiazarian, Karen}, 100 | year = {2007}, 101 | month = {09}, 102 | pages = {I - 313 }, 103 | title = {Color Image Denoising via Sparse 3D Collaborative Filtering with Grouping Constraint in Luminance-Chrominance Space}, 104 | volume = {1}, 105 | isbn = {978-1-4244-1437-6}, 106 | doi = {10.1109/ICIP.2007.4378954} 107 | } 108 | 109 | @misc{he2015delving, 110 | title={Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification}, 111 | author={Kaiming He and Xiangyu Zhang and Shaoqing Ren and Jian Sun}, 112 | year={2015}, 113 | eprint={1502.01852}, 114 | archivePrefix={arXiv}, 115 | primaryClass={cs.CV} 116 | } -------------------------------------------------------------------------------- /report/report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMP6248-Reproducability-Challenge/selfsupervised-denoising/566c83ad9779b3fc4b0788727da37de5e29f774f/report/report.pdf -------------------------------------------------------------------------------- /report/report.tex: -------------------------------------------------------------------------------- 1 | 2 | \documentclass{article} % For LaTeX2e 3 | \usepackage{iclr2020_conference,times} 4 | 5 | % Optional math commands from https://github.com/goodfeli/dlbook_notation. 6 | \input{math_commands.tex} 7 | 8 | \usepackage{caption} 9 | \usepackage{float} 10 | \usepackage{graphicx} 11 | \usepackage{hyperref} 12 | \usepackage{url} 13 | \usepackage{tikz} 14 | \usepackage{pgfplots} 15 | \usepackage{import} 16 | \usepackage{subfig} 17 | \usepackage{makecell} 18 | \usepackage{booktabs} 19 | 20 | \newcommand\inputpgf[2]{{ 21 | \let\pgfimageWithoutPath\pgfimage 22 | \renewcommand{\pgfimage}[2][]{\pgfimageWithoutPath[##1]{#1/##2}} 23 | \input{#1/#2} 24 | }} 25 | 26 | \title{COMP6248 Reproducibility Challenge\\ 27 | {\Large High-Quality Self-Supervised Deep Image Denoising}} 28 | 29 | \author{David Jones, Richard Crosland \& Jason Barrett \\ 30 | Department of Electronics \& Computer Science \\ 31 | University of Southampton \\ 32 | \texttt{\{dsj1n15, rtc1g16, jb17g16\}@soton.ac.uk} 33 | } 34 | 35 | \newcommand{\fix}{\marginpar{FIX}} 36 | \newcommand{\new}{\marginpar{NEW}} 37 | 38 | \iclrfinalcopy 39 | \begin{document} 40 | 41 | \maketitle 42 | 43 | % \begin{abstract} 44 | % This report gives an account of the reimplementation and testing of the paper `High-Quality Self-Supervised Deep Image Denoising'. It details the implementation process, as well as the obstacles faced, before comparing the results against a subset of those seen in the original paper. The results were reproducible, however, full training runs were not attempted and the reimplementation relied heavily on the submitted paper's supplement and existing codebase as a reference. 45 | % \end{abstract} 46 | 47 | \section{Introduction} 48 | The aim of denoising is to reduce noise in a corrupted image while preserving features. Traditional techniques include median filtering and non-local means. Recently a focus has been placed on training autoencoder neural networks to learn this typically non-linear mapping. 
Often these models are trained via supervised methods, requiring clean and noisy image pairs for training; when a clean target is not provided the training is considered to be self-supervised. This report details the reimplementation of the NeurIPS 2019 conference paper `High-Quality Self-Supervised Deep Image Denoising' \citep{ssdn}, which claims significant improvements to both image quality and training efficiency over other self-supervised methods. An overview of the reimplementation process and a comparison of results against a subset of those seen in the original paper follow. 49 | 50 | \section{Background} 51 | 52 | \subsection{Self-Supervised Deep Image Denoising (SSDN)} 53 | This paper introduces a network architecture that incorporates a blindspot into the receptive field of the convolution and down-sampling layers. As a result, the central pixel is not considered as part of the loss function, so the autoencoder does not learn the identity mapping when the target is the same as the input; this permits self-supervised learning. This follows on from the concepts of Noise2Void \citep{noise2void}, which uses a mask to replace the central pixel with a different value. At evaluation time it is deemed that the central pixel likely carries relevant information, and it is therefore incorporated back into the cleaned image via Bayes' rule, as in the following equation: 54 | \[p(x|y, \Omega_{y}) \propto p(y|x) p(x|\Omega_{y})\] 55 | where $x$ is the clean value of the pixel being analysed, $y$ the noisy value, and $\Omega_{y}$ the noisy context. The model learns $p(x|\Omega_{y})$, and the noise model is either given as a parameter to training or learnt via a second, internal model. The mean of this posterior can then be used, as it minimises the MSE loss and therefore maximises the PSNR. 56 | 57 | \subsection{Baselines} 58 | To supplement results, implementations of baseline models were required for comparison. The following gives a brief overview of the baselines used in the original paper which were also reimplemented: 59 | \begin{itemize} 60 | \item \textbf{CBM3D} \citep{cbm3d} is the colour variant of the block-matching and 3D filtering algorithm used for noise reduction. This is a state-of-the-art non-neural-network approach which does not require training. 61 | 62 | \item \textbf{Noise2Clean} (N2C), as named in the original paper, is an autoencoder that uses clean reference images in training alongside a standard MSE loss; this can generally be considered the best-case scenario for training-based approaches. 63 | 64 | \item \textbf{Noise2Noise} (N2N) \citep{noise2noise} is an approach to image denoising that trains using two different noisy versions of the same image. This removes the need for clean references but still has limited applications. 65 | 66 | \item \textbf{Noise2Void} (N2V) \citep{noise2void} is a self-supervised training method that introduces a blindspot via a masking scheme. This involves selecting a certain percentage of pixels in an image patch as centre pixels for sub-patches. From each sub-patch, a random pixel is selected to replace the centre pixel. The training loss is calculated using this mask rather than the whole image. 67 | 68 | \end{itemize} 69 | 70 | % These were also implemented in PyTorch so as to offer a more accurate comparison when tested against our reimplementation.
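% A commented-out sketch (not rendered in the report) of the posterior-mean combination
% described above, under the simplifying assumption that both the learnt prior
% $p(x|\Omega_{y}) = \mathcal{N}(\mu_x, \sigma_x^2)$ and the noise model
% $p(y|x) = \mathcal{N}(x, \sigma_n^2)$ are univariate Gaussians (the implementation
% additionally supports full and diagonal covariances over the colour channels):
%   \bar{x} = \frac{\sigma_n^2 \mu_x + \sigma_x^2 y}{\sigma_x^2 + \sigma_n^2}
% i.e. the denoised value interpolates between the blindspot prediction and the noisy
% observation, weighted by the two variances.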
71 | \section{Implementation Overview} 72 | \label{sec:implementation} 73 | The original paper, for which the source code was available\footnote{Tensorflow Source, https://github.com/NVlabs/selfsupervised-denoising}, was implemented in Tensorflow. The reimplementation\footnote{PyTorch Source, https://github.com/COMP6248-Reproducability-Challenge/selfsupervised-denoising} instead used PyTorch. The original paper's supplement details the network structure, defining a network based on U-Net \citep{unet} with additional layers for handling the creation and collation of multiple rotated views of the input, as well as modifications to layers to handle upward shifts. These modifications define the blindspot in the network. The network was implemented as per this specification. Where details were not specified, such as the upsampling method, the original source was referred to (which uses nearest-neighbour interpolation). All other baseline networks used the same U-Net architecture without the blindspot additions. The differences between methods, other than the network used, were the inputs/targets and the loss calculation pipeline. Data preparation involved adding synthetic noise and padding to shapes accepted by the networks. 74 | 75 | The paper itself details the loss function and the integration of the prior. For SSDN this involved prior calculation using either the known $\sigma$, learning of a constant $\sigma$ using a single learnable parameter, or a variable $\sigma$ using a separate U-Net. When the prior is not incorporated this is referred to as SSDN-$\mu$. Other features implemented and used in testing, but not discussed in this report, include support for Poisson noise, diagonal covariance matrices, and single-channel inputs. Impulse noise was not implemented. 76 | 77 | Initialisation of weights was handled as directed using \cite{he2015delving}. Initial testing showed that, when using parameter estimation, the model failed to converge; this was traced in the original source to the last layer of the parameter estimation network being initialised with zeroed weights, a detail not mentioned in either the paper or supplement. Training conditions were mimicked using the Adam optimiser with default parameters except the learning rate, which was set every iteration following a cosine ramp-up/ramp-down schedule. It should be noted that training used an iteration-based approach (not epoch-based), where each iteration uses a randomly cropped patch from a randomly sampled image of the training dataset (sampled without replacement); the dataset then resets once all images have been used. The suggested minibatch size of 4 was used, with this batch split across GPUs if available. Tensorboard logging and checkpointing were implemented alongside to aid tracking and to allow pausing/resuming of long training runs. 78 | 79 | The main implementation issue that occurred was a misunderstanding of how to treat noise addition. The paper expects that synthetic noise is not clipped after addition; this leads to values outside the standard \texttt{uint8} boundary ($<0$ and $>255$). Initial reimplementations clipped these values to better represent real-world scenarios; this caused performance drops when $\sigma$ was known and caused parameter estimation to fail completely. These results are attributed to clipping producing a truncated noise distribution, which would require different prior calculations -- performance for this case is not known. Establishing this issue required close inspection of the original source, as it was not clear from either the supplement or the paper.
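% A rough commented-out sketch of the per-iteration learning-rate schedule described
% above: a cosine ramp-up over the first fraction of training and a cosine ramp-down
% over the final fraction. The helper name is hypothetical; the base rate (3e-4) and
% fractions (0.3 up, 0.1 down) are the defaults from ssdn/cfg.py, and the exact ramp
% formula used by the implementation may differ slightly.
%
%   import math
%
%   def ramped_lr(i, total, base_lr=3e-4, rampup=0.3, rampdown=0.1):
%       scale = 1.0
%       if i < rampup * total:  # cosine ramp-up from 0 towards base_lr
%           t = i / (rampup * total)
%           scale *= 0.5 - math.cos(t * math.pi) / 2
%       if i > (1.0 - rampdown) * total:  # cosine ramp-down back towards 0
%           t = (total - i) / (rampdown * total)
%           scale *= 0.5 - math.cos(t * math.pi) / 2
%       return base_lr * scale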
80 | 81 | \section{Results} 82 | Due to limited compute infrastructure, it was decided to limit result reproduction to a single noise type and level -- Gaussian ($\sigma = 25$); this is one of the main benchmarks provided in the original paper. The original paper typically used 2,000,000 iterations per training run; however, due to slower hardware and a slightly slower implementation, this would take approximately 7 days per run (versus 16 hours). The number of iterations used for each model was therefore 200,000 unless stated otherwise. Outputs from the models are evaluated by their peak signal-to-noise ratio (PSNR), calculated between the clean (pre-noised) image and the denoised image. This metric is commonly used to quantify the extent to which an image has been restored after removing noise, with higher values indicating less residual noise. The training curves for a subset of the models are shown in Figure \ref{fig:training_performance}. Table \ref{tab:eval_psnrs} shows calculated metrics for SSDN and all neural-net baselines, with data for CBM3D sourced from \cite{ssdn}. Additional verification of the parameter estimation aspects of the network was performed using a variable $\sigma$ network configuration trained on images with $\sigma$ between 25 and 50. This was trained for 844k iterations and achieved a PSNR of 28.89dB on the BSD300 dataset. 83 | 84 | \begin{figure}[H] 85 | \centering 86 | \captionsetup{justification=centering} 87 | \setlength\tabcolsep{0pt} 88 | \begin{tabular}{cc} 89 | {\inputpgf{figures/}{gauss25_train_psnr_relative.pgf}} & 90 | \hspace{-7.5mm} 91 | {\inputpgf{figures/}{gauss25_val_psnr_relative.pgf}} 92 | \end{tabular} 93 | \vspace{-5mm} 94 | \caption{Training and validation curves comparing PSNRs [5th-100th percentile] of the implementation versus baselines. Gaussian noise ($\sigma=25$). Training (ImageNet-Validation) [left], Validation (BSD300-Test) [right]. N.B. Variable $\sigma$ estimation used for SSDN.} 95 | \label{fig:training_performance} 96 | \end{figure} 97 | 98 | 99 | \begin{table}[H] 100 | \captionsetup{justification=centering} 101 | \caption{Validation PSNRs for denoising images from both the Kodak and BSD300 datasets with Gaussian noise ($\sigma = 25$). See Figure \ref{fig:images-psnr} for an example image comparison.} 102 | \label{tab:eval_psnrs} 103 | \begin{center} 104 | \begin{tabular}{lllll} 105 | \toprule 106 | Method & $\sigma$ known? 
& KODAK & BSD300 & Average \\ 107 | \midrule 108 | CBM3D (Untrained Baseline) & no & 31.81 & 30.40 & 31.11 \\ 109 | N2C (Baseline) & no & 32.19 & 31.22 & 31.71 \\ 110 | N2N (Baseline) & no & 32.16 & 31.20 & 31.68 \\ 111 | N2V (Baseline) & no & 31.03 & 30.24 & 30.64 \\ 112 | SSDN ($\mu$ only) & no & 30.00 & 28.56 & 29.28 \\ 113 | SSDN & no & 31.61 & 30.55 & 31.08 \\ 114 | SSDN & yes & 32.12 & 31.13 & 31.63 \\ 115 | \bottomrule 116 | \end{tabular} 117 | \end{center} 118 | \end{table} 119 | \vspace{-5mm} 120 | \begin{figure}[H] 121 | \centering 122 | \setlength{\fboxsep}{0pt}% 123 | \setlength{\fboxrule}{0.5pt}% 124 | \captionsetup{justification=centering} 125 | \setlength\tabcolsep{3pt} 126 | \begin{tabular}{cccccc} 127 | \fbox{\includegraphics[height=0.14\textwidth]{images/Clean.png}} & 128 | \fbox{\includegraphics[height=0.14\textwidth]{images/Noisy_cropped.png}} & 129 | \fbox{\includegraphics[height=0.14\textwidth]{images/CBM3D_cropped.png}} & 130 | \fbox{\includegraphics[height=0.14\textwidth]{images/N2C_cropped.png}} & 131 | \fbox{\includegraphics[height=0.14\textwidth]{images/SSDN_mu_cropped.png}} & 132 | \fbox{\includegraphics[height=0.14\textwidth]{images/SSDN_cropped.png}} \\ 133 | \makecell{Test Image\\ (KODAK-2)} & \makecell{Input\\ 20.59 dB} & \makecell{CBM3D\\ 32.44 dB} & \makecell{N2C\\ 32.80 dB} & \makecell{SSDN ($\mu$)\\ 31.38 dB} & \makecell{SSDN\\32.73 dB} 134 | \end{tabular} 135 | \caption{Results of applying SSDN and baselines to Gaussian ($\sigma=25$) noise [$\sigma$ known]} 136 | \label{fig:images-psnr} 137 | \end{figure} 138 | 139 | 140 | \section{Discussion} 141 | Despite constraints in the breadth of models trained and in training durations, a range of results from the original paper were reproduced, with relative comparisons against the baselines showing the expected trends. Results in Table \ref{tab:eval_psnrs} correspond to the results shown in Table 1 of the original paper, covering the validation performance of CBM3D, N2C, N2N and different configurations of SSDN. A major claim is that SSDN's training performance matches that of N2C; this was reproducible even when $\sigma$ was not known, as can be seen in Figure \ref{fig:training_performance}. Likewise, visual checks using Figure \ref{fig:images-psnr} show very little difference in detail between N2C and SSDN. Validation performance with $\sigma$ known also manages to exceed that of CBM3D, as expected. In validation with a known $\sigma$, SSDN is expected to achieve an average PSNR of 31.73dB (across BSD300 and KODAK); the reproduced results show a similar average PSNR of 31.63dB across the same datasets, only 0.1dB less with one tenth of the training iterations. Since the number of training iterations in these experiments is lower than that used in the original paper, it is not possible to verify the performance of the fully trained network; however, the trends of all plots indicate that further training would continue to increase PSNR for all models. One discrepancy noted between the paper's abstract and its results, as well as the reproduced results, is the claim that the methods introduced improve image quality; however, they are only ever seen to at best match the state-of-the-art results rather than exceed them. 142 | 143 | One key claim of this paper is that it achieves improved training efficiency compared to state-of-the-art self-supervised denoising methods. 
The original paper suggests training the N2V baseline by maintaining a smoothed network created from an exponential moving average of previous weights; this is not suggested by \cite{noise2void}. The reproduced baseline results (Figure \ref{fig:training_performance}), which do not do this, indicate that the implementation of SSDN does indeed learn at a higher rate (PSNR achieved in same iterations) than N2V. However, it should be noted that the training time per iteration for SSDN was $4\times$ that of N2V due to loss calculation and the $4\times$ larger network, it is therefore unclear if real-world training performance is improved. 144 | 145 | To verify the implementation, training performance was captured in the same configuration using the original Tensorflow implementation (SSDN-TF in Figure \ref{fig:training_performance}). Although training PSNRs tracked almost exactly, the validation PSNRs from the Tensorflow implementation decreased significantly after approximately 65,000 iterations (this was repeatable for default seed setups). It is not clear why this occurred but may be a quirk in training or highlight a potential issue in their original implementation that was not carried to the PyTorch implementation. 146 | 147 | \section{Conclusion} 148 | % Is it a reproducible paper? 149 | When reading the paper, the reader is provided with a detailed account of the theory behind the workings of the paper, as well as justifications for the changes made to the previous solutions to the problem. The provision of the supplement also gave more specific details about the model itself that would not be crucial to the theory, namely details about testing, and the architecture of the network used. The combination of paper and supplement alone was not enough to fully reproduce the results. The authors' code was also required, shedding light on the clipping issue discussed at the end of Section \ref{sec:implementation}. In conclusion, fully utilising a combination of sources allowed for a successful reimplementation and reproduction of results. 150 | 151 | \bibliography{references} 152 | \bibliographystyle{iclr2020_conference} 153 | 154 | \end{document} 155 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 119 -------------------------------------------------------------------------------- /ssdn/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | exec(open("ssdn/version.py").read()) 4 | 5 | setup( 6 | name="ssdn", 7 | version=__version__, # noqa 8 | packages=find_packages(), 9 | entry_points={"console_scripts": ["ssdn = ssdn.__main__:start_cli"]}, 10 | install_requires=[ 11 | "nptyping", 12 | "h5py", 13 | "imagesize", 14 | "overrides", 15 | "colorlog", 16 | "colored_traceback", 17 | "tqdm" 18 | ], 19 | ) 20 | -------------------------------------------------------------------------------- /ssdn/ssdn/__init__.py: -------------------------------------------------------------------------------- 1 | import ssdn.utils as utils 2 | import ssdn.logging_helper as logging_helper 3 | import ssdn.cfg as cfg 4 | -------------------------------------------------------------------------------- /ssdn/ssdn/__main__.py: -------------------------------------------------------------------------------- 1 | """Main method for interacting with denoiser through CLI. 
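A simple programmatic invocation (paths are hypothetical; the same arguments are
available through the installed ``ssdn`` entry point) might look like:

    start_cli(["eval", "--model", "path/to/weights.wt", "--dataset", "path/to/kodak"])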
2 | """ 3 | 4 | import sys 5 | import ssdn 6 | import ssdn.cli 7 | 8 | from typing import List 9 | 10 | 11 | def start_cli(args: List[str] = None): 12 | ssdn.logging_helper.setup() 13 | if args is not None: 14 | sys.argv[1:] = args 15 | ssdn.cli.start() 16 | 17 | 18 | if __name__ == "__main__": 19 | start_cli() 20 | -------------------------------------------------------------------------------- /ssdn/ssdn/cfg.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from ssdn.params import ConfigValue, DatasetType, NoiseAlgorithm, Pipeline 4 | from typing import Dict 5 | 6 | 7 | DEFAULT_RUN_DIR = "runs" 8 | 9 | 10 | def base(): 11 | return { 12 | ConfigValue.TRAIN_ITERATIONS: 2000000, 13 | ConfigValue.TRAIN_MINIBATCH_SIZE: 4, 14 | ConfigValue.TEST_MINIBATCH_SIZE: 2, 15 | ConfigValue.IMAGE_CHANNELS: 3, 16 | ConfigValue.TRAIN_PATCH_SIZE: 64, 17 | ConfigValue.LEARNING_RATE: 3e-4, 18 | ConfigValue.LR_RAMPDOWN_FRACTION: 0.1, 19 | ConfigValue.LR_RAMPUP_FRACTION: 0.3, 20 | ConfigValue.EVAL_INTERVAL: 10000, 21 | ConfigValue.PRINT_INTERVAL: 1000, 22 | ConfigValue.SNAPSHOT_INTERVAL: 10000, 23 | ConfigValue.DATALOADER_WORKERS: 4, 24 | ConfigValue.PIN_DATA_MEMORY: False, 25 | ConfigValue.DIAGONAL_COVARIANCE: False, 26 | ConfigValue.TRAIN_DATA_PATH: None, 27 | ConfigValue.TRAIN_DATASET_TYPE: None, 28 | ConfigValue.TRAIN_DATASET_NAME: None, 29 | ConfigValue.TEST_DATA_PATH: None, 30 | ConfigValue.TEST_DATASET_TYPE: None, 31 | ConfigValue.TEST_DATASET_NAME: None, 32 | } 33 | 34 | 35 | class DatasetName: 36 | BSD = "bsd" 37 | IMAGE_NET = "ilsvrc" 38 | KODAK = "kodak" 39 | SET14 = "set14" 40 | 41 | 42 | def infer_datasets(cfg: Dict): 43 | """For training and test dataset parameters infer from the path the name of 44 | the dataset being targetted and whether or not the data should be loaded as 45 | a h5 file or a folder. 46 | 47 | Args: 48 | cfg (Dict): Configuration to infer for. 
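    Example (hypothetical path, illustrating the inference rules below):

        cfg = {ConfigValue.TRAIN_DATA_PATH: "datasets/ilsvrc_val.h5"}
        infer_datasets(cfg)
        # cfg[ConfigValue.TRAIN_DATASET_NAME] == DatasetName.IMAGE_NET ("ilsvrc" in path)
        # cfg[ConfigValue.TRAIN_DATASET_TYPE] == DatasetType.HDF5 (path is not a directory)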
49 | """ 50 | 51 | def infer_dname(path: str): 52 | # Look for part of dataset name in path for guessing dataset 53 | dataset_dict = { 54 | "BSDS300": DatasetName.BSD, 55 | "ILSVRC": DatasetName.IMAGE_NET, 56 | "KODAK": DatasetName.KODAK, 57 | "SET14": DatasetName.SET14, 58 | } 59 | potentials = [] 60 | for key, name in dataset_dict.items(): 61 | if key.lower() in path.lower(): 62 | potentials += [name] 63 | if len(potentials) == 0: 64 | raise ValueError("Could not infer dataset from path.") 65 | if len(potentials) > 1: 66 | raise ValueError("Matched multiple datasets with dataset path.") 67 | return potentials[0] 68 | 69 | def infer_dtype(path: str): 70 | # Treat files as HDF5 and directories as folders 71 | dtype = DatasetType.FOLDER if os.path.isdir(path) else DatasetType.HDF5 72 | return dtype 73 | 74 | # Infer for training set 75 | if cfg.get(ConfigValue.TRAIN_DATA_PATH, None) is not None: 76 | if cfg.get(ConfigValue.TRAIN_DATASET_NAME, None) is None: 77 | cfg[ConfigValue.TRAIN_DATASET_NAME] = infer_dname( 78 | cfg[ConfigValue.TRAIN_DATA_PATH] 79 | ) 80 | if cfg.get(ConfigValue.TRAIN_DATASET_TYPE, None) is None: 81 | cfg[ConfigValue.TRAIN_DATASET_TYPE] = infer_dtype( 82 | cfg[ConfigValue.TRAIN_DATA_PATH] 83 | ) 84 | # Infer for testing/validation set 85 | if cfg.get(ConfigValue.TEST_DATA_PATH, None) is not None: 86 | if cfg.get(ConfigValue.TEST_DATASET_NAME, None) is None: 87 | cfg[ConfigValue.TEST_DATASET_NAME] = infer_dname( 88 | cfg[ConfigValue.TEST_DATA_PATH] 89 | ) 90 | if cfg.get(ConfigValue.TEST_DATASET_TYPE, None) is None: 91 | cfg[ConfigValue.TEST_DATASET_TYPE] = infer_dtype( 92 | cfg[ConfigValue.TEST_DATA_PATH] 93 | ) 94 | 95 | 96 | def test_length(dataset_name: str) -> int: 97 | """To give meaningful PSNR results similar amounts of data should be evaluated. 98 | Return the test length based on image size and image count. Note that for all 99 | datasets it is assumed the test dataset is being used. 100 | 101 | Args: 102 | dataset_name (str): Name of the dataset (BSD...), 103 | 104 | Returns: 105 | int: Image count to test for. When higher than the dataset length existing 106 | images should be reused. 
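    Example (values taken from the mapping below):

        test_length(DatasetName.KODAK)  # -> 240, i.e. 10 passes over the 24 Kodak images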
107 | """ 108 | mapping = { 109 | DatasetName.BSD: 300, # 3 x Testset Length 110 | DatasetName.KODAK: 240, # 10 x Testset Length 111 | DatasetName.SET14: 280, # 20 x Testset Length 112 | } 113 | return mapping[dataset_name] 114 | 115 | 116 | def infer_pipeline(algorithm: NoiseAlgorithm) -> Pipeline: 117 | if algorithm in [NoiseAlgorithm.SELFSUPERVISED_DENOISING]: 118 | return Pipeline.SSDN 119 | elif algorithm in [ 120 | NoiseAlgorithm.SELFSUPERVISED_DENOISING_MEAN_ONLY, 121 | NoiseAlgorithm.NOISE_TO_NOISE, 122 | NoiseAlgorithm.NOISE_TO_CLEAN, 123 | ]: 124 | return Pipeline.MSE 125 | elif algorithm in [NoiseAlgorithm.NOISE_TO_VOID]: 126 | return Pipeline.MASK_MSE 127 | else: 128 | raise NotImplementedError("Algorithm does not have a default pipeline.") 129 | 130 | 131 | def infer_blindspot(algorithm: NoiseAlgorithm): 132 | if algorithm in [ 133 | NoiseAlgorithm.SELFSUPERVISED_DENOISING, 134 | NoiseAlgorithm.SELFSUPERVISED_DENOISING_MEAN_ONLY, 135 | ]: 136 | return True 137 | elif algorithm in [ 138 | NoiseAlgorithm.NOISE_TO_NOISE, 139 | NoiseAlgorithm.NOISE_TO_CLEAN, 140 | NoiseAlgorithm.NOISE_TO_VOID, 141 | ]: 142 | return False 143 | else: 144 | raise NotImplementedError("Not known if algorithm requires blindspot.") 145 | 146 | 147 | def infer(cfg: Dict, model_only: bool = False) -> Dict: 148 | if cfg.get(ConfigValue.PIPELINE, None) is None: 149 | cfg[ConfigValue.PIPELINE] = infer_pipeline(cfg[ConfigValue.ALGORITHM]) 150 | if cfg.get(ConfigValue.BLINDSPOT, None) is None: 151 | cfg[ConfigValue.BLINDSPOT] = infer_blindspot(cfg[ConfigValue.ALGORITHM]) 152 | 153 | if not model_only: 154 | infer_datasets(cfg) 155 | return cfg 156 | 157 | 158 | def config_name(cfg: Dict) -> str: 159 | cfg = infer(cfg) 160 | config_lst = [cfg[ConfigValue.ALGORITHM].value] 161 | 162 | # Check if pipeline cannot be inferred 163 | inferred_pipeline = infer_pipeline(cfg[ConfigValue.ALGORITHM]) 164 | if cfg[ConfigValue.PIPELINE] != inferred_pipeline: 165 | config_lst += [cfg[ConfigValue.PIPELINE].value + "_pipeline"] 166 | # Check if blindspot enable cannot be inferred 167 | inferred_blindspot = infer_blindspot(cfg[ConfigValue.ALGORITHM]) 168 | if cfg[ConfigValue.BLINDSPOT] != inferred_blindspot: 169 | config_lst += [ 170 | "blindspot" if cfg[ConfigValue.BLINDSPOT] else "blindspot_disabled" 171 | ] 172 | # Add noise information 173 | config_lst += [cfg[ConfigValue.NOISE_STYLE]] 174 | if cfg[ConfigValue.PIPELINE] in [Pipeline.SSDN]: 175 | config_lst += ["sigma_" + cfg[ConfigValue.NOISE_VALUE].value] 176 | 177 | if cfg[ConfigValue.IMAGE_CHANNELS] == 1: 178 | config_lst += ["mono"] 179 | if cfg[ConfigValue.PIPELINE] in [Pipeline.SSDN]: 180 | if cfg[ConfigValue.DIAGONAL_COVARIANCE]: 181 | config_lst += ["diag"] 182 | 183 | config_name = "-".join(config_lst) 184 | return config_name 185 | -------------------------------------------------------------------------------- /ssdn/ssdn/cli/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | from ssdn.cli.cli import * 4 | -------------------------------------------------------------------------------- /ssdn/ssdn/cli/cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from ssdn.version import __version__ 4 | 5 | from ssdn.cli.cmds.train import TrainCommand 6 | from ssdn.cli.cmds.eval import EvaluateCommand 7 | 8 | 9 | def start(): 10 | parser = argparse.ArgumentParser( 11 | prog="ssdn", 12 | description=( 13 | "Command line interface for the 
denoising training and evaluation system. " 14 | + "Supported algorithms include Self Supervised Denoising (SSDN), Noise2Clean, " 15 | + "Noise2Void, and Noise2Noise." 16 | ), 17 | ) 18 | parser.add_argument( 19 | "--version", action="version", version="%(prog)s v" + __version__ 20 | ) 21 | 22 | cmd_parsers = parser.add_subparsers(dest="command", required=True) 23 | # Populate available commands 24 | cmd_list = [ 25 | TrainCommand(), 26 | EvaluateCommand(), 27 | ] 28 | 29 | # Add commands to parser 30 | cmds = {} 31 | for cmd in cmd_list: 32 | cmd.configure(cmd_parsers) 33 | cmds[cmd.cmd()] = cmd 34 | # Process arguments 35 | args = parser.parse_args() 36 | arg_dict = vars(args) 37 | arg_dict["PARSER"] = parser 38 | 39 | # Call handle on function 40 | cmds[args.command].execute(arg_dict) 41 | 42 | 43 | if __name__ == "__main__": 44 | start() 45 | -------------------------------------------------------------------------------- /ssdn/ssdn/cli/cmds/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | from ssdn.cli.cmds.cmd import Command 4 | -------------------------------------------------------------------------------- /ssdn/ssdn/cli/cmds/cmd.py: -------------------------------------------------------------------------------- 1 | """Base command case for CLI. 2 | """ 3 | 4 | from argparse import _SubParsersAction 5 | from abc import ABCMeta, abstractmethod 6 | from overrides import EnforceOverrides 7 | from typing import Dict 8 | 9 | 10 | class Command(EnforceOverrides, metaclass=ABCMeta): 11 | """Base class for system commmands. `configure` method should be implemented 12 | to attach the required command(s) to the provided parser. The command behaviour 13 | should run using the ``execute`` command with the parsed arguments. 14 | """ 15 | 16 | def __init__(self): 17 | self.device_manager = None 18 | 19 | @abstractmethod 20 | def configure(self, parser: _SubParsersAction): 21 | """Add the command and any subparsing to the provided parser. 22 | 23 | Args: 24 | parser (_SubParsersAction): A parser to add the command to. 25 | """ 26 | pass 27 | 28 | @abstractmethod 29 | def execute(self, args: Dict): 30 | """Execute command behaviour with parsed arguments. 31 | 32 | Args: 33 | args (Dict): Dictionary of parsed arguments including those from any parent 34 | argument parsers. 35 | """ 36 | pass 37 | 38 | @abstractmethod 39 | def cmd(self) -> str: 40 | """The string that executes the command. 41 | 42 | Returns: 43 | str: Command - case sensitive. 
44 | """ 45 | pass 46 | -------------------------------------------------------------------------------- /ssdn/ssdn/cli/cmds/eval.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from overrides import overrides 4 | from argparse import _SubParsersAction 5 | from typing import Dict 6 | 7 | from ssdn.cli.cmds import Command 8 | from ssdn.eval import DenoiserEvaluator 9 | from ssdn.cfg import DEFAULT_RUN_DIR 10 | from ssdn.params import ConfigValue 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class EvaluateCommand(Command): 16 | @overrides 17 | def configure(self, parser: _SubParsersAction): 18 | cmd_parser = parser.add_parser(self.cmd(), help="Evaluate a pre-trained model.") 19 | cmd_parser.add_argument( 20 | "--model", 21 | "-m", 22 | required=True, 23 | help="Path to model weights or training file.", 24 | ) 25 | cmd_parser.add_argument( 26 | "--dataset", 27 | "-d", 28 | required=True, 29 | help="Path to either a hdf5 file generated by 'dataset_tool_h5.py' or a folder of images.", 30 | ) 31 | cmd_parser.add_argument( 32 | "--runs_dir", 33 | default=DEFAULT_RUN_DIR, 34 | help="Directory in which the output directory is generated." 35 | ) 36 | cmd_parser.add_argument( 37 | "--batch_size", 38 | type=int, 39 | help="Batch size to use, will default to that used while training.", 40 | ) 41 | 42 | @overrides 43 | def execute(self, args: Dict): 44 | evaluator = DenoiserEvaluator(args["model"], runs_dir=args["runs_dir"]) 45 | if args.get("batch_size", None) is not None: 46 | evaluator.cfg[ConfigValue.TEST_MINIBATCH_SIZE] = args["batch_size"] 47 | evaluator.set_test_data(args["dataset"]) 48 | evaluator.evaluate() 49 | 50 | @overrides 51 | def cmd(self) -> str: 52 | return "eval" 53 | -------------------------------------------------------------------------------- /ssdn/ssdn/cli/cmds/train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import ssdn 3 | 4 | from overrides import overrides 5 | from argparse import _SubParsersAction 6 | from typing import Dict 7 | 8 | from ssdn.cli.cmds import Command 9 | from ssdn.params import NoiseAlgorithm, NoiseValue, ConfigValue 10 | from ssdn.train import resume_run, DenoiserTrainer 11 | from ssdn.cfg import DEFAULT_RUN_DIR 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class TrainCommand(Command): 17 | 18 | START_CMD = "start" 19 | RESUME_CMD = "resume" 20 | 21 | @overrides 22 | def configure(self, parser: _SubParsersAction): 23 | cmd_parser = parser.add_parser( 24 | self.cmd(), help="Train or resume training of a Denosier model. " 25 | ) 26 | # Split behaviour into start and resume 27 | action_parsers = cmd_parser.add_subparsers(dest="train_cmd", required=True) 28 | start_parser = action_parsers.add_parser( 29 | TrainCommand.START_CMD, help="start command help" 30 | ) 31 | self.add_shared_args(start_parser, True) 32 | start_parser.add_argument( 33 | "--algorithm", 34 | "-a", 35 | required=True, 36 | help="The algorithm to train.", 37 | choices=[c.value for c in ssdn.utils.list_constants(NoiseAlgorithm)], 38 | ) 39 | start_parser.add_argument( 40 | "--noise_style", 41 | "-n", 42 | required=True, 43 | help="Noise style using a string configuration in the format: {noise_type}{args} " 44 | + "where {args} are the arguments passed to the noise function. The formats for the " 45 | + "supported noise types include 'gauss{SD}', 'gauss{MIN_SD}_{MAX_SD}', 'poisson{LAMBDA}' " 46 | + "'poisson{MIN_LAMBDA}_{MAX_LAMBDA}'. 
If parameters contain a decimal point they are " 47 | + "treated as floats. This means the underlying noise adding method will not attempt to " 48 | + "scale them (/ 255). By default noise algorithms will use truncated versions (clipped " 49 | + "prior to training) - Append '_nc' to avoid this.", 50 | ) 51 | start_parser.add_argument( 52 | "--noise_value", 53 | help="[SSDN] Whether the noise value should be estimated.", 54 | choices=[c.value for c in ssdn.utils.list_constants(NoiseValue)], 55 | ) 56 | start_parser.add_argument( 57 | "--mono", 58 | action="store_true", 59 | help="Convert input data to greyscale (single channel).", 60 | ) 61 | start_parser.add_argument( 62 | "--diagonal", 63 | action="store_true", 64 | help="[SSDN] Enforce a diagonal covariance matrix.", 65 | ) 66 | 67 | start_parser.add_argument( 68 | "--runs_dir", 69 | default=DEFAULT_RUN_DIR, 70 | help="Directory in which the output directory is generated." 71 | ) 72 | 73 | resume_parser = action_parsers.add_parser( 74 | TrainCommand.RESUME_CMD, 75 | help="Resume the training of a model. Note that configuration arguments " 76 | + "used on start are valid when resuming but may cause undefined behaviour - " 77 | + "use these for redefining data locations if needed.", 78 | ) 79 | resume_parser.add_argument( 80 | "run_dir", 81 | help="Path to run directory to resume, the latest *.training file will be used.", 82 | ) 83 | 84 | self.add_shared_args(resume_parser, False) 85 | 86 | def add_shared_args(self, parser: _SubParsersAction, start: bool): 87 | parser.add_argument( 88 | "--train_dataset", 89 | "-t", 90 | required=start, 91 | help="Path to training dataset. This can be either a hdf5 file generated by " 92 | + "'dataset_tool_h5.py' or a folder of images. Note that images smaller than " 93 | + "the patch size will be padded using reflection when a folder is used.", 94 | ) 95 | parser.add_argument( 96 | "--validation_dataset", 97 | "-v", 98 | help="Path to validation dataset. This can be either a hdf5 file generated by " 99 | + "'dataset_tool_h5.py' or a folder of images.", 100 | ) 101 | 102 | parser.add_argument( 103 | "--iterations", 104 | "-i", 105 | required=start, 106 | type=int, 107 | help="Number of iterations (input images) to train for.", 108 | ) 109 | parser.add_argument( 110 | "--eval_interval", 111 | type=int, 112 | help="Number of iterations between evaluations. Should be divisible by " 113 | + "training batch size.", 114 | ) 115 | parser.add_argument( 116 | "--checkpoint_interval", 117 | type=int, 118 | help="Number of iterations between saving checkpoints. 
Should be divisible by " 119 | + "training batch size.", 120 | ) 121 | parser.add_argument( 122 | "--print_interval", 123 | type=int, 124 | help="Number of iterations between printing ongoing results to command line and " 125 | + "Tensorboard, should be divisible by training batch size.", 126 | ) 127 | parser.add_argument( 128 | "--train_batch_size", 129 | type=int, 130 | help="Batch size to use for training images.", 131 | ) 132 | parser.add_argument( 133 | "--validation_batch_size", 134 | type=int, 135 | help="Batch size to use for validation images.", 136 | ) 137 | parser.add_argument( 138 | "--patch_size", type=int, help="Patch size to use for training (square).", 139 | ) 140 | 141 | @overrides 142 | def execute(self, args: Dict): 143 | if args["train_cmd"] == "start": 144 | if args["algorithm"] == "ssdn" and args.get("noise_value", None) == None: 145 | args["PARSER"].error("SSDN requires --noise_value") 146 | cfg = ssdn.cfg.base() 147 | if args.get("algorithm", None) is not None: 148 | cfg[ConfigValue.ALGORITHM] = NoiseAlgorithm(args["algorithm"]) 149 | if args.get("noise_style", None) is not None: 150 | cfg[ConfigValue.NOISE_STYLE] = args["noise_style"] 151 | if args.get("noise_value", None) is not None: 152 | cfg[ConfigValue.NOISE_VALUE] = NoiseValue(args["noise_value"]) 153 | if args.get("mono", None) is not None: 154 | cfg[ConfigValue.IMAGE_CHANNELS] = 1 155 | if args.get("diagonal", None) is not None: 156 | cfg[ConfigValue.TRAIN_ITERATIONS] = args["diagonal"] 157 | trainer = DenoiserTrainer(cfg, runs_dir=args["runs_dir"]) 158 | elif args["train_cmd"] == "resume": 159 | trainer = resume_run(args["run_dir"]) 160 | else: 161 | raise NotImplementedError("Invalid train command") 162 | 163 | # Handle shared args 164 | if args.get("train_dataset", None) is not None: 165 | trainer.set_train_data(args["train_dataset"]) 166 | if args.get("validation_dataset", None) is not None: 167 | trainer.set_test_data(args["validation_dataset"]) 168 | if args.get("iterations", None) is not None: 169 | cfg[ConfigValue.TRAIN_ITERATIONS] = args["iterations"] 170 | if args.get("eval_interval", None) is not None: 171 | cfg[ConfigValue.EVAL_INTERVAL] = args["eval_interval"] 172 | if args.get("checkpoint_interval", None) is not None: 173 | cfg[ConfigValue.SNAPSHOT_INTERVAL] = args["checkpoint_interval"] 174 | if args.get("print_interval", None) is not None: 175 | cfg[ConfigValue.PRINT_INTERVAL] = args["print_interval"] 176 | if args.get("train_batch_size", None) is not None: 177 | cfg[ConfigValue.TRAIN_MINIBATCH_SIZE] = args["train_batch_size"] 178 | if args.get("validation_batch_size", None) is not None: 179 | cfg[ConfigValue.TEST_MINIBATCH_SIZE] = args["validation_batch_size"] 180 | if args.get("patch_size", None) is not None: 181 | cfg[ConfigValue.TRAIN_PATCH_SIZE] = args["patch_size"] 182 | 183 | # Start the training 184 | trainer.train() 185 | 186 | @overrides 187 | def cmd(self) -> str: 188 | return "train" 189 | -------------------------------------------------------------------------------- /ssdn/ssdn/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from ssdn.datasets.folder import UnlabelledImageFolderDataset 2 | from ssdn.datasets.hdf5 import HDF5Dataset 3 | from ssdn.datasets.sampler import FixedLengthSampler, SamplingOrder 4 | from ssdn.datasets.noise_wrapper import NoisyDataset 5 | -------------------------------------------------------------------------------- /ssdn/ssdn/datasets/folder.py: 
-------------------------------------------------------------------------------- 1 | """Contains custom dataset for loading unlabelled images from a folder. 2 | """ 3 | __authors__ = "David Jones " 4 | 5 | import torch 6 | import torchvision.transforms.functional as F 7 | import os 8 | import glob 9 | import tempfile 10 | import string 11 | import imagesize 12 | import ssdn 13 | 14 | from torch import Tensor 15 | from ssdn.utils.transforms import Transform 16 | from ssdn.utils.data_format import DataFormat, PIL_FORMAT, permute_tuple 17 | from torch.utils.data import Dataset 18 | from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS 19 | from PIL import Image 20 | from typing import List, Tuple 21 | 22 | 23 | class UnlabelledImageFolderDataset(Dataset): 24 | """Custom dataset using behaviour similar to Torchvision's `ImageFolder` with the 25 | difference of not expecting subfolders of labelled data. By default outputs are three 26 | channel tensors where channels are split at the top structure, i.e. 27 | [[[R, R]], [[G, G]], [[B, B, B]]]. 28 | 29 | Args: 30 | dir_path (str): Root directory to load images from. 31 | extensions (str, optional): Image extensions to match on. Defaults to those used 32 | by Torchvision's `ImageFolder` dataset. These will always be matched in a case 33 | insensitive manner. Duplicate extensions will be removed. 34 | transform (Transform, optional): A custom transform to apply after loading. 35 | A PIL input will be fed into this transform. A Tensor conversion operation 36 | will always occur after. Defaults to None. 37 | recursive (bool, optional): Whether to search folders recursively for images. 38 | Defaults to False. 39 | output_format (str, optional): Data format to output data in, if None the default 40 | format used by PyTorch will be used. Defaults to DataFormat.CHW. 41 | channels (int, optional): Number of output channels (1 or 3). If the loaded 42 | image is 1 channel and 3 channels are required the single channel 43 | will be copied across each channel. If 3 channels are loaded and 1 channel 44 | is required a weighted RGB to L conversion occurs. 
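    Example (hypothetical directory of images):

        dataset = UnlabelledImageFolderDataset("kodak/", channels=3)
        img, index = dataset[0]  # CHW float tensor in [0, 1] and the dataset index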
45 | """ 46 | 47 | def __init__( 48 | self, 49 | dir_path: str, 50 | extensions: List = IMG_EXTENSIONS, 51 | transform: Transform = None, 52 | recursive: bool = False, 53 | output_format: str = DataFormat.CHW, 54 | channels: int = 3, 55 | ): 56 | super(UnlabelledImageFolderDataset, self).__init__() 57 | self.dir_path = dir_path 58 | self.files = sorted(find_files(dir_path, extensions, recursive)) 59 | self.cache = {} 60 | self.loader = default_loader 61 | self.transform = transform 62 | self.output_format = output_format 63 | assert channels in [1, 3] 64 | self.channels = channels 65 | if len(self.files) == 0: 66 | raise ( 67 | RuntimeError( 68 | "Found 0 files in directory: " + self.dir_path + "\n" 69 | "Supported extensions are: " + ",".join(extensions) 70 | ) 71 | ) 72 | 73 | def __getitem__(self, index: int) -> Tuple[Tensor, int]: 74 | path = self.files[index] 75 | img = self.loader(path) 76 | img = ssdn.utils.set_color_channels(img, self.channels) 77 | # Apply custom transform 78 | if self.transform: 79 | img = self.transform(img) 80 | # Convert to tensor if this hasn't be done during the transform 81 | if not isinstance(img, torch.Tensor): 82 | img = F.to_tensor(img) 83 | if self.output_format is not None: 84 | img = img.permute(permute_tuple(PIL_FORMAT, self.output_format)) 85 | return img, index 86 | 87 | def image_size(self, index: int, ignore_transform: bool = False) -> Tensor: 88 | """Quick method to check image size using header information. Note that if a 89 | transform is in place then the data must be loaded directly from the dataset 90 | to ensure the transform has not changed the shape. 91 | 92 | Args: 93 | index (int): Index of image in dataset. 94 | ignore_transform (bool, optional): Whether the transform is known not to 95 | affect the output size. This will cause the true image size to always 96 | be returned. Defaults to False. 97 | 98 | Returns: 99 | Tensor: Shape tensor in output data format 100 | """ 101 | # Check if quick method viable 102 | if self.transform is not None and not ignore_transform: 103 | return torch.tensor(self.__getitem__(index)[0].shape) 104 | # Can use quick method 105 | path = self.files[index] 106 | size = imagesize.get(path) 107 | cwh = torch.tensor(tuple((self.channels, *size))) 108 | return cwh[list(permute_tuple(DataFormat.CWH, self.output_format))] 109 | 110 | def __len__(self): 111 | return len(self.files) 112 | 113 | 114 | def is_fs_case_sensitive() -> bool: 115 | """Check if file system is case sensitive using a temporary file. The result will 116 | be cached for future calls. See `https://stackoverflow.com/a/36580834`. 117 | 118 | Returns: 119 | bool: Whether the file system is case insensitive. 120 | """ 121 | if not hasattr(is_fs_case_sensitive, "case_sensitive"): 122 | with tempfile.NamedTemporaryFile(prefix="TmP") as tmp_file: 123 | setattr( 124 | is_fs_case_sensitive, 125 | "case_sensitive", 126 | not os.path.exists(tmp_file.name.lower()), 127 | ) 128 | return is_fs_case_sensitive.case_sensitive 129 | 130 | 131 | def case_insensitive_extensions(extensions: List[str]) -> List[str]: 132 | """In the case that the current file system is case sensitive, this method will 133 | map the provided list of extensions from the form ".PNG" to ".[pP][nN][gG]". When 134 | used with glob case will be ignored on any file system. Any duplicated match patterns 135 | generated will be removed to ensure files will not be matched twice. 136 | 137 | Args: 138 | extensions (List[str]): List of file extensions. 
139 | 140 | Returns: 141 | List[str]: Extensions formatted for case insensitivity. The original order may 142 | not be preserved. If the file system is not case sensitive the original list 143 | is returned untouched. 144 | """ 145 | 146 | def mp(char: str) -> str: 147 | if char not in string.ascii_letters: 148 | return char 149 | return "[{}{}]".format(char.lower(), char.upper()) 150 | 151 | extensions = [ext.lower() for ext in extensions] 152 | # Add both lower and uppercase versions when file system case sensitive 153 | if is_fs_case_sensitive(): 154 | cs_extensions = [] 155 | for extension in extensions: 156 | cs_extensions += ["".join(map(mp, extension))] 157 | extensions = cs_extensions 158 | # Remove duplicates 159 | return list(set(extensions)) 160 | 161 | 162 | def find_files( 163 | dir_path: str, extensions: List[str], recursive: bool, case_insensitive: bool = True 164 | ) -> List[str]: 165 | """Structured glob match for finding files ending with a given extension. These 166 | extensions by default will be treated as case insensitive on all file systems. 167 | 168 | Args: 169 | dir_path (str): Root directory path to search from. 170 | extensions (List[str]): List of extensions to match. These can be prefixed with 171 | a '.' but this is not required. 172 | recursive (bool): Whether to search folders recursively for matches. 173 | case_insensitive (bool, optional): Whether to force the file system to ignore 174 | case. Defaults to True. 175 | 176 | Returns: 177 | List[str]: List of matched files. 178 | """ 179 | if case_insensitive: 180 | extensions = case_insensitive_extensions(extensions) 181 | star_match = os.path.join("**", "*") if recursive else "*" 182 | files = [] 183 | for ext in extensions: 184 | if ext[0] != ".": 185 | ext = "." + ext 186 | path = os.path.join(dir_path, star_match + ext) 187 | files.extend(glob.glob(path, recursive=recursive)) 188 | return files 189 | -------------------------------------------------------------------------------- /ssdn/ssdn/datasets/hdf5.py: -------------------------------------------------------------------------------- 1 | """Contains custom dataset for loading from files created by `dataset_tool_h5.py`. 2 | """ 3 | __authors__ = "David Jones " 4 | 5 | import h5py 6 | import numpy as np 7 | import torch 8 | import ssdn 9 | import torchvision.transforms.functional as F 10 | 11 | from PIL import Image 12 | from torch import Tensor 13 | from torch.utils.data import Dataset 14 | from ssdn.utils.transforms import Transform 15 | from ssdn.utils.data_format import DataFormat, PIL_FORMAT, permute_tuple 16 | from typing import Tuple 17 | 18 | 19 | class HDF5Dataset(Dataset): 20 | def __init__( 21 | self, 22 | file_path: str, 23 | transform: Transform = None, 24 | h5_format: str = PIL_FORMAT, 25 | output_format: str = DataFormat.CHW, 26 | channels: int = 3, 27 | ): 28 | """Custom dataset for loading from a file stored in the HDF5 format. This is 29 | provided to mirror the file format created by the `dataset_tool_h5.py` used 30 | by the Tensorflow implementation. By default outputs are three channel tensors 31 | where channels are split at the top structure, i.e. [[[R, R]], [[G, G]], [[B, B, B]]]. 32 | 33 | Args: 34 | file_path (str): HDF5 file to load from. 35 | transform (Transform, optional): A custom transform to apply after loading. 36 | A PIL input will be fed into this transform. A Tensor conversion operation 37 | will always occur after. Defaults to None. 38 | h5_format (str, optional): Format h5 data is stored in. 
Defaults to PIL_FORMAT. 39 | output_format (str, optional): Data format to output data in, if None the default 40 | format used by PyTorch will be used. Defaults to DataFormat.CHW. 41 | channels (int, optional): Number of output channels (1 or 3). If the loaded 42 | image is 1 channel and 3 channels are required the single channel 43 | will be copied across each channel. If 3 channels are loaded and 1 channel 44 | is required a weighted RGB to L conversion occurs. 45 | """ 46 | super(HDF5Dataset, self).__init__() 47 | 48 | self.file_path = file_path 49 | with h5py.File(self.file_path, "r") as h5file: 50 | self.img_count = h5file["images"].shape[0] 51 | self.transform = transform 52 | self.output_format = output_format 53 | self.channels = channels 54 | self.h5_format = h5_format 55 | 56 | def __getitem__(self, index: int) -> Tuple[Tensor, int]: 57 | with h5py.File(self.file_path, "r") as h5file: 58 | img = h5file["images"][index] 59 | shp = h5file["shapes"][index] 60 | img = np.reshape(img, shp) 61 | # Get actual PIL object for transforms to be applied to 62 | img = img.transpose(*permute_tuple(self.h5_format, "WHC")) 63 | img = Image.fromarray(img) 64 | img = ssdn.utils.set_color_channels(img, self.channels) 65 | # Apply custom transform 66 | if self.transform: 67 | img = self.transform(img) 68 | # Convert to tensor if this hasn't be done during the transform 69 | if not isinstance(img, torch.Tensor): 70 | img = F.to_tensor(img) 71 | if self.output_format is not None: 72 | img = img.permute(permute_tuple(PIL_FORMAT, self.output_format)) 73 | 74 | return img, index 75 | 76 | def image_size(self, index: int, ignore_transform: bool = False) -> Tensor: 77 | """Quick method to check image size by accessing only shape field. Note that if a 78 | transform is in place then the data must be loaded directly from the dataset 79 | to ensure the transform has not changed the shape. 80 | 81 | Args: 82 | index (int): Index of image in dataset. 83 | ignore_transform (bool, optional): Whether the transform is known not to 84 | affect the output size. This will cause the true image size to always 85 | be returned. Defaults to False. 86 | 87 | Returns: 88 | Tensor: Shape tensor in output data format 89 | """ 90 | # Check if quick method viable 91 | if self.transform is not None and not ignore_transform: 92 | return torch.tensor(self.__getitem__(index)[0].shape) 93 | # Can use quick method 94 | with h5py.File(self.file_path, "r") as h5file: 95 | shp = h5file["shapes"][index] 96 | shp = shp[list(permute_tuple(self.h5_format, self.output_format))] 97 | return torch.tensor(shp) 98 | 99 | def __len__(self): 100 | return self.img_count 101 | -------------------------------------------------------------------------------- /ssdn/ssdn/datasets/noise_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import ssdn 4 | 5 | from torch import Tensor 6 | 7 | from torch.utils.data import Dataset 8 | from ssdn.params import NoiseAlgorithm 9 | 10 | from enum import Enum, auto 11 | from typing import Union, Dict, Tuple, List, Optional 12 | from numbers import Number 13 | from ssdn.utils.data_format import DataFormat, DATA_FORMAT_DIM_INDEX, DataDim 14 | 15 | NULL_IMAGE = torch.zeros(0) 16 | 17 | 18 | class NoisyDataset(Dataset): 19 | """Wrapper for a child dataset for creating inputs for training denoising 20 | algorithms. This involves adding noise to data and creating appropriate 21 | references for the algorithm being trained. 
Data can be padded to match 22 | requirements of input network; this can be unpadded again using information 23 | provided in the metadata dictionary. Metadata includes shapes of the inputs 24 | before padding, the index of the data returned, and noise coefficients used. 25 | 26 | Args: 27 | child (Dataset): Child dataset to load data from. It is expected that an 28 | unlabelled image is the first element of any returned data. 29 | noise_style (str): The noise style to use in string representation. 30 | algorithm (NoiseAlgorithm): The algorithm the loader should prepare data for. 31 | This will dictate the appropriate reference images created. 32 | enable_metadata (bool, optional): Whether to return a dictionary containing 33 | information about data creation. When False only two values are returned. 34 | Defaults to True. 35 | pad_uniform (bool, optional): Whether to pad returned images to the same size 36 | as the largest image. This may cause very slow initialisation for large 37 | datasets. Defaults to False. 38 | pad_multiple (int, optional): Whether to pad the width and height of returned 39 | images to a divisor. Ignored if None. Defaults to None. 40 | square (bool, optional): Whether to pad such that width and height are equal. 41 | Defaults to False. 42 | data_format (str, optional): Format of data from underlying dataset. 43 | Defaults to DataFormat.CHW. 44 | """ 45 | 46 | INPUT = 0 47 | REFERENCE = 1 48 | METADATA = 2 49 | """ Indexes for returned data.""" 50 | 51 | def __init__( 52 | self, 53 | child: Dataset, 54 | noise_style: str, 55 | algorithm: NoiseAlgorithm, 56 | enable_metadata: bool = True, 57 | pad_uniform: bool = False, 58 | pad_multiple: int = None, 59 | square: bool = False, 60 | data_format: str = DataFormat.CHW, 61 | training_mode: bool = False 62 | ): 63 | self.child = child 64 | self.enable_metadata = enable_metadata 65 | self.noise_style = noise_style 66 | self.algorithm = algorithm 67 | self.pad_uniform = pad_uniform 68 | self.pad_multiple = pad_multiple 69 | self.square = square 70 | self.data_format = data_format 71 | self.training_mode = training_mode 72 | 73 | # Initialise max image size property, this will load all data 74 | self._max_image_size = None 75 | if self.pad_uniform: 76 | _ = self.max_image_size 77 | 78 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor, Optional[Dict]]: 79 | data = self.child.__getitem__(index) 80 | img = data[0] 81 | 82 | if self.enable_metadata: 83 | metadata = {} 84 | metadata[NoisyDataset.Metadata.INDEXES] = index 85 | else: 86 | metadata = None 87 | # Noisify and create appropriate reference 88 | (inp, ref, metadata) = self.prepare_input(img, metadata) 89 | 90 | if self.enable_metadata: 91 | return (inp, ref, metadata) 92 | else: 93 | return (inp, ref) 94 | 95 | def __len__(self) -> int: 96 | return self.child.__len__() 97 | 98 | def prepare_input( 99 | self, clean: Tensor, metadata: Dict = {} 100 | ) -> Tuple[Tensor, Tensor, Dict]: 101 | """Translate clean reference into training input and reference. The algorithm 102 | being trained dictates the reference used, e.g. Noise2Noise will create a 103 | noisy input and noisy reference. 104 | 105 | Args: 106 | clean (Tensor): Clean input image to create noisy input and reference from. 107 | metadata (Dict, optional): Dictionary to fill with metadata. Defaults to 108 | creating a new dictionary. 
109 | 110 | Returns: 111 | Tuple[Tensor, Tensor, Dict]: Input, Reference, Metadata Dictionary 112 | """ 113 | # Helper function to fix coefficient shape to [1, 1, 1] shape, batcher will 114 | # automatically elevate to [n, 1, 1, 1] shape if required 115 | def broadcast_coeffs(imgs: Tensor, coeffs: Union[Tensor, Number]): 116 | return torch.zeros((1, 1, 1)) + coeffs 117 | 118 | # Track the true shape as batching may lead to padding distorting this shape 119 | image_shape = clean.shape 120 | 121 | # Create the noisy input images 122 | noisy_in, noisy_in_coeff = ssdn.utils.noise.add_style(clean, self.noise_style) 123 | if self.algorithm == NoiseAlgorithm.NOISE_TO_VOID and self.training_mode: 124 | noisy_in, mask_coords = ssdn.utils.n2v_ups.manipulate(noisy_in, 5) # TODO use config for neighbourhood radius 125 | metadata[NoisyDataset.Metadata.MASK_COORDS] = mask_coords 126 | inp, inp_coeff = noisy_in, noisy_in_coeff 127 | 128 | # N2C requires noisy input and clean reference images 129 | if self.algorithm == NoiseAlgorithm.NOISE_TO_CLEAN: 130 | ref, ref_coeff = clean, 0 131 | # N2N and N2V require noisy input and noisy reference images 132 | elif ( 133 | self.algorithm == NoiseAlgorithm.NOISE_TO_NOISE or 134 | self.algorithm == NoiseAlgorithm.NOISE_TO_VOID 135 | ): 136 | ref, ref_coeff = ssdn.utils.noise.add_style(clean, self.noise_style) 137 | # SSDN requires noisy input and no reference images 138 | elif self.algorithm == NoiseAlgorithm.SELFSUPERVISED_DENOISING: 139 | ref, ref_coeff = NULL_IMAGE, 0 140 | # SSDN mean only requires noisy input and same image as noisy input reference 141 | elif self.algorithm == NoiseAlgorithm.SELFSUPERVISED_DENOISING_MEAN_ONLY: 142 | ref, ref_coeff = noisy_in, noisy_in_coeff 143 | else: 144 | raise NotImplementedError("Denoising algorithm not supported") 145 | 146 | # Original implementation pads before adding noise, here it is done after as it 147 | # reduces the false scenario of adding structured noise across the full image 148 | inp = self.pad_to_output_size(inp) 149 | if ref is not NULL_IMAGE: 150 | ref = self.pad_to_output_size(ref) 151 | 152 | # Fill metdata dictionary 153 | if metadata is not None: 154 | metadata[NoisyDataset.Metadata.CLEAN] = self.pad_to_output_size(clean) 155 | metadata[NoisyDataset.Metadata.IMAGE_SHAPE] = torch.tensor(image_shape) 156 | metadata[NoisyDataset.Metadata.INPUT_NOISE_VALUES] = broadcast_coeffs( 157 | inp, inp_coeff 158 | ) 159 | metadata[NoisyDataset.Metadata.REFERENCE_NOISE_VALUES] = broadcast_coeffs( 160 | ref, ref_coeff 161 | ) 162 | 163 | return (inp, ref, metadata) 164 | 165 | @property 166 | def max_image_size(self) -> List[int]: 167 | """ Find the maximum image size in the dataset. Will try calling `image_size` method 168 | first in case a fast method for checking size has been implemented. Will fall back 169 | to loading images from the dataset as normal and checking their shape. Once this 170 | method has been called once the maximum size will be cached for subsequent calls. 
171 | """ 172 | if self._max_image_size is None: 173 | try: 174 | image_sizes = [self.child.image_size(i) for i in range(len(self.child))] 175 | except AttributeError: 176 | image_sizes = [torch.tensor(data[0].shape) for data in self.child] 177 | 178 | image_sizes = torch.stack(image_sizes) 179 | max_image_size = torch.max(image_sizes, dim=0).values 180 | self._max_image_size = max_image_size 181 | return self._max_image_size 182 | 183 | def get_output_size(self, image: Tensor) -> Tensor: 184 | """Calculate output size of an image using the current padding configuration. 185 | """ 186 | df = DATA_FORMAT_DIM_INDEX[self.data_format] 187 | # Use largest image size in dataset if returning uniform sized tensors 188 | if self.pad_uniform: 189 | image_size = self.max_image_size 190 | else: 191 | image_size = image.shape 192 | image_size = list(image_size) 193 | 194 | # Pad width and height axis up to a supported multiple 195 | if self.pad_multiple: 196 | pad = self.pad_multiple 197 | for dim in [DataDim.HEIGHT, DataDim.WIDTH]: 198 | image_size[df[dim]] = (image_size[df[dim]] + pad - 1) // pad * pad 199 | 200 | # Pad to be a square 201 | if self.square: 202 | size = max(image_size[df[DataDim.HEIGHT]], image_size[df[DataDim.WIDTH]]) 203 | image_size[df[DataDim.HEIGHT]] = size 204 | image_size[df[DataDim.WIDTH]] = size 205 | 206 | return torch.tensor(image_size) 207 | 208 | def pad_to_output_size(self, image: Tensor) -> Tensor: 209 | """ Apply reflection padding to the image to meet the current padding 210 | configuration. Note that padding is handled by Numpy. 211 | """ 212 | 213 | supported = [DataFormat.CHW, DataFormat.CWH, DataFormat.BCHW, DataFormat.BCWH] 214 | if self.data_format not in supported: 215 | raise NotImplementedError("Padding not supported by data format") 216 | 217 | df = DATA_FORMAT_DIM_INDEX[self.data_format] 218 | output_size = self.get_output_size(image) 219 | # Already correct, do not pad 220 | if all(output_size == torch.tensor(image.shape)): 221 | return image 222 | left, top = 0, 0 223 | right = output_size[df[DataDim.WIDTH]] - image.shape[df[DataDim.WIDTH]] 224 | bottom = output_size[df[DataDim.HEIGHT]] - image.shape[df[DataDim.HEIGHT]] 225 | # Pad Width/Height ignoring other axis 226 | pad_matrix = [[0, 0]] * len(self.data_format) 227 | pad_matrix[df[DataDim.WIDTH]] = [left, right] 228 | pad_matrix[df[DataDim.HEIGHT]] = [top, bottom] 229 | # PyTorch methods expect PIL images so fallback to Numpy for padding 230 | np_padded = np.pad(image, pad_matrix, mode="reflect") 231 | # Convert back to Tensor 232 | return torch.tensor( 233 | np_padded, device=image.device, requires_grad=image.requires_grad 234 | ) 235 | 236 | @staticmethod 237 | def _unpad_single(image: Tensor, shape: Tensor) -> Tensor: 238 | # Create slice list extracting from 0:n for each shape axis 239 | slices = list(map(lambda x: slice(*x), (zip([0] * len(shape), shape)))) 240 | return image[slices] 241 | 242 | @staticmethod 243 | def _unpad(image: Tensor, shape: Tensor) -> Union[Tensor, List[Tensor]]: 244 | if len(image.shape) <= shape.shape[-1]: 245 | return NoisyDataset._unpad_single(image, shape) 246 | return [NoisyDataset._unpad_single(i, s) for i, s in zip(image, shape)] 247 | 248 | @staticmethod 249 | def unpad( 250 | image: Tensor, metadata: Dict, batch_index: int = None 251 | ) -> Union[Tensor, List[Tensor]]: 252 | """For a padded image or batch of padded images, undo padding. It is 253 | assumed that the original image is positioned in the top left and 254 | that the channel count has not changed. 
255 | 
256 |         Args:
257 |             image (Tensor): Single image or batch of images.
258 |             metadata (Tensor): Metadata dictionary associated with images to
259 |                 unpad.
260 | 
261 |         Returns:
262 |             Union[Tensor, List[Tensor]]: Unpadded image tensor if not batched.
263 |                 List of unpadded images if batched.
264 |         """
265 |         inp_shape = metadata[NoisyDataset.Metadata.IMAGE_SHAPE]
266 |         if batch_index is not None:
267 |             image = image[batch_index]
268 |             inp_shape = inp_shape[batch_index]
269 |         return NoisyDataset._unpad(image, inp_shape)
270 | 
271 |     class Metadata(Enum):
272 |         """ Enumeration of fields that can be contained in the metadata dictionary.
273 |         """
274 | 
275 |         CLEAN = auto()
276 |         IMAGE_SHAPE = auto()
277 |         INDEXES = auto()
278 |         INPUT_NOISE_VALUES = auto()
279 |         REFERENCE_NOISE_VALUES = auto()
280 |         MASK_COORDS = auto()
281 | 
--------------------------------------------------------------------------------
/ssdn/ssdn/datasets/sampler.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | 
3 | """Contains custom sampler to allow repeated use of the same data with fair extraction.
4 | """
5 | __authors__ = "David Jones "
6 | 
7 | import torch
8 | 
9 | from torch.utils.data import Sampler, Dataset
10 | from typing import Generator, List, Dict
11 | 
12 | 
13 | class FixedLengthSampler(Sampler):
14 |     """Sample in either sequential or random order for the given number of samples. If the
15 |     number of requested samples exceeds the length of the dataset, the dataset will loop.
16 |     Unlike standard sampling with replacement this means a sample will be used at most once
17 |     more than any other sample.
18 | 
19 |     There is no option for fully random selection with replacement; use PyTorch's
20 |     `RandomSampler` if this behaviour is desired.
21 | 
22 |     Args:
23 |         data_source (Dataset): Dataset to load samples from.
24 |         num_samples (int, optional): The number of samples to be returned by the dataset.
25 |             Defaults to None; this is equivalent to the length of the dataset.
26 |         shuffled (bool, optional): Whether to randomise order. Defaults to False.
27 |     """
28 | 
29 |     def __init__(
30 |         self, data_source: Dataset, num_samples: int = None, shuffled: bool = False,
31 |     ):
32 |         self.data_source = data_source
33 |         self._num_samples = num_samples
34 |         self.shuffled = shuffled
35 |         self._next_iter = None
36 |         self._last_iter = None
37 | 
38 |     @property
39 |     def num_samples(self) -> int:
40 |         if self._num_samples is None:
41 |             return len(self.data_source)
42 |         else:
43 |             return self._num_samples
44 | 
45 |     def sampler(self) -> Generator[int, None, None]:
46 |         """Iterator handling both shuffled and non-shuffled behaviour.
47 | 
48 |         Yields:
49 |             Generator[int, None, None]: Next index to sample.
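
        Example (illustrative, not part of the original docstring): for a dataset of
        length 3 with num_samples=5 and shuffled=False the yielded order is
        0, 1, 2, 0, 1.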
50 | """ 51 | remaining = self.num_samples 52 | if self.shuffled: 53 | while remaining > 0: 54 | n = min(remaining, len(self.data_source)) 55 | for idx in torch.randperm(len(self.data_source))[0:n]: 56 | yield int(idx) 57 | remaining -= n 58 | else: 59 | current_idx = None 60 | while remaining > 0: 61 | if current_idx is None or current_idx >= len(self.data_source): 62 | current_idx = 0 63 | yield current_idx 64 | current_idx += 1 65 | remaining -= 1 66 | 67 | def __iter__(self) -> Generator[int, None, None]: 68 | if self._next_iter is None: 69 | sample_order = list(self.sampler()) 70 | self._last_iter = SamplingOrder(sample_order) 71 | return self._last_iter 72 | else: 73 | return self._next_iter 74 | 75 | def __len__(self) -> int: 76 | return self.num_samples 77 | 78 | def for_next_iter(self, iter_order: SamplingOrder): 79 | self._next_iter = iter_order 80 | self._last_iter = iter_order 81 | 82 | def last_iter(self) -> Generator[int, None, None]: 83 | return self._last_iter 84 | 85 | 86 | class SamplingOrder: 87 | def __init__(self, order: List[int], index: int = 0): 88 | self.order = order 89 | self.index = index 90 | 91 | def __iter__(self) -> Generator[int, None, None]: 92 | return self 93 | 94 | def __len__(self) -> int: 95 | return len(self.order) 96 | 97 | def __next__(self) -> int: 98 | if self.index < len(self.order): 99 | value = self.order[self.index] 100 | self.index += 1 101 | return value 102 | else: 103 | raise StopIteration() 104 | 105 | def state_dict(self) -> Dict: 106 | state_dict = {"order": self.order, "index": self.index} 107 | return state_dict 108 | 109 | @staticmethod 110 | def from_state_dict(state_dict: Dict) -> SamplingOrder: 111 | return SamplingOrder(state_dict["order"], state_dict["index"]) 112 | -------------------------------------------------------------------------------- /ssdn/ssdn/denoiser.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | 4 | import torch 5 | import torch.nn as nn 6 | import ssdn 7 | 8 | from torch import Tensor 9 | 10 | from ssdn.params import ( 11 | ConfigValue, 12 | PipelineOutput, 13 | Pipeline, 14 | NoiseValue, 15 | ) 16 | 17 | from ssdn.models import NoiseNetwork 18 | from ssdn.datasets import NoisyDataset 19 | 20 | from typing import Dict, List 21 | 22 | 23 | class Denoiser(nn.Module): 24 | 25 | MODEL = "denoiser_model" 26 | SIGMA_ESTIMATOR = "sigma_estimation_model" 27 | ESTIMATED_SIGMA = "estimated_sigma" 28 | 29 | def __init__( 30 | self, cfg: Dict, device: str = None, 31 | ): 32 | super().__init__() 33 | # Configure device 34 | if device: 35 | device = torch.device(device) 36 | else: 37 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 38 | self.device = device 39 | 40 | # Store the denoiser configuration 41 | self.cfg = cfg 42 | 43 | # Models to use during training, these can be parallelised 44 | self.models = nn.ModuleDict() 45 | # References to models that are guaranteed to be device independent 46 | self._models = nn.ModuleDict() 47 | # Initialise networks using current configuration 48 | self.init_networks() 49 | # Learnable parameters 50 | self.l_params = nn.ParameterDict() 51 | self.init_l_params() 52 | 53 | def init_networks(self): 54 | # Calculate input and output channel count for networks 55 | in_channels = self.cfg[ConfigValue.IMAGE_CHANNELS] 56 | if self.cfg[ConfigValue.PIPELINE] == Pipeline.SSDN: 57 | if self.cfg[ConfigValue.DIAGONAL_COVARIANCE]: 58 | out_channels = in_channels * 2 # Means, diagonal of A 59 | 
else: 60 | out_channels = ( 61 | in_channels + (in_channels * (in_channels + 1)) // 2 62 | ) # Means, triangular A. 63 | else: 64 | out_channels = in_channels 65 | 66 | # Create general denoising model 67 | self.add_model( 68 | Denoiser.MODEL, 69 | NoiseNetwork( 70 | in_channels=in_channels, 71 | out_channels=out_channels, 72 | blindspot=self.cfg[ConfigValue.BLINDSPOT], 73 | ), 74 | ) 75 | # Create separate model for variable parameter estimation 76 | if ( 77 | self.cfg[ConfigValue.PIPELINE] == Pipeline.SSDN 78 | and self.cfg[ConfigValue.NOISE_VALUE] == NoiseValue.UNKNOWN_VARIABLE 79 | ): 80 | self.add_model( 81 | Denoiser.SIGMA_ESTIMATOR, 82 | NoiseNetwork( 83 | in_channels=in_channels, 84 | out_channels=1, 85 | blindspot=False, 86 | zero_output_weights=True, 87 | ), 88 | ) 89 | 90 | def init_l_params(self): 91 | if ( 92 | self.cfg[ConfigValue.PIPELINE] == Pipeline.SSDN 93 | and self.cfg[ConfigValue.NOISE_VALUE] == NoiseValue.UNKNOWN_CONSTANT 94 | ): 95 | init_value = torch.zeros((1, 1, 1, 1)) 96 | self.l_params[Denoiser.ESTIMATED_SIGMA] = nn.Parameter(init_value) 97 | 98 | def get_model(self, model_id: str, parallelised: bool = True) -> nn.Module: 99 | model_dict = self.models if parallelised else self._models 100 | return model_dict[model_id] 101 | 102 | def add_model(self, model_id: str, model: nn.Module, parallelise: bool = True): 103 | self._models[model_id] = model 104 | if parallelise: 105 | parallel_model = nn.DataParallel(model) 106 | else: 107 | parallel_model = model 108 | # Move to master device (GPU 0 or CPU) 109 | parallel_model.to(self.device) 110 | self.models[model_id] = parallel_model 111 | 112 | def forward(self, data: Tensor) -> Tensor: 113 | """Pass an input into the denoiser for inference. This will not train 114 | the network. Inference will be applied using current model state with 115 | the configured pipeline. 116 | 117 | Args: 118 | data (Tensor): Image or batch of images to denoise in BCHW format. 119 | 120 | Returns: 121 | Tensor: Denoised image or images. 
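
        Example (illustrative sketch, not from the original source; the weights path
        is a placeholder):

            state = torch.load("path/to/weights.wt", map_location="cpu")
            denoiser = Denoiser.from_state_dict(state)
            denoised = denoiser(noisy_batch)  # noisy_batch is a BCHW float tensor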
122 | """ 123 | assert NoisyDataset.INPUT == 0 124 | inputs = [data] 125 | outputs = self.run_pipeline(inputs) 126 | return outputs[PipelineOutput.IMG_DENOISED] 127 | 128 | def run_pipeline(self, data: List, **kwargs): 129 | if self.cfg[ConfigValue.PIPELINE] == Pipeline.MSE: 130 | return self._mse_pipeline(data, **kwargs) 131 | elif self.cfg[ConfigValue.PIPELINE] == Pipeline.SSDN: 132 | return self._ssdn_pipeline(data, **kwargs) 133 | elif self.cfg[ConfigValue.PIPELINE] == Pipeline.MASK_MSE: 134 | return self._mask_mse_pipeline(data, **kwargs) 135 | # return self._mse_pipeline(data, **kwargs) 136 | 137 | else: 138 | raise NotImplementedError("Unsupported processing pipeline") 139 | 140 | def _mse_pipeline(self, data: List, **kwargs) -> Dict: 141 | outputs = {PipelineOutput.INPUTS: data} 142 | # Run the input through the model 143 | inp = data[NoisyDataset.INPUT].to(self.device) 144 | inp.requires_grad = True 145 | cleaned = self.models[Denoiser.MODEL](inp) 146 | outputs[PipelineOutput.IMG_DENOISED] = cleaned 147 | 148 | # If reference images are provided calculate the loss 149 | # as MSE whilst preserving individual loss per batch 150 | if len(data) >= NoisyDataset.REFERENCE: 151 | ref = data[NoisyDataset.REFERENCE].to(self.device) 152 | ref.requires_grad = True 153 | loss = nn.MSELoss(reduction="none")(cleaned, ref) 154 | loss = loss.view(loss.shape[0], -1).mean(1, keepdim=True) 155 | outputs[PipelineOutput.LOSS] = loss 156 | 157 | return outputs 158 | 159 | def _mask_mse_pipeline(self, data: List, **kwargs) -> Dict: 160 | outputs = {PipelineOutput.INPUTS: data} 161 | # Run the input through the model 162 | inp = data[NoisyDataset.INPUT].to(self.device) 163 | inp.requires_grad = True 164 | cleaned = self.models[Denoiser.MODEL](inp) 165 | outputs[PipelineOutput.IMG_DENOISED] = cleaned 166 | 167 | # If reference images are provided calculate the loss 168 | # as MSE whilst preserving individual loss per batch 169 | if len(data) >= NoisyDataset.REFERENCE: 170 | ref = data[NoisyDataset.REFERENCE].to(self.device) 171 | ref.requires_grad = True 172 | 173 | if NoisyDataset.Metadata.MASK_COORDS in data[NoisyDataset.METADATA]: 174 | mask_coords = data[NoisyDataset.METADATA][NoisyDataset.Metadata.MASK_COORDS] 175 | loss = ssdn.utils.n2v_loss.loss_mask_mse(mask_coords, cleaned, ref) 176 | loss = loss.view(loss.shape[0], -1).mean(1, keepdim=True) 177 | outputs[PipelineOutput.LOSS] = loss 178 | 179 | 180 | return outputs 181 | 182 | def _ssdn_pipeline(self, data: List, **kwargs) -> Dict: 183 | debug = False 184 | 185 | inp = data[NoisyDataset.INPUT] 186 | noisy_in = inp.to(self.device) 187 | 188 | # noisy_params_in = standard deviation of noise 189 | metadata = data[NoisyDataset.METADATA] 190 | noise_params_in = metadata[NoisyDataset.Metadata.INPUT_NOISE_VALUES] 191 | 192 | # config for noise params/style 193 | noise_style = self.cfg[ConfigValue.NOISE_STYLE] 194 | noise_params = self.cfg[ConfigValue.NOISE_VALUE] 195 | 196 | # Equivalent of blindspot_pipeline 197 | input_shape = metadata[NoisyDataset.Metadata.IMAGE_SHAPE] 198 | num_channels = self.cfg[ConfigValue.IMAGE_CHANNELS] 199 | assert num_channels in [1, 3] 200 | 201 | diagonal_covariance = self.cfg[ConfigValue.DIAGONAL_COVARIANCE] 202 | 203 | if debug: 204 | print("Image shape:", input_shape) 205 | if debug: 206 | print("Num. channels:", num_channels) 207 | 208 | # Clean data distribution. 
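        # Note (added for clarity, not in the original source): the network predicts,
        # per pixel, the channel means plus the free entries of a triangular factor A
        # of the covariance (sigma_x = A^T A). For 3 channels that is 3 + 6 = 9 output
        # components (3 + 3 = 6 with a diagonal covariance); for 1 channel it is 2.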
209 |         # Calculation still needed for line 175
210 |         num_output_components = (
211 |             num_channels + (num_channels * (num_channels + 1)) // 2
212 |         )  # Means, triangular A.
213 |         if diagonal_covariance:
214 |             num_output_components = num_channels * 2  # Means, diagonal of A.
215 |         if debug:
216 |             print("Num. output components:", num_output_components)
217 |         # Call the NN with the current image etc.
218 |         net_out = self.models[Denoiser.MODEL](noisy_in)
219 |         # net_out = net_out.type(torch.float64)
220 |         if debug:
221 |             print("Net output shape:", net_out.shape)
222 |         mu_x = net_out[:, 0:num_channels, ...]  # Means (NCHW).
223 |         A_c = net_out[
224 |             :, num_channels:num_output_components, ...
225 |         ]  # Components of triangular A.
226 |         if debug:
227 |             print("Shape of A_c:", A_c.shape)
228 |         if num_channels == 1:
229 |             sigma_x = A_c ** 2  # N1HW
230 |         elif num_channels == 3:
231 |             if debug:
232 |                 print("Shape before permute:", A_c.shape)
233 |             A_c = A_c.permute(0, 2, 3, 1)  # NHWC
234 |             if debug:
235 |                 print("Shape after permute:", A_c.shape)
236 |             if diagonal_covariance:
237 |                 c00 = A_c[..., 0] ** 2
238 |                 c11 = A_c[..., 1] ** 2
239 |                 c22 = A_c[..., 2] ** 2
240 |                 zro = torch.zeros_like(c00)  # zeros matching c00's shape and device
241 |                 c0 = torch.stack([c00, zro, zro], dim=-1)  # NHW3
242 |                 c1 = torch.stack([zro, c11, zro], dim=-1)  # NHW3
243 |                 c2 = torch.stack([zro, zro, c22], dim=-1)  # NHW3
244 |             else:
245 |                 # Calculate A^T * A
246 |                 c00 = A_c[..., 0] ** 2 + A_c[..., 1] ** 2 + A_c[..., 2] ** 2  # NHW
247 |                 c01 = A_c[..., 1] * A_c[..., 3] + A_c[..., 2] * A_c[..., 4]
248 |                 c02 = A_c[..., 2] * A_c[..., 5]
249 |                 c11 = A_c[..., 3] ** 2 + A_c[..., 4] ** 2
250 |                 c12 = A_c[..., 4] * A_c[..., 5]
251 |                 c22 = A_c[..., 5] ** 2
252 |                 c0 = torch.stack([c00, c01, c02], dim=-1)  # NHW3
253 |                 c1 = torch.stack([c01, c11, c12], dim=-1)  # NHW3
254 |                 c2 = torch.stack([c02, c12, c22], dim=-1)  # NHW3
255 |             sigma_x = torch.stack([c0, c1, c2], dim=-1)  # NHW33
256 | 
257 |         # Data on which noise parameter estimation is based.
258 |         if noise_params == NoiseValue.UNKNOWN_CONSTANT:
259 |             # Global constant over the entire dataset.
260 |             noise_est_out = self.l_params[Denoiser.ESTIMATED_SIGMA]
261 |         elif noise_params == NoiseValue.UNKNOWN_VARIABLE:
262 |             # Separate analysis network.
263 |             param_est_net_out = self.models[Denoiser.SIGMA_ESTIMATOR](noisy_in)
264 |             param_est_net_out = torch.mean(param_est_net_out, dim=(2, 3), keepdim=True)
265 |             noise_est_out = param_est_net_out  # .type(torch.float64)
266 | 
267 |         # Cast remaining data into float64.
268 |         # noisy_in = noisy_in.type(torch.float64)
269 |         # noise_params_in = noise_params_in.type(torch.float64)
270 | 
271 |         # Remap noise estimate to ensure it is always positive and starts near zero.
272 |         if noise_params != NoiseValue.KNOWN:
273 |             # default pytorch vals: beta=1, threshold=20
274 |             softplus = torch.nn.Softplus()  # applied as a module with the default parameters above
275 |             noise_est_out = softplus(noise_est_out - 4.0) + 1e-3
276 | 
277 |         # Distill noise parameters from learned/known data.
278 |         if noise_style.startswith("gauss"):
279 |             if noise_params == NoiseValue.KNOWN:
280 |                 noise_std = torch.max(
281 |                     noise_params_in, torch.tensor(1e-3)  # , dtype=torch.float64)
282 |                 )  # N111
283 |             else:
284 |                 noise_std = noise_est_out
285 |         elif noise_style.startswith(
286 |             "poisson"
287 |         ):  # Simple signal-dependent Poisson approximation [Hasinoff 2012].
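            # Note (added for clarity, not in the original source): for Poisson-like
            # noise the variance scales with the signal, so the per-pixel std is
            # approximated as sqrt(max(mu_x, 1e-3) / lambda) for a known noise
            # parameter lambda, or sqrt(max(mu_x, 1e-3) * estimate) when learned.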
288 | if noise_params == NoiseValue.KNOWN: 289 | noise_std = ( 290 | torch.maximum(mu_x, torch.tensor(1e-3)) # , dtype=torch.float64)) 291 | / noise_params_in 292 | ) ** 0.5 # NCHW 293 | else: 294 | noise_std = ( 295 | torch.maximum(mu_x, torch.tensor(1e-3)) # , dtype=torch.float64)) 296 | * noise_est_out 297 | ) ** 0.5 # NCHW 298 | 299 | # Casts and vars. 300 | # noise_std = noise_std.type(torch.float64) 301 | noise_std = noise_std.to(self.device) 302 | # I = tf.eye(num_channels, batch_shape=[1, 1, 1], dtype=tf.float64) 303 | I = torch.eye(num_channels, device=self.device) # dtype=torch.float64 304 | I = I.reshape( 305 | 1, 1, 1, num_channels, num_channels 306 | ) # Creates the same shape as the tensorflow thing did, wouldn't work for other batch shapes 307 | Ieps = I * 1e-6 308 | zero64 = torch.tensor(0.0, device=self.device) # , dtype=torch.float64 309 | 310 | # Helpers. 311 | def batch_mvmul(m, v): # Batched (M * v). 312 | return torch.sum(m * v[..., None, :], dim=-1) 313 | 314 | def batch_vtmv(v, m): # Batched (v^T * M * v). 315 | return torch.sum(v[..., :, None] * v[..., None, :] * m, dim=[-2, -1]) 316 | 317 | def batch_vvt(v): # Batched (v * v^T). 318 | return v[..., :, None] * v[..., None, :] 319 | 320 | # Negative log-likelihood loss and posterior mean estimation. 321 | if noise_style.startswith("gauss") or noise_style.startswith("poisson"): 322 | if num_channels == 1: 323 | sigma_n = noise_std ** 2 # N111 / N1HW 324 | sigma_y = sigma_x + sigma_n # N1HW. Total variance. 325 | loss_out = ((noisy_in - mu_x) ** 2) / sigma_y + torch.log( 326 | sigma_y 327 | ) # N1HW 328 | pme_out = (noisy_in * sigma_x + mu_x * sigma_n) / ( 329 | sigma_x + sigma_n 330 | ) # N1HW 331 | net_std_out = (sigma_x ** 0.5)[:, 0, ...] # NHW 332 | noise_std_out = noise_std[:, 0, ...] # N11 / NHW 333 | if noise_params != NoiseValue.KNOWN: 334 | loss_out = loss_out - 0.1 * noise_std # Balance regularization. 335 | else: 336 | # Training loss. 337 | noise_std_sqr = noise_std ** 2 338 | sigma_n = ( 339 | noise_std_sqr.permute(0, 2, 3, 1)[..., None] * I 340 | ) # NHWC1 * NHWCC = NHWCC 341 | if debug: 342 | print("sigma_n device:", sigma_n.device) 343 | if debug: 344 | print("sigma_x device:", sigma_x.device) 345 | sigma_y = ( 346 | sigma_x + sigma_n 347 | ) # NHWCC, total covariance matrix. Cannot be singular because sigma_n is at least a small diagonal. 348 | if debug: 349 | print("sigma_y device:", sigma_y.device) 350 | sigma_y_inv = torch.inverse(sigma_y) # NHWCC 351 | mu_x2 = mu_x.permute(0, 2, 3, 1) # NHWC 352 | noisy_in2 = noisy_in.permute(0, 2, 3, 1) # NHWC 353 | diff = noisy_in2 - mu_x2 # NHWC 354 | diff = -0.5 * batch_vtmv(diff, sigma_y_inv) # NHW 355 | dets = torch.det(sigma_y) # NHW 356 | dets = torch.max( 357 | zero64, dets 358 | ) # NHW. Avoid division by zero and negative square roots. 359 | loss_out = 0.5 * torch.log(dets) - diff # NHW 360 | if noise_params != NoiseValue.KNOWN: 361 | loss_out = loss_out - 0.1 * torch.mean( 362 | noise_std, dim=1 363 | ) # Balance regularization. 364 | 365 | # Posterior mean estimate. 366 | sigma_x_inv = torch.inverse(sigma_x + Ieps) # NHWCC 367 | sigma_n_inv = torch.inverse(sigma_n + Ieps) # NHWCC 368 | pme_c1 = torch.inverse(sigma_x_inv + sigma_n_inv + Ieps) # NHWCC 369 | pme_c2 = batch_mvmul(sigma_x_inv, mu_x2) # NHWCC * NHWC -> NHWC 370 | pme_c2 = pme_c2 + batch_mvmul(sigma_n_inv, noisy_in2) # NHWC 371 | pme_out = batch_mvmul(pme_c1, pme_c2) # NHWC 372 | pme_out = pme_out.permute(0, 3, 1, 2) # NCHW 373 | 374 | # Summary statistics. 
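                # Note (added for clarity, not in the original source): the sixth root
                # of the determinant of a 3x3 covariance reduces it to a single
                # standard-deviation-like magnitude per pixel.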
375 | net_std_out = torch.max(zero64, torch.det(sigma_x)) ** ( 376 | 1.0 / 6.0 377 | ) # NHW 378 | noise_std_out = torch.max(zero64, torch.det(sigma_n)) ** ( 379 | 1.0 / 6.0 380 | ) # N11 / NHW 381 | 382 | # mu_x = mean of x 383 | # pme_out = posterior mean estimate 384 | # loss_out = loss 385 | # net_std_out = std estimate from nn 386 | # noise_std_out = predicted noise std? 387 | # return mu_x, pme_out, loss_out, net_std_out, noise_std_out 388 | loss_out = loss_out.view(loss_out.shape[0], -1).mean(1, keepdim=True) 389 | return { 390 | PipelineOutput.INPUTS: data, 391 | PipelineOutput.IMG_MU: mu_x, 392 | # PipelineOutput.IMG_PME: pme_out, 393 | PipelineOutput.IMG_DENOISED: pme_out, 394 | PipelineOutput.LOSS: loss_out, 395 | PipelineOutput.NOISE_STD_DEV: noise_std_out, 396 | PipelineOutput.MODEL_STD_DEV: net_std_out, 397 | } 398 | 399 | def state_dict(self, params_only: bool = False) -> Dict: 400 | state_dict = state_dict = super().state_dict() 401 | if not params_only: 402 | state_dict["cfg"] = self.cfg 403 | return state_dict 404 | 405 | @staticmethod 406 | def from_state_dict(state_dict: Dict) -> Denoiser: 407 | denoiser = Denoiser(state_dict["cfg"]) 408 | denoiser.load_state_dict(state_dict, strict=False) 409 | return denoiser 410 | 411 | def config_name(self) -> str: 412 | return ssdn.cfg.config_name(self.cfg) 413 | -------------------------------------------------------------------------------- /ssdn/ssdn/eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import ssdn 3 | import logging 4 | import os 5 | 6 | from torch.utils.data import Dataset 7 | from ssdn.datasets import NoisyDataset 8 | from ssdn.denoiser import Denoiser 9 | from ssdn.train import DenoiserTrainer 10 | from ssdn.params import PipelineOutput 11 | from ssdn.cfg import DEFAULT_RUN_DIR 12 | from typing import Callable, Dict 13 | from tqdm import tqdm 14 | 15 | 16 | logger = logging.getLogger("ssdn.eval") 17 | 18 | 19 | class DenoiserEvaluator(DenoiserTrainer): 20 | """Class to start evaluation of a dataset on a trained Denoiser model. 21 | Initialise with constructor and set evaluation dataset with `set_test_data()` 22 | before calling `evaluate()`. 23 | 24 | Args: 25 | target_path (str): Path to weights or training file to evaluate. 26 | runs_dir (str, optional): Root directory to create run directory. 27 | Defaults to DEFAULT_RUN_DIR. 28 | run_dir (str, optional): Explicit run directory name, will automatically 29 | generate using configuration if not provided. Defaults to None. 
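
    Example (illustrative sketch, not from the original source; the weights path is
    a placeholder):

        evaluator = DenoiserEvaluator("path/to/weights.wt")
        evaluator.set_test_data(...)  # configure the evaluation dataset first
        evaluator.evaluate()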
30 | """ 31 | 32 | def __init__( 33 | self, target_path: str, runs_dir: str = DEFAULT_RUN_DIR, run_dir: str = None, 34 | ): 35 | super().__init__({}) 36 | state_dict = torch.load(target_path, map_location="cpu") 37 | if "denoiser" in state_dict: 38 | self.load_state_dict(state_dict) 39 | else: 40 | self.denoiser = Denoiser.from_state_dict(state_dict) 41 | self.cfg = self.denoiser.cfg 42 | self._run_dir = run_dir 43 | self.init_state() 44 | 45 | def evaluate(self): 46 | self.reset_metrics(train=False) 47 | if self.denoiser is None: 48 | raise RuntimeError("Denoiser not initialised for evaluation") 49 | # Ensure writer is initialised 50 | _ = self.writer 51 | ssdn.logging_helper.setup(self.run_dir_path, "log.txt") 52 | logger.info(ssdn.utils.separator()) 53 | logger.info("Loading Test Dataset...") 54 | self.testloader, self.testset, self.test_sampler = self.test_data() 55 | logger.info("Loaded Test Dataset.") 56 | 57 | logger.info(ssdn.utils.separator()) 58 | logger.info("EVALUATION STARTED") 59 | logger.info(ssdn.utils.separator()) 60 | 61 | dataloader = tqdm(self.testloader) 62 | save_callback = self.evaluation_output_callback(self.testset) 63 | self._evaluate(dataloader, save_callback) 64 | logger.info(self.eval_state_str("EVALUATION RESULT")) 65 | logger.info(ssdn.utils.separator()) 66 | logger.info("EVALUATION FINISHED") 67 | logger.info(ssdn.utils.separator()) 68 | 69 | @property 70 | def run_dir(self) -> str: 71 | """The run path to use for this run. When this method is first called 72 | a new directory name will be generated using the next run ID and current 73 | configuration. 74 | 75 | Returns: 76 | str: Run directory name, note this is not a full path. 77 | """ 78 | if self._run_dir is None: 79 | config_name = self.config_name() 80 | next_run_id = self.next_run_id() 81 | run_dir_name = "{:05d}-eval-{}".format(next_run_id, config_name) 82 | self._run_dir = run_dir_name 83 | 84 | return self._run_dir 85 | 86 | def evaluation_output_callback( 87 | self, dataset: Dataset 88 | ) -> Callable[[int, Dict], None]: 89 | """Callback that saves all dataset images for evaluation with an associated 90 | PSNR record. 91 | 92 | Args: 93 | dataset (Dataset): Dataset which determines how many images are saved in case 94 | of repeats. 95 | 96 | Returns: 97 | Callable[[int, Dict], None]: Callback function for evaluator. 
98 | """ 99 | 100 | def callback(output_0_index: int, outputs: Dict): 101 | remaining = (len(dataset) - 1) - output_0_index 102 | inp = outputs[PipelineOutput.INPUTS][NoisyDataset.INPUT] 103 | metadata = outputs[PipelineOutput.INPUTS][NoisyDataset.METADATA] 104 | batch_size = inp.shape[0] 105 | if remaining > 0: 106 | bis = range(min(remaining, batch_size)) 107 | output_dir = os.path.join(self.run_dir_path, "eval_imgs") 108 | os.makedirs(output_dir, exist_ok=True) 109 | fileformat = "img_{index:05}_{desc}.png" 110 | self.save_image_outputs( 111 | outputs, output_dir, fileformat, batch_indexes=bis 112 | ) 113 | with open(os.path.join(self.run_dir_path, "psnrs.csv"), "a") as f: 114 | if output_0_index == 0: 115 | fields = ["id", "psnr_nsy"] + list(self.img_outputs(prefix="psnr").values()) 116 | f.write(",".join(fields) + "\n") 117 | # FIXME: Doing PSNR calculations again 118 | values = [] 119 | values += [ssdn.utils.calculate_psnr(inp, metadata[NoisyDataset.Metadata.CLEAN])] 120 | for key in self.img_outputs(prefix="psnr"): 121 | values += [self.calculate_psnr(outputs, key, unpad=True)] 122 | for i in range(batch_size): 123 | str_lst = ["{:04d}".format(output_0_index + i)] 124 | str_lst += ["{:.4f}".format(value[i]) for value in values] 125 | f.write(",".join(str_lst) + "\n") 126 | 127 | return callback 128 | -------------------------------------------------------------------------------- /ssdn/ssdn/logging_helper.py: -------------------------------------------------------------------------------- 1 | """Configure the package logger. 2 | 3 | Warning: 4 | This module has state and manipulates external use of the `logging` package. 5 | """ 6 | 7 | import sys 8 | import os 9 | import datetime 10 | import logging 11 | 12 | from colorlog import ColoredFormatter 13 | from colored_traceback import Colorizer 14 | 15 | FILE_FORMAT = "%(asctime)s %(name)-30s %(levelname)-8s %(message)s" 16 | FILE_DATE_FORMAT = "%m-%d %H:%M:%S" 17 | CONSOLE_FORMAT = "%(log_color)s%(message)s%(reset)s" 18 | 19 | 20 | # Module level variable for tracking the console output handler 21 | console_handle = None 22 | # Package level logger 23 | root_logger = logging.getLogger("") 24 | package_logger = logging.getLogger(__name__.split(".")[0]) 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | def _log_exception(exc_type, exc_value, exc_traceback): 29 | # The exception will still get thrown, so ensure it will not get written 30 | # to the console twice by removing the console handler before readding it 31 | if console_handle: 32 | root_logger.removeHandler(console_handle) 33 | if not issubclass(exc_type, KeyboardInterrupt): 34 | root_logger.error( 35 | "Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback) 36 | ) 37 | if console_handle: 38 | root_logger.addHandler(console_handle) 39 | colorizer = Colorizer("default", False) 40 | sys.excepthook = colorizer.colorize_traceback 41 | colorizer.colorize_traceback(exc_type, exc_value, exc_traceback) 42 | 43 | 44 | def setup(log_dir: str = None, filename: str = None): 45 | """Automatic logging setup for the SSDN package. Logging goes to 46 | an output file and the console. Console output is coloured using `colorlog` and 47 | any exception traceback is coloured using `coloured_traceback`. Handlers are 48 | attached to the root logger and will therefore capture and potentially manipulate output 49 | from any external packages. 50 | 51 | Args: 52 | log_dir (str, optional): Directory path to store logs in. Defaults to None. 
53 | When None is used logs are not stored to a file. 54 | filename (str, optional): Filename to store log file as. Defaults to None. 55 | When None the current date and time will be used as the log name. 56 | """ 57 | # Setup directory if it doesn't exist 58 | if log_dir is not None: 59 | if log_dir != "" and not os.path.exists(log_dir): 60 | os.makedirs(log_dir) 61 | # Use the current datetime if filename not set 62 | if filename is None: 63 | date = datetime.datetime.now() 64 | filename = "log_{date:%Y_%m_%d-%H_%M_%S}.txt".format(date=date) 65 | file_path = os.path.join(log_dir, filename) 66 | 67 | # Configure logging to a file 68 | file = logging.FileHandler(file_path, mode="a") 69 | formatter = logging.Formatter(fmt=FILE_FORMAT, datefmt=FILE_DATE_FORMAT) 70 | file.setLevel(logging.DEBUG) 71 | file.setFormatter(formatter) 72 | root_logger.addHandler(file) 73 | 74 | # Configure console logging 75 | # Must keep track of handle in global state so it can be removed before 76 | # uncaught exceptions are logged 77 | global console_handle # noqa: E261 78 | if console_handle is None: 79 | console = logging.StreamHandler() 80 | console.setLevel(logging.DEBUG) 81 | formatter = ColoredFormatter(CONSOLE_FORMAT, datefmt=None) 82 | console.setFormatter(formatter) 83 | root_logger.addHandler(console) 84 | console_handle = console 85 | # Attach hook to log any exceptions and add coloured traceback 86 | sys.excepthook = _log_exception 87 | # By default expose all messages by default 88 | package_logger.setLevel(logging.DEBUG) 89 | -------------------------------------------------------------------------------- /ssdn/ssdn/models/__init__.py: -------------------------------------------------------------------------------- 1 | from ssdn.models.utility import * 2 | from ssdn.models.noise_network import NoiseNetwork 3 | -------------------------------------------------------------------------------- /ssdn/ssdn/models/noise_network.py: -------------------------------------------------------------------------------- 1 | """ PyTorch implementation of U-Net model for N2N and SSDN. 2 | """ 3 | 4 | import torch 5 | import ssdn 6 | import torch.nn as nn 7 | 8 | from torch import Tensor 9 | 10 | from ssdn.models.utility import Shift2d 11 | 12 | 13 | class NoiseNetwork(nn.Module): 14 | """Custom U-Net architecture for Self Supervised Denoising (SSDN) and Noise2Noise (N2N). 15 | Base N2N implementation was made with reference to @joeylitalien's N2N implementation. 16 | Changes made are removal of weight sharing when blocks are reused. Usage of LeakyReLu 17 | over standard ReLu and incorporation of blindspot functionality. 18 | 19 | Unlike other typical U-Net implementations dropout is not used when the model is trained. 20 | 21 | When in blindspot mode the following behaviour changes occur: 22 | 23 | * Input batches are duplicated for rotations: 0, 90, 180, 270. This increases the 24 | batch size by 4x. After the encode-decode stage the rotations are undone and 25 | concatenated on the channel axis with the associated original image. This 4x 26 | increase in channel count is collapsed to the standard channel count in the 27 | first 1x1 kernel convolution. 28 | 29 | * To restrict the receptive field into the upward direction a shift is used for 30 | convolutions (see ShiftConv2d) and downsampling. Downsampling uses a single 31 | pixel shift prior to max pooling as dictated by Laine et al. This is equivalent 32 | to applying a shift on the upsample. 
33 | 34 | Args: 35 | in_channels (int, optional): Number of input channels, this will typically be either 36 | 1 (Mono) or 3 (RGB) but can be more. Defaults to 3. 37 | out_channels (int, optional): Number of channels the final convolution should output. 38 | Defaults to 3. 39 | blindspot (bool, optional): Whether to enable the network blindspot. This will 40 | add in rotation stages and shift stages while max pooling and during convolutions. 41 | A futher shift will occur after upsample. Defaults to False. 42 | zero_output_weights (bool, optional): Whether to initialise the weights of 43 | `nin_c` to zero. This is not mentioned in literature but is done as part 44 | of the tensorflow implementation for the parameter estimation network. 45 | Defaults to False. 46 | """ 47 | 48 | def __init__( 49 | self, 50 | in_channels: int = 3, 51 | out_channels: int = 3, 52 | blindspot: bool = False, 53 | zero_output_weights: bool = False, 54 | ): 55 | super(NoiseNetwork, self).__init__() 56 | self._blindspot = blindspot 57 | self._zero_output_weights = zero_output_weights 58 | self.Conv2d = ShiftConv2d if self.blindspot else nn.Conv2d 59 | 60 | #################################### 61 | # Encode Blocks 62 | #################################### 63 | 64 | def _max_pool_block(max_pool: nn.Module) -> nn.Module: 65 | if blindspot: 66 | return nn.Sequential(Shift2d((1, 0)), max_pool) 67 | return max_pool 68 | 69 | # Layers: enc_conv0, enc_conv1, pool1 70 | self.encode_block_1 = nn.Sequential( 71 | self.Conv2d(in_channels, 48, 3, stride=1, padding=1), 72 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 73 | self.Conv2d(48, 48, 3, padding=1), 74 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 75 | _max_pool_block(nn.MaxPool2d(2)), 76 | ) 77 | 78 | # Layers: enc_conv(i), pool(i); i=2..5 79 | def _encode_block_2_3_4_5() -> nn.Module: 80 | return nn.Sequential( 81 | self.Conv2d(48, 48, 3, stride=1, padding=1), 82 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 83 | _max_pool_block(nn.MaxPool2d(2)), 84 | ) 85 | 86 | # Separate instances of same encode module definition created 87 | self.encode_block_2 = _encode_block_2_3_4_5() 88 | self.encode_block_3 = _encode_block_2_3_4_5() 89 | self.encode_block_4 = _encode_block_2_3_4_5() 90 | self.encode_block_5 = _encode_block_2_3_4_5() 91 | 92 | # Layers: enc_conv6 93 | self.encode_block_6 = nn.Sequential( 94 | self.Conv2d(48, 48, 3, stride=1, padding=1), 95 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 96 | ) 97 | 98 | #################################### 99 | # Decode Blocks 100 | #################################### 101 | # Layers: upsample5 102 | self.decode_block_6 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest")) 103 | 104 | # Layers: dec_conv5a, dec_conv5b, upsample4 105 | self.decode_block_5 = nn.Sequential( 106 | self.Conv2d(96, 96, 3, stride=1, padding=1), 107 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 108 | self.Conv2d(96, 96, 3, stride=1, padding=1), 109 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 110 | nn.Upsample(scale_factor=2, mode="nearest"), 111 | ) 112 | 113 | # Layers: dec_deconv(i)a, dec_deconv(i)b, upsample(i-1); i=4..2 114 | def _decode_block_4_3_2() -> nn.Module: 115 | return nn.Sequential( 116 | self.Conv2d(144, 96, 3, stride=1, padding=1), 117 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 118 | self.Conv2d(96, 96, 3, stride=1, padding=1), 119 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 120 | nn.Upsample(scale_factor=2, mode="nearest"), 121 | ) 122 | 123 | # Separate instances of same decode module 
definition created 124 | self.decode_block_4 = _decode_block_4_3_2() 125 | self.decode_block_3 = _decode_block_4_3_2() 126 | self.decode_block_2 = _decode_block_4_3_2() 127 | 128 | # Layers: dec_conv1a, dec_conv1b, dec_conv1c, 129 | self.decode_block_1 = nn.Sequential( 130 | self.Conv2d(96 + in_channels, 96, 3, stride=1, padding=1), 131 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 132 | self.Conv2d(96, 96, 3, stride=1, padding=1), 133 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 134 | ) 135 | 136 | #################################### 137 | # Output Block 138 | #################################### 139 | 140 | if self.blindspot: 141 | # Shift 1 pixel down 142 | self.shift = Shift2d((1, 0)) 143 | # 4 x Channels due to batch rotations 144 | nin_a_io = 384 145 | else: 146 | nin_a_io = 96 147 | 148 | # nin_a,b,c, linear_act 149 | self.output_conv = self.Conv2d(96, out_channels, 1) 150 | self.output_block = nn.Sequential( 151 | self.Conv2d(nin_a_io, nin_a_io, 1), 152 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 153 | self.Conv2d(nin_a_io, 96, 1), 154 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 155 | self.output_conv, 156 | ) 157 | 158 | # Initialize weights 159 | self.init_weights() 160 | 161 | @property 162 | def blindspot(self) -> bool: 163 | return self._blindspot 164 | 165 | def init_weights(self): 166 | """Initializes weights using Kaiming He et al. (2015). 167 | 168 | Only convolution layers have learnable weights. All convolutions use a leaky 169 | relu activation function (negative_slope = 0.1) except the last which is just 170 | a linear output. 171 | """ 172 | with torch.no_grad(): 173 | self._init_weights() 174 | 175 | def _init_weights(self): 176 | for m in self.modules(): 177 | if isinstance(m, nn.Conv2d): 178 | nn.init.kaiming_normal_(m.weight.data, a=0.1) 179 | m.bias.data.zero_() 180 | # Initialise last output layer 181 | if self._zero_output_weights: 182 | self.output_conv.weight.zero_() 183 | else: 184 | nn.init.kaiming_normal_(self.output_conv.weight.data, nonlinearity="linear") 185 | 186 | def forward(self, x: Tensor) -> Tensor: 187 | if self.blindspot: 188 | rotated = [ssdn.utils.rotate(x, rot) for rot in (0, 90, 180, 270)] 189 | x = torch.cat((rotated), dim=0) 190 | 191 | # Encoder 192 | pool1 = self.encode_block_1(x) 193 | pool2 = self.encode_block_2(pool1) 194 | pool3 = self.encode_block_3(pool2) 195 | pool4 = self.encode_block_4(pool3) 196 | pool5 = self.encode_block_5(pool4) 197 | encoded = self.encode_block_6(pool5) 198 | 199 | # Decoder 200 | upsample5 = self.decode_block_6(encoded) 201 | concat5 = torch.cat((upsample5, pool4), dim=1) 202 | upsample4 = self.decode_block_5(concat5) 203 | concat4 = torch.cat((upsample4, pool3), dim=1) 204 | upsample3 = self.decode_block_4(concat4) 205 | concat3 = torch.cat((upsample3, pool2), dim=1) 206 | upsample2 = self.decode_block_3(concat3) 207 | concat2 = torch.cat((upsample2, pool1), dim=1) 208 | upsample1 = self.decode_block_2(concat2) 209 | concat1 = torch.cat((upsample1, x), dim=1) 210 | x = self.decode_block_1(concat1) 211 | 212 | # Output 213 | if self.blindspot: 214 | # Apply shift 215 | shifted = self.shift(x) 216 | # Unstack, rotate and combine 217 | rotated_batch = torch.chunk(shifted, 4, dim=0) 218 | aligned = [ 219 | ssdn.utils.rotate(rotated, rot) 220 | for rotated, rot in zip(rotated_batch, (0, 270, 180, 90)) 221 | ] 222 | x = torch.cat(aligned, dim=1) 223 | 224 | x = self.output_block(x) 225 | 226 | return x 227 | 228 | @staticmethod 229 | def input_wh_mul() -> int: 230 | """Multiple that both 
the width and height dimensions of an input must be to be 231 | processed by the network. This is devised from the number of pooling layers that 232 | reduce the input size. 233 | 234 | Returns: 235 | int: Dimension multiplier 236 | """ 237 | max_pool_layers = 5 238 | return 2 ** max_pool_layers 239 | 240 | 241 | class ShiftConv2d(nn.Conv2d): 242 | def __init__(self, *args, **kwargs): 243 | """Custom convolution layer as defined by Laine et al. for restricting the 244 | receptive field of a convolution layer to only be upwards. For a h × w kernel, 245 | a downwards offset of k = [h/2] pixels is used. This is applied as a k sized pad 246 | to the top of the input before applying the convolution. The bottom k rows are 247 | cropped out for output. 248 | """ 249 | super().__init__(*args, **kwargs) 250 | self.shift_size = (self.kernel_size[0] // 2, 0) 251 | # Use individual layers of shift for wrapping conv with shift 252 | shift = Shift2d(self.shift_size) 253 | self.pad = shift.pad 254 | self.crop = shift.crop 255 | 256 | def forward(self, x: Tensor) -> Tensor: 257 | x = self.pad(x) 258 | x = super().forward(x) 259 | x = self.crop(x) 260 | return x 261 | -------------------------------------------------------------------------------- /ssdn/ssdn/models/utility.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from torch import Tensor 4 | from typing import Tuple 5 | 6 | 7 | class Crop2d(nn.Module): 8 | """Crop input using slicing. Assumes BCHW data. 9 | 10 | Args: 11 | crop (Tuple[int, int, int, int]): Amounts to crop from each side of the image. 12 | Tuple is treated as [left, right, top, bottom]/ 13 | """ 14 | 15 | def __init__(self, crop: Tuple[int, int, int, int]): 16 | super().__init__() 17 | self.crop = crop 18 | assert len(crop) == 4 19 | 20 | def forward(self, x: Tensor) -> Tensor: 21 | (left, right, top, bottom) = self.crop 22 | x0, x1 = left, x.shape[-1] - right 23 | y0, y1 = top, x.shape[-2] - bottom 24 | return x[:, :, y0:y1, x0:x1] 25 | 26 | 27 | class Shift2d(nn.Module): 28 | """Shift an image in either or both of the vertical and horizontal axis by first 29 | zero padding on the opposite side that the image is shifting towards before 30 | cropping the side being shifted towards. 31 | 32 | Args: 33 | shift (Tuple[int, int]): Tuple of vertical and horizontal shift. Positive values 34 | shift towards right and bottom, negative values shift towards left and top. 
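
    Example (illustrative, not part of the original docstring): Shift2d((1, 0))
        zero-pads one row at the top and crops one row from the bottom, shifting
        the image content down by a single pixel.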
35 | """ 36 | 37 | def __init__(self, shift: Tuple[int, int]): 38 | super().__init__() 39 | self.shift = shift 40 | vert, horz = self.shift 41 | y_a, y_b = abs(vert), 0 42 | x_a, x_b = abs(horz), 0 43 | if vert < 0: 44 | y_a, y_b = y_b, y_a 45 | if horz < 0: 46 | x_a, x_b = x_b, x_a 47 | # Order : Left, Right, Top Bottom 48 | self.pad = nn.ZeroPad2d((x_a, x_b, y_a, y_b)) 49 | self.crop = Crop2d((x_b, x_a, y_b, y_a)) 50 | self.shift_block = nn.Sequential(self.pad, self.crop) 51 | 52 | def forward(self, x: Tensor) -> Tensor: 53 | return self.shift_block(x) 54 | -------------------------------------------------------------------------------- /ssdn/ssdn/params.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from enum import Enum, auto 4 | from typing import List 5 | 6 | 7 | class NoiseAlgorithm(Enum): 8 | SELFSUPERVISED_DENOISING = "ssdn" 9 | SELFSUPERVISED_DENOISING_MEAN_ONLY = "ssdn_u_only" 10 | NOISE_TO_NOISE = "n2n" 11 | NOISE_TO_CLEAN = "n2c" 12 | NOISE_TO_VOID = "n2v" # Unsupported 13 | 14 | 15 | class NoiseValue(Enum): 16 | UNKNOWN_CONSTANT = "const" 17 | UNKNOWN_VARIABLE = "var" 18 | KNOWN = "known" 19 | 20 | 21 | class Pipeline(Enum): 22 | MSE = "mse" 23 | SSDN = "ssdn" 24 | MASK_MSE = "mask_mse" 25 | 26 | 27 | class Blindspot(Enum): 28 | ENABLED = "blindspot" 29 | DISABLED = "normal" 30 | 31 | 32 | class ConfigValue(Enum): 33 | INFER_CFG = auto() 34 | ALGORITHM = auto() 35 | BLINDSPOT = auto() 36 | PIPELINE = auto() 37 | IMAGE_CHANNELS = auto() 38 | 39 | NOISE_STYLE = auto() 40 | 41 | LEARNING_RATE = auto() 42 | LR_RAMPUP_FRACTION = auto() 43 | LR_RAMPDOWN_FRACTION = auto() 44 | 45 | NOISE_VALUE = auto() 46 | DIAGONAL_COVARIANCE = auto() 47 | 48 | EVAL_INTERVAL = auto() 49 | PRINT_INTERVAL = auto() 50 | SNAPSHOT_INTERVAL = auto() 51 | TRAIN_ITERATIONS = auto() 52 | 53 | DATALOADER_WORKERS = auto() 54 | TRAIN_DATASET_NAME = auto() 55 | TRAIN_DATASET_TYPE = auto() 56 | TRAIN_DATA_PATH = auto() 57 | TRAIN_PATCH_SIZE = auto() 58 | TRAIN_MINIBATCH_SIZE = auto() 59 | 60 | TEST_DATASET_NAME = auto() 61 | TEST_DATASET_TYPE = auto() 62 | TEST_DATA_PATH = auto() 63 | TEST_MINIBATCH_SIZE = auto() 64 | PIN_DATA_MEMORY = auto() 65 | 66 | 67 | class DatasetType(Enum): 68 | HDF5 = auto() 69 | FOLDER = auto() 70 | 71 | 72 | class StateValue(Enum): 73 | INITIALISED = auto() 74 | MODE = auto() 75 | 76 | ITERATION = auto() 77 | REFERENCE = auto() 78 | HISTORY = auto() 79 | 80 | 81 | class HistoryValue(Enum): 82 | TRAIN = auto() 83 | EVAL = auto() 84 | TIMINGS = auto() 85 | 86 | 87 | class PipelineOutput(Enum): 88 | INPUTS = auto() 89 | LOSS = "loss" 90 | IMG_DENOISED = "out" 91 | IMG_MU = "out_mu" 92 | NOISE_STD_DEV = "noise_std" 93 | MODEL_STD_DEV = "model_std" 94 | -------------------------------------------------------------------------------- /ssdn/ssdn/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from ssdn.utils.utils import * 2 | from ssdn.utils.data import * 3 | from ssdn.utils import noise, transforms 4 | from ssdn.utils import n2v_ups, n2v_loss -------------------------------------------------------------------------------- /ssdn/ssdn/utils/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision 4 | import torch.nn.functional as F 5 | 6 | from torch import Tensor 7 | 8 | from PIL import Image 9 | from ssdn.utils.data_format import ( 10 | DataFormat, 11 | 
DataDim, 12 | DATA_FORMAT_DIM_INDEX, 13 | permute_tuple, 14 | batch, 15 | unbatch, 16 | ) 17 | 18 | 19 | def clip_img(img: Tensor, inplace: bool = False) -> Tensor: 20 | """Clip tensor data so that it is within a valid image range. That is 21 | between 0-1 for float images and 0-255 for int images. Values are clamped 22 | meaning any values outside this range are set to the limits, other values 23 | are not touched. 24 | 25 | Args: 26 | img (Tensor): Image or batch of images to clip. 27 | inplace (bool, optional): Whether to do the operation in place. 28 | Defaults to False; this will first clone the data. 29 | 30 | Returns: 31 | Tensor: Reference to input image or new image. 32 | """ 33 | if not inplace: 34 | img = img.clone() 35 | if img.is_floating_point(): 36 | c_min, c_max = (0, 1) 37 | else: 38 | c_min, c_max = (0, 255) 39 | return torch.clamp_(img, c_min, c_max) 40 | 41 | 42 | def rotate( 43 | x: torch.Tensor, angle: int, data_format: str = DataFormat.BCHW 44 | ) -> torch.Tensor: 45 | """Rotate images by 90 degrees clockwise. Can handle any 2D data format. 46 | Args: 47 | x (Tensor): Image or batch of images. 48 | angle (int): Clockwise rotation angle in multiples of 90. 49 | data_format (str, optional): Format of input image data, e.g. BCHW, 50 | HWC. Defaults to BCHW. 51 | Returns: 52 | Tensor: Copy of tensor with rotation applied. 53 | """ 54 | dims = DATA_FORMAT_DIM_INDEX[data_format] 55 | h_dim = dims[DataDim.HEIGHT] 56 | w_dim = dims[DataDim.WIDTH] 57 | 58 | if angle == 0: 59 | return x 60 | elif angle == 90: 61 | return x.flip(w_dim).transpose(h_dim, w_dim) 62 | elif angle == 180: 63 | return x.flip(w_dim).flip(h_dim) 64 | elif angle == 270: 65 | return x.flip(h_dim).transpose(h_dim, w_dim) 66 | else: 67 | raise NotImplementedError("Must be rotation divisible by 90 degrees") 68 | 69 | 70 | def tensor2image(img: Tensor, data_format: str = DataFormat.CHW) -> Image: 71 | img = img.cpu().detach() 72 | # Create a grid of images if batched 73 | if isinstance(img, list) or len(img.shape) == 4: 74 | img = img.permute(permute_tuple(batch(data_format), DataFormat.BCHW)) 75 | img = torchvision.utils.make_grid(img) 76 | img = img.permute(permute_tuple(DataFormat.CHW, unbatch(data_format))) 77 | 78 | np_img = img.numpy() 79 | np_img = np.clip(np_img, 0, 1) 80 | np_img = np_img.transpose(*permute_tuple(data_format, DataFormat.WHC)) 81 | channels = np_img.shape[-1] 82 | if channels == 3: 83 | mode = "RGB" 84 | elif channels == 1: 85 | mode = "L" 86 | np_img = np.squeeze(np_img) 87 | else: 88 | raise NotImplementedError( 89 | "Cannot convert image with {} channels to PIL image.".format(channels) 90 | ) 91 | return Image.fromarray(np.uint8(np_img * 255), mode=mode) 92 | 93 | 94 | def mse2psnr(mse: Tensor, float_imgs: bool = True): 95 | high_val = torch.tensor(1.0) if float_imgs else torch.tensor(255) 96 | return 20 * torch.log10(high_val) - 10 * torch.log10(mse) 97 | 98 | 99 | def calculate_psnr(img: Tensor, ref: Tensor, data_format: str = DataFormat.BCHW): 100 | dim_indexes = dict(DATA_FORMAT_DIM_INDEX[data_format]) # shallow copy 101 | dim_indexes.pop(DataDim.BATCH, None) 102 | dims = tuple(dim_indexes.values()) 103 | mse = F.mse_loss(img, ref, reduction="none") 104 | mse = torch.mean(mse, dim=dims) 105 | return mse2psnr(mse, img.is_floating_point()) 106 | 107 | 108 | def show_tensor_image(img: Tensor, data_format: str = DataFormat.CHW): 109 | pil_img = tensor2image(img, data_format=data_format) 110 | pil_img.show() 111 | 112 | 113 | def save_tensor_image(img: Tensor, path: str, 
data_format: str = DataFormat.CHW):
114 | pil_img = tensor2image(img, data_format=data_format)
115 | pil_img.save(path)
116 | 
117 | 
118 | def set_color_channels(img: Image, channels: int) -> Image:
119 | cur_channels = len(img.getbands())
120 | if cur_channels != channels:
121 | if channels == 1:
122 | return img.convert("L")
123 | if channels == 3:
124 | return img.convert("RGB")
125 | return img
126 | 
--------------------------------------------------------------------------------
/ssdn/ssdn/utils/data_format.py:
--------------------------------------------------------------------------------
1 | import ssdn.utils.utils as utils
2 | 
3 | from enum import Enum, auto
4 | from collections import OrderedDict
5 | from typing import Dict, Tuple
6 | 
7 | 
8 | class DataDim(Enum):
9 | BATCH = auto()
10 | CHANNEL = auto()
11 | WIDTH = auto()
12 | HEIGHT = auto()
13 | 
14 | 
15 | DIM_CHAR_DICT = {
16 | DataDim.BATCH: "B",
17 | DataDim.CHANNEL: "C",
18 | DataDim.HEIGHT: "H",
19 | DataDim.WIDTH: "W",
20 | }
21 | """ Mapping from dimension enumerations to their character representations.
22 | """
23 | 
24 | CHAR_DIM_DICT = dict((v, k) for k, v in DIM_CHAR_DICT.items())
25 | """ Mapping from characters to their dimension enumerations.
26 | """
27 | 
28 | 
29 | def batch(data_format: str) -> str:
30 | """ Append the batch dimension to a format if it is not already there. Assumes LHS.
31 | """
32 | if DIM_CHAR_DICT[DataDim.BATCH] not in data_format:
33 | return DIM_CHAR_DICT[DataDim.BATCH] + data_format
34 | else:
35 | return data_format
36 | 
37 | 
38 | def unbatch(data_format: str) -> str:
39 | """ Remove the batch dimension from a format if it is present. Assumes LHS.
40 | """
41 | return data_format.replace(DIM_CHAR_DICT[DataDim.BATCH], "")
42 | 
43 | 
44 | class DataFormat:
45 | BHWC = "BHWC"
46 | BWHC = "BWHC"
47 | BCHW = "BCHW"
48 | BCWH = "BCWH"
49 | HWC = "HWC"
50 | WHC = "WHC"
51 | CHW = "CHW"
52 | CWH = "CWH"
53 | 
54 | 
55 | PIL_FORMAT = DataFormat.CWH
56 | PIL_BATCH_FORMAT = DataFormat.BCWH
57 | """ Formats used by Pillow/PIL.
58 | """
59 | 
60 | 
61 | DATA_FORMAT_INDEX_DIM = {}
62 | """ Storage for pre-defined dimension format dictionaries that map
63 | axis index to dimension type.
64 | """
65 | 
66 | DATA_FORMAT_DIM_INDEX = {}
67 | """ Storage for pre-defined dimension format dictionaries that map
68 | dimension type to axis index.
69 | """ 70 | 71 | 72 | def make_index_dim_dict(data_format: str) -> Dict: 73 | dim_dict = OrderedDict() 74 | for i, c in enumerate(data_format): 75 | dim_dict[i] = CHAR_DIM_DICT[c] 76 | return dim_dict 77 | 78 | 79 | def make_dim_index_dict(data_format: str) -> Dict: 80 | dim_dict = OrderedDict() 81 | for i, c in enumerate(data_format): 82 | dim_dict[CHAR_DIM_DICT[c]] = i 83 | return dim_dict 84 | 85 | 86 | def add_format(data_format: str): 87 | global DATA_FORMAT_INDEX_DIM 88 | DATA_FORMAT_INDEX_DIM[data_format] = make_index_dim_dict(data_format) 89 | global DATA_FORMAT_DIM_INDEX 90 | DATA_FORMAT_DIM_INDEX[data_format] = make_dim_index_dict(data_format) 91 | 92 | 93 | # Create dictionary entries for all formats in DataFormat class 94 | for data_format in utils.list_constants(DataFormat): 95 | add_format(data_format) 96 | 97 | 98 | def permute_tuple(cur: str, target: str) -> Tuple[int]: 99 | assert sorted(cur) == sorted(target) 100 | 101 | # Ensure reference dictionaries exist 102 | if cur not in DATA_FORMAT_INDEX_DIM: 103 | add_format(cur) 104 | if target not in DATA_FORMAT_DIM_INDEX: 105 | add_format(target) 106 | 107 | dims_cur = DATA_FORMAT_DIM_INDEX[cur] 108 | dims_target = DATA_FORMAT_DIM_INDEX[target] 109 | transpose = [dims_cur[target] for target in dims_target.keys()] 110 | return tuple(transpose) 111 | -------------------------------------------------------------------------------- /ssdn/ssdn/utils/n2v_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | import torch.nn as nn 4 | 5 | 6 | def loss_mask_mse( 7 | masked_coords: Tensor, 8 | input: Tensor, 9 | target: Tensor 10 | ): 11 | mse = 0 12 | coords = masked_coords.tolist()[0] 13 | for coord in coords: 14 | x, y = coord 15 | diff = target[:, :, x, y] - input[:, :, x, y] 16 | mse += (diff ** 2) 17 | return mse 18 | 19 | -------------------------------------------------------------------------------- /ssdn/ssdn/utils/n2v_ups.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import math 4 | 5 | from torch import Tensor 6 | 7 | def manipulate( 8 | image: Tensor, 9 | subpatch_size: int = 5, 10 | inplace: bool = False 11 | ): 12 | """ 13 | Take sub-patches from within an image at random coordinates 14 | to replace a given percentage of central pixels with a random 15 | pixel within the sub-patch (aka Uniform Pixel Selection). 16 | 17 | Args: 18 | image (Tensor): a 64x64 image tensor. 19 | subpatch_size (Number): the size of the sub-patch. 20 | 21 | Returns: 22 | Tensor: the image after the Uniform Pixel Selection. 
23 | """
24 | if subpatch_size % 2 == 0:
25 | raise ValueError("subpatch_size must be odd")
26 | 
27 | if not inplace:
28 | image = image.clone()
29 | 
30 | image_x = image.shape[2]
31 | image_y = image.shape[1]
32 | subpatch_radius = math.floor(subpatch_size / 2)
33 | coords = get_stratified_coords((image_x, image_y))
34 | 
35 | mask_coords = []
36 | for coord in zip(*coords):
37 | x, y = coord
38 | mask_coords.append((x, y))
39 | 
40 | min_x = max([x - subpatch_radius, 0])
41 | max_x = min([x + subpatch_radius, image_x - 1])
42 | min_y = max([y - subpatch_radius, 0])
43 | max_y = min([y + subpatch_radius, image_y - 1])
44 | 
45 | rand_x = rand_num_exclude(min_x, max_x, [x])
46 | rand_y = rand_num_exclude(min_y, max_y, [y])
47 | # Now replace pixel at x,y with pixel from rand_x,rand_y
48 | image[:, y, x] = image[:, rand_y, rand_x]
49 | return image, torch.tensor(mask_coords)
50 | 
51 | def rand_num_exclude(_min: int, _max: int, exclude: list):
52 | """
53 | Get a random integer in a range, excluding any integers in the
54 | provided exclusion list.
55 | 
56 | Args:
57 | _min (Number): minimum integer value.
58 | _max (Number): maximum integer value.
59 | exclude (list): list of integers to be excluded.
60 | """
61 | rand = torch.randint(_min, _max, (1,))[0]
62 | return rand_num_exclude(_min, _max, exclude) if rand in exclude else rand
63 | 
64 | 
65 | def get_stratified_coords(shape):
66 | """
67 | Args:
68 | shape (Tuple[Number,Number]): the shape of the input patch
69 | 
70 | Credit: https://github.com/juglab/n2v
71 | """
72 | perc_pix = 1.5 # TODO put in config
73 | box_size = np.round(np.sqrt(100/perc_pix)).astype(np.int)
74 | coord_gen = get_random_coords(box_size)
75 | 
76 | box_count_y = int(np.ceil(shape[0] / box_size))
77 | box_count_x = int(np.ceil(shape[1] / box_size))
78 | x_coords = []
79 | y_coords = []
80 | for i in range(box_count_y):
81 | for j in range(box_count_x):
82 | y, x = next(coord_gen)
83 | y = int(i * box_size + y)
84 | x = int(j * box_size + x)
85 | if (y < shape[0] and x < shape[1]):
86 | y_coords.append(y)
87 | x_coords.append(x)
88 | return (y_coords, x_coords)
89 | 
90 | 
91 | def get_random_coords(box_size):
92 | """
93 | Credit: https://github.com/juglab/n2v
94 | """
95 | while True:
96 | # Generator: call next() to obtain the next random offset within a box
97 | yield (torch.rand(1) * box_size, torch.rand(1) * box_size)
98 | 
--------------------------------------------------------------------------------
/ssdn/ssdn/utils/noise.py:
--------------------------------------------------------------------------------
1 | """Contains noise methods for use with batched image inputs.
2 | """
3 | 
4 | import torch
5 | import ssdn
6 | import re
7 | 
8 | from torch import Tensor
9 | from torch.distributions import Uniform, Poisson
10 | from numbers import Number
11 | from typing import Union, Tuple
12 | 
13 | 
14 | def add_gaussian(
15 | tensor: Tensor,
16 | std_dev: Union[Number, Tuple[Number, Number]],
17 | mean: Number = 0,
18 | inplace: bool = False,
19 | clip: bool = True,
20 | ) -> Tuple[Tensor, Union[Number, Tensor]]:
21 | """Adds Gaussian noise to a batch of input images.
22 | 
23 | Args:
24 | tensor (Tensor): Tensor to add noise to; this should be in a B*** format, e.g. BCHW.
25 | std_dev (Union[Number, Tuple[Number, Number]]): Standard deviation of noise being
26 | added. If a Tuple is provided then a standard deviation pulled from the
27 | uniform distribution between the two values is used for each batched input (B***).
28 | If the input value(s) are integers they will be divided by 255 in line with the
29 | input image dynamic ranges.
30 | mean (Number, optional): Mean of noise being added. Defaults to 0.
31 | inplace (bool, optional): Whether to add the noise in-place. Defaults to False.
32 | clip (bool, optional): Whether to clip between image bounds (0.0-1.0 or 0-255).
33 | Defaults to True.
34 | 
35 | Returns:
36 | Tuple[Tensor, Union[Number, Tensor]]: Tuple containing:
37 | * Copy of or reference to input tensor with noise added.
38 | * Standard deviation (SD) used for noise generation. This will be an array of
39 | the different SDs used if a range of SDs is being used.
40 | """
41 | if not inplace:
42 | tensor = tensor.clone()
43 | 
44 | if isinstance(std_dev, (list, tuple)):
45 | if len(std_dev) == 1:
46 | std_dev = std_dev[0]
47 | else:
48 | assert len(std_dev) == 2
49 | (min_std_dev, max_std_dev) = std_dev
50 | if isinstance(min_std_dev, int):
51 | min_std_dev /= 255
52 | if isinstance(max_std_dev, int):
53 | max_std_dev /= 255
54 | uniform_generator = Uniform(min_std_dev, max_std_dev)
55 | shape = [tensor.shape[0]] + [1] * (len(tensor.shape) - 1)
56 | std_dev = uniform_generator.sample(shape)
57 | if isinstance(std_dev, int):
58 | std_dev = std_dev / 255
59 | tensor = tensor.add_(torch.randn(tensor.size()) * std_dev + mean)
60 | if clip:
61 | tensor = ssdn.utils.clip_img(tensor, inplace=True)
62 | 
63 | return tensor, std_dev
64 | 
65 | 
66 | def add_poisson(
67 | tensor: Tensor,
68 | lam: Union[Number, Tuple[Number, Number]],
69 | inplace: bool = False,
70 | clip: bool = True,
71 | ) -> Tuple[Tensor, Union[Number, Tensor]]:
72 | """Adds Poisson noise to a batch of input images.
73 | 
74 | Args:
75 | tensor (Tensor): Tensor to add noise to; this should be in a B*** format, e.g. BCHW.
76 | lam (Union[Number, Tuple[Number, Number]]): Distribution rate parameter (lambda) for
77 | noise being added. If a Tuple is provided then a lambda pulled from the
78 | uniform distribution between the two values is used for each batched input (B***).
79 | inplace (bool, optional): Whether to add the noise in-place. Defaults to False.
80 | clip (bool, optional): Whether to clip between image bounds (0.0-1.0 or 0-255).
81 | Defaults to True.
82 | 
83 | Returns:
84 | Tuple[Tensor, Union[Number, Tensor]]: Tuple containing:
85 | * Copy of or reference to input tensor with noise added.
86 | * Lambda used for noise generation. This will be an array of the different
87 | lambdas used if a range of lambdas is being used.
88 | """
89 | if not inplace:
90 | tensor = tensor.clone()
91 | 
92 | if isinstance(lam, (list, tuple)):
93 | if len(lam) == 1:
94 | lam = lam[0]
95 | else:
96 | assert len(lam) == 2
97 | (min_lam, max_lam) = lam
98 | uniform_generator = Uniform(min_lam, max_lam)
99 | shape = [tensor.shape[0]] + [1] * (len(tensor.shape) - 1)
100 | lam = uniform_generator.sample(shape)
101 | tensor.mul_(lam)
102 | poisson_generator = Poisson(torch.tensor(1, dtype=float))
103 | noise = poisson_generator.sample(tensor.shape)
104 | tensor.add_(noise)
105 | tensor.div_(lam)
106 | if clip:
107 | tensor = ssdn.utils.clip_img(tensor, inplace=True)
108 | 
109 | return tensor, lam
110 | 
111 | 
112 | def add_style(
113 | images: Tensor, style: str, inplace: bool = False
114 | ) -> Tuple[Tensor, Union[Number, Tensor]]:
115 | """Adds a style using a string configuration in the format: {noise_type}{args}
116 | where {args} are the arguments passed to the noise function.
The formats for the
117 | supported noise types include 'gauss{SD}', 'gauss{MIN_SD}_{MAX_SD}', 'poisson{LAMBDA}',
118 | 'poisson{MIN_LAMBDA}_{MAX_LAMBDA}'. If parameters contain a decimal point they are
119 | treated as floats. This means the underlying noise-adding method will not attempt to
120 | scale them. An extra optional parameter can be passed after noise arguments to disable
121 | clipping between normal image bounds (0.0-1.0 or 0-255): 'gauss{SD}_nc'. This is provided
122 | as the original paper does not clip images at this point.
123 | 
124 | Args:
125 | images (Tensor): Tensor to add noise to; this should be in a B*** format, e.g. BCHW.
126 | style (str): Style string. NotImplementedError will be thrown if the noise type is
127 | not valid.
128 | inplace (bool, optional): Whether to add the noise in-place. Defaults to False.
129 | 
130 | Returns:
131 | Tuple[Tensor, Union[Number, Tensor]]: Tuple containing:
132 | * Copy or reference of input tensor with noise added.
133 | * Noise parameters from underlying noise generation.
134 | """
135 | # Extract noise type
136 | noise_type = re.findall(r"[a-zA-Z]+", style)[0]
137 | params = [p for p in style.replace(noise_type, "").split("_")]
138 | # Extract clipping parameter
139 | clip = "nc" not in params
140 | params = [x for x in params if x != "nc" and x != ""]
141 | # Map remaining parameters to either floats or ints
142 | floats = any(map(lambda x: "." in x, params))
143 | if floats:
144 | params = [float(p) for p in params]
145 | else:
146 | params = [int(p) for p in params]
147 | # Apply noise
148 | if noise_type == "gauss":
149 | return add_gaussian(images, params, inplace=inplace, clip=clip)
150 | elif noise_type == "poisson":
151 | return add_poisson(images, params, inplace=inplace, clip=clip)
152 | else:
153 | raise NotImplementedError("Noise type not supported")
154 | 
--------------------------------------------------------------------------------
/ssdn/ssdn/utils/pickle_fix.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import sys
4 | 
5 | from pathlib import Path
6 | from ssdn.utils import MetricDict
7 | from ssdn.params import HistoryValue, StateValue
8 | OrderedDefaultDict = MetricDict
9 | 
10 | if __name__ == "__main__":
11 | """Some old training files were created using a class that appeared in the __main__
12 | namespace; this breaks unpickling when that class is not found in __main__.
13 | Run this to replace the reference. All training files will be found recursively
14 | from the provided root directory.
15 | """
16 | if len(sys.argv) < 2:
17 | raise ValueError("Expected root path argument")
18 | 
19 | root = sys.argv[1]
20 | for path in Path(root).rglob('*.training'):
21 | print(path)
22 | state_dict = torch.load(path, map_location="cpu")
23 | history = state_dict["state"][StateValue.HISTORY]
24 | history[HistoryValue.TRAIN] = MetricDict(history[HistoryValue.TRAIN])
25 | history[HistoryValue.EVAL] = MetricDict(history[HistoryValue.EVAL])
26 | bak_path = str(path) + "_bak"
27 | os.rename(path, bak_path)
28 | torch.save(state_dict, path)
29 | os.remove(bak_path)
30 | 
--------------------------------------------------------------------------------
/ssdn/ssdn/utils/transforms.py:
--------------------------------------------------------------------------------
1 | import ssdn
2 | from typing import NewType
3 | 
4 | 
5 | Transform = NewType("Transform", object)
6 | """Typing label for otherwise undefined Transform type.
7 | """ 8 | 9 | 10 | class NoiseTransform(object): 11 | def __init__(self, style: str): 12 | self.style = style 13 | 14 | def __call__(self, imgs): 15 | imgs, _ = ssdn.utils.noise.add_style(imgs, self.style) 16 | return imgs 17 | -------------------------------------------------------------------------------- /ssdn/ssdn/utils/utils.py: -------------------------------------------------------------------------------- 1 | """Generic utilities for use with package. 2 | """ 3 | 4 | __authors__ = "David Jones " 5 | 6 | import numpy as np 7 | import torch 8 | import re 9 | import os 10 | import time 11 | 12 | from typing import Any, List 13 | from torch import Tensor 14 | from contextlib import contextmanager 15 | from collections import OrderedDict 16 | 17 | 18 | def compute_ramped_lrate( 19 | i: int, 20 | iteration_count: int, 21 | ramp_up_fraction: float, 22 | ramp_down_fraction: float, 23 | learning_rate: float, 24 | ) -> float: 25 | if ramp_up_fraction > 0.0: 26 | ramp_up_end_iter = iteration_count * ramp_up_fraction 27 | if i <= ramp_up_end_iter: 28 | t = (i / ramp_up_fraction) / iteration_count 29 | learning_rate = learning_rate * (0.5 - np.cos(t * np.pi) / 2) 30 | 31 | if ramp_down_fraction > 0.0: 32 | ramp_down_start_iter = iteration_count * (1 - ramp_down_fraction) 33 | if i >= ramp_down_start_iter: 34 | t = ((i - ramp_down_start_iter) / ramp_down_fraction) / iteration_count 35 | learning_rate = learning_rate * (0.5 + np.cos(t * np.pi) / 2) ** 2 36 | 37 | return learning_rate 38 | 39 | 40 | def list_constants(clazz: Any, private: bool = False) -> List[Any]: 41 | """Fetch all values from variables formatted as constants in a class. 42 | 43 | Args: 44 | clazz (Any): Class to fetch constants from. 45 | 46 | Returns: 47 | List[Any]: List of values. 48 | """ 49 | variables = [i for i in dir(clazz) if not callable(i)] 50 | regex = re.compile(r"^{}[A-Z0-9_]*$".format("" if private else "[A-Z]")) 51 | names = list(filter(regex.match, variables)) 52 | values = [clazz.__dict__[name] for name in names] 53 | return values 54 | 55 | 56 | @contextmanager 57 | def cd(newdir: str): 58 | """Context manager for managing changes of directory where when the context is left 59 | the original directory is restored. 60 | 61 | Args: 62 | newdir (str): New directory to enter 63 | """ 64 | prevdir = os.getcwd() 65 | os.chdir(os.path.expanduser(newdir)) 66 | try: 67 | yield 68 | finally: 69 | os.chdir(prevdir) 70 | 71 | 72 | class TrackedTime: 73 | """Class for tracking an ongoing total time. Every update tracks the previous 74 | time for future updates. 75 | """ 76 | 77 | def __init__(self): 78 | self.total = 0 79 | self.last_time = None 80 | 81 | def update(self): 82 | """Update the total time with the time since the last tracked time. 83 | """ 84 | current_time = time.time() 85 | if self.last_time is not None: 86 | self.total += current_time - self.last_time 87 | self.last_time = current_time 88 | 89 | def forget(self): 90 | """Clear the last tracked time. 91 | """ 92 | self.last_time = None 93 | 94 | 95 | def seconds_to_dhms(seconds: float, trim: bool = True) -> str: 96 | """Convert time in seconds to a string of seconds, minutes, hours, days. 97 | 98 | Args: 99 | seconds (float): Time to convert. 100 | trim (bool, optional): Whether to remove leading time units if not needed. 101 | 102 | Returns: 103 | str: String representation of time. 
104 | """ 105 | s = seconds % 60 106 | m = (seconds // 60) % 60 107 | h = seconds // (60 * 60) % 24 108 | d = seconds // (60 * 60 * 24) 109 | times = [(d, "d"), (h, "h"), (m, "m"), (s, "s")] 110 | time_str = "" 111 | for t, char in times: 112 | if trim and t < 1: 113 | continue 114 | trim = False 115 | time_str += "{:02}{}".format(int(t), char) 116 | return time_str 117 | 118 | 119 | class Metric: 120 | """ Only works if batch is in first dim. 121 | """ 122 | 123 | def __init__(self, batched: bool = True, collapse: bool = True): 124 | self.reset() 125 | self.batched = batched 126 | self.collapse = collapse 127 | 128 | def add(self, value: Tensor): 129 | n = value.shape[0] if self.batched else 1 130 | if self.collapse: 131 | data_start = 1 if self.batched else 0 132 | mean_dims = list(range(data_start, len(value.shape))) 133 | if len(mean_dims) > 0: 134 | value = torch.mean(value, dim=mean_dims) 135 | if self.batched: 136 | value = torch.sum(value, dim=0) 137 | if self.total is None: 138 | self.total = value 139 | else: 140 | self.total += value 141 | self.n += n 142 | 143 | def __add__(self, value: Tensor): 144 | self.add(value) 145 | return self 146 | 147 | def accumulated(self, reset: bool = False): 148 | if self.n == 0: 149 | return None 150 | acc = self.total / self.n 151 | if reset: 152 | self.reset() 153 | return acc 154 | 155 | def reset(self): 156 | self.total = None 157 | self.n = 0 158 | 159 | def empty(self) -> bool: 160 | return self.n == 0 161 | 162 | 163 | class MetricDict(OrderedDict): 164 | def __missing__(self, key): 165 | self[key] = value = Metric() 166 | return value 167 | 168 | 169 | def separator(cols=100) -> str: 170 | return "#" * cols 171 | -------------------------------------------------------------------------------- /ssdn/ssdn/version.py: -------------------------------------------------------------------------------- 1 | """Module containing the current library version. 2 | """ 3 | __version__ = "1.0" 4 | -------------------------------------------------------------------------------- /ssdn/tests/test_sampler.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader, Dataset 2 | from ssdn.datasets import FixedLengthSampler, SamplingOrder 3 | from typing import List, Any 4 | 5 | num_workers = 4 6 | 7 | 8 | class MockIndexDataset(Dataset): 9 | """Simple dataset which just returns the index that was read for 10 | tracking read order. 11 | 12 | Args: 13 | length (int): Maximum index to return. 
14 | """ 15 | 16 | def __init__(self, length: int): 17 | self.length = length 18 | 19 | def __getitem__(self, index: int): 20 | return index 21 | 22 | def __len__(self): 23 | return self.length 24 | 25 | 26 | def read_dataloader(dataloader: DataLoader) -> List[int]: 27 | return [data for data in dataloader] 28 | 29 | 30 | def collapse_index_batches(index_batches: List[List[Any]]) -> List[Any]: 31 | indexes = [] 32 | for batch in index_batches: 33 | indexes.extend(batch) 34 | return indexes 35 | 36 | 37 | def split_list(list: List, n: int) -> List[List[Any]]: 38 | return [list[i * n : (i + 1) * n] for i in range((len(list) + n - 1) // n)] 39 | 40 | 41 | def check_sequential_with_reset(indexes: List[int], max_index: int): 42 | expect = 0 43 | for index in indexes: 44 | if index != expect: 45 | raise AssertionError("Got: {}, Expected: {}".format(index, expect)) 46 | if index + 1 >= max_index: 47 | expect = 0 48 | else: 49 | expect = index + 1 50 | 51 | 52 | def check_full_batches(index_batches: List[List[Any]], batch_size: int): 53 | if any(map(lambda x: len(x) != batch_size, index_batches)): 54 | raise AssertionError("Expected '{}' length index sets ".format(batch_size)) 55 | 56 | 57 | def check_index_repetition(indexes: List[int], max_index: int): 58 | # Ensure no repeats within shuffled subsection 59 | index_sets = split_list(indexes, max_index) 60 | if any(map(lambda x: len(set(x)) != len(x), index_sets)): 61 | raise AssertionError("An index was repeated within a subsection") 62 | 63 | 64 | def _test_sequential(dataset_length: int, num_samples: int): 65 | dataset = MockIndexDataset(dataset_length) 66 | sampler = FixedLengthSampler(dataset, num_samples=num_samples, shuffled=False) 67 | loader_params = { 68 | "batch_size": 1, 69 | "num_workers": num_workers, 70 | "drop_last": False, 71 | "sampler": sampler, 72 | } 73 | dataloader = DataLoader(dataset, **loader_params) 74 | indexes = read_dataloader(dataloader) 75 | check_sequential_with_reset(indexes, dataset_length) 76 | 77 | 78 | def test_sequential_oversampled(): 79 | _test_sequential(5, 8) 80 | 81 | 82 | def test_sequential_undersampled(): 83 | _test_sequential(20, 8) 84 | 85 | 86 | def test_sequential_source_length(): 87 | _test_sequential(20, None) 88 | 89 | 90 | def test_sequential_batched(): 91 | dataset_length = 5 92 | batch_size = 2 93 | dataset = MockIndexDataset(dataset_length) 94 | sampler = FixedLengthSampler(dataset, num_samples=9, shuffled=False) 95 | loader_params = { 96 | "batch_size": batch_size, 97 | "num_workers": num_workers, 98 | "drop_last": True, 99 | "sampler": sampler, 100 | } 101 | dataloader = DataLoader(dataset, **loader_params) 102 | index_batches = read_dataloader(dataloader) 103 | # Note drop_last == True ensures we only get full batches back 104 | check_full_batches(index_batches, batch_size) 105 | indexes = collapse_index_batches(index_batches) 106 | check_sequential_with_reset(indexes, dataset_length) 107 | 108 | 109 | def _test_shuffled(dataset_length: int, num_samples: int): 110 | dataset = MockIndexDataset(dataset_length) 111 | sampler = FixedLengthSampler(dataset, num_samples=num_samples, shuffled=True) 112 | loader_params = { 113 | "batch_size": 1, 114 | "num_workers": num_workers, 115 | "drop_last": False, 116 | "sampler": sampler, 117 | } 118 | dataloader = DataLoader(dataset, **loader_params) 119 | indexes = read_dataloader(dataloader) 120 | check_index_repetition(indexes, dataset_length) 121 | 122 | 123 | def test_shuffled_oversampled(): 124 | _test_shuffled(5, 8) 125 | 126 | 127 | def 
test_shuffled_undersampled(): 128 | _test_shuffled(20, 8) 129 | 130 | 131 | def test_shuffled_source_length(): 132 | _test_shuffled(20, None) 133 | 134 | 135 | def test_shuffled_batched(): 136 | dataset_length = 5 137 | batch_size = 2 138 | dataset = MockIndexDataset(dataset_length) 139 | sampler = FixedLengthSampler(dataset, num_samples=8, shuffled=True) 140 | loader_params = { 141 | "batch_size": batch_size, 142 | "num_workers": num_workers, 143 | "drop_last": True, 144 | "sampler": sampler, 145 | } 146 | dataloader = DataLoader(dataset, **loader_params) 147 | index_batches = read_dataloader(dataloader) 148 | # Note drop_last == True ensures we only get full batches back 149 | check_full_batches(index_batches, batch_size) 150 | indexes = collapse_index_batches(index_batches) 151 | check_index_repetition(indexes, dataset_length) 152 | 153 | 154 | def _test_state_saving(): 155 | dataset_length = 5 156 | batch_size = 2 157 | pause_batch = 2 158 | dataset = MockIndexDataset(dataset_length) 159 | sampler = FixedLengthSampler(dataset, num_samples=9, shuffled=False) 160 | loader_params = { 161 | "batch_size": batch_size, 162 | "num_workers": num_workers, 163 | "drop_last": False, 164 | "sampler": sampler, 165 | } 166 | dataloader = DataLoader(dataset, **loader_params) 167 | index_batches_ref = read_dataloader(dataloader) 168 | indexes_ref = collapse_index_batches(index_batches_ref) 169 | 170 | index_batches = [] 171 | for i, data in enumerate(dataloader): 172 | index_batches += [data] 173 | if i == pause_batch: 174 | break 175 | 176 | # State save, must update actual read count in case data has been loaded 177 | # that has not yet been used 178 | saved_iterator = sampler.last_iter() 179 | read_count = len(collapse_index_batches(index_batches)) 180 | sampler_state_dict = saved_iterator.state_dict() 181 | sampler_state_dict["index"] = read_count 182 | 183 | # Attempt to resume 184 | dataset = MockIndexDataset(dataset_length) 185 | dataloader = DataLoader(dataset, **loader_params) 186 | loaded_iterator = SamplingOrder.from_state_dict(sampler_state_dict) 187 | sampler.for_next_iter(loaded_iterator) 188 | for data in dataloader: 189 | index_batches += [data] 190 | 191 | # Resumed should match reference read in one iteration 192 | indexes_ref = collapse_index_batches(index_batches_ref) 193 | indexes = collapse_index_batches(index_batches) 194 | zipped = zip(index_batches_ref, index_batches) 195 | if not indexes_ref == indexes: 196 | raise AssertionError("Got: {}, Expected: {}".format(indexes, indexes_ref)) 197 | 198 | zipped = zip(index_batches_ref, index_batches) 199 | if not all(map(lambda x: all(x[0] == x[1]), zipped)): 200 | raise AssertionError( 201 | "Got: {}, Expected: {}".format(index_batches, index_batches_ref) 202 | ) 203 | 204 | 205 | def test_state_saving(): 206 | # Run test multiple times to account for multiple workers 207 | for i in range(5): 208 | _test_state_saving() 209 | --------------------------------------------------------------------------------
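
As a usage illustration of the noise style strings documented in `ssdn/utils/noise.py` and the `NoiseTransform` wrapper in `ssdn/utils/transforms.py`, the following is a minimal sketch only, assuming the `ssdn` package from this repository is installed. The input batch is a random placeholder of BCHW float images in the 0-1 range, and the style strings `gauss25` and `gauss5_50_nc` are illustrative values of the documented `gauss{SD}` and `gauss{MIN_SD}_{MAX_SD}` formats rather than prescribed settings.

```python
# Minimal sketch (not part of the repository). Assumes the ssdn package is installed.
import torch
from ssdn.utils import noise
from ssdn.utils.transforms import NoiseTransform

clean = torch.rand(4, 3, 64, 64)  # placeholder batch in BCHW format, values in [0, 1]

# Fixed-sigma Gaussian noise; integer parameters are divided by 255.
noisy, std = noise.add_style(clean, "gauss25")

# Variable-sigma Gaussian noise without clipping ('_nc'), one sigma per batched image.
noisy_var, stds = noise.add_style(clean, "gauss5_50_nc")

# The same style strings can be applied lazily as a dataset transform.
transform = NoiseTransform("gauss25")
noisy_again = transform(clean)
```

Because `add_style` returns both the noisy tensor and the sampled noise parameters, the per-image standard deviations from the variable-sigma case can be kept wherever a known noise level is required.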