├── LICENSE.txt ├── README.md ├── data ├── __init__.py └── datasets.py ├── dataset ├── test │ └── download_testset.sh ├── train │ └── download_trainset.sh └── val │ └── download_valset.sh ├── demo.py ├── demo_dir.py ├── earlystop.py ├── eval.py ├── eval_config.py ├── examples ├── fake.png ├── real.png └── realfakedir │ ├── 0_real │ └── real.png │ └── 1_fake │ └── fake.png ├── networks ├── __init__.py ├── base_model.py ├── lpf.py ├── resnet.py ├── resnet_lpf.py └── trainer.py ├── options ├── __init__.py ├── base_options.py ├── test_options.py └── train_options.py ├── requirements.txt ├── train.py ├── util.py ├── validate.py └── weights └── download_weights.sh /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020, Sheng-Yu Wang, Oliver Wang, Richard Zhang, Andrew Owens, Alexei A. Efros. All rights reserved. 2 | 3 | 4 | Attribution-NonCommercial-ShareAlike 4.0 International 5 | 6 | ======================================================================= 7 | 8 | Creative Commons Corporation ("Creative Commons") is not a law firm and 9 | does not provide legal services or legal advice. Distribution of 10 | Creative Commons public licenses does not create a lawyer-client or 11 | other relationship. Creative Commons makes its licenses and related 12 | information available on an "as-is" basis. Creative Commons gives no 13 | warranties regarding its licenses, any material licensed under their 14 | terms and conditions, or any related information. Creative Commons 15 | disclaims all liability for damages resulting from their use to the 16 | fullest extent possible. 17 | 18 | Using Creative Commons Public Licenses 19 | 20 | Creative Commons public licenses provide a standard set of terms and 21 | conditions that creators and other rights holders may use to share 22 | original works of authorship and other material subject to copyright 23 | and certain other rights specified in the public license below. The 24 | following considerations are for informational purposes only, are not 25 | exhaustive, and do not form part of our licenses. 26 | 27 | Considerations for licensors: Our public licenses are 28 | intended for use by those authorized to give the public 29 | permission to use material in ways otherwise restricted by 30 | copyright and certain other rights. Our licenses are 31 | irrevocable. Licensors should read and understand the terms 32 | and conditions of the license they choose before applying it. 33 | Licensors should also secure all rights necessary before 34 | applying our licenses so that the public can reuse the 35 | material as expected. Licensors should clearly mark any 36 | material not subject to the license. This includes other CC- 37 | licensed material, or material used under an exception or 38 | limitation to copyright. More considerations for licensors: 39 | wiki.creativecommons.org/Considerations_for_licensors 40 | 41 | Considerations for the public: By using one of our public 42 | licenses, a licensor grants the public permission to use the 43 | licensed material under specified terms and conditions. If 44 | the licensor's permission is not necessary for any reason--for 45 | example, because of any applicable exception or limitation to 46 | copyright--then that use is not regulated by the license. Our 47 | licenses grant only permissions under copyright and certain 48 | other rights that a licensor has authority to grant. 
Use of 49 | the licensed material may still be restricted for other 50 | reasons, including because others have copyright or other 51 | rights in the material. A licensor may make special requests, 52 | such as asking that all changes be marked or described. 53 | Although not required by our licenses, you are encouraged to 54 | respect those requests where reasonable. More considerations 55 | for the public: 56 | wiki.creativecommons.org/Considerations_for_licensees 57 | 58 | ======================================================================= 59 | 60 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International 61 | Public License 62 | 63 | By exercising the Licensed Rights (defined below), You accept and agree 64 | to be bound by the terms and conditions of this Creative Commons 65 | Attribution-NonCommercial-ShareAlike 4.0 International Public License 66 | ("Public License"). To the extent this Public License may be 67 | interpreted as a contract, You are granted the Licensed Rights in 68 | consideration of Your acceptance of these terms and conditions, and the 69 | Licensor grants You such rights in consideration of benefits the 70 | Licensor receives from making the Licensed Material available under 71 | these terms and conditions. 72 | 73 | 74 | Section 1 -- Definitions. 75 | 76 | a. Adapted Material means material subject to Copyright and Similar 77 | Rights that is derived from or based upon the Licensed Material 78 | and in which the Licensed Material is translated, altered, 79 | arranged, transformed, or otherwise modified in a manner requiring 80 | permission under the Copyright and Similar Rights held by the 81 | Licensor. For purposes of this Public License, where the Licensed 82 | Material is a musical work, performance, or sound recording, 83 | Adapted Material is always produced where the Licensed Material is 84 | synched in timed relation with a moving image. 85 | 86 | b. Adapter's License means the license You apply to Your Copyright 87 | and Similar Rights in Your contributions to Adapted Material in 88 | accordance with the terms and conditions of this Public License. 89 | 90 | c. BY-NC-SA Compatible License means a license listed at 91 | creativecommons.org/compatiblelicenses, approved by Creative 92 | Commons as essentially the equivalent of this Public License. 93 | 94 | d. Copyright and Similar Rights means copyright and/or similar rights 95 | closely related to copyright including, without limitation, 96 | performance, broadcast, sound recording, and Sui Generis Database 97 | Rights, without regard to how the rights are labeled or 98 | categorized. For purposes of this Public License, the rights 99 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 100 | Rights. 101 | 102 | e. Effective Technological Measures means those measures that, in the 103 | absence of proper authority, may not be circumvented under laws 104 | fulfilling obligations under Article 11 of the WIPO Copyright 105 | Treaty adopted on December 20, 1996, and/or similar international 106 | agreements. 107 | 108 | f. Exceptions and Limitations means fair use, fair dealing, and/or 109 | any other exception or limitation to Copyright and Similar Rights 110 | that applies to Your use of the Licensed Material. 111 | 112 | g. License Elements means the license attributes listed in the name 113 | of a Creative Commons Public License. The License Elements of this 114 | Public License are Attribution, NonCommercial, and ShareAlike. 115 | 116 | h. 
Licensed Material means the artistic or literary work, database, 117 | or other material to which the Licensor applied this Public 118 | License. 119 | 120 | i. Licensed Rights means the rights granted to You subject to the 121 | terms and conditions of this Public License, which are limited to 122 | all Copyright and Similar Rights that apply to Your use of the 123 | Licensed Material and that the Licensor has authority to license. 124 | 125 | j. Licensor means the individual(s) or entity(ies) granting rights 126 | under this Public License. 127 | 128 | k. NonCommercial means not primarily intended for or directed towards 129 | commercial advantage or monetary compensation. For purposes of 130 | this Public License, the exchange of the Licensed Material for 131 | other material subject to Copyright and Similar Rights by digital 132 | file-sharing or similar means is NonCommercial provided there is 133 | no payment of monetary compensation in connection with the 134 | exchange. 135 | 136 | l. Share means to provide material to the public by any means or 137 | process that requires permission under the Licensed Rights, such 138 | as reproduction, public display, public performance, distribution, 139 | dissemination, communication, or importation, and to make material 140 | available to the public including in ways that members of the 141 | public may access the material from a place and at a time 142 | individually chosen by them. 143 | 144 | m. Sui Generis Database Rights means rights other than copyright 145 | resulting from Directive 96/9/EC of the European Parliament and of 146 | the Council of 11 March 1996 on the legal protection of databases, 147 | as amended and/or succeeded, as well as other essentially 148 | equivalent rights anywhere in the world. 149 | 150 | n. You means the individual or entity exercising the Licensed Rights 151 | under this Public License. Your has a corresponding meaning. 152 | 153 | 154 | Section 2 -- Scope. 155 | 156 | a. License grant. 157 | 158 | 1. Subject to the terms and conditions of this Public License, 159 | the Licensor hereby grants You a worldwide, royalty-free, 160 | non-sublicensable, non-exclusive, irrevocable license to 161 | exercise the Licensed Rights in the Licensed Material to: 162 | 163 | a. reproduce and Share the Licensed Material, in whole or 164 | in part, for NonCommercial purposes only; and 165 | 166 | b. produce, reproduce, and Share Adapted Material for 167 | NonCommercial purposes only. 168 | 169 | 2. Exceptions and Limitations. For the avoidance of doubt, where 170 | Exceptions and Limitations apply to Your use, this Public 171 | License does not apply, and You do not need to comply with 172 | its terms and conditions. 173 | 174 | 3. Term. The term of this Public License is specified in Section 175 | 6(a). 176 | 177 | 4. Media and formats; technical modifications allowed. The 178 | Licensor authorizes You to exercise the Licensed Rights in 179 | all media and formats whether now known or hereafter created, 180 | and to make technical modifications necessary to do so. The 181 | Licensor waives and/or agrees not to assert any right or 182 | authority to forbid You from making technical modifications 183 | necessary to exercise the Licensed Rights, including 184 | technical modifications necessary to circumvent Effective 185 | Technological Measures. For purposes of this Public License, 186 | simply making modifications authorized by this Section 2(a) 187 | (4) never produces Adapted Material. 188 | 189 | 5. 
Downstream recipients. 190 | 191 | a. Offer from the Licensor -- Licensed Material. Every 192 | recipient of the Licensed Material automatically 193 | receives an offer from the Licensor to exercise the 194 | Licensed Rights under the terms and conditions of this 195 | Public License. 196 | 197 | b. Additional offer from the Licensor -- Adapted Material. 198 | Every recipient of Adapted Material from You 199 | automatically receives an offer from the Licensor to 200 | exercise the Licensed Rights in the Adapted Material 201 | under the conditions of the Adapter's License You apply. 202 | 203 | c. No downstream restrictions. You may not offer or impose 204 | any additional or different terms or conditions on, or 205 | apply any Effective Technological Measures to, the 206 | Licensed Material if doing so restricts exercise of the 207 | Licensed Rights by any recipient of the Licensed 208 | Material. 209 | 210 | 6. No endorsement. Nothing in this Public License constitutes or 211 | may be construed as permission to assert or imply that You 212 | are, or that Your use of the Licensed Material is, connected 213 | with, or sponsored, endorsed, or granted official status by, 214 | the Licensor or others designated to receive attribution as 215 | provided in Section 3(a)(1)(A)(i). 216 | 217 | b. Other rights. 218 | 219 | 1. Moral rights, such as the right of integrity, are not 220 | licensed under this Public License, nor are publicity, 221 | privacy, and/or other similar personality rights; however, to 222 | the extent possible, the Licensor waives and/or agrees not to 223 | assert any such rights held by the Licensor to the limited 224 | extent necessary to allow You to exercise the Licensed 225 | Rights, but not otherwise. 226 | 227 | 2. Patent and trademark rights are not licensed under this 228 | Public License. 229 | 230 | 3. To the extent possible, the Licensor waives any right to 231 | collect royalties from You for the exercise of the Licensed 232 | Rights, whether directly or through a collecting society 233 | under any voluntary or waivable statutory or compulsory 234 | licensing scheme. In all other cases the Licensor expressly 235 | reserves any right to collect such royalties, including when 236 | the Licensed Material is used other than for NonCommercial 237 | purposes. 238 | 239 | 240 | Section 3 -- License Conditions. 241 | 242 | Your exercise of the Licensed Rights is expressly made subject to the 243 | following conditions. 244 | 245 | a. Attribution. 246 | 247 | 1. If You Share the Licensed Material (including in modified 248 | form), You must: 249 | 250 | a. retain the following if it is supplied by the Licensor 251 | with the Licensed Material: 252 | 253 | i. identification of the creator(s) of the Licensed 254 | Material and any others designated to receive 255 | attribution, in any reasonable manner requested by 256 | the Licensor (including by pseudonym if 257 | designated); 258 | 259 | ii. a copyright notice; 260 | 261 | iii. a notice that refers to this Public License; 262 | 263 | iv. a notice that refers to the disclaimer of 264 | warranties; 265 | 266 | v. a URI or hyperlink to the Licensed Material to the 267 | extent reasonably practicable; 268 | 269 | b. indicate if You modified the Licensed Material and 270 | retain an indication of any previous modifications; and 271 | 272 | c. indicate the Licensed Material is licensed under this 273 | Public License, and include the text of, or the URI or 274 | hyperlink to, this Public License. 275 | 276 | 2. 
You may satisfy the conditions in Section 3(a)(1) in any 277 | reasonable manner based on the medium, means, and context in 278 | which You Share the Licensed Material. For example, it may be 279 | reasonable to satisfy the conditions by providing a URI or 280 | hyperlink to a resource that includes the required 281 | information. 282 | 3. If requested by the Licensor, You must remove any of the 283 | information required by Section 3(a)(1)(A) to the extent 284 | reasonably practicable. 285 | 286 | b. ShareAlike. 287 | 288 | In addition to the conditions in Section 3(a), if You Share 289 | Adapted Material You produce, the following conditions also apply. 290 | 291 | 1. The Adapter's License You apply must be a Creative Commons 292 | license with the same License Elements, this version or 293 | later, or a BY-NC-SA Compatible License. 294 | 295 | 2. You must include the text of, or the URI or hyperlink to, the 296 | Adapter's License You apply. You may satisfy this condition 297 | in any reasonable manner based on the medium, means, and 298 | context in which You Share Adapted Material. 299 | 300 | 3. You may not offer or impose any additional or different terms 301 | or conditions on, or apply any Effective Technological 302 | Measures to, Adapted Material that restrict exercise of the 303 | rights granted under the Adapter's License You apply. 304 | 305 | 306 | Section 4 -- Sui Generis Database Rights. 307 | 308 | Where the Licensed Rights include Sui Generis Database Rights that 309 | apply to Your use of the Licensed Material: 310 | 311 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 312 | to extract, reuse, reproduce, and Share all or a substantial 313 | portion of the contents of the database for NonCommercial purposes 314 | only; 315 | 316 | b. if You include all or a substantial portion of the database 317 | contents in a database in which You have Sui Generis Database 318 | Rights, then the database in which You have Sui Generis Database 319 | Rights (but not its individual contents) is Adapted Material, 320 | including for purposes of Section 3(b); and 321 | 322 | c. You must comply with the conditions in Section 3(a) if You Share 323 | all or a substantial portion of the contents of the database. 324 | 325 | For the avoidance of doubt, this Section 4 supplements and does not 326 | replace Your obligations under this Public License where the Licensed 327 | Rights include other Copyright and Similar Rights. 328 | 329 | 330 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 331 | 332 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 333 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 334 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 335 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 336 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 337 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 338 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 339 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 340 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 341 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 342 | 343 | b. 
TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 344 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 345 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 346 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 347 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 348 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 349 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 350 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 351 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 352 | 353 | c. The disclaimer of warranties and limitation of liability provided 354 | above shall be interpreted in a manner that, to the extent 355 | possible, most closely approximates an absolute disclaimer and 356 | waiver of all liability. 357 | 358 | 359 | Section 6 -- Term and Termination. 360 | 361 | a. This Public License applies for the term of the Copyright and 362 | Similar Rights licensed here. However, if You fail to comply with 363 | this Public License, then Your rights under this Public License 364 | terminate automatically. 365 | 366 | b. Where Your right to use the Licensed Material has terminated under 367 | Section 6(a), it reinstates: 368 | 369 | 1. automatically as of the date the violation is cured, provided 370 | it is cured within 30 days of Your discovery of the 371 | violation; or 372 | 373 | 2. upon express reinstatement by the Licensor. 374 | 375 | For the avoidance of doubt, this Section 6(b) does not affect any 376 | right the Licensor may have to seek remedies for Your violations 377 | of this Public License. 378 | 379 | c. For the avoidance of doubt, the Licensor may also offer the 380 | Licensed Material under separate terms or conditions or stop 381 | distributing the Licensed Material at any time; however, doing so 382 | will not terminate this Public License. 383 | 384 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 385 | License. 386 | 387 | 388 | Section 7 -- Other Terms and Conditions. 389 | 390 | a. The Licensor shall not be bound by any additional or different 391 | terms or conditions communicated by You unless expressly agreed. 392 | 393 | b. Any arrangements, understandings, or agreements regarding the 394 | Licensed Material not stated herein are separate from and 395 | independent of the terms and conditions of this Public License. 396 | 397 | 398 | Section 8 -- Interpretation. 399 | 400 | a. For the avoidance of doubt, this Public License does not, and 401 | shall not be interpreted to, reduce, limit, restrict, or impose 402 | conditions on any use of the Licensed Material that could lawfully 403 | be made without permission under this Public License. 404 | 405 | b. To the extent possible, if any provision of this Public License is 406 | deemed unenforceable, it shall be automatically reformed to the 407 | minimum extent necessary to make it enforceable. If the provision 408 | cannot be reformed, it shall be severed from this Public License 409 | without affecting the enforceability of the remaining terms and 410 | conditions. 411 | 412 | c. No term or condition of this Public License will be waived and no 413 | failure to comply consented to unless expressly agreed to by the 414 | Licensor. 415 | 416 | d. 
Nothing in this Public License constitutes or may be interpreted 417 | as a limitation upon, or waiver of, any privileges and immunities 418 | that apply to the Licensor or You, including from the legal 419 | processes of any jurisdiction or authority. 420 | 421 | ======================================================================= 422 | 423 | Creative Commons is not a party to its public 424 | licenses. Notwithstanding, Creative Commons may elect to apply one of 425 | its public licenses to material it publishes and in those instances 426 | will be considered the “Licensor.” The text of the Creative Commons 427 | public licenses is dedicated to the public domain under the CC0 Public 428 | Domain Dedication. Except for the limited purpose of indicating that 429 | material is shared under a Creative Commons public license or as 430 | otherwise permitted by the Creative Commons policies published at 431 | creativecommons.org/policies, Creative Commons does not authorize the 432 | use of the trademark "Creative Commons" or any other trademark or logo 433 | of Creative Commons without its prior written consent including, 434 | without limitation, in connection with any unauthorized modifications 435 | to any of its public licenses or any other arrangements, 436 | understandings, or agreements concerning use of licensed material. For 437 | the avoidance of doubt, this paragraph does not form part of the 438 | public licenses. 439 | 440 | Creative Commons may be contacted at creativecommons.org. 441 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Detecting CNN-Generated Images [[Project Page]](https://peterwang512.github.io/CNNDetection/) 2 | 3 | **CNN-generated images are surprisingly easy to spot...for now** 4 | [Sheng-Yu Wang](https://peterwang512.github.io/), [Oliver Wang](http://www.oliverwang.info/), [Richard Zhang](https://richzhang.github.io/), [Andrew Owens](http://andrewowens.com/), [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros/). 5 |
In [CVPR](https://arxiv.org/abs/1912.11035), 2020.

This repository contains models, evaluation code, and training code for the datasets from our paper. **If you would like to run our pretrained model on your image/dataset, see [(2) Quick start](https://github.com/PeterWang512/CNNDetection#2-quick-start).**

**Jun 20th 2020 Update** Training code and dataset released; test results on uncropped images added (recommended for best performance).

**Oct 26th 2020 Update** Some users reported that the download link for the training data does not work. If this happens, please try the updated alternative links: [1](https://drive.google.com/drive/u/2/folders/14E_R19lqIE9JgotGz09fLPQ4NVqlYbVc) and [2](https://cmu.app.box.com/folder/124997172518?s=4syr4womrggfin0tsfhxohaec5dh6n48)

**Oct 18th 2021 Update** Our method gets 92% AUC on the recently released StyleGAN3 model! For more details, please visit this [link](https://github.com/NVlabs/stylegan3-detector).

**Jul 24th, 2024 Update** Unfortunately, the previous Google Drive link for the dataset is no longer available. Please use this temporary download [link](https://drive.google.com/drive/folders/1RwCSaraEUctIwFgoQXWMKFvW07gM80_3?usp=drive_link). I am planning to host the dataset on Huggingface within a week.

**Jul 26th, 2024 Update** The link has been fixed! Please follow the README to download the dataset. You will need 7z to prepare the dataset; on Linux, run `sudo apt-get install p7zip-full` to install it.

## (1) Setup

### Install packages
- Install PyTorch ([pytorch.org](http://pytorch.org))
- `pip install -r requirements.txt`

### Download model weights
- Run `bash weights/download_weights.sh`


## (2) Quick start

### Run on a single image

This command runs the model on a single image and outputs the uncalibrated prediction.

```
# Model weights need to be downloaded.
python demo.py -f examples/real.png -m weights/blur_jpg_prob0.5.pth
python demo.py -f examples/fake.png -m weights/blur_jpg_prob0.5.pth
```

### Run on a dataset

This command computes AP and accuracy on a dataset. See the [provided directory](examples/realfakedir) for an example. Put your real/fake images into the appropriate subfolders to test.

```
python demo_dir.py -d examples/realfakedir -m weights/blur_jpg_prob0.5.pth
```

## (3) Dataset

### Testset
The testset evaluated in the paper can be downloaded [here](https://drive.google.com/file/d/1z_fD3UKgWQyOTZIBbYSaQ-hz4AzUrLC1/view?usp=sharing).

The zip file contains images from 13 CNN-based synthesis algorithms: the 12 testsets from the paper plus images downloaded from whichfaceisreal.com. Images from each algorithm are stored in a separate folder. In each category, real images are in the `0_real` folder and synthetic images are in the `1_fake` folder.

Note: the ProGAN, StyleGAN, StyleGAN2, and CycleGAN testsets contain multiple classes, which are stored in separate subdirectories.

### Training set
The training set used in the paper can be downloaded [here](https://drive.google.com/file/d/1iVNBV0glknyTYGA9bCxT_d0CVTOgGcKh/view?usp=sharing) (try alternative links [1](https://drive.google.com/drive/u/2/folders/14E_R19lqIE9JgotGz09fLPQ4NVqlYbVc) and [2](https://cmu.app.box.com/folder/124997172518?s=4syr4womrggfin0tsfhxohaec5dh6n48) if the previous link does not work). All images are either from LSUN or generated by ProGAN, split into 20 object categories. As in the testset, real images in each category are in the `0_real` folder and synthetic images are in the `1_fake` folder.

### Validation set
The validation set consists of held-out ProGAN real and fake images, and can be downloaded [here](https://drive.google.com/file/d/1FU7xF8Wl_F8b0tgL0529qg2nZ_RpdVNL/view?usp=sharing). The directory structure is identical to that of the training set.

### Download the dataset
Before downloading, install 7z if needed.
```
# Download script for Linux
sudo apt-get install p7zip-full
```

A script for downloading the dataset is as follows:
```
# Download the testset
cd dataset/test
bash download_testset.sh
cd ../..

# Download the training set
cd dataset/train
bash download_trainset.sh
cd ../..

# Download the validation set
cd dataset/val
bash download_valset.sh
cd ../..
```

**If the script doesn't work, an alternative is to download the zip files manually from the Google Drive links above. Place the testset, training set, and validation set archives in the `dataset/test`, `dataset/train`, and `dataset/val` folders, respectively, and unzip them there, as sketched below.**
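For reference, a minimal sketch of that manual setup, assuming the archives carry the names used by the download scripts further below (`CNN_synth_testset.zip`, `progan_train.7z.001`–`.007`, `progan_val.zip`); adjust paths and names to match your actual downloads:

```
# Hypothetical manual setup mirroring the download scripts.
mv CNN_synth_testset.zip dataset/test/
cd dataset/test && unzip CNN_synth_testset.zip && cd ../..

mv progan_train.7z.* dataset/train/
cd dataset/train && 7z x progan_train.7z.001 && unzip progan_train.zip && cd ../..

mv progan_val.zip dataset/val/
cd dataset/val && unzip progan_val.zip && cd ../..
```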
## (4) Train your models
We provide two example scripts to train our `Blur+JPEG(0.5)` and `Blur+JPEG(0.1)` models. We use `checkpoints/[model_name]/model_epoch_best.pth` as our final model.
```
# Train Blur+JPEG(0.5)
python train.py --name blur_jpg_prob0.5 --blur_prob 0.5 --blur_sig 0.0,3.0 --jpg_prob 0.5 --jpg_method cv2,pil --jpg_qual 30,100 --dataroot ./dataset/ --classes airplane,bird,bicycle,boat,bottle,bus,car,cat,cow,chair,diningtable,dog,person,pottedplant,motorbike,tvmonitor,train,sheep,sofa,horse

# Train Blur+JPEG(0.1)
python train.py --name blur_jpg_prob0.1 --blur_prob 0.1 --blur_sig 0.0,3.0 --jpg_prob 0.1 --jpg_method cv2,pil --jpg_qual 30,100 --dataroot ./dataset/ --classes airplane,bird,bicycle,boat,bottle,bus,car,cat,cow,chair,diningtable,dog,person,pottedplant,motorbike,tvmonitor,train,sheep,sofa,horse
```

## (5) Evaluation

After the testset and the model weights are downloaded, one can evaluate the models by running:

```
# Run evaluation script. Model weights need to be downloaded. See eval_config.py for flags
python eval.py
```

Besides the printed output, the results are also stored in a CSV file in the `results` folder. Configuration options such as the dataset path and model weights are set in `eval_config.py`; modify them there to change the evaluation.

**6/13/2020 Update** Additionally, we tested on uncropped images and observed better performance on most categories. To evaluate without center-cropping:
```
# Run evaluation script without cropping. Model weights need to be downloaded.
python eval.py --no_crop --batch_size 1
```

The following are the models' performance on the released set, with cropping to 224x224 (as in the paper) and without cropping.
[Blur+JPEG(0.5)]

|Testset        |Acc (224)|AP (224)|Acc (No crop)|AP (No crop)|
|:-------------:|:-------:|:------:|:-----------:|:----------:|
|ProGAN         | 100.0%  | 100.0% |   100.0%    |   100.0%   |
|StyleGAN       |  73.4%  |  98.5% |    77.5%    |    99.3%   |
|BigGAN         |  59.0%  |  88.2% |    59.5%    |    90.4%   |
|CycleGAN       |  80.8%  |  96.8% |    84.6%    |    97.9%   |
|StarGAN        |  81.0%  |  95.4% |    84.7%    |    97.5%   |
|GauGAN         |  79.3%  |  98.1% |    82.9%    |    98.8%   |
|CRN            |  87.6%  |  98.9% |    97.8%    |   100.0%   |
|IMLE           |  94.1%  |  99.5% |    98.8%    |   100.0%   |
|SITD           |  78.3%  |  92.7% |    93.9%    |    99.6%   |
|SAN            |  50.0%  |  63.9% |    50.0%    |    62.8%   |
|Deepfake       |  51.1%  |  66.3% |    50.4%    |    63.1%   |
|StyleGAN2      |  68.4%  |  98.0% |    72.4%    |    99.1%   |
|Whichfaceisreal|  63.9%  |  88.8% |    75.2%    |   100.0%   |


[Blur+JPEG(0.1)]

|Testset        |Acc (224)|AP (224)|Acc (No crop)|AP (No crop)|
|:-------------:|:-------:|:------:|:-----------:|:----------:|
|ProGAN         | 100.0%  | 100.0% |   100.0%    |   100.0%   |
|StyleGAN       |  87.1%  |  99.6% |    90.2%    |    99.8%   |
|BigGAN         |  70.2%  |  84.5% |    71.2%    |    86.0%   |
|CycleGAN       |  85.2%  |  93.5% |    87.6%    |    94.9%   |
|StarGAN        |  91.7%  |  98.2% |    94.6%    |    99.0%   |
|GauGAN         |  78.9%  |  89.5% |    81.4%    |    90.8%   |
|CRN            |  86.3%  |  98.2% |    86.3%    |    99.8%   |
|IMLE           |  86.2%  |  98.4% |    86.3%    |    99.8%   |
|SITD           |  90.3%  |  97.2% |    98.1%    |    99.8%   |
|SAN            |  50.5%  |  70.5% |    50.0%    |    68.6%   |
|Deepfake       |  53.5%  |  89.0% |    50.7%    |    84.5%   |
|StyleGAN2      |  84.4%  |  99.1% |    86.9%    |    99.5%   |
|Whichfaceisreal|  83.6%  |  93.2% |    91.6%    |    99.8%   |

## (A) Acknowledgments

This repository borrows partially from the [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) and PyTorch [torchvision models](https://github.com/pytorch/vision/tree/master/torchvision/models) repositories.

## (B) Citation, Contact

If you find this useful for your research, please consider citing this [bibtex](https://peterwang512.github.io/CNNDetection/bibtex.txt). Please contact Sheng-Yu Wang with any comments or feedback.

--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
from torch.utils.data.sampler import WeightedRandomSampler

from .datasets import dataset_folder


def get_dataset(opt):
    dset_lst = []
    for cls in opt.classes:
        root = opt.dataroot + '/' + cls
        dset = dataset_folder(opt, root)
        dset_lst.append(dset)
    return torch.utils.data.ConcatDataset(dset_lst)


def get_bal_sampler(dataset):
    targets = []
    for d in dataset.datasets:
        targets.extend(d.targets)

    # weight each sample by the inverse frequency of its class
    ratio = np.bincount(targets)
    w = 1. / torch.tensor(ratio, dtype=torch.float)
    sample_weights = w[targets]
    sampler = WeightedRandomSampler(weights=sample_weights,
                                    num_samples=len(sample_weights))
    return sampler


def create_dataloader(opt):
    shuffle = not opt.serial_batches if (opt.isTrain and not opt.class_bal) else False
    dataset = get_dataset(opt)
    sampler = get_bal_sampler(dataset) if opt.class_bal else None

    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=opt.batch_size,
                                              shuffle=shuffle,
                                              sampler=sampler,
                                              num_workers=int(opt.num_threads))
    return data_loader
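The weighting in `get_bal_sampler` is worth spelling out: `np.bincount` counts the samples per class, so each sample's weight is the inverse of its class frequency and every class is drawn equally often in expectation. A self-contained sketch with toy labels (not from the repo):

```
import numpy as np
import torch
from torch.utils.data.sampler import WeightedRandomSampler

# Toy labels: class 0 is 3x over-represented relative to class 1.
targets = [0, 0, 0, 0, 0, 0, 1, 1]
ratio = np.bincount(targets)                     # [6, 2]
w = 1. / torch.tensor(ratio, dtype=torch.float)  # [1/6, 1/2]
sample_weights = w[targets]                      # one weight per sample

sampler = WeightedRandomSampler(weights=sample_weights,
                                num_samples=len(sample_weights))
draws = [targets[i] for i in sampler]
# In expectation, roughly half of the draws come from each class.
print(draws)
```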
--------------------------------------------------------------------------------
/data/datasets.py:
--------------------------------------------------------------------------------
import cv2
import numpy as np
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from random import random, choice
from io import BytesIO
from PIL import Image
from PIL import ImageFile
from scipy.ndimage import gaussian_filter  # scipy.ndimage.filters was removed in SciPy 1.10


ImageFile.LOAD_TRUNCATED_IMAGES = True


def dataset_folder(opt, root):
    if opt.mode == 'binary':
        return binary_dataset(opt, root)
    if opt.mode == 'filename':
        return FileNameDataset(opt, root)
    raise ValueError('opt.mode needs to be binary or filename.')


def binary_dataset(opt, root):
    if opt.isTrain:
        crop_func = transforms.RandomCrop(opt.cropSize)
    elif opt.no_crop:
        crop_func = transforms.Lambda(lambda img: img)
    else:
        crop_func = transforms.CenterCrop(opt.cropSize)

    if opt.isTrain and not opt.no_flip:
        flip_func = transforms.RandomHorizontalFlip()
    else:
        flip_func = transforms.Lambda(lambda img: img)

    if not opt.isTrain and opt.no_resize:
        rz_func = transforms.Lambda(lambda img: img)
    else:
        rz_func = transforms.Lambda(lambda img: custom_resize(img, opt))

    dset = datasets.ImageFolder(
        root,
        transforms.Compose([
            rz_func,
            transforms.Lambda(lambda img: data_augment(img, opt)),
            crop_func,
            flip_func,
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]))
    return dset


class FileNameDataset(datasets.ImageFolder):
    def name(self):
        return 'FileNameDataset'

    def __init__(self, opt, root):
        self.opt = opt
        super().__init__(root)

    def __getitem__(self, index):
        # Return only the file path of the sample.
        path, target = self.samples[index]
        return path


def data_augment(img, opt):
    img = np.array(img)

    if random() < opt.blur_prob:
        sig = sample_continuous(opt.blur_sig)
        gaussian_blur(img, sig)  # blurs img in place

    if random() < opt.jpg_prob:
        method = sample_discrete(opt.jpg_method)
        qual = sample_discrete(opt.jpg_qual)
        img = jpeg_from_key(img, qual, method)

    return Image.fromarray(img)


def sample_continuous(s):
    if len(s) == 1:
        return s[0]
    if len(s) == 2:
        rg = s[1] - s[0]
        return random() * rg + s[0]
    raise ValueError("Length of iterable s should be 1 or 2.")


def sample_discrete(s):
    if len(s) == 1:
        return s[0]
    return choice(s)


def gaussian_blur(img, sigma):
    gaussian_filter(img[:, :, 0], output=img[:, :, 0], sigma=sigma)
    gaussian_filter(img[:, :, 1], output=img[:, :, 1], sigma=sigma)
    gaussian_filter(img[:, :, 2], output=img[:, :, 2], sigma=sigma)


def cv2_jpg(img, compress_val):
    img_cv2 = img[:, :, ::-1]  # RGB -> BGR for OpenCV
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), compress_val]
    result, encimg = cv2.imencode('.jpg', img_cv2, encode_param)
    decimg = cv2.imdecode(encimg, 1)
    return decimg[:, :, ::-1]  # BGR -> RGB


def pil_jpg(img, compress_val):
    out = BytesIO()
    img = Image.fromarray(img)
    img.save(out, format='jpeg', quality=compress_val)
    img = Image.open(out)
    # load from memory before BytesIO closes
    img = np.array(img)
    out.close()
    return img


jpeg_dict = {'cv2': cv2_jpg, 'pil': pil_jpg}


def jpeg_from_key(img, compress_val, key):
    method = jpeg_dict[key]
    return method(img, compress_val)


rz_dict = {'bilinear': Image.BILINEAR,
           'bicubic': Image.BICUBIC,
           'lanczos': Image.LANCZOS,
           'nearest': Image.NEAREST}


def custom_resize(img, opt):
    interp = sample_discrete(opt.rz_interp)
    return TF.resize(img, opt.loadSize, interpolation=rz_dict[interp])
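To make the augmentation behavior concrete, here is a minimal sketch that drives `data_augment` with a hand-built options object, assuming the functions from `data/datasets.py` above are in scope. The `Namespace` fields here are hypothetical; in the repo they come from the training options parser. Note that `blur_sig=[0.0, 3.0]` is treated as a continuous range by `sample_continuous`, whereas `jpg_qual` is a discrete list that `sample_discrete` chooses from:

```
from argparse import Namespace
import numpy as np
from PIL import Image

# Hypothetical options object mirroring the flags in the README train commands.
opt = Namespace(
    blur_prob=0.5, blur_sig=[0.0, 3.0],       # sigma drawn uniformly from [0, 3]
    jpg_prob=0.5, jpg_method=['cv2', 'pil'],  # compression backend chosen at random
    jpg_qual=[30, 100],                       # quality sampled from this list
)

img = Image.fromarray(np.uint8(np.random.rand(256, 256, 3) * 255))
aug = data_augment(img, opt)  # possibly blurred and/or JPEG-compressed
print(aug.size)
```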
--------------------------------------------------------------------------------
/dataset/test/download_testset.sh:
--------------------------------------------------------------------------------
#!/bin/bash
wget https://huggingface.co/datasets/sywang/CNNDetection/resolve/main/CNN_synth_testset.zip

unzip CNN_synth_testset.zip
rm CNN_synth_testset.zip

--------------------------------------------------------------------------------
/dataset/train/download_trainset.sh:
--------------------------------------------------------------------------------
#!/bin/bash
wget https://huggingface.co/datasets/sywang/CNNDetection/resolve/main/progan_train.7z.001 &
wget https://huggingface.co/datasets/sywang/CNNDetection/resolve/main/progan_train.7z.002 &
wget https://huggingface.co/datasets/sywang/CNNDetection/resolve/main/progan_train.7z.003 &
wget https://huggingface.co/datasets/sywang/CNNDetection/resolve/main/progan_train.7z.004 &
wget https://huggingface.co/datasets/sywang/CNNDetection/resolve/main/progan_train.7z.005 &
wget https://huggingface.co/datasets/sywang/CNNDetection/resolve/main/progan_train.7z.006 &
wget https://huggingface.co/datasets/sywang/CNNDetection/resolve/main/progan_train.7z.007 &
wait $(jobs -p)

7z x progan_train.7z.001
rm progan_train.7z.*
unzip progan_train.zip
rm progan_train.zip

--------------------------------------------------------------------------------
/dataset/val/download_valset.sh:
--------------------------------------------------------------------------------
#!/bin/bash
wget https://huggingface.co/datasets/sywang/CNNDetection/resolve/main/progan_val.zip

unzip progan_val.zip
rm progan_val.zip

--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
import argparse
import torch
import torchvision.transforms as transforms
from PIL import Image
from networks.resnet import resnet50

parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-f', '--file', default='examples_realfakedir')
parser.add_argument('-m', '--model_path', type=str, default='weights/blur_jpg_prob0.5.pth')
parser.add_argument('-c', '--crop', type=int, default=None, help='by default, do not crop. specify crop size')
parser.add_argument('--use_cpu', action='store_true', help='uses gpu by default, turn on to use cpu')

opt = parser.parse_args()

model = resnet50(num_classes=1)
state_dict = torch.load(opt.model_path, map_location='cpu')
model.load_state_dict(state_dict['model'])
if(not opt.use_cpu):
    model.cuda()
model.eval()

# Transform
trans_init = []
if(opt.crop is not None):
    trans_init = [transforms.CenterCrop(opt.crop),]
    print('Cropping to [%i]' % opt.crop)
else:
    print('Not cropping')
trans = transforms.Compose(trans_init + [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

img = trans(Image.open(opt.file).convert('RGB'))

with torch.no_grad():
    in_tens = img.unsqueeze(0)
    if(not opt.use_cpu):
        in_tens = in_tens.cuda()
    prob = model(in_tens).sigmoid().item()

print('probability of being synthetic: {:.2f}%'.format(prob * 100))
--------------------------------------------------------------------------------
/demo_dir.py:
--------------------------------------------------------------------------------
import argparse
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.utils.data
import numpy as np
from sklearn.metrics import average_precision_score, accuracy_score
from tqdm import tqdm

from networks.resnet import resnet50

parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-d', '--dir', nargs='+', type=str, default='examples/realfakedir')
parser.add_argument('-m', '--model_path', type=str, default='weights/blur_jpg_prob0.5.pth')
parser.add_argument('-b', '--batch_size', type=int, default=32)
parser.add_argument('-j', '--workers', type=int, default=4, help='number of workers')
parser.add_argument('-c', '--crop', type=int, default=None, help='by default, do not crop. specify crop size')
parser.add_argument('--use_cpu', action='store_true', help='uses gpu by default, turn on to use cpu')
parser.add_argument('--size_only', action='store_true', help='only look at sizes of images in dataset')

opt = parser.parse_args()

# Load model
if(not opt.size_only):
    model = resnet50(num_classes=1)
    if(opt.model_path is not None):
        state_dict = torch.load(opt.model_path, map_location='cpu')
        model.load_state_dict(state_dict['model'])
    model.eval()
    if(not opt.use_cpu):
        model.cuda()

# Transform
trans_init = []
if(opt.crop is not None):
    trans_init = [transforms.CenterCrop(opt.crop),]
    print('Cropping to [%i]' % opt.crop)
else:
    print('Not cropping')
trans = transforms.Compose(trans_init + [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Dataset loaders (one per directory)
if(type(opt.dir) == str):
    opt.dir = [opt.dir,]

print('Loading [%i] datasets' % len(opt.dir))
data_loaders = []
for dir in opt.dir:
    dataset = datasets.ImageFolder(dir, transform=trans)
    data_loaders += [torch.utils.data.DataLoader(dataset,
                                                 batch_size=opt.batch_size,
                                                 shuffle=False,
                                                 num_workers=opt.workers),]

y_true, y_pred = [], []
Hs, Ws = [], []
with torch.no_grad():
    for data_loader in data_loaders:
        for data, label in tqdm(data_loader):
            Hs.append(data.shape[2])
            Ws.append(data.shape[3])

            y_true.extend(label.flatten().tolist())
            if(not opt.size_only):
                if(not opt.use_cpu):
                    data = data.cuda()
                y_pred.extend(model(data).sigmoid().flatten().tolist())

Hs, Ws = np.array(Hs), np.array(Ws)
y_true, y_pred = np.array(y_true), np.array(y_pred)

print('Average sizes: [{:2.2f}+/-{:2.2f}] x [{:2.2f}+/-{:2.2f}] = [{:2.2f}+/-{:2.2f} Mpix]'.format(
    np.mean(Hs), np.std(Hs), np.mean(Ws), np.std(Ws), np.mean(Hs * Ws) / 1e6, np.std(Hs * Ws) / 1e6))
print('Num reals: {}, Num fakes: {}'.format(np.sum(1 - y_true), np.sum(y_true)))

if(not opt.size_only):
    # Predictions above 0.5 are counted as "fake".
    r_acc = accuracy_score(y_true[y_true == 0], y_pred[y_true == 0] > 0.5)
    f_acc = accuracy_score(y_true[y_true == 1], y_pred[y_true == 1] > 0.5)
    acc = accuracy_score(y_true, y_pred > 0.5)
    ap = average_precision_score(y_true, y_pred)

    print('AP: {:2.2f}, Acc: {:2.2f}, Acc (real): {:2.2f}, Acc (fake): {:2.2f}'.format(
        ap * 100., acc * 100., r_acc * 100., f_acc * 100.))
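Since `-d/--dir` is declared with `nargs='+'`, several directories can be scored in one run; their predictions are pooled into a single AP/accuracy computation. For example (the second path is a placeholder for your own real/fake directory):

```
python demo_dir.py -d examples/realfakedir path/to/another/realfakedir -m weights/blur_jpg_prob0.5.pth
```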
--------------------------------------------------------------------------------
/earlystop.py:
--------------------------------------------------------------------------------
import numpy as np


class EarlyStopping:
    """Early stops the training if the validation score doesn't improve after a given patience."""
    def __init__(self, patience=1, verbose=False, delta=0):
        """
        Args:
            patience (int): How long to wait after the last time the validation score improved.
                            Default: 1
            verbose (bool): If True, prints a message for each validation score improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                           Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.score_max = -np.inf
        self.delta = delta

    def __call__(self, score, model):
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(score, model)
        elif score < self.best_score - self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(score, model)
            self.counter = 0

    def save_checkpoint(self, score, model):
        '''Saves the model when the validation score improves.'''
        if self.verbose:
            print(f'Validation accuracy increased ({self.score_max:.6f} --> {score:.6f}). Saving model ...')
        model.save_networks('best')
        self.score_max = score
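A sketch of how `EarlyStopping` is meant to be driven from a validation loop. The helper names here are hypothetical; in this repo `train.py` plays this role, and `model.save_networks('best')` is provided by `BaseModel` further below:

```
# Hypothetical training-loop excerpt, assuming train_one_epoch and
# val_accuracy helpers exist and that model exposes save_networks().
early_stopping = EarlyStopping(patience=5, delta=0.001, verbose=True)

for epoch in range(100):
    train_one_epoch(model)       # assumed helper
    score = val_accuracy(model)  # higher is better
    early_stopping(score, model) # checkpoints on improvement
    if early_stopping.early_stop:
        print(f'Early stopping at epoch {epoch}')
        break
```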
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
import os
import csv
import torch

from validate import validate
from networks.resnet import resnet50
from options.test_options import TestOptions
from eval_config import *


# Running tests
opt = TestOptions().parse(print_options=False)
model_name = os.path.basename(model_path).replace('.pth', '')
rows = [["{} model testing on...".format(model_name)],
        ['testset', 'accuracy', 'avg precision']]

print("{} model testing on...".format(model_name))
for v_id, val in enumerate(vals):
    opt.dataroot = '{}/{}'.format(dataroot, val)
    opt.classes = os.listdir(opt.dataroot) if multiclass[v_id] else ['']
    opt.no_resize = True    # testing without resizing by default

    model = resnet50(num_classes=1)
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict['model'])
    model.cuda()
    model.eval()

    acc, ap, _, _, _, _ = validate(model, opt)
    rows.append([val, acc, ap])
    print("({}) acc: {}; ap: {}".format(val, acc, ap))

csv_name = results_dir + '/{}.csv'.format(model_name)
with open(csv_name, 'w') as f:
    csv_writer = csv.writer(f, delimiter=',')
    csv_writer.writerows(rows)

--------------------------------------------------------------------------------
/eval_config.py:
--------------------------------------------------------------------------------
from util import mkdir


# directory to store the results
results_dir = './results/'
mkdir(results_dir)

# root to the testsets
dataroot = './dataset/test/'

# list of synthesis algorithms
vals = ['progan', 'stylegan', 'biggan', 'cyclegan', 'stargan', 'gaugan',
        'crn', 'imle', 'seeingdark', 'san', 'deepfake', 'stylegan2', 'whichfaceisreal']

# indicates if the corresponding testset has multiple classes
multiclass = [1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0]

# model
model_path = 'weights/blur_jpg_prob0.5.pth'
--------------------------------------------------------------------------------
/examples/fake.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PeterWang512/CNNDetection/ea0b5622365e3a9cd31d1b54b6b5971131a839ab/examples/fake.png

--------------------------------------------------------------------------------
/examples/real.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PeterWang512/CNNDetection/ea0b5622365e3a9cd31d1b54b6b5971131a839ab/examples/real.png

--------------------------------------------------------------------------------
/examples/realfakedir/0_real/real.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PeterWang512/CNNDetection/ea0b5622365e3a9cd31d1b54b6b5971131a839ab/examples/realfakedir/0_real/real.png

--------------------------------------------------------------------------------
/examples/realfakedir/1_fake/fake.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PeterWang512/CNNDetection/ea0b5622365e3a9cd31d1b54b6b5971131a839ab/examples/realfakedir/1_fake/fake.png

--------------------------------------------------------------------------------
/networks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PeterWang512/CNNDetection/ea0b5622365e3a9cd31d1b54b6b5971131a839ab/networks/__init__.py

--------------------------------------------------------------------------------
/networks/base_model.py:
--------------------------------------------------------------------------------
# from pix2pix
import os
import torch
import torch.nn as nn
from torch.nn import init
from torch.optim import lr_scheduler


class BaseModel(nn.Module):
    def __init__(self, opt):
        super(BaseModel, self).__init__()
        self.opt = opt
        self.total_steps = 0
        self.isTrain = opt.isTrain
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
        self.device = torch.device('cuda:{}'.format(opt.gpu_ids[0])) if opt.gpu_ids else torch.device('cpu')

    def save_networks(self, epoch):
        save_filename = 'model_epoch_%s.pth' % epoch
        save_path = os.path.join(self.save_dir, save_filename)

        # serialize model and optimizer to dict
        state_dict = {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'total_steps': self.total_steps,
        }

        torch.save(state_dict, save_path)

    # load models from the disk
    def load_networks(self, epoch):
        load_filename = 'model_epoch_%s.pth' % epoch
        load_path = os.path.join(self.save_dir, load_filename)

        print('loading the model from %s' % load_path)
        state_dict = torch.load(load_path, map_location=self.device)
        if hasattr(state_dict, '_metadata'):
            del state_dict._metadata

        self.model.load_state_dict(state_dict['model'])
        self.total_steps = state_dict['total_steps']

        if self.isTrain and not self.opt.new_optim:
            self.optimizer.load_state_dict(state_dict['optimizer'])
            # move optimizer state to the GPU
            for state in self.optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.to(self.device)

            for g in self.optimizer.param_groups:
                g['lr'] = self.opt.lr

    def eval(self):
        self.model.eval()

    def test(self):
        with torch.no_grad():
            self.forward()


def init_weights(net, init_type='normal', gain=0.02):
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)
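Note the checkpoint layout that `save_networks` writes: a dict with `'model'`, `'optimizer'`, and `'total_steps'` entries. This is why `demo.py` and `eval.py` index into `state_dict['model']` before calling `load_state_dict`. A minimal sketch of inspecting such a checkpoint (the path follows the `checkpoints/[model_name]/model_epoch_best.pth` naming scheme from the README, and is hypothetical here):

```
import torch
from networks.resnet import resnet50

ckpt = torch.load('checkpoints/blur_jpg_prob0.5/model_epoch_best.pth', map_location='cpu')
print(ckpt.keys())  # expected: dict_keys(['model', 'optimizer', 'total_steps'])

model = resnet50(num_classes=1)
model.load_state_dict(ckpt['model'])  # restore the weights only
print('trained for', ckpt['total_steps'], 'steps')
```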
--------------------------------------------------------------------------------
/networks/lpf.py:
--------------------------------------------------------------------------------
# Copyright (c) 2019, Adobe Inc. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike
# 4.0 International Public License. To view a copy of this license, visit
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode.

import torch
import torch.nn.parallel
import numpy as np
import torch.nn as nn
import torch.nn.functional as F


class Downsample(nn.Module):
    def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0):
        super(Downsample, self).__init__()
        self.filt_size = filt_size
        self.pad_off = pad_off
        self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2)),
                          int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))]
        self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
        self.stride = stride
        self.off = int((self.stride - 1) / 2.)
        self.channels = channels

        # binomial filter taps (rows of Pascal's triangle)
        if(self.filt_size == 1):
            a = np.array([1., ])
        elif(self.filt_size == 2):
            a = np.array([1., 1.])
        elif(self.filt_size == 3):
            a = np.array([1., 2., 1.])
        elif(self.filt_size == 4):
            a = np.array([1., 3., 3., 1.])
        elif(self.filt_size == 5):
            a = np.array([1., 4., 6., 4., 1.])
        elif(self.filt_size == 6):
            a = np.array([1., 5., 10., 10., 5., 1.])
        elif(self.filt_size == 7):
            a = np.array([1., 6., 15., 20., 15., 6., 1.])

        # normalized 2D (outer-product) filter, repeated per channel for a depthwise conv
        filt = torch.Tensor(a[:, None] * a[None, :])
        filt = filt / torch.sum(filt)
        self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))

        self.pad = get_pad_layer(pad_type)(self.pad_sizes)

    def forward(self, inp):
        if(self.filt_size == 1):
            if(self.pad_off == 0):
                return inp[:, :, ::self.stride, ::self.stride]
            else:
                return self.pad(inp)[:, :, ::self.stride, ::self.stride]
        else:
            return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])


def get_pad_layer(pad_type):
    if(pad_type in ['refl', 'reflect']):
        PadLayer = nn.ReflectionPad2d
    elif(pad_type in ['repl', 'replicate']):
        PadLayer = nn.ReplicationPad2d
    elif(pad_type == 'zero'):
        PadLayer = nn.ZeroPad2d
    else:
        raise ValueError('Pad type [%s] not recognized' % pad_type)
    return PadLayer
class Downsample1D(nn.Module):
    def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0):
        super(Downsample1D, self).__init__()
        self.filt_size = filt_size
        self.pad_off = pad_off
        self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))]
        self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
        self.stride = stride
        self.off = int((self.stride - 1) / 2.)
        self.channels = channels

        # binomial filter taps (rows of Pascal's triangle)
        if(self.filt_size == 1):
            a = np.array([1., ])
        elif(self.filt_size == 2):
            a = np.array([1., 1.])
        elif(self.filt_size == 3):
            a = np.array([1., 2., 1.])
        elif(self.filt_size == 4):
            a = np.array([1., 3., 3., 1.])
        elif(self.filt_size == 5):
            a = np.array([1., 4., 6., 4., 1.])
        elif(self.filt_size == 6):
            a = np.array([1., 5., 10., 10., 5., 1.])
        elif(self.filt_size == 7):
            a = np.array([1., 6., 15., 20., 15., 6., 1.])

        filt = torch.Tensor(a)
        filt = filt / torch.sum(filt)
        self.register_buffer('filt', filt[None, None, :].repeat((self.channels, 1, 1)))

        self.pad = get_pad_layer_1d(pad_type)(self.pad_sizes)

    def forward(self, inp):
        if(self.filt_size == 1):
            if(self.pad_off == 0):
                return inp[:, :, ::self.stride]
            else:
                return self.pad(inp)[:, :, ::self.stride]
        else:
            return F.conv1d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])


def get_pad_layer_1d(pad_type):
    if(pad_type in ['refl', 'reflect']):
        PadLayer = nn.ReflectionPad1d
    elif(pad_type in ['repl', 'replicate']):
        PadLayer = nn.ReplicationPad1d
    elif(pad_type == 'zero'):
        # nn.ZeroPad1d requires a recent PyTorch; nn.ConstantPad1d(padding, 0.) is the older equivalent
        PadLayer = nn.ZeroPad1d
    else:
        raise ValueError('Pad type [%s] not recognized' % pad_type)
    return PadLayer
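`Downsample` implements the anti-aliased (blur-pool) downsampling used by `resnet_lpf.py` below: low-pass filter with a binomial kernel, then subsample with the given stride. A small sketch checking the kernel and output shape (values here are illustrative, assuming the classes above are in scope):

```
import torch

# Depthwise binomial blur + stride-2 subsampling on a 3-channel input.
down = Downsample(pad_type='reflect', filt_size=3, stride=2, channels=3)
x = torch.randn(1, 3, 224, 224)
y = down(x)
print(y.shape)  # torch.Size([1, 3, 112, 112])

# The 2D kernel is the normalized outer product of [1, 2, 1] with itself.
print(down.filt[0, 0] * 16)  # [[1., 2., 1.], [2., 4., 2.], [1., 2., 1.]]
```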
--------------------------------------------------------------------------------
/networks/resnet.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.utils.model_zoo as model_zoo


__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
self).__init__() 66 | self.conv1 = conv1x1(inplanes, planes) 67 | self.bn1 = nn.BatchNorm2d(planes) 68 | self.conv2 = conv3x3(planes, planes, stride) 69 | self.bn2 = nn.BatchNorm2d(planes) 70 | self.conv3 = conv1x1(planes, planes * self.expansion) 71 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 72 | self.relu = nn.ReLU(inplace=True) 73 | self.downsample = downsample 74 | self.stride = stride 75 | 76 | def forward(self, x): 77 | identity = x 78 | 79 | out = self.conv1(x) 80 | out = self.bn1(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv2(out) 84 | out = self.bn2(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv3(out) 88 | out = self.bn3(out) 89 | 90 | if self.downsample is not None: 91 | identity = self.downsample(x) 92 | 93 | out += identity 94 | out = self.relu(out) 95 | 96 | return out 97 | 98 | 99 | class ResNet(nn.Module): 100 | 101 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False): 102 | super(ResNet, self).__init__() 103 | self.inplanes = 64 104 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 105 | bias=False) 106 | self.bn1 = nn.BatchNorm2d(64) 107 | self.relu = nn.ReLU(inplace=True) 108 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 109 | self.layer1 = self._make_layer(block, 64, layers[0]) 110 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 111 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 112 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 113 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 114 | self.fc = nn.Linear(512 * block.expansion, num_classes) 115 | 116 | for m in self.modules(): 117 | if isinstance(m, nn.Conv2d): 118 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 119 | elif isinstance(m, nn.BatchNorm2d): 120 | nn.init.constant_(m.weight, 1) 121 | nn.init.constant_(m.bias, 0) 122 | 123 | # Zero-initialize the last BN in each residual branch, 124 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 125 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 126 | if zero_init_residual: 127 | for m in self.modules(): 128 | if isinstance(m, Bottleneck): 129 | nn.init.constant_(m.bn3.weight, 0) 130 | elif isinstance(m, BasicBlock): 131 | nn.init.constant_(m.bn2.weight, 0) 132 | 133 | def _make_layer(self, block, planes, blocks, stride=1): 134 | downsample = None 135 | if stride != 1 or self.inplanes != planes * block.expansion: 136 | downsample = nn.Sequential( 137 | conv1x1(self.inplanes, planes * block.expansion, stride), 138 | nn.BatchNorm2d(planes * block.expansion), 139 | ) 140 | 141 | layers = [] 142 | layers.append(block(self.inplanes, planes, stride, downsample)) 143 | self.inplanes = planes * block.expansion 144 | for _ in range(1, blocks): 145 | layers.append(block(self.inplanes, planes)) 146 | 147 | return nn.Sequential(*layers) 148 | 149 | def forward(self, x): 150 | x = self.conv1(x) 151 | x = self.bn1(x) 152 | x = self.relu(x) 153 | x = self.maxpool(x) 154 | 155 | x = self.layer1(x) 156 | x = self.layer2(x) 157 | x = self.layer3(x) 158 | x = self.layer4(x) 159 | 160 | x = self.avgpool(x) 161 | x = x.view(x.size(0), -1) 162 | x = self.fc(x) 163 | 164 | return x 165 | 166 | 167 | def resnet18(pretrained=False, **kwargs): 168 | """Constructs a ResNet-18 model. 
169 | Args: 170 | pretrained (bool): If True, returns a model pre-trained on ImageNet 171 | """ 172 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 173 | if pretrained: 174 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 175 | return model 176 | 177 | 178 | def resnet34(pretrained=False, **kwargs): 179 | """Constructs a ResNet-34 model. 180 | Args: 181 | pretrained (bool): If True, returns a model pre-trained on ImageNet 182 | """ 183 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 184 | if pretrained: 185 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 186 | return model 187 | 188 | 189 | def resnet50(pretrained=False, **kwargs): 190 | """Constructs a ResNet-50 model. 191 | Args: 192 | pretrained (bool): If True, returns a model pre-trained on ImageNet 193 | """ 194 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 195 | if pretrained: 196 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 197 | return model 198 | 199 | 200 | def resnet101(pretrained=False, **kwargs): 201 | """Constructs a ResNet-101 model. 202 | Args: 203 | pretrained (bool): If True, returns a model pre-trained on ImageNet 204 | """ 205 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 206 | if pretrained: 207 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 208 | return model 209 | 210 | 211 | def resnet152(pretrained=False, **kwargs): 212 | """Constructs a ResNet-152 model. 213 | Args: 214 | pretrained (bool): If True, returns a model pre-trained on ImageNet 215 | """ 216 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 217 | if pretrained: 218 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 219 | return model 220 | -------------------------------------------------------------------------------- /networks/resnet_lpf.py: -------------------------------------------------------------------------------- 1 | # This code is built from the PyTorch examples repository: https://github.com/pytorch/vision/tree/master/torchvision/models. 2 | # Copyright (c) 2017 Torch Contributors. 3 | # The Pytorch examples are available under the BSD 3-Clause License. 4 | # 5 | # ========================================================================================== 6 | # 7 | # Adobe’s modifications are Copyright 2019 Adobe. All rights reserved. 8 | # Adobe’s modifications are licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 9 | # 4.0 International Public License (CC-NC-SA-4.0). To view a copy of the license, visit 10 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode. 11 | # 12 | # ========================================================================================== 13 | # 14 | # BSD-3 License 15 | # 16 | # Redistribution and use in source and binary forms, with or without 17 | # modification, are permitted provided that the following conditions are met: 18 | # 19 | # * Redistributions of source code must retain the above copyright notice, this 20 | # list of conditions and the following disclaimer. 21 | # 22 | # * Redistributions in binary form must reproduce the above copyright notice, 23 | # this list of conditions and the following disclaimer in the documentation 24 | # and/or other materials provided with the distribution. 25 | # 26 | # * Neither the name of the copyright holder nor the names of its 27 | # contributors may be used to endorse or promote products derived from 28 | # this software without specific prior written permission. 
29 | #
30 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
31 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
32 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
33 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
34 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
35 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
36 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
37 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
38 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
39 | 
40 | import torch.nn as nn
41 | import torch.utils.model_zoo as model_zoo
42 | from .lpf import *  # provides Downsample (blur-then-subsample, a.k.a. BlurPool)
43 | 
44 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
45 |            'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']
46 | 
47 | 
48 | # model_urls = {
49 | #     'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
50 | #     'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
51 | #     'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
52 | #     'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
53 | #     'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
54 | # }
# (editor's note) with model_urls commented out, the pretrained=True branches
# further down would raise NameError; the plain torchvision weights would not
# match this anti-aliased architecture anyway, so keep pretrained=False here.
55 | 
56 | 
57 | def conv3x3(in_planes, out_planes, stride=1, groups=1):
58 |     """3x3 convolution with padding"""
59 |     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
60 |                      padding=1, groups=groups, bias=False)
61 | 
62 | def conv1x1(in_planes, out_planes, stride=1):
63 |     """1x1 convolution"""
64 |     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
65 | 
66 | class BasicBlock(nn.Module):
67 |     expansion = 1
68 | 
69 |     def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, norm_layer=None, filter_size=1):
70 |         super(BasicBlock, self).__init__()
71 |         if norm_layer is None:
72 |             norm_layer = nn.BatchNorm2d
73 |         if groups != 1:
74 |             raise ValueError('BasicBlock only supports groups=1')
75 |         # In this anti-aliased variant, downsampling happens in self.conv2 (blur + conv) and in self.downsample when stride != 1
76 |         self.conv1 = conv3x3(inplanes, planes)
77 |         self.bn1 = norm_layer(planes)
78 |         self.relu = nn.ReLU(inplace=True)
79 |         if(stride==1):
80 |             self.conv2 = conv3x3(planes,planes)
81 |         else:
82 |             self.conv2 = nn.Sequential(Downsample(filt_size=filter_size, stride=stride, channels=planes),
83 |                                        conv3x3(planes, planes),)
84 |         self.bn2 = norm_layer(planes)
85 |         self.downsample = downsample
86 |         self.stride = stride
87 | 
88 |     def forward(self, x):
89 |         identity = x
90 | 
91 |         out = self.conv1(x)
92 |         out = self.bn1(out)
93 |         out = self.relu(out)
94 | 
95 |         out = self.conv2(out)
96 |         out = self.bn2(out)
97 | 
98 |         if self.downsample is not None:
99 |             identity = self.downsample(x)
100 | 
101 |         out += identity
102 |         out = self.relu(out)
103 | 
104 |         return out
105 | 
106 | 
107 | class Bottleneck(nn.Module):
108 |     expansion = 4
109 | 
110 |     def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, norm_layer=None, filter_size=1):
111 |         super(Bottleneck, self).__init__()
112 |         if norm_layer is None:
113 |             norm_layer = nn.BatchNorm2d
114 |         # In this anti-aliased variant, downsampling happens in self.conv3 (blur + conv) and in self.downsample when stride != 1
115 |         self.conv1 = conv1x1(inplanes, planes)
116 |         self.bn1 = norm_layer(planes)
117 |         self.conv2 = conv3x3(planes, planes, groups=groups)  # fixed: passed positionally, 'groups' landed in the 'stride' slot
# stride moved 118 | self.bn2 = norm_layer(planes) 119 | if(stride==1): 120 | self.conv3 = conv1x1(planes, planes * self.expansion) 121 | else: 122 | self.conv3 = nn.Sequential(Downsample(filt_size=filter_size, stride=stride, channels=planes), 123 | conv1x1(planes, planes * self.expansion)) 124 | self.bn3 = norm_layer(planes * self.expansion) 125 | self.relu = nn.ReLU(inplace=True) 126 | self.downsample = downsample 127 | self.stride = stride 128 | 129 | def forward(self, x): 130 | identity = x 131 | 132 | out = self.conv1(x) 133 | out = self.bn1(out) 134 | out = self.relu(out) 135 | 136 | out = self.conv2(out) 137 | out = self.bn2(out) 138 | out = self.relu(out) 139 | 140 | out = self.conv3(out) 141 | out = self.bn3(out) 142 | 143 | if self.downsample is not None: 144 | identity = self.downsample(x) 145 | 146 | out += identity 147 | out = self.relu(out) 148 | 149 | return out 150 | 151 | 152 | class ResNet(nn.Module): 153 | 154 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 155 | groups=1, width_per_group=64, norm_layer=None, filter_size=1, pool_only=True): 156 | super(ResNet, self).__init__() 157 | if norm_layer is None: 158 | norm_layer = nn.BatchNorm2d 159 | planes = [int(width_per_group * groups * 2 ** i) for i in range(4)] 160 | self.inplanes = planes[0] 161 | 162 | if(pool_only): 163 | self.conv1 = nn.Conv2d(3, planes[0], kernel_size=7, stride=2, padding=3, bias=False) 164 | else: 165 | self.conv1 = nn.Conv2d(3, planes[0], kernel_size=7, stride=1, padding=3, bias=False) 166 | self.bn1 = norm_layer(planes[0]) 167 | self.relu = nn.ReLU(inplace=True) 168 | 169 | if(pool_only): 170 | self.maxpool = nn.Sequential(*[nn.MaxPool2d(kernel_size=2, stride=1), 171 | Downsample(filt_size=filter_size, stride=2, channels=planes[0])]) 172 | else: 173 | self.maxpool = nn.Sequential(*[Downsample(filt_size=filter_size, stride=2, channels=planes[0]), 174 | nn.MaxPool2d(kernel_size=2, stride=1), 175 | Downsample(filt_size=filter_size, stride=2, channels=planes[0])]) 176 | 177 | self.layer1 = self._make_layer(block, planes[0], layers[0], groups=groups, norm_layer=norm_layer) 178 | self.layer2 = self._make_layer(block, planes[1], layers[1], stride=2, groups=groups, norm_layer=norm_layer, filter_size=filter_size) 179 | self.layer3 = self._make_layer(block, planes[2], layers[2], stride=2, groups=groups, norm_layer=norm_layer, filter_size=filter_size) 180 | self.layer4 = self._make_layer(block, planes[3], layers[3], stride=2, groups=groups, norm_layer=norm_layer, filter_size=filter_size) 181 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 182 | self.fc = nn.Linear(planes[3] * block.expansion, num_classes) 183 | 184 | for m in self.modules(): 185 | if isinstance(m, nn.Conv2d): 186 | if(m.in_channels!=m.out_channels or m.out_channels!=m.groups or m.bias is not None): 187 | # don't want to reinitialize downsample layers, code assuming normal conv layers will not have these characteristics 188 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 189 | else: 190 | print('Not initializing') 191 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 192 | nn.init.constant_(m.weight, 1) 193 | nn.init.constant_(m.bias, 0) 194 | 195 | # Zero-initialize the last BN in each residual branch, 196 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
197 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 198 | if zero_init_residual: 199 | for m in self.modules(): 200 | if isinstance(m, Bottleneck): 201 | nn.init.constant_(m.bn3.weight, 0) 202 | elif isinstance(m, BasicBlock): 203 | nn.init.constant_(m.bn2.weight, 0) 204 | 205 | def _make_layer(self, block, planes, blocks, stride=1, groups=1, norm_layer=None, filter_size=1): 206 | if norm_layer is None: 207 | norm_layer = nn.BatchNorm2d 208 | downsample = None 209 | if stride != 1 or self.inplanes != planes * block.expansion: 210 | # downsample = nn.Sequential( 211 | # conv1x1(self.inplanes, planes * block.expansion, stride, filter_size=filter_size), 212 | # norm_layer(planes * block.expansion), 213 | # ) 214 | 215 | downsample = [Downsample(filt_size=filter_size, stride=stride, channels=self.inplanes),] if(stride !=1) else [] 216 | downsample += [conv1x1(self.inplanes, planes * block.expansion, 1), 217 | norm_layer(planes * block.expansion)] 218 | # print(downsample) 219 | downsample = nn.Sequential(*downsample) 220 | 221 | layers = [] 222 | layers.append(block(self.inplanes, planes, stride, downsample, groups, norm_layer, filter_size=filter_size)) 223 | self.inplanes = planes * block.expansion 224 | for _ in range(1, blocks): 225 | layers.append(block(self.inplanes, planes, groups=groups, norm_layer=norm_layer, filter_size=filter_size)) 226 | 227 | return nn.Sequential(*layers) 228 | 229 | def forward(self, x): 230 | x = self.conv1(x) 231 | x = self.bn1(x) 232 | x = self.relu(x) 233 | x = self.maxpool(x) 234 | 235 | x = self.layer1(x) 236 | x = self.layer2(x) 237 | x = self.layer3(x) 238 | x = self.layer4(x) 239 | 240 | x = self.avgpool(x) 241 | x = x.view(x.size(0), -1) 242 | x = self.fc(x) 243 | 244 | return x 245 | 246 | 247 | def resnet18(pretrained=False, filter_size=1, pool_only=True, **kwargs): 248 | """Constructs a ResNet-18 model. 249 | Args: 250 | pretrained (bool): If True, returns a model pre-trained on ImageNet 251 | """ 252 | model = ResNet(BasicBlock, [2, 2, 2, 2], filter_size=filter_size, pool_only=pool_only, **kwargs) 253 | if pretrained: 254 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 255 | return model 256 | 257 | 258 | def resnet34(pretrained=False, filter_size=1, pool_only=True, **kwargs): 259 | """Constructs a ResNet-34 model. 260 | Args: 261 | pretrained (bool): If True, returns a model pre-trained on ImageNet 262 | """ 263 | model = ResNet(BasicBlock, [3, 4, 6, 3], filter_size=filter_size, pool_only=pool_only, **kwargs) 264 | if pretrained: 265 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 266 | return model 267 | 268 | 269 | def resnet50(pretrained=False, filter_size=1, pool_only=True, **kwargs): 270 | """Constructs a ResNet-50 model. 271 | Args: 272 | pretrained (bool): If True, returns a model pre-trained on ImageNet 273 | """ 274 | model = ResNet(Bottleneck, [3, 4, 6, 3], filter_size=filter_size, pool_only=pool_only, **kwargs) 275 | if pretrained: 276 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 277 | return model 278 | 279 | 280 | def resnet101(pretrained=False, filter_size=1, pool_only=True, **kwargs): 281 | """Constructs a ResNet-101 model. 
282 |     Args:
283 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
284 |     """
285 |     model = ResNet(Bottleneck, [3, 4, 23, 3], filter_size=filter_size, pool_only=pool_only, **kwargs)
286 |     if pretrained:
287 |         model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
288 |     return model
289 | 
290 | 
291 | def resnet152(pretrained=False, filter_size=1, pool_only=True, **kwargs):
292 |     """Constructs a ResNet-152 model.
293 |     Args:
294 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
295 |     """
296 |     model = ResNet(Bottleneck, [3, 8, 36, 3], filter_size=filter_size, pool_only=pool_only, **kwargs)
297 |     if pretrained:
298 |         model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
299 |     return model
300 | 
301 | 
302 | def resnext50_32x4d(pretrained=False, filter_size=1, pool_only=True, **kwargs):
303 |     model = ResNet(Bottleneck, [3, 4, 6, 3], groups=32, width_per_group=4, filter_size=filter_size, pool_only=pool_only, **kwargs)  # fixed: 32x4d means 32 groups of width 4; the original swapped the two (total width is unchanged)
304 |     # if pretrained:
305 |     #     model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
306 |     return model
307 | 
308 | 
309 | def resnext101_32x8d(pretrained=False, filter_size=1, pool_only=True, **kwargs):
310 |     model = ResNet(Bottleneck, [3, 4, 23, 3], groups=32, width_per_group=8, filter_size=filter_size, pool_only=pool_only, **kwargs)  # fixed: 32x8d means 32 groups of width 8
311 |     # if pretrained:
312 |     #     model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
313 |     return model
314 | 
--------------------------------------------------------------------------------
/networks/trainer.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import torch
3 | import torch.nn as nn
4 | from networks.resnet import resnet50
5 | from networks.base_model import BaseModel, init_weights
6 | 
7 | 
8 | class Trainer(BaseModel):
9 |     def name(self):
10 |         return 'Trainer'
11 | 
12 |     def __init__(self, opt):
13 |         super(Trainer, self).__init__(opt)
14 | 
15 |         if self.isTrain and not opt.continue_train:
16 |             self.model = resnet50(pretrained=True)
17 |             self.model.fc = nn.Linear(2048, 1)
18 |             torch.nn.init.normal_(self.model.fc.weight.data, 0.0, opt.init_gain)
19 | 
20 |         if not self.isTrain or opt.continue_train:
21 |             self.model = resnet50(num_classes=1)
22 | 
23 |         if self.isTrain:
24 |             self.loss_fn = nn.BCEWithLogitsLoss()
25 |             # initialize optimizers
26 |             if opt.optim == 'adam':
27 |                 self.optimizer = torch.optim.Adam(self.model.parameters(),
28 |                                                   lr=opt.lr, betas=(opt.beta1, 0.999))
29 |             elif opt.optim == 'sgd':
30 |                 self.optimizer = torch.optim.SGD(self.model.parameters(),
31 |                                                  lr=opt.lr, momentum=0.0, weight_decay=0)
32 |             else:
33 |                 raise ValueError("optim should be [adam, sgd]")
34 | 
35 |         if not self.isTrain or opt.continue_train:
36 |             self.load_networks(opt.epoch)
37 |         self.model.to(opt.gpu_ids[0])  # note: assumes at least one GPU id; --gpu_ids -1 (CPU) would fail here
38 | 
39 | 
40 |     def adjust_learning_rate(self, min_lr=1e-6):
41 |         for param_group in self.optimizer.param_groups:
42 |             param_group['lr'] /= 10.
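            # (editor's note, added comment) the current group's LR has just been
            # divided by 10; the check below returns False so train.py knows to
            # stop once the LR falls under min_lr, and True to keep training.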
43 | if param_group['lr'] < min_lr: 44 | return False 45 | return True 46 | 47 | def set_input(self, input): 48 | self.input = input[0].to(self.device) 49 | self.label = input[1].to(self.device).float() 50 | 51 | 52 | def forward(self): 53 | self.output = self.model(self.input) 54 | 55 | def get_loss(self): 56 | return self.loss_fn(self.output.squeeze(1), self.label) 57 | 58 | def optimize_parameters(self): 59 | self.forward() 60 | self.loss = self.loss_fn(self.output.squeeze(1), self.label) 61 | self.optimizer.zero_grad() 62 | self.loss.backward() 63 | self.optimizer.step() 64 | 65 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeterWang512/CNNDetection/ea0b5622365e3a9cd31d1b54b6b5971131a839ab/options/__init__.py -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import util 4 | import torch 5 | #import models 6 | #import data 7 | 8 | 9 | class BaseOptions(): 10 | def __init__(self): 11 | self.initialized = False 12 | 13 | def initialize(self, parser): 14 | parser.add_argument('--mode', default='binary') 15 | parser.add_argument('--arch', type=str, default='res50', help='architecture for binary classification') 16 | 17 | # data augmentation 18 | parser.add_argument('--rz_interp', default='bilinear') 19 | parser.add_argument('--blur_prob', type=float, default=0) 20 | parser.add_argument('--blur_sig', default='0.5') 21 | parser.add_argument('--jpg_prob', type=float, default=0) 22 | parser.add_argument('--jpg_method', default='cv2') 23 | parser.add_argument('--jpg_qual', default='75') 24 | 25 | parser.add_argument('--dataroot', default='./dataset/', help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') 26 | parser.add_argument('--classes', default='', help='image classes to train on') 27 | parser.add_argument('--class_bal', action='store_true') 28 | parser.add_argument('--batch_size', type=int, default=64, help='input batch size') 29 | parser.add_argument('--loadSize', type=int, default=256, help='scale images to this size') 30 | parser.add_argument('--cropSize', type=int, default=224, help='then crop to this size') 31 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 32 | parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models') 33 | parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? 
set to latest to use latest cached model') 34 | parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') 35 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 36 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 37 | parser.add_argument('--resize_or_crop', type=str, default='scale_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop|none]') 38 | parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') 39 | parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]') 40 | parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') 41 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{loadSize}') 42 | self.initialized = True 43 | return parser 44 | 45 | def gather_options(self): 46 | # initialize parser with basic options 47 | if not self.initialized: 48 | parser = argparse.ArgumentParser( 49 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 50 | parser = self.initialize(parser) 51 | 52 | # get the basic options 53 | opt, _ = parser.parse_known_args() 54 | self.parser = parser 55 | 56 | return parser.parse_args() 57 | 58 | def print_options(self, opt): 59 | message = '' 60 | message += '----------------- Options ---------------\n' 61 | for k, v in sorted(vars(opt).items()): 62 | comment = '' 63 | default = self.parser.get_default(k) 64 | if v != default: 65 | comment = '\t[default: %s]' % str(default) 66 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 67 | message += '----------------- End -------------------' 68 | print(message) 69 | 70 | # save to the disk 71 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 72 | util.mkdirs(expr_dir) 73 | file_name = os.path.join(expr_dir, 'opt.txt') 74 | with open(file_name, 'wt') as opt_file: 75 | opt_file.write(message) 76 | opt_file.write('\n') 77 | 78 | def parse(self, print_options=True): 79 | 80 | opt = self.gather_options() 81 | opt.isTrain = self.isTrain # train or test 82 | 83 | # process opt.suffix 84 | if opt.suffix: 85 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' 86 | opt.name = opt.name + suffix 87 | 88 | if print_options: 89 | self.print_options(opt) 90 | 91 | # set gpu ids 92 | str_ids = opt.gpu_ids.split(',') 93 | opt.gpu_ids = [] 94 | for str_id in str_ids: 95 | id = int(str_id) 96 | if id >= 0: 97 | opt.gpu_ids.append(id) 98 | if len(opt.gpu_ids) > 0: 99 | torch.cuda.set_device(opt.gpu_ids[0]) 100 | 101 | # additional 102 | opt.classes = opt.classes.split(',') 103 | opt.rz_interp = opt.rz_interp.split(',') 104 | opt.blur_sig = [float(s) for s in opt.blur_sig.split(',')] 105 | opt.jpg_method = opt.jpg_method.split(',') 106 | opt.jpg_qual = [int(s) for s in opt.jpg_qual.split(',')] 107 | if len(opt.jpg_qual) == 2: 108 | opt.jpg_qual = list(range(opt.jpg_qual[0], opt.jpg_qual[1] + 1)) 109 | elif len(opt.jpg_qual) > 2: 110 | raise ValueError("Shouldn't have more than 2 values for --jpg_qual.") 111 | 112 | self.opt = opt 113 | return self.opt 114 | -------------------------------------------------------------------------------- /options/test_options.py: 
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 | 
3 | 
4 | class TestOptions(BaseOptions):
5 |     def initialize(self, parser):
6 |         parser = BaseOptions.initialize(self, parser)
7 |         parser.add_argument('--model_path')
8 |         parser.add_argument('--no_resize', action='store_true')
9 |         parser.add_argument('--no_crop', action='store_true')
10 |         parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
11 | 
12 |         self.isTrain = False
13 |         return parser
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 | 
3 | 
4 | class TrainOptions(BaseOptions):
5 |     def initialize(self, parser):
6 |         parser = BaseOptions.initialize(self, parser)
7 |         parser.add_argument('--earlystop_epoch', type=int, default=5)
8 |         parser.add_argument('--data_aug', action='store_true', help='if specified, perform additional data augmentation (photometric, blurring, jpegging)')
9 |         parser.add_argument('--optim', type=str, default='adam', help='optim to use [sgd, adam]')
10 |         parser.add_argument('--new_optim', action='store_true', help='new optimizer instead of loading the optim state')
11 |         parser.add_argument('--loss_freq', type=int, default=400, help='frequency of showing loss on tensorboard')
12 |         parser.add_argument('--save_latest_freq', type=int, default=2000, help='frequency of saving the latest results')
13 |         parser.add_argument('--save_epoch_freq', type=int, default=20, help='frequency of saving checkpoints at the end of epochs')
14 |         parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
15 |         parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
16 |         parser.add_argument('--last_epoch', type=int, default=-1, help='starting epoch count for scheduler initialization')
17 |         parser.add_argument('--train_split', type=str, default='train', help='train, val, test, etc')
18 |         parser.add_argument('--val_split', type=str, default='val', help='train, val, test, etc')
19 |         parser.add_argument('--niter', type=int, default=10000, help='# of iter at starting learning rate')
20 |         parser.add_argument('--beta1', type=float, default=0.9, help='momentum term of adam')
21 |         parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam')
22 | 
23 |         self.isTrain = True
24 |         return parser
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | scipy
2 | scikit-learn
3 | numpy
4 | opencv_python
5 | Pillow
6 | torch>=1.2.0
7 | torchvision
tensorboardX  # added: imported by train.py but missing from the original list
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import torch
5 | import torch.nn
6 | import argparse
7 | from PIL import Image
8 | from tensorboardX import SummaryWriter
9 | 
10 | from validate import validate
11 | from data import create_dataloader
12 | from earlystop import EarlyStopping
13 | from networks.trainer import Trainer
14 | from options.train_options import TrainOptions
15 | 
16 | 
17 | """Currently assumes jpg_prob, blur_prob 0 or 1"""
18 | def get_val_opt():
19 | 
val_opt = TrainOptions().parse(print_options=False) 20 | val_opt.dataroot = '{}/{}/'.format(val_opt.dataroot, val_opt.val_split) 21 | val_opt.isTrain = False 22 | val_opt.no_resize = False 23 | val_opt.no_crop = False 24 | val_opt.serial_batches = True 25 | val_opt.jpg_method = ['pil'] 26 | if len(val_opt.blur_sig) == 2: 27 | b_sig = val_opt.blur_sig 28 | val_opt.blur_sig = [(b_sig[0] + b_sig[1]) / 2] 29 | if len(val_opt.jpg_qual) != 1: 30 | j_qual = val_opt.jpg_qual 31 | val_opt.jpg_qual = [int((j_qual[0] + j_qual[-1]) / 2)] 32 | 33 | return val_opt 34 | 35 | 36 | if __name__ == '__main__': 37 | opt = TrainOptions().parse() 38 | opt.dataroot = '{}/{}/'.format(opt.dataroot, opt.train_split) 39 | val_opt = get_val_opt() 40 | 41 | data_loader = create_dataloader(opt) 42 | dataset_size = len(data_loader) 43 | print('#training images = %d' % dataset_size) 44 | 45 | train_writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, "train")) 46 | val_writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, "val")) 47 | 48 | model = Trainer(opt) 49 | early_stopping = EarlyStopping(patience=opt.earlystop_epoch, delta=-0.001, verbose=True) 50 | for epoch in range(opt.niter): 51 | epoch_start_time = time.time() 52 | iter_data_time = time.time() 53 | epoch_iter = 0 54 | 55 | for i, data in enumerate(data_loader): 56 | model.total_steps += 1 57 | epoch_iter += opt.batch_size 58 | 59 | model.set_input(data) 60 | model.optimize_parameters() 61 | 62 | if model.total_steps % opt.loss_freq == 0: 63 | print("Train loss: {} at step: {}".format(model.loss, model.total_steps)) 64 | train_writer.add_scalar('loss', model.loss, model.total_steps) 65 | 66 | if model.total_steps % opt.save_latest_freq == 0: 67 | print('saving the latest model %s (epoch %d, model.total_steps %d)' % 68 | (opt.name, epoch, model.total_steps)) 69 | model.save_networks('latest') 70 | 71 | # print("Iter time: %d sec" % (time.time()-iter_data_time)) 72 | # iter_data_time = time.time() 73 | 74 | if epoch % opt.save_epoch_freq == 0: 75 | print('saving the model at the end of epoch %d, iters %d' % 76 | (epoch, model.total_steps)) 77 | model.save_networks('latest') 78 | model.save_networks(epoch) 79 | 80 | # Validation 81 | model.eval() 82 | acc, ap = validate(model.model, val_opt)[:2] 83 | val_writer.add_scalar('accuracy', acc, model.total_steps) 84 | val_writer.add_scalar('ap', ap, model.total_steps) 85 | print("(Val @ epoch {}) acc: {}; ap: {}".format(epoch, acc, ap)) 86 | 87 | early_stopping(acc, model) 88 | if early_stopping.early_stop: 89 | cont_train = model.adjust_learning_rate() 90 | if cont_train: 91 | print("Learning rate dropped by 10, continue training...") 92 | early_stopping = EarlyStopping(patience=opt.earlystop_epoch, delta=-0.002, verbose=True) 93 | else: 94 | print("Early stopping.") 95 | break 96 | model.train() 97 | 98 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | 5 | def mkdirs(paths): 6 | if isinstance(paths, list) and not isinstance(paths, str): 7 | for path in paths: 8 | mkdir(path) 9 | else: 10 | mkdir(paths) 11 | 12 | 13 | def mkdir(path): 14 | if not os.path.exists(path): 15 | os.makedirs(path) 16 | 17 | 18 | def unnormalize(tens, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): 19 | # assume tensor of shape NxCxHxW 20 | return tens * torch.Tensor(std)[None, :, None, None] + torch.Tensor( 21 | mean)[None, :, None, None] 22 | 
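Editor's note: a minimal usage sketch for unnormalize() above (not part of the repo). The random batch is a stand-in for real dataloader output; the mean/std defaults are the ImageNet statistics shown in the function signature.

import torch
from util import unnormalize

batch = torch.randn(4, 3, 224, 224)       # stand-in for a normalized NxCxHxW batch
images = unnormalize(batch).clamp(0, 1)   # undo Normalize(mean, std); clamp to [0, 1] for display
print(images.shape)                       # torch.Size([4, 3, 224, 224])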
-------------------------------------------------------------------------------- /validate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from networks.resnet import resnet50 4 | from sklearn.metrics import average_precision_score, precision_recall_curve, accuracy_score 5 | from options.test_options import TestOptions 6 | from data import create_dataloader 7 | 8 | 9 | def validate(model, opt): 10 | data_loader = create_dataloader(opt) 11 | 12 | with torch.no_grad(): 13 | y_true, y_pred = [], [] 14 | for img, label in data_loader: 15 | in_tens = img.cuda() 16 | y_pred.extend(model(in_tens).sigmoid().flatten().tolist()) 17 | y_true.extend(label.flatten().tolist()) 18 | 19 | y_true, y_pred = np.array(y_true), np.array(y_pred) 20 | r_acc = accuracy_score(y_true[y_true==0], y_pred[y_true==0] > 0.5) 21 | f_acc = accuracy_score(y_true[y_true==1], y_pred[y_true==1] > 0.5) 22 | acc = accuracy_score(y_true, y_pred > 0.5) 23 | ap = average_precision_score(y_true, y_pred) 24 | return acc, ap, r_acc, f_acc, y_true, y_pred 25 | 26 | 27 | if __name__ == '__main__': 28 | opt = TestOptions().parse(print_options=False) 29 | 30 | model = resnet50(num_classes=1) 31 | state_dict = torch.load(opt.model_path, map_location='cpu') 32 | model.load_state_dict(state_dict['model']) 33 | model.cuda() 34 | model.eval() 35 | 36 | acc, avg_precision, r_acc, f_acc, y_true, y_pred = validate(model, opt) 37 | 38 | print("accuracy:", acc) 39 | print("average precision:", avg_precision) 40 | 41 | print("accuracy of real images:", r_acc) 42 | print("accuracy of fake images:", f_acc) 43 | -------------------------------------------------------------------------------- /weights/download_weights.sh: -------------------------------------------------------------------------------- 1 | wget https://www.dropbox.com/s/2g2jagq2jn1fd0i/blur_jpg_prob0.5.pth?dl=0 -O ./weights/blur_jpg_prob0.5.pth 2 | wget https://www.dropbox.com/s/h7tkpcgiwuftb6g/blur_jpg_prob0.1.pth?dl=0 -O ./weights/blur_jpg_prob0.1.pth 3 | 4 | --------------------------------------------------------------------------------
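Editor's note: a minimal single-image scoring sketch assembled from validate.py (checkpoint format and sigmoid readout), the loadSize/cropSize defaults in options/base_options.py, and the weights downloaded by the script above. It is a hedged sketch of typical use, not a verbatim copy of demo.py; the 256-resize / 224-crop / ImageNet-normalization preprocessing is an assumption based on those defaults.

import torch
from PIL import Image
from torchvision import transforms
from networks.resnet import resnet50

# load the detector; the checkpoint stores weights under the 'model' key (see validate.py)
model = resnet50(num_classes=1)
state_dict = torch.load('weights/blur_jpg_prob0.5.pth', map_location='cpu')
model.load_state_dict(state_dict['model'])
model.eval()

# assumed preprocessing: mirrors loadSize=256 / cropSize=224 and ImageNet statistics
tfm = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
img = tfm(Image.open('examples/real.png').convert('RGB')).unsqueeze(0)

with torch.no_grad():
    prob_fake = model(img).sigmoid().item()   # label 1 = fake, 0 = real (cf. 1_fake / 0_real folders)
print('probability of being CNN-generated: %.4f' % prob_fake)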