├── LICENSE.md ├── README.md ├── config.yaml ├── example ├── celeba │ ├── 1.jpg │ ├── 1_mask.png │ ├── 1_tsmooth.png │ ├── 2.jpg │ ├── 2_mask.png │ ├── 2_tsmooth.png │ ├── 3.jpg │ ├── 3_mask.png │ ├── 3_tsmooth.png │ ├── 4.jpg │ ├── 4_mask.png │ └── 4_tsmooth.jpg ├── paris │ ├── 1.jpg │ ├── 1_mask.png │ ├── 1_tsmooth.jpg │ ├── 2.png │ ├── 2_mask.png │ └── 2_tsmooth.jpg └── places │ ├── 1.jpg │ ├── 1_mask.png │ ├── 1_tsmooth.png │ ├── 2.jpg │ ├── 2_mask.png │ ├── 2_tsmooth.jpg │ ├── 3.jpg │ ├── 3_mask.png │ ├── 3_tsmooth.jpg │ ├── 4.jpg │ ├── 4_mask.png │ └── 4_tsmooth.jpg ├── main.py ├── model_config.yaml ├── resample2d_package ├── __init__.py ├── resample2d_cuda.cc ├── resample2d_kernel.cu ├── resample2d_kernel.cuh └── setup.py ├── scripts ├── flist.py ├── inception.py ├── matlab │ ├── code │ │ └── tsmooth.m │ ├── dirPlus.m │ └── generate_structure_images.m └── metrics.py ├── src ├── base_model.py ├── config.py ├── data.py ├── loss.py ├── metrics.py ├── models.py ├── network.py ├── resample2d.py ├── structure_flow.py └── utils.py ├── test.py └── train.py /LICENSE.md: -------------------------------------------------------------------------------- 1 | ## creative commons 2 | 3 | # Attribution-NonCommercial 4.0 International 4 | 5 | Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible. 6 | 7 | ### Using Creative Commons Public Licenses 8 | 9 | Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses. 10 | 11 | * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors). 12 | 13 | * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. 
Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees). 14 | 15 | ## Creative Commons Attribution-NonCommercial 4.0 International Public License 16 | 17 | By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions. 18 | 19 | ### Section 1 – Definitions. 20 | 21 | a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image. 22 | 23 | b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License. 24 | 25 | c. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights. 26 | 27 | d. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements. 28 | 29 | e. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material. 30 | 31 | f. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License. 32 | 33 | g. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license. 34 | 35 | h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License. 36 | 37 | i. 
__NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange. 38 | 39 | j. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them. 40 | 41 | k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world. 42 | 43 | l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning. 44 | 45 | ### Section 2 – Scope. 46 | 47 | a. ___License grant.___ 48 | 49 | 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to: 50 | 51 | A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and 52 | 53 | B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only. 54 | 55 | 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions. 56 | 57 | 3. __Term.__ The term of this Public License is specified in Section 6(a). 58 | 59 | 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material. 60 | 61 | 5. __Downstream recipients.__ 62 | 63 | A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License. 64 | 65 | B. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material. 66 | 67 | 6. 
__No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i). 68 | 69 | b. ___Other rights.___ 70 | 71 | 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise. 72 | 73 | 2. Patent and trademark rights are not licensed under this Public License. 74 | 75 | 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes. 76 | 77 | ### Section 3 – License Conditions. 78 | 79 | Your exercise of the Licensed Rights is expressly made subject to the following conditions. 80 | 81 | a. ___Attribution.___ 82 | 83 | 1. If You Share the Licensed Material (including in modified form), You must: 84 | 85 | A. retain the following if it is supplied by the Licensor with the Licensed Material: 86 | 87 | i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated); 88 | 89 | ii. a copyright notice; 90 | 91 | iii. a notice that refers to this Public License; 92 | 93 | iv. a notice that refers to the disclaimer of warranties; 94 | 95 | v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable; 96 | 97 | B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and 98 | 99 | C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License. 100 | 101 | 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information. 102 | 103 | 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable. 104 | 105 | 4. If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License. 106 | 107 | ### Section 4 – Sui Generis Database Rights. 108 | 109 | Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material: 110 | 111 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only; 112 | 113 | b. 
if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and 114 | 115 | c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database. 116 | 117 | For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights. 118 | 119 | ### Section 5 – Disclaimer of Warranties and Limitation of Liability. 120 | 121 | a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__ 122 | 123 | b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__ 124 | 125 | c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability. 126 | 127 | ### Section 6 – Term and Termination. 128 | 129 | a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically. 130 | 131 | b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates: 132 | 133 | 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or 134 | 135 | 2. upon express reinstatement by the Licensor. 136 | 137 | For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License. 138 | 139 | c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License. 140 | 141 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License. 142 | 143 | ### Section 7 – Other Terms and Conditions. 144 | 145 | a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed. 146 | 147 | b. 
Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License. 148 | 149 | ### Section 8 – Interpretation. 150 | 151 | a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License. 152 | 153 | b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions. 154 | 155 | c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor. 156 | 157 | d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority. 158 | 159 | > Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses. 160 | > 161 | > Creative Commons may be contacted at creativecommons.org 162 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # StructureFlow 2 | Code for our paper "[StructureFlow: Image Inpainting via Structure-aware Appearance Flow](https://arxiv.org/abs/1908.03852)" (ICCV 2019) 3 | 4 | ### Introduction 5 | 6 | We propose a two-stage image inpainting network which splits the task into two parts: **structure reconstruction** and **texture generation**. In the first stage, edge-preserved smooth images are employed to train a structure reconstructor which completes the missing structures of the inputs. In the second stage, based on the reconstructed structures, a texture generator using appearance flow is designed to yield image details. 7 | 8 |

9 | *(example result figures omitted; see the caption below)* 10 |

11 | 12 | *(From left to right) Input corrupted images, reconstructed structure images, visualizations of the appearance flow fields, and final output images. To visualize the appearance flow fields, we plot the sample points of some typical missing regions. The arrows show the direction of the appearance flow.* 13 | 14 | ### Requirements 15 | 16 | 1. PyTorch >= 1.0 17 | 2. Python 3 18 | 3. NVIDIA GPU + CUDA 9.0 19 | 4. TensorBoard 20 | 5. MATLAB 21 | 22 | ### Installation 23 | 24 | 1. Clone this repository 25 | 26 | ```bash 27 | git clone https://github.com/RenYurui/StructureFlow 28 | ``` 29 | 30 | 2. Build the Gaussian sampling CUDA package 31 | 32 | ```bash 33 | cd ./StructureFlow/resample2d_package 34 | python setup.py install --user 35 | ``` 36 | 37 | 38 | ### Running 39 | 40 | **1. Image Preparation** 41 | 42 | We train our model on three public datasets: Places2, CelebA, and Paris StreetView. We use the irregular mask dataset provided by [PConv](https://arxiv.org/abs/1804.07723). You can download these datasets from their project websites. 43 | 44 | 1. [Places2](http://places2.csail.mit.edu) 45 | 2. [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) 46 | 3. [Paris Street-View](https://github.com/pathak22/context-encoder) 47 | 4. [Irregular Masks](http://masc.cs.gmu.edu/wiki/partialconv) 48 | 49 | After downloading the datasets, the edge-preserved smooth images can be obtained with the [RTV smoothing method](http://www.cse.cuhk.edu.hk/~leojia/projects/texturesep/). Run the generation function [`scripts/matlab/generate_structure_images.m`](scripts/matlab/generate_structure_images.m) in MATLAB. For example, to generate smooth images for Places2, run: 50 | 51 | ```matlab 52 | generate_structure_images("path to Places2 dataset root", "path to output folder"); 53 | ``` 54 | 55 | Finally, you can generate the image lists for training and testing using the script [`scripts/flist.py`](scripts/flist.py). 56 | 57 | **2. Training** 58 | 59 | To train our model, modify the model config file [model_config.yaml](model_config.yaml). You may need to change the dataset paths, network parameters, etc. Then run the following code: 60 | 61 | ```bash 62 | python train.py \ 63 | --name=[the name of your experiment] \ 64 | --path=[path to save the results] 65 | ``` 66 | 67 | **3. Testing** 68 | 69 | To generate results for your inputs, use [test.py](test.py). Run the following code: 70 | 71 | ```bash 72 | python test.py \ 73 | --name=[the name of your experiment] \ 74 | --path=[path of your experiments] \ 75 | --input=[input images] \ 76 | --mask=[mask images] \ 77 | --structure=[structure images] \ 78 | --output=[path to save the output images] \ 79 | --model=[which model to be tested] 80 | ``` 81 | 82 | To evaluate the model performance over a dataset, you can use the provided script [`scripts/metrics.py`](scripts/metrics.py). This script reports the PSNR, SSIM, and Fréchet Inception Distance ([FID score](https://github.com/mseitzer/pytorch-fid)) of the results.
83 | 84 | ```bash 85 | python ./scripts/metrics.py \ 86 | --input_path=[path to ground-truth images] \ 87 | --output_path=[path to model outputs] \ 88 | --fid_real_path=[path to the real images used to calculate FID] 89 | ``` 90 | 91 | **The pre-trained weights can be downloaded from [Places2](https://drive.google.com/open?id=1K7U6fYthC4Acsx0GBde5iszHJWymyv1A), [CelebA](https://drive.google.com/open?id=1PrLgcEd964etxZcHIOE93uUONB9-b6pI), [Paris StreetView](https://drive.google.com/open?id=18AQpgsYZtA_eL-aJb6n8-geWLdihwXAi).** 92 | 93 | Download the checkpoints and save them to './path_of_your_experiments/name_of_your_experiment/checkpoints'. 94 | 95 | For example, you can download the Places2 checkpoints, save them to './results/places/checkpoints', and run the following code: 96 | 97 | ```bash 98 | python test.py \ 99 | --name=places \ 100 | --path=results \ 101 | --input=./example/places/1.jpg \ 102 | --mask=./example/places/1_mask.png \ 103 | --structure=./example/places/1_tsmooth.png \ 104 | --output=./result_images \ 105 | --model=3 106 | ``` 107 | 108 | ### Citation 109 | 110 | If you find this code helpful for your research, please cite our paper: 111 | 112 | ``` 113 | @inproceedings{ren2019structureflow, 114 | author = {Ren, Yurui and Yu, Xiaoming and Zhang, Ruonan and Li, Thomas H. and Liu, Shan and Li, Ge}, 115 | title = {StructureFlow: Image Inpainting via Structure-aware Appearance Flow}, 116 | booktitle = {IEEE International Conference on Computer Vision (ICCV)}, 117 | year = {2019} 118 | } 119 | ``` 120 | 121 | 122 | 123 | ### Acknowledgements 124 | 125 | We built our code on [Edge-Connect](https://github.com/knazeri/edge-connect). Parts of the code were derived from [FlowNet2](https://github.com/NVIDIA/flownet2-pytorch). Please consider citing their papers. 126 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 3 # 1: structure model, 2: inpaint model, 3: structure-inpaint model 2 | VERBOSE: 1 # turns on verbose mode in the output console 3 | GPU: [1] # gpu ids 4 | 5 | 6 | MAX_ITERS: 8e6 # maximum number of iterations to train the model 7 | LR: 0.0001 # learning rate 8 | BETA1: 0 # adam optimizer beta1 9 | BETA2: 0.999 # adam optimizer beta2 10 | LR_POLICY: constant # the method to adjust learning rate (eg: constant|step) 11 | STEP_SIZE: 100000 # Period of learning rate decay (only used when choosing "step" as the lr adjustment method) 12 | GAMMA: 0.5 # Multiplicative factor of learning rate decay (only used when choosing "step" as the lr adjustment method)
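# Note: with LR_POLICY: step, the effective learning rate decays roughly as
# lr = LR * GAMMA^(iteration // STEP_SIZE); with LR_POLICY: constant it
# stays at LR for the whole run.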
13 | INIT_TYPE: xavier # initialization [gaussian/kaiming/xavier/orthogonal] 14 | 15 | SAVE_INTERVAL: 30 # how many iterations to wait before saving the model (0: never) 16 | SAVE_LATEST: 10 # how many iterations to wait before saving the latest model (0: never) 17 | SAMPLE_INTERVAL: 10 # how many iterations to wait before sampling (0: never) 18 | SAMPLE_SIZE: 4 # number of images to sample 19 | EVAL_INTERVAL: 1000 # how many iterations to wait before model evaluation (0: never) 20 | LOG_INTERVAL: 10 # how many iterations to wait before logging training status (0: never) 21 | WHICH_ITER: latest # which iteration to load 22 | 23 | DIS_GAN_LOSS: lsgan # type of gan loss 24 | 25 | STRUCTURE_L1: 4 # structure net parameter of l1 loss 26 | STRUCTURE_ADV_GEN: 1 # structure net parameter of gan loss 27 | 28 | FLOW_ADV_GEN: 1 # texture net parameter of gan loss 29 | FLOW_L1: 5 # texture net parameter of l1 loss 30 | FLOW_CORRECTNESS: 0.25 # texture net parameter of sampling correctness loss 31 | VGG_STYLE: 250 # texture net parameter of vgg_style loss (optional loss on stage_3) 32 | VGG_CONTENT: 0.1 # texture net parameter of vgg_content loss (optional loss on stage_3) 33 | 34 | 35 | TRAIN_BATCH_SIZE: 8 # batch size 36 | DATA_TRAIN_SIZE: 256 # image size for training 37 | DATA_TEST_SIZE: False # image size for testing (False for no resizing) 38 | DATA_FLIP: False # flip images or not when training 39 | DATA_CROP: FALSE #[537,537] # crop size when training (False for no cropping) 40 | DATA_MASK_TYPE: from_file # mask type (random_bbox|random_free_form|from_file) 41 | DATA_RANDOM_BBOX_SETTING: # parameters for random_bbox 42 | random_size: False # random hole size within [0.4*shape, shape] 43 | shape: [80, 80] # hole size 44 | margin: [0, 0] # minimum distance from the image boundary 45 | num: 3 46 | DATA_RANDOM_FF_SETTING: # parameters for random_free_form 47 | mv: 5 48 | ma: 4.0 49 | ml: 40 50 | mbw: 10 51 | DATA_MASK_FILE: ./txt/irregular_mask.txt # parameters for from_file 52 | 53 | DATA_TRAIN_GT: ./txt/places_gt_train.txt 54 | DATA_TRAIN_STRUCTURE: ./txt/places_structure_train.txt 55 | DATA_VAL_GT: ./txt/places_gt_val.txt 56 | DATA_VAL_STRUCTURE: ./txt/places_structure_val.txt 57 | DATA_VAL_MASK: ./txt/places_mask_val.txt 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /example/celeba/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/celeba/1.jpg -------------------------------------------------------------------------------- /example/celeba/1_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/celeba/1_mask.png -------------------------------------------------------------------------------- /example/celeba/1_tsmooth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/celeba/1_tsmooth.png -------------------------------------------------------------------------------- /example/celeba/2.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/celeba/2.jpg -------------------------------------------------------------------------------- /example/celeba/2_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/celeba/2_mask.png -------------------------------------------------------------------------------- /example/celeba/2_tsmooth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/celeba/2_tsmooth.png -------------------------------------------------------------------------------- /example/celeba/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/celeba/3.jpg -------------------------------------------------------------------------------- /example/celeba/3_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/celeba/3_mask.png -------------------------------------------------------------------------------- /example/celeba/3_tsmooth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/celeba/3_tsmooth.png -------------------------------------------------------------------------------- /example/celeba/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/celeba/4.jpg -------------------------------------------------------------------------------- /example/celeba/4_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/celeba/4_mask.png -------------------------------------------------------------------------------- /example/celeba/4_tsmooth.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/celeba/4_tsmooth.jpg -------------------------------------------------------------------------------- /example/paris/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/paris/1.jpg -------------------------------------------------------------------------------- /example/paris/1_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/paris/1_mask.png -------------------------------------------------------------------------------- /example/paris/1_tsmooth.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/paris/1_tsmooth.jpg -------------------------------------------------------------------------------- /example/paris/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/paris/2.png -------------------------------------------------------------------------------- /example/paris/2_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/paris/2_mask.png -------------------------------------------------------------------------------- /example/paris/2_tsmooth.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/paris/2_tsmooth.jpg -------------------------------------------------------------------------------- /example/places/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/places/1.jpg -------------------------------------------------------------------------------- /example/places/1_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/places/1_mask.png -------------------------------------------------------------------------------- /example/places/1_tsmooth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/places/1_tsmooth.png -------------------------------------------------------------------------------- /example/places/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/places/2.jpg -------------------------------------------------------------------------------- /example/places/2_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/places/2_mask.png -------------------------------------------------------------------------------- /example/places/2_tsmooth.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/places/2_tsmooth.jpg -------------------------------------------------------------------------------- /example/places/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/places/3.jpg -------------------------------------------------------------------------------- /example/places/3_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/places/3_mask.png 
-------------------------------------------------------------------------------- /example/places/3_tsmooth.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/places/3_tsmooth.jpg -------------------------------------------------------------------------------- /example/places/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/places/4.jpg -------------------------------------------------------------------------------- /example/places/4_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/places/4_mask.png -------------------------------------------------------------------------------- /example/places/4_tsmooth.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/places/4_tsmooth.jpg -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import shutil 5 | from src.config import Config 6 | from src.structure_flow import StructureFlow 7 | 8 | def main(mode=None): 9 | r"""starts the model 10 | Args: 11 | mode : train, test, eval, reads from config file if not specified 12 | """ 13 | 14 | config = load_config(mode) 15 | config.MODE = mode 16 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(e) for e in config.GPU) # comma-separated gpu ids; ''.join would break for multi-gpu lists 17 | 18 | if torch.cuda.is_available(): 19 | config.DEVICE = torch.device("cuda") 20 | torch.backends.cudnn.benchmark = True # cudnn auto-tuner 21 | else: 22 | config.DEVICE = torch.device("cpu") 23 | 24 | model = StructureFlow(config) 25 | 26 | if mode == 'train': 27 | # config.print() 28 | print('\nstart training...\n') 29 | model.train() 30 | 31 | elif mode == 'test': 32 | print('\nstart test...\n') 33 | model.test() 34 | 35 | elif mode == 'eval': 36 | print('\nstart eval...\n') 37 | model.eval() 38 | 39 | 40 | def load_config(mode=None): 41 | r"""loads model config 42 | """ 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument('--name', type=str, help='output model name.') 45 | parser.add_argument('--config', type=str, default='model_config.yaml', help='Path to the config file.') 46 | parser.add_argument('--path', type=str, default='./results', help='outputs path') 47 | parser.add_argument("--resume_all", action="store_true", help='load model from checkpoints') 48 | parser.add_argument("--remove_log", action="store_true", help='remove previous tensorboard log files') 49 | 50 | 51 | if mode == 'test': 52 | parser.add_argument('--input', type=str, help='path to the input image files') 53 | parser.add_argument('--mask', type=str, help='path to the mask files') 54 | parser.add_argument('--structure', type=str, help='path to the structure files') 55 | parser.add_argument('--output', type=str, help='path to the output directory') 56 | parser.add_argument('--model', type=int, default=3, help='which model to test') 57 | 58 | opts = parser.parse_args() 59 | config = Config(opts, mode) 60 | output_dir = os.path.join(opts.path, opts.name)
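# create the experiment's output sub-directories (images/ and checkpoints/)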
61 | prepare_sub_folder(output_dir) 62 | 63 | if mode == 'train': 64 | config_dir = os.path.join(output_dir, 'config.yaml') 65 | shutil.copyfile(opts.config, config_dir) 66 | return config 67 | 68 | 69 | def prepare_sub_folder(output_path): 70 | img_dir = os.path.join(output_path, 'images') 71 | if not os.path.exists(img_dir): 72 | print("Creating directory: {}".format(img_dir)) 73 | os.makedirs(img_dir) 74 | 75 | 76 | checkpoints_dir = os.path.join(output_path, 'checkpoints') 77 | if not os.path.exists(checkpoints_dir): 78 | print("Creating directory: {}".format(checkpoints_dir)) 79 | os.makedirs(checkpoints_dir) 80 | 81 | 82 | if __name__ == "__main__": 83 | main() 84 | -------------------------------------------------------------------------------- /model_config.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 3 # 1: structure model, 2: inpaint model, 3: structure-inpaint model 2 | VERBOSE: 1 # turns on verbose mode in the output console 3 | GPU: [1] # gpu ids 4 | 5 | 6 | MAX_ITERS: 8e6 # maximum number of iterations to train the model 7 | LR: 0.0001 # learning rate 8 | BETA1: 0 # adam optimizer beta1 9 | BETA2: 0.999 # adam optimizer beta2 10 | LR_POLICY: constant # the method to adjust learning rate (eg: constant|step) 11 | STEP_SIZE: 100000 # Period of learning rate decay (only used when choosing "step" as the lr adjustment method) 12 | GAMMA: 0.5 # Multiplicative factor of learning rate decay (only used when choosing "step" as the lr adjustment method) 13 | INIT_TYPE: xavier # initialization [gaussian/kaiming/xavier/orthogonal] 14 | 15 | SAVE_INTERVAL: 10000 # how many iterations to wait before saving the model 16 | SAVE_LATEST: 1000 # how many iterations to wait before saving the latest model 17 | SAMPLE_INTERVAL: 1000 # how many iterations to wait before sampling 18 | SAMPLE_SIZE: 4 # number of images to sample 19 | EVAL_INTERVAL: 10000 # how many iterations to wait before model evaluation 20 | LOG_INTERVAL: 100 # how many iterations to wait before logging training status 21 | WHICH_ITER: latest # which iteration to load 22 | 23 | DIS_GAN_LOSS: lsgan # type of gan loss 24 | 25 | STRUCTURE_L1: 4 # structure net parameter of l1 loss 26 | STRUCTURE_ADV_GEN: 1 # structure net parameter of gan loss 27 | 28 | FLOW_ADV_GEN: 1 # texture net parameter of gan loss 29 | FLOW_L1: 5 # texture net parameter of l1 loss 30 | FLOW_CORRECTNESS: 0.25 # texture net parameter of sampling correctness loss 31 | VGG_STYLE: 250 # texture net parameter of vgg_style loss (optional loss on stage_3) 32 | VGG_CONTENT: 0.1 # texture net parameter of vgg_content loss (optional loss on stage_3) 33 | 34 | 35 | TRAIN_BATCH_SIZE: 8 # batch size 36 | DATA_TRAIN_SIZE: 256 # image size for training 37 | DATA_TEST_SIZE: False # image size for testing (False for no resizing) 38 | DATA_FLIP: False # flip images or not when training 39 | DATA_CROP: FALSE #[537,537] # crop size when training (False for no cropping) 40 | DATA_MASK_TYPE: from_file # mask type (random_bbox|random_free_form|from_file) 41 | DATA_RANDOM_BBOX_SETTING: # parameters for random_bbox 42 | random_size: False # random hole size within [0.4*shape, shape] 43 | shape: [80, 80] # hole size 44 | margin: [0, 0] # minimum distance from the image boundary 45 | num: 3 46 | DATA_RANDOM_FF_SETTING: # parameters for random_free_form 47 | mv: 5 48 | ma: 4.0 49 | ml: 40 50 | mbw: 10 51 | DATA_MASK_FILE: ./txt/irregular_mask.list # parameters for from_file 52 | 53 | # use places365 dataset 54 | DATA_TRAIN_GT: ./txt/places_gt_train.list
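# Note: the .list files below are plain-text lists of image paths, one per
# line; they can be generated with scripts/flist.py (see README, step 1).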
55 | DATA_TRAIN_STRUCTURE: ./txt/places_structure_train.list 56 | DATA_VAL_GT: ./txt/places_gt_val.list 57 | DATA_VAL_STRUCTURE: ./txt/places_structure_val.list 58 | DATA_VAL_MASK: ./txt/places_mask_val.list 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /resample2d_package/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/resample2d_package/__init__.py -------------------------------------------------------------------------------- /resample2d_package/resample2d_cuda.cc: -------------------------------------------------------------------------------- 1 | #include <ATen/ATen.h> 2 | #include <torch/torch.h> 3 | 4 | #include "resample2d_kernel.cuh" 5 | 6 | int resample2d_cuda_forward( 7 | at::Tensor& input1, 8 | at::Tensor& input2, 9 | at::Tensor& output, 10 | int kernel_size, 11 | int dilation) { 12 | resample2d_kernel_forward(input1, input2, output, kernel_size, dilation); 13 | return 1; 14 | } 15 | 16 | int resample2d_cuda_backward( 17 | at::Tensor& input1, 18 | at::Tensor& input2, 19 | at::Tensor& gradOutput, 20 | at::Tensor& gradInput1, 21 | at::Tensor& gradInput2, 22 | int kernel_size, 23 | int dilation) { 24 | resample2d_kernel_backward(input1, input2, gradOutput, gradInput1, gradInput2, kernel_size, dilation); 25 | return 1; 26 | } 27 | 28 | 29 | 30 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 31 | m.def("forward", &resample2d_cuda_forward, "Resample2D forward (CUDA)"); 32 | m.def("backward", &resample2d_cuda_backward, "Resample2D backward (CUDA)"); 33 | } 34 | 35 | -------------------------------------------------------------------------------- /resample2d_package/resample2d_kernel.cu: -------------------------------------------------------------------------------- 1 | #include <ATen/ATen.h> 2 | #include <ATen/Context.h> 3 | #include <ATen/cuda/CUDAContext.h> 4 | 5 | #define CUDA_NUM_THREADS 256 6 | #define THREADS_PER_BLOCK 64 7 | 8 | #define DIM0(TENSOR) ((TENSOR).x) 9 | #define DIM1(TENSOR) ((TENSOR).y) 10 | #define DIM2(TENSOR) ((TENSOR).z) 11 | #define DIM3(TENSOR) ((TENSOR).w) 12 | 13 | #define DIM3_INDEX(TENSOR, xx, yy, zz, ww) ((TENSOR)[((xx) * (TENSOR##_stride.x)) + ((yy) * (TENSOR##_stride.y)) + ((zz) * (TENSOR##_stride.z)) + ((ww) * (TENSOR##_stride.w))]) 14 | #define EPS 1e-8 15 | #define SAFE_DIV(a, b) ( (b==0)? ( (a)/(EPS) ): ( (a)/(b) ) )
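// Note: the kernels below implement Gaussian-weighted sampling. Each output
// pixel gathers a kernel_size x kernel_size window around its flow target
// (x + dx, y + dy); each tap is weighted by exp(-d^2 / (2*sigma^2))
// separably in x and y, and the result is normalized by the summed weights.
// SAFE_DIV guards the degenerate sigma == 0 case.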
16 | 17 | 18 | 19 | 20 | template <typename scalar_t> 21 | __global__ void kernel_resample2d_update_output(const int n, 22 | const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride, 23 | const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride, 24 | scalar_t* __restrict__ output, const long4 output_size, const long4 output_stride, int kernel_size, int dilation) { 25 | int index = blockIdx.x * blockDim.x + threadIdx.x; 26 | 27 | if (index >= n) { 28 | return; 29 | } 30 | 31 | scalar_t val = 0.0f; 32 | scalar_t sum = 0.0f; 33 | 34 | 35 | int dim_b = DIM0(output_size); 36 | int dim_c = DIM1(output_size); 37 | int dim_h = DIM2(output_size); 38 | int dim_w = DIM3(output_size); 39 | int dim_chw = dim_c * dim_h * dim_w; 40 | int dim_hw = dim_h * dim_w; 41 | 42 | int b = ( index / dim_chw ) % dim_b; 43 | int c = ( index / dim_hw ) % dim_c; 44 | int y = ( index / dim_w ) % dim_h; 45 | int x = ( index ) % dim_w; 46 | 47 | scalar_t dx = DIM3_INDEX(input2, b, 0, y, x); 48 | scalar_t dy = DIM3_INDEX(input2, b, 1, y, x); 49 | scalar_t sigma = DIM3_INDEX(input2, b, 2, y, x); 50 | 51 | 52 | scalar_t xf = static_cast<scalar_t>(x) + dx; 53 | scalar_t yf = static_cast<scalar_t>(y) + dy; 54 | scalar_t alpha = xf - floor(xf); // alpha 55 | scalar_t beta = yf - floor(yf); // beta 56 | 57 | 58 | int idim_h = DIM2(input1_size); 59 | int idim_w = DIM3(input1_size); 60 | 61 | 62 | for (int fy = 0; fy < kernel_size/2; fy += 1) { 63 | int yT = max(min( int (floor(yf)-fy*dilation), idim_h-1), 0); 64 | int yB = max(min( int (floor(yf)+(fy+1)*dilation),idim_h-1), 0); 65 | 66 | for (int fx = 0; fx < kernel_size/2; fx += 1) { 67 | int xL = max(min( int (floor(xf)-fx*dilation ), idim_w-1), 0); 68 | int xR = max(min( int (floor(xf)+(fx+1)*dilation), idim_w-1), 0); 69 | 70 | scalar_t xL_ = ( static_cast<scalar_t>( fx *dilation)+alpha ); 71 | scalar_t xR_ = ( static_cast<scalar_t>((1.+fx)*dilation)-alpha ); 72 | scalar_t yT_ = ( static_cast<scalar_t>( fy *dilation)+beta ); 73 | scalar_t yB_ = ( static_cast<scalar_t>((1.+fy)*dilation)-beta ); 74 | 75 | scalar_t xL_P = exp(SAFE_DIV(-xL_*xL_, 2*sigma*sigma)); 76 | scalar_t xR_P = exp(SAFE_DIV(-xR_*xR_, 2*sigma*sigma)); 77 | scalar_t yT_P = exp(SAFE_DIV(-yT_*yT_, 2*sigma*sigma)); 78 | scalar_t yB_P = exp(SAFE_DIV(-yB_*yB_, 2*sigma*sigma)); 79 | // if (sigma==0){ 80 | // printf("xL_P %.10f\n", xL_P); 81 | // // printf("%.10f\n", -(xL_*xL_)/(2*sigma*sigma)); 82 | 83 | // } 84 | 85 | val += static_cast<scalar_t>(yT_P*xL_P * DIM3_INDEX(input1, b, c, yT, xL)); 86 | val += static_cast<scalar_t>(yT_P*xR_P * DIM3_INDEX(input1, b, c, yT, xR)); 87 | val += static_cast<scalar_t>(yB_P*xL_P * DIM3_INDEX(input1, b, c, yB, xL)); 88 | val += static_cast<scalar_t>(yB_P*xR_P * DIM3_INDEX(input1, b, c, yB, xR)); 89 | sum += (yT_P*xL_P + yT_P*xR_P + yB_P*xL_P + yB_P*xR_P); 90 | } 91 | } 92 | 93 | output[index] = SAFE_DIV(val, sum); 94 | 95 | } 96 | 97 | 98 | template <typename scalar_t> 99 | __global__ void kernel_resample2d_backward_input1( 100 | const int n, const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride, 101 | const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride, 102 | const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride, 103 | scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4 gradInput_stride, int kernel_size, int dilation) { 104 | 105 | int index = blockIdx.x * blockDim.x + threadIdx.x; 106 | 107 | if (index >= n) { 108 | return; 109 | } 110 | 111 | scalar_t sum = 0.0f; 112 | // scalar_t *xL_P = new scalar_t [kernel_size*kernel_size/4];
113 | // scalar_t *xR_P = new scalar_t [kernel_size*kernel_size/4]; 114 | // scalar_t *yT_P = new scalar_t [kernel_size*kernel_size/4]; 115 | // scalar_t *yB_P = new scalar_t [kernel_size*kernel_size/4]; 116 | 117 | int dim_b = DIM0(gradOutput_size); 118 | int dim_c = DIM1(gradOutput_size); 119 | int dim_h = DIM2(gradOutput_size); 120 | int dim_w = DIM3(gradOutput_size); 121 | int dim_chw = dim_c * dim_h * dim_w; 122 | int dim_hw = dim_h * dim_w; 123 | 124 | int b = ( index / dim_chw ) % dim_b; 125 | int c = ( index / dim_hw ) % dim_c; 126 | int y = ( index / dim_w ) % dim_h; 127 | int x = ( index ) % dim_w; 128 | 129 | scalar_t dx = DIM3_INDEX(input2, b, 0, y, x); 130 | scalar_t dy = DIM3_INDEX(input2, b, 1, y, x); 131 | scalar_t sigma = DIM3_INDEX(input2, b, 2, y, x); 132 | 133 | 134 | 135 | scalar_t xf = static_cast<scalar_t>(x) + dx; 136 | scalar_t yf = static_cast<scalar_t>(y) + dy; 137 | scalar_t alpha = xf - int(xf); // alpha 138 | scalar_t beta = yf - int(yf); // beta 139 | 140 | for (int fy = 0; fy < kernel_size/2; fy += 1) { 141 | for (int fx = 0; fx < kernel_size/2; fx += 1) { 142 | scalar_t xL_ = ( static_cast<scalar_t>( fx *dilation)+alpha ); 143 | scalar_t xR_ = ( static_cast<scalar_t>((1.+fx)*dilation)-alpha ); 144 | scalar_t yT_ = ( static_cast<scalar_t>( fy *dilation)+beta ); 145 | scalar_t yB_ = ( static_cast<scalar_t>((1.+fy)*dilation)-beta ); 146 | // scalar_t xL_ = ( alpha+static_cast<scalar_t>(fx) ); 147 | // scalar_t xR_ = ( 1.-alpha+static_cast<scalar_t>(fx) ); 148 | // scalar_t yT_ = ( beta+static_cast<scalar_t>(fy) ); 149 | // scalar_t yB_ = ( 1-beta+static_cast<scalar_t>(fy) ); 150 | 151 | scalar_t xL_P = exp(SAFE_DIV(-xL_*xL_, 2*sigma*sigma)); 152 | scalar_t xR_P = exp(SAFE_DIV(-xR_*xR_, 2*sigma*sigma)); 153 | scalar_t yT_P = exp(SAFE_DIV(-yT_*yT_, 2*sigma*sigma)); 154 | scalar_t yB_P = exp(SAFE_DIV(-yB_*yB_, 2*sigma*sigma)); 155 | // scalar_t xL_P = exp(SAFE_DIV(-xL_*xL_,2*sigma*sigma)); 156 | // scalar_t xR_P = exp(-(xR_*xR_)/(2*sigma*sigma)); 157 | // scalar_t yT_P = exp(-(yT_*yT_)/(2*sigma*sigma)); 158 | // scalar_t yB_P = exp(-(yB_*yB_)/(2*sigma*sigma)); 159 | sum += (yT_P*xL_P + yT_P*xR_P + yB_P*xL_P + yB_P*xR_P); 160 | // printf("%f\n", SAFE_DIV(-xL_*xL_, 2*sigma*sigma)); 161 | } 162 | } 163 | 164 | int idim_h = DIM2(input1_size); 165 | int idim_w = DIM3(input1_size); 166 | 167 | 168 | for (int fy = 0; fy < kernel_size/2; fy += 1) { 169 | int yT = max(min( int (floor(yf)-fy*dilation), idim_h-1), 0); 170 | int yB = max(min( int (floor(yf)+(fy+1)*dilation),idim_h-1), 0); 171 | // int yT = max(min( int (floor(yf)-fy ), idim_h-1), 0); 172 | // int yB = max(min( int (floor(yf)+fy+1), idim_h-1), 0); 173 | 174 | for (int fx = 0; fx < kernel_size/2; fx += 1) { 175 | int xL = max(min( int (floor(xf)-fx*dilation ), idim_w-1), 0); 176 | int xR = max(min( int (floor(xf)+(fx+1)*dilation), idim_w-1), 0); 177 | // int xL = max(min( int (floor(xf)-fx ), idim_w-1), 0); 178 | // int xR = max(min( int (floor(xf)+fx+1), idim_w-1), 0); 179 | 180 | scalar_t xL_ = ( static_cast<scalar_t>( fx *dilation)+alpha ); 181 | scalar_t xR_ = ( static_cast<scalar_t>((1.+fx)*dilation)-alpha ); 182 | scalar_t yT_ = ( static_cast<scalar_t>( fy *dilation)+beta ); 183 | scalar_t yB_ = ( static_cast<scalar_t>((1.+fy)*dilation)-beta ); 184 | // scalar_t xL_ = ( alpha+static_cast<scalar_t>(fx) ); 185 | // scalar_t xR_ = ( 1.-alpha+static_cast<scalar_t>(fx) ); 186 | // scalar_t yT_ = ( beta+static_cast<scalar_t>(fy) ); 187 | // scalar_t yB_ = ( 1-beta+static_cast<scalar_t>(fy) ); 188 | 189 | scalar_t xL_P = exp(SAFE_DIV(-xL_*xL_, 2*sigma*sigma)); 190 | scalar_t xR_P = exp(SAFE_DIV(-xR_*xR_, 2*sigma*sigma)); 191 | scalar_t yT_P = exp(SAFE_DIV(-yT_*yT_, 2*sigma*sigma));
192 | scalar_t yB_P = exp(SAFE_DIV(-yB_*yB_, 2*sigma*sigma)); 193 | 194 | 195 | atomicAdd(&DIM3_INDEX(gradInput, b, c, (yT), (xL)), SAFE_DIV(yT_P*xL_P, sum) * DIM3_INDEX(gradOutput, b, c, y, x)); 196 | atomicAdd(&DIM3_INDEX(gradInput, b, c, (yT), (xR)), SAFE_DIV(yT_P*xR_P, sum) * DIM3_INDEX(gradOutput, b, c, y, x)); 197 | atomicAdd(&DIM3_INDEX(gradInput, b, c, (yB), (xL)), SAFE_DIV(yB_P*xL_P, sum) * DIM3_INDEX(gradOutput, b, c, y, x)); 198 | atomicAdd(&DIM3_INDEX(gradInput, b, c, (yB), (xR)), SAFE_DIV(yB_P*xR_P, sum) * DIM3_INDEX(gradOutput, b, c, y, x)); 199 | } 200 | } 201 | 202 | } 203 | 204 | template <typename scalar_t> 205 | __global__ void kernel_resample2d_backward_input2( 206 | const int n, const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride, 207 | const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride, 208 | const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride, 209 | scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4 gradInput_stride, int kernel_size, int dilation) { 210 | 211 | int index = blockIdx.x * blockDim.x + threadIdx.x; 212 | 213 | if (index >= n) { 214 | return; 215 | } 216 | 217 | scalar_t grad1 = 0.0; 218 | scalar_t grad2 = 0.0; 219 | scalar_t sum = 0.0; 220 | 221 | 222 | 223 | int dim_b = DIM0(gradInput_size); 224 | int dim_c = DIM1(gradInput_size); 225 | int dim_h = DIM2(gradInput_size); 226 | int dim_w = DIM3(gradInput_size); 227 | int dim_chw = dim_c * dim_h * dim_w; 228 | int dim_hw = dim_h * dim_w; 229 | 230 | int b = ( index / dim_chw ) % dim_b; 231 | int c = ( index / dim_hw ) % dim_c; 232 | int y = ( index / dim_w ) % dim_h; 233 | int x = ( index ) % dim_w; 234 | 235 | int odim_c = DIM1(gradOutput_size); 236 | 237 | scalar_t dx = DIM3_INDEX(input2, b, 0, y, x); 238 | scalar_t dy = DIM3_INDEX(input2, b, 1, y, x); 239 | scalar_t sigma = DIM3_INDEX(input2, b, 2, y, x); 240 | 241 | 242 | scalar_t xf = static_cast<scalar_t>(x) + dx; 243 | scalar_t yf = static_cast<scalar_t>(y) + dy; 244 | scalar_t alpha = xf - floor(xf); // alpha 245 | scalar_t beta = yf - floor(yf); // beta 246 | 247 | 248 | int idim_h = DIM2(input1_size); 249 | int idim_w = DIM3(input1_size); 250 | scalar_t sumgrad = 0.0; 251 | 252 | for (int fy = 0; fy < kernel_size/2; fy += 1) { 253 | int yT = max(min( int (floor(yf)-fy*dilation), idim_h-1), 0); 254 | int yB = max(min( int (floor(yf)+(fy+1)*dilation),idim_h-1), 0); 255 | 256 | for (int fx = 0; fx < kernel_size/2; fx += 1) { 257 | int xL = max(min( int (floor(xf)-fx*dilation ), idim_w-1), 0); 258 | int xR = max(min( int (floor(xf)+(fx+1)*dilation), idim_w-1), 0); 259 | 260 | scalar_t xL_ = ( static_cast<scalar_t>( fx *dilation)+alpha ); 261 | scalar_t xR_ = ( static_cast<scalar_t>((1.+fx)*dilation)-alpha ); 262 | scalar_t yT_ = ( static_cast<scalar_t>( fy *dilation)+beta ); 263 | scalar_t yB_ = ( static_cast<scalar_t>((1.+fy)*dilation)-beta ); 264 | 265 | scalar_t xL_P = exp(SAFE_DIV(-xL_*xL_, 2*sigma*sigma)); 266 | scalar_t xR_P = exp(SAFE_DIV(-xR_*xR_, 2*sigma*sigma)); 267 | scalar_t yT_P = exp(SAFE_DIV(-yT_*yT_, 2*sigma*sigma)); 268 | scalar_t yB_P = exp(SAFE_DIV(-yB_*yB_, 2*sigma*sigma)); 269 | sum += (yT_P*xL_P + yT_P*xR_P + yB_P*xL_P + yB_P*xR_P); 270 | 271 | for (int ch = 0; ch < odim_c; ++ch) { 272 | if (c==0) { 273 | grad1 += SAFE_DIV(xL_ * yT_P * xL_P * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, yT, xL), -sigma*sigma); 274 | grad1 -= SAFE_DIV(xR_ * yT_P * xR_P * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, yT, xR), -sigma*sigma);
275 | grad1 += SAFE_DIV(xL_ * yB_P * xL_P * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, yB, xL), -sigma*sigma); 276 | grad1 -= SAFE_DIV(xR_ * yB_P * xR_P * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, yB, xR), -sigma*sigma); 277 | sumgrad += SAFE_DIV(( xL_*yT_P*xL_P - xR_*yT_P*xR_P + xL_*yB_P*xL_P - xR_*yB_P*xR_P ), -sigma*sigma); 278 | } 279 | else if (c==1) { 280 | grad1 += SAFE_DIV(yT_ * yT_P * xL_P * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, yT, xL), -sigma*sigma); 281 | grad1 += SAFE_DIV(yT_ * yT_P * xR_P * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, yT, xR), -sigma*sigma); 282 | grad1 -= SAFE_DIV(yB_ * yB_P * xL_P * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, yB, xL), -sigma*sigma); 283 | grad1 -= SAFE_DIV(yB_ * yB_P * xR_P * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, yB, xR), -sigma*sigma); 284 | sumgrad += SAFE_DIV(( yT_*yT_P*xL_P + yT_*yT_P*xR_P - yB_*yB_P*xL_P - yB_*yB_P*xR_P ), -sigma*sigma); 285 | } 286 | else if (c==2) { 287 | grad1 += SAFE_DIV((yT_*yT_+xL_*xL_) * yT_P * xL_P * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, yT, xL), sigma*sigma*sigma); 288 | grad1 += SAFE_DIV((yT_*yT_+xR_*xR_) * yT_P * xR_P * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, yT, xR), sigma*sigma*sigma); 289 | grad1 += SAFE_DIV((yB_*yB_+xL_*xL_) * yB_P * xL_P * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, yB, xL), sigma*sigma*sigma); 290 | grad1 += SAFE_DIV((yB_*yB_+xR_*xR_) * yB_P * xR_P * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, yB, xR), sigma*sigma*sigma); 291 | sumgrad += SAFE_DIV(( (yT_*yT_+xL_*xL_)*yT_P*xL_P + (yT_*yT_+xR_*xR_)*yT_P*xR_P + (yB_*yB_+xL_*xL_)*yB_P*xL_P + (yB_*yB_+xR_*xR_)*yB_P*xR_P ), sigma*sigma*sigma); 292 | 293 | } 294 | } 295 | } 296 | } 297 | 298 | 299 | 300 | for (int fy = 0; fy < kernel_size/2; fy += 1) { 301 | int yT = max(min( int (floor(yf)-fy*dilation), idim_h-1), 0); 302 | int yB = max(min( int (floor(yf)+(fy+1)*dilation),idim_h-1), 0); 303 | 304 | for (int fx = 0; fx < kernel_size/2; fx += 1) { 305 | int xL = max(min( int (floor(xf)-fx*dilation ), idim_w-1), 0); 306 | int xR = max(min( int (floor(xf)+(fx+1)*dilation), idim_w-1), 0); 307 | 308 | scalar_t xL_ = ( static_cast<scalar_t>( fx *dilation)+alpha ); 309 | scalar_t xR_ = ( static_cast<scalar_t>((1.+fx)*dilation)-alpha ); 310 | scalar_t yT_ = ( static_cast<scalar_t>( fy *dilation)+beta ); 311 | scalar_t yB_ = ( static_cast<scalar_t>((1.+fy)*dilation)-beta ); 312 | 313 | scalar_t xL_P = exp(SAFE_DIV(-xL_*xL_, 2*sigma*sigma)); 314 | scalar_t xR_P = exp(SAFE_DIV(-xR_*xR_, 2*sigma*sigma)); 315 | scalar_t yT_P = exp(SAFE_DIV(-yT_*yT_, 2*sigma*sigma)); 316 | scalar_t yB_P = exp(SAFE_DIV(-yB_*yB_, 2*sigma*sigma)); 317 | for (int ch = 0; ch < odim_c; ++ch) { 318 | grad2 += sumgrad/odim_c * yT_P * xL_P * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, yT, xL); 319 | grad2 += sumgrad/odim_c * yT_P * xR_P * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, yT, xR); 320 | grad2 += sumgrad/odim_c * yB_P * xL_P * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, yB, xL); 321 | grad2 += sumgrad/odim_c * yB_P * xR_P * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, yB, xR); 322 | } 323 | 324 | } 325 | } 326 | 327 | 328 | gradInput[index] = SAFE_DIV(grad1, sum) - SAFE_DIV(grad2, sum*sum); 329 | 330 | } 331 | 332 | 333 | 334 | 335 | void resample2d_kernel_forward(
336 | at::Tensor& input1, 337 | at::Tensor& input2, 338 | at::Tensor& output, 339 | int kernel_size, 340 | int dilation) { 341 | 342 | int n = output.numel(); 343 | 344 | const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3)); 345 | const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3)); 346 | 347 | const long4 input2_size = make_long4(input2.size(0), input2.size(1), input2.size(2), input2.size(3)); 348 | const long4 input2_stride = make_long4(input2.stride(0), input2.stride(1), input2.stride(2), input2.stride(3)); 349 | 350 | const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3)); 351 | const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3)); 352 | 353 | // TODO: when atomicAdd gets resolved, change to AT_DISPATCH_FLOATING_TYPES_AND_HALF 354 | AT_DISPATCH_FLOATING_TYPES(input1.type(), "resample_forward_kernel", ([&] { 355 | kernel_resample2d_update_output<scalar_t><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>( 356 | n, 357 | input1.data<scalar_t>(), 358 | input1_size, 359 | input1_stride, 360 | input2.data<scalar_t>(), 361 | input2_size, 362 | input2_stride, 363 | output.data<scalar_t>(), 364 | output_size, 365 | output_stride, 366 | kernel_size, 367 | dilation); 368 | 369 | })); 370 | 371 | // TODO: ATen-equivalent check 372 | 373 | // THCudaCheck(cudaGetLastError()); 374 | 375 | } 376 | 377 | void resample2d_kernel_backward( 378 | at::Tensor& input1, 379 | at::Tensor& input2, 380 | at::Tensor& gradOutput, 381 | at::Tensor& gradInput1, 382 | at::Tensor& gradInput2, 383 | int kernel_size, 384 | int dilation) { 385 | 386 | int n = gradOutput.numel(); 387 | 388 | const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3)); 389 | const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3)); 390 | 391 | const long4 input2_size = make_long4(input2.size(0), input2.size(1), input2.size(2), input2.size(3)); 392 | const long4 input2_stride = make_long4(input2.stride(0), input2.stride(1), input2.stride(2), input2.stride(3)); 393 | 394 | const long4 gradOutput_size = make_long4(gradOutput.size(0), gradOutput.size(1), gradOutput.size(2), gradOutput.size(3)); 395 | const long4 gradOutput_stride = make_long4(gradOutput.stride(0), gradOutput.stride(1), gradOutput.stride(2), gradOutput.stride(3)); 396 | 397 | const long4 gradInput1_size = make_long4(gradInput1.size(0), gradInput1.size(1), gradInput1.size(2), gradInput1.size(3)); 398 | const long4 gradInput1_stride = make_long4(gradInput1.stride(0), gradInput1.stride(1), gradInput1.stride(2), gradInput1.stride(3)); 399 | 400 | AT_DISPATCH_FLOATING_TYPES(input1.type(), "resample_backward_input1", ([&] { 401 | 402 | kernel_resample2d_backward_input1<scalar_t><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>( 403 | n, 404 | input1.data<scalar_t>(), 405 | input1_size, 406 | input1_stride, 407 | input2.data<scalar_t>(), 408 | input2_size, 409 | input2_stride, 410 | gradOutput.data<scalar_t>(), 411 | gradOutput_size, 412 | gradOutput_stride, 413 | gradInput1.data<scalar_t>(), 414 | gradInput1_size, 415 | gradInput1_stride, 416 | kernel_size, 417 | dilation 418 | ); 419 | 420 | })); 421 | 422 | const long4 gradInput2_size = make_long4(gradInput2.size(0), gradInput2.size(1), gradInput2.size(2), gradInput2.size(3)); 423 | const long4 
gradInput2_stride = make_long4(gradInput2.stride(0), gradInput2.stride(1), gradInput2.stride(2), gradInput2.stride(3)); 424 | 425 | n = gradInput2.numel(); 426 | 427 | AT_DISPATCH_FLOATING_TYPES(gradInput2.type(), "resample_backward_input2", ([&] { 428 | 429 | 430 | kernel_resample2d_backward_input2<scalar_t><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>( 431 | n, 432 | input1.data<scalar_t>(), 433 | input1_size, 434 | input1_stride, 435 | input2.data<scalar_t>(), 436 | input2_size, 437 | input2_stride, 438 | gradOutput.data<scalar_t>(), 439 | gradOutput_size, 440 | gradOutput_stride, 441 | gradInput2.data<scalar_t>(), 442 | gradInput2_size, 443 | gradInput2_stride, 444 | kernel_size, 445 | dilation 446 | ); 447 | 448 | })); 449 | 450 | // TODO: Use the ATen equivalent to get last error 451 | 452 | // THCudaCheck(cudaGetLastError()); 453 | 454 | } 455 | -------------------------------------------------------------------------------- /resample2d_package/resample2d_kernel.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <ATen/ATen.h> 4 | 5 | void resample2d_kernel_forward( 6 | at::Tensor& input1, 7 | at::Tensor& input2, 8 | at::Tensor& output, 9 | int kernel_size, 10 | int dilation); 11 | 12 | void resample2d_kernel_backward( 13 | at::Tensor& input1, 14 | at::Tensor& input2, 15 | at::Tensor& gradOutput, 16 | at::Tensor& gradInput1, 17 | at::Tensor& gradInput2, 18 | int kernel_size, 19 | int dilation); 20 | -------------------------------------------------------------------------------- /resample2d_package/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import torch 4 | 5 | from setuptools import setup 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | 8 | cxx_args = ['-std=c++11'] 9 | 10 | nvcc_args = [ 11 | #'-gencode', 'arch=compute_50,code=sm_50', 12 | #'-gencode', 'arch=compute_52,code=sm_52', 13 | '-gencode', 'arch=compute_60,code=sm_60', 14 | '-gencode', 'arch=compute_61,code=sm_61', 15 | #'-gencode', 'arch=compute_70,code=sm_70', 16 | #'-gencode', 'arch=compute_70,code=compute_70' 17 | ] 18 | 19 | setup( 20 | name='resample2d_cuda', 21 | ext_modules=[ 22 | CUDAExtension('resample2d_cuda', [ 23 | 'resample2d_cuda.cc', 24 | 'resample2d_kernel.cu' 25 | ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}) 26 | ], 27 | cmdclass={ 28 | 'build_ext': BuildExtension 29 | }) 30 | 31 | -------------------------------------------------------------------------------- /scripts/flist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--path', type=str, help='path to the dataset') 7 | parser.add_argument('--output', type=str, help='path to the file list') 8 | args = parser.parse_args() 9 | 10 | ext = {'.jpg', '.png'} 11 | 12 | images = [] 13 | for root, dirs, files in os.walk(args.path): 14 | print('loading ' + root) 15 | for file in files: 16 | if os.path.splitext(file)[1] in ext: 17 | images.append(os.path.join(root, file)) 18 | 19 | images = sorted(images) 20 | np.savetxt(args.output, images, fmt='%s') -------------------------------------------------------------------------------- /scripts/inception.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from 
torchvision import models 4 | 5 | 6 | class InceptionV3(nn.Module): 7 | """Pretrained InceptionV3 network returning feature maps""" 8 | 9 | # Index of default block of inception to return, 10 | # corresponds to output of final average pooling 11 | DEFAULT_BLOCK_INDEX = 3 12 | 13 | # Maps feature dimensionality to their output blocks indices 14 | BLOCK_INDEX_BY_DIM = { 15 | 64: 0, # First max pooling features 16 | 192: 1, # Second max pooling featurs 17 | 768: 2, # Pre-aux classifier features 18 | 2048: 3 # Final average pooling features 19 | } 20 | 21 | def __init__(self, 22 | output_blocks=[DEFAULT_BLOCK_INDEX], 23 | resize_input=True, 24 | normalize_input=True, 25 | requires_grad=False): 26 | """Build pretrained InceptionV3 27 | Parameters 28 | ---------- 29 | output_blocks : list of int 30 | Indices of blocks to return features of. Possible values are: 31 | - 0: corresponds to output of first max pooling 32 | - 1: corresponds to output of second max pooling 33 | - 2: corresponds to output which is fed to aux classifier 34 | - 3: corresponds to output of final average pooling 35 | resize_input : bool 36 | If true, bilinearly resizes input to width and height 299 before 37 | feeding input to model. As the network without fully connected 38 | layers is fully convolutional, it should be able to handle inputs 39 | of arbitrary size, so resizing might not be strictly needed 40 | normalize_input : bool 41 | If true, normalizes the input to the statistics the pretrained 42 | Inception network expects 43 | requires_grad : bool 44 | If true, parameters of the model require gradient. Possibly useful 45 | for finetuning the network 46 | """ 47 | super(InceptionV3, self).__init__() 48 | 49 | self.resize_input = resize_input 50 | self.normalize_input = normalize_input 51 | self.output_blocks = sorted(output_blocks) 52 | self.last_needed_block = max(output_blocks) 53 | 54 | assert self.last_needed_block <= 3, \ 55 | 'Last possible output block index is 3' 56 | 57 | self.blocks = nn.ModuleList() 58 | 59 | inception = models.inception_v3(pretrained=True) 60 | 61 | # Block 0: input to maxpool1 62 | block0 = [ 63 | inception.Conv2d_1a_3x3, 64 | inception.Conv2d_2a_3x3, 65 | inception.Conv2d_2b_3x3, 66 | nn.MaxPool2d(kernel_size=3, stride=2) 67 | ] 68 | self.blocks.append(nn.Sequential(*block0)) 69 | 70 | # Block 1: maxpool1 to maxpool2 71 | if self.last_needed_block >= 1: 72 | block1 = [ 73 | inception.Conv2d_3b_1x1, 74 | inception.Conv2d_4a_3x3, 75 | nn.MaxPool2d(kernel_size=3, stride=2) 76 | ] 77 | self.blocks.append(nn.Sequential(*block1)) 78 | 79 | # Block 2: maxpool2 to aux classifier 80 | if self.last_needed_block >= 2: 81 | block2 = [ 82 | inception.Mixed_5b, 83 | inception.Mixed_5c, 84 | inception.Mixed_5d, 85 | inception.Mixed_6a, 86 | inception.Mixed_6b, 87 | inception.Mixed_6c, 88 | inception.Mixed_6d, 89 | inception.Mixed_6e, 90 | ] 91 | self.blocks.append(nn.Sequential(*block2)) 92 | 93 | # Block 3: aux classifier to final avgpool 94 | if self.last_needed_block >= 3: 95 | block3 = [ 96 | inception.Mixed_7a, 97 | inception.Mixed_7b, 98 | inception.Mixed_7c, 99 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 100 | ] 101 | self.blocks.append(nn.Sequential(*block3)) 102 | 103 | for param in self.parameters(): 104 | param.requires_grad = requires_grad 105 | 106 | def forward(self, inp): 107 | """Get Inception feature maps 108 | Parameters 109 | ---------- 110 | inp : torch.autograd.Variable 111 | Input tensor of shape Bx3xHxW. 
Values are expected to be in 112 | range (0, 1) 113 | Returns 114 | ------- 115 | List of torch.autograd.Variable, corresponding to the selected output 116 | block, sorted ascending by index 117 | """ 118 | outp = [] 119 | x = inp 120 | 121 | if self.resize_input: 122 | x = F.interpolate(x, size=(299, 299), mode='bilinear') 123 | 124 | if self.normalize_input: 125 | x = x.clone() 126 | x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 127 | x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 128 | x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 129 | 130 | for idx, block in enumerate(self.blocks): 131 | x = block(x) 132 | if idx in self.output_blocks: 133 | outp.append(x) 134 | 135 | if idx == self.last_needed_block: 136 | break 137 | 138 | return outp 139 | -------------------------------------------------------------------------------- /scripts/matlab/code/tsmooth.m: -------------------------------------------------------------------------------- 1 | function S = tsmooth(I,lambda,sigma,sharpness,maxIter,mask) 2 | %tsmooth - Structure Extraction from Texture via Relative Total Variation 3 | % modified for StructureFlow image inpainting. 4 | % S = tsmooth(I, lambda, sigma, sharpness, maxIter, mask) extracts structure S from 5 | % structure+texture input I, with smoothness weight lambda, scale 6 | % parameter sigma and iteration number maxIter. 7 | % 8 | % Paras: 9 | % @I : Input UINT8 image, both grayscale and color images are acceptable. 10 | % @lambda : Parameter controlling the degree of smoothing. 11 | % Range (0, 0.05], 0.01 by default. 12 | % @sigma : Parameter specifying the maximum size of texture elements. 13 | % Range (0, 6], 3 by default. 14 | % @sharpness : Parameter controlling the sharpness of the final results, 15 | % which corresponds to \epsilon_s in the paper [1]. The smaller the value, the sharper the result. 16 | % Range (1e-3, 0.03], 0.02 by default. 17 | % @maxIter : Number of iterations, 4 by default. 18 | % @mask : The mask of the input image for the inpainting task. 19 | % 20 | % Example 21 | % ========== 22 | % I = imread('Bishapur_zan.jpg'); 23 | % S = tsmooth(I); % Default Parameters (lambda = 0.01, sigma = 3, sharpness = 0.02, maxIter = 4) 24 | % figure, imshow(I), figure, imshow(S); 25 | % 26 | % ========== 27 | % The Code is created based on the method described in the following paper 28 | % [1] "Structure Extraction from Texture via Relative Total Variation", Li Xu, Qiong Yan, Yang Xia, Jiaya Jia, ACM Transactions on Graphics, 29 | % (SIGGRAPH Asia 2012), 2012. 30 | % The code and the algorithm are for non-commercial use only. 31 | % 32 | % Author: Li Xu (xuli@cse.cuhk.edu.hk) 33 | % Date : 08/25/2012 34 | % Version : 1.0 35 | % Copyright 2012, The Chinese University of Hong Kong. 
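% Example (inpainting use; file names assumed from the example/ folder)
% ==========
% The extra mask argument zeroes the gradients inside the holes (see
% computeTextureWeights below), so structure is estimated from valid
% pixels only. Parameter values follow generate_structure_images.m:
%
% I = imread('example/places/1.jpg');
% M = imread('example/places/1_mask.png');
% S = tsmooth(I, 0.015, 3, 0.001, 3, M);
% imwrite(S, 'example/places/1_tsmooth.png');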
36 | % 37 | 38 | if (~exist('lambda','var')) 39 | lambda=0.01; 40 | end 41 | if (~exist('sigma','var')) 42 | sigma=3.0; 43 | end 44 | if (~exist('sharpness','var')) 45 | sharpness = 0.02; 46 | end 47 | if (~exist('maxIter','var')) 48 | maxIter=4; 49 | end 50 | if (~exist('mask','var')) 51 | [w,h,~] = size(I); 52 | mask=zeros(w,h); 53 | end 54 | I = im2double(I); 55 | mask = im2double(mask); 56 | x = I; 57 | sigma_iter = sigma; 58 | lambda = lambda/2.0; 59 | dec=2.0; 60 | for iter = 1:maxIter 61 | [wx, wy] = computeTextureWeights(x, mask, sigma_iter, sharpness); 62 | x = solveLinearEquation(I, wx, wy, lambda); 63 | sigma_iter = sigma_iter/dec; 64 | if sigma_iter < 0.5 65 | sigma_iter = 0.5; 66 | end 67 | end 68 | S = x; 69 | end 70 | 71 | % function [retx, rety] = computeTextureWeights(fin, sigma,sharpness) 72 | % 73 | % fx = diff(fin,1,2); 74 | % fx = padarray(fx, [0 1 0], 'post'); 75 | % fy = diff(fin,1,1); 76 | % fy = padarray(fy, [1 0 0], 'post'); 77 | % 78 | % vareps_s = sharpness; 79 | % vareps = 0.001; 80 | % 81 | % wto = max(sum(sqrt(fx.^2+fy.^2),3)/size(fin,3),vareps_s).^(-1); 82 | % fbin = lpfilter(fin, sigma); 83 | % gfx = diff(fbin,1,2); 84 | % gfx = padarray(gfx, [0 1], 'post'); 85 | % gfy = diff(fbin,1,1); 86 | % gfy = padarray(gfy, [1 0], 'post'); 87 | % wtbx = max(sum(abs(gfx),3)/size(fin,3),vareps).^(-1); 88 | % wtby = max(sum(abs(gfy),3)/size(fin,3),vareps).^(-1); 89 | % retx = wtbx.*wto; 90 | % rety = wtby.*wto; 91 | % 92 | % retx(:,end) = 0; 93 | % rety(end,:) = 0; 94 | % 95 | % end 96 | function [retx, rety] = computeTextureWeights(fin, mask, sigma,sharpness) 97 | 98 | fx = diff(fin,1,2); 99 | fx = padarray(fx, [0 1 0], 'post'); 100 | fy = diff(fin,1,1); 101 | fy = padarray(fy, [1 0 0], 'post'); 102 | 103 | mask(mask>0)=1; 104 | fx=fx.*(1-mask); 105 | fy=fy.*(1-mask); 106 | 107 | vareps_s = sharpness; 108 | vareps = 0.001; 109 | 110 | wto = max(sum(sqrt(fx.^2+fy.^2),3)/size(fin,3),vareps_s).^(-1); 111 | fbin = lpfilter(fin, sigma); 112 | gfx = diff(fbin,1,2); 113 | gfx = padarray(gfx, [0 1], 'post'); 114 | gfy = diff(fbin,1,1); 115 | gfy = padarray(gfy, [1 0], 'post'); 116 | wtbx = max(sum(abs(gfx),3)/size(fin,3),vareps).^(-1); 117 | wtby = max(sum(abs(gfy),3)/size(fin,3),vareps).^(-1); 118 | retx = wtbx.*wto; 119 | rety = wtby.*wto; 120 | 121 | retx(:,end) = 0; 122 | rety(end,:) = 0; 123 | 124 | end 125 | 126 | function ret = conv2_sep(im, sigma) 127 | ksize = bitor(round(5*sigma),1); 128 | g = fspecial('gaussian', [1,ksize], sigma); 129 | ret = conv2(im,g,'same'); 130 | ret = conv2(ret,g','same'); 131 | end 132 | 133 | function FBImg = lpfilter(FImg, sigma) 134 | FBImg = FImg; 135 | for ic = 1:size(FBImg,3) 136 | FBImg(:,:,ic) = conv2_sep(FImg(:,:,ic), sigma); 137 | end 138 | end 139 | 140 | function OUT = solveLinearEquation(IN, wx, wy, lambda) 141 | % 142 | % The code for constructing inhomogenious Laplacian is adapted from 143 | % the implementaion of the wlsFilter. 144 | % 145 | % For color images, we enforce wx and wy be same for three channels 146 | % and thus the pre-conditionar only need to be computed once. 
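% A sketch of the system assembled below: with k = r*c unknowns per channel,
% the code builds the five-point sparse matrix
%
%   A = I + lambda * L_w,   and solves   A * x = IN(:)
%
% where L_w is the weighted (inhomogeneous) Laplacian: the off-diagonals
% hold -lambda*wx / -lambda*wy for horizontal / vertical neighbours, and
% the diagonal D = 1-(e+w+s+n) accumulates the matching positive weights,
% so smoothing is strongest where the texture weights wx, wy are largest.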
147 | % 148 | [r,c,ch] = size(IN); 149 | k = r*c; 150 | dx = -lambda*wx(:); 151 | dy = -lambda*wy(:); 152 | B(:,1) = dx; 153 | B(:,2) = dy; 154 | d = [-r,-1]; 155 | A = spdiags(B,d,k,k); 156 | e = dx; 157 | w = padarray(dx, r, 'pre'); w = w(1:end-r); 158 | s = dy; 159 | n = padarray(dy, 1, 'pre'); n = n(1:end-1); 160 | D = 1-(e+w+s+n); 161 | A = A + A' + spdiags(D, 0, k, k); 162 | if exist('ichol','builtin') 163 | L = ichol(A,struct('michol','on')); 164 | OUT = IN; 165 | for ii=1:ch 166 | tin = IN(:,:,ii); 167 | [tout, flag] = pcg(A, tin(:),0.1,100, L, L'); 168 | OUT(:,:,ii) = reshape(tout, r, c); 169 | end 170 | else 171 | OUT = IN; 172 | for ii=1:ch 173 | tin = IN(:,:,ii); 174 | tout = A\tin(:); 175 | OUT(:,:,ii) = reshape(tout, r, c); 176 | end 177 | end 178 | 179 | end -------------------------------------------------------------------------------- /scripts/matlab/dirPlus.m: -------------------------------------------------------------------------------- 1 | function output = dirPlus(rootPath, varargin) 2 | %dirPlus Recursively collect files or directories within a folder. 3 | % LIST = dirPlus(ROOTPATH) will search recursively through the folder 4 | % tree beneath ROOTPATH and collect a cell array LIST of all files it 5 | % finds. The list will contain the absolute paths to each file starting 6 | % at ROOTPATH. 7 | % 8 | % LIST = dirPlus(ROOTPATH, 'PropertyName', PropertyValue, ...) will 9 | % modify how files and directories are selected, as well as the format of 10 | % LIST, based on the property/value pairs specified. Valid properties 11 | % that the user can set are: 12 | % 13 | % GENERAL: 14 | % 'Struct' - A logical value determining if the output LIST should 15 | % instead be a structure array of the form returned by 16 | % the DIR function. If TRUE, LIST will be an N-by-1 17 | % structure array instead of a cell array. 18 | % 'Depth' - A non-negative integer value for the maximum folder 19 | % tree depth that dirPlus will search through. A value 20 | % of 0 will only search in ROOTPATH, a value of 1 will 21 | % search in ROOTPATH and its subfolders, etc. Default 22 | % (and maximum allowable) value is the current 23 | % recursion limit set on the root object (i.e. 24 | % get(0, 'RecursionLimit')). 25 | % 'ReturnDirs' - A logical value determining if the output will be a 26 | % list of files or subdirectories. If TRUE, LIST will 27 | % be a cell array of subdirectory names/paths. Default 28 | % is FALSE. 29 | % 'PrependPath' - A logical value determining if the full path from 30 | % ROOTPATH to the file/subdirectory is prepended to 31 | % each item in LIST. The default TRUE will prepend the 32 | % full path, otherwise just the file/subdirectory name 33 | % is returned. This setting is ignored if the 'Struct' 34 | % argument is TRUE. 35 | % 36 | % FILE-SPECIFIC: 37 | % 'FileFilter' - A string defining a regular-expression pattern 38 | % that will be applied to the file name. Only files 39 | % matching the pattern will be included in LIST. 40 | % Default is '' (i.e. all files are included). 41 | % 'ValidateFileFcn' - A handle to a function that takes as input a 42 | % structure of the form returned by the DIR 43 | % function and returns a logical value. This 44 | % function will be applied to all files found and 45 | % only files that have a TRUE return value will be 46 | % included in LIST. Default is [] (i.e. all files 47 | % are included). 
48 | % 49 | % DIRECTORY-SPECIFIC: 50 | % 'DirFilter' - A string defining a regular-expression pattern 51 | % that will be applied to the subdirectory name. 52 | % Only subdirectories matching the pattern will be 53 | % considered valid (i.e. included in LIST themselves 54 | % or having their files included in LIST). Default 55 | % is '' (i.e. all subdirectories are valid). The 56 | % setting of the 'RecurseInvalid' argument 57 | % determines if invalid subdirectories are still 58 | % recursed down. 59 | % 'ValidateDirFcn' - A handle to a function that takes as input a 60 | % structure of the form returned by the DIR function 61 | % and returns a logical value. This function will be 62 | % applied to all subdirectories found and only 63 | % subdirectories that have a TRUE return value will 64 | % be considered valid (i.e. included in LIST 65 | % themselves or having their files included in 66 | % LIST). Default is [] (i.e. all subdirectories are 67 | % valid). The setting of the 'RecurseInvalid' 68 | % argument determines if invalid subdirectories are 69 | % still recursed down. 70 | % 'RecurseInvalid' - A logical value determining if invalid 71 | % subdirectories (as identified by the 'DirFilter' 72 | % and 'ValidateDirFcn' arguments) should still be 73 | % recursed down. Default is FALSE (i.e the recursive 74 | % searching stops at invalid subdirectories). 75 | % 76 | % Examples: 77 | % 78 | % 1) Find all '.m' files: 79 | % 80 | % fileList = dirPlus(rootPath, 'FileFilter', '\.m$'); 81 | % 82 | % 2) Find all '.m' files, returning the list as a structure array: 83 | % 84 | % fileList = dirPlus(rootPath, 'Struct', true, ... 85 | % 'FileFilter', '\.m$'); 86 | % 87 | % 3) Find all '.jpg', '.png', and '.tif' files: 88 | % 89 | % fileList = dirPlus(rootPath, 'FileFilter', '\.(jpg|png|tif)$'); 90 | % 91 | % 4) Find all '.m' files in the given folder and its subfolders: 92 | % 93 | % fileList = dirPlus(rootPath, 'Depth', 1, 'FileFilter', '\.m$'); 94 | % 95 | % 5) Find all '.m' files, returning only the file names: 96 | % 97 | % fileList = dirPlus(rootPath, 'FileFilter', '\.m$', ... 98 | % 'PrependPath', false); 99 | % 100 | % 6) Find all '.jpg' files with a size of more than 1MB: 101 | % 102 | % bigFcn = @(s) (s.bytes > 1024^2); 103 | % fileList = dirPlus(rootPath, 'FileFilter', '\.jpg$', ... 104 | % 'ValidateFcn', bigFcn); 105 | % 106 | % 7) Find all '.m' files contained in folders containing the string 107 | % 'addons', recursing without restriction: 108 | % 109 | % fileList = dirPlus(rootPath, 'DirFilter', 'addons', ... 110 | % 'FileFilter', '\.m$', ... 111 | % 'RecurseInvalid', true); 112 | % 113 | % See also dir, regexp. 114 | 115 | % Author: Ken Eaton 116 | % Version: MATLAB R2016b - R2011a 117 | % Last modified: 4/14/17 118 | % Copyright 2017 by Kenneth P. Eaton 119 | % Copyright 2017 by Stephen Larroque - backwards compatibility 120 | %-------------------------------------------------------------------------- 121 | 122 | % Create input parser (only have to do this once, hence the use of a 123 | % persistent variable): 124 | 125 | persistent parser 126 | if isempty(parser) 127 | recursionLimit = get(0, 'RecursionLimit'); 128 | parser = inputParser(); 129 | parser.FunctionName = 'dirPlus'; 130 | if verLessThan('matlab', '8.2') % MATLAB R2013b = 8.2 131 | addPVPair = @addParamValue; 132 | else 133 | parser.PartialMatching = true; 134 | addPVPair = @addParameter; 135 | end 136 | 137 | % Add general parameters: 138 | 139 | addRequired(parser, 'rootPath', ... 
140 | @(s) validateattributes(s, {'char'}, {'nonempty'})); 141 | addPVPair(parser, 'Struct', false, ... 142 | @(b) validateattributes(b, {'logical'}, {'scalar'})); 143 | addPVPair(parser, 'Depth', recursionLimit, ... 144 | @(s) validateattributes(s, {'numeric'}, ... 145 | {'scalar', 'nonnegative', ... 146 | 'nonnan', 'integer', ... 147 | '<=', recursionLimit})); 148 | addPVPair(parser, 'ReturnDirs', false, ... 149 | @(b) validateattributes(b, {'logical'}, {'scalar'})); 150 | addPVPair(parser, 'PrependPath', true, ... 151 | @(b) validateattributes(b, {'logical'}, {'scalar'})); 152 | 153 | % Add file-specific parameters: 154 | 155 | addPVPair(parser, 'FileFilter', '', ... 156 | @(s) validateattributes(s, {'char'}, {'row'})); 157 | addPVPair(parser, 'ValidateFileFcn', [], ... 158 | @(f) validateattributes(f, {'function_handle'}, {'scalar'})); 159 | 160 | % Add directory-specific parameters: 161 | 162 | addPVPair(parser, 'DirFilter', '', ... 163 | @(s) validateattributes(s, {'char'}, {'row'})); 164 | addPVPair(parser, 'ValidateDirFcn', [], ... 165 | @(f) validateattributes(f, {'function_handle'}, {'scalar'})); 166 | addPVPair(parser, 'RecurseInvalid', false, ... 167 | @(b) validateattributes(b, {'logical'}, {'scalar'})); 168 | 169 | end 170 | 171 | % Parse input and recursively find contents: 172 | 173 | parse(parser, rootPath, varargin{:}); 174 | output = dirPlus_core(parser.Results.rootPath, ... 175 | rmfield(parser.Results, 'rootPath'), 0, true); 176 | if parser.Results.Struct 177 | output = vertcat(output{:}); 178 | end 179 | 180 | end 181 | 182 | %~~~Begin local functions~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 183 | 184 | %-------------------------------------------------------------------------- 185 | % Core recursive function to find files and directories. 186 | function output = dirPlus_core(rootPath, optionStruct, depth, isValid) 187 | 188 | % Backwards compatibility for fullfile: 189 | 190 | persistent fullfilecell 191 | if isempty(fullfilecell) 192 | if verLessThan('matlab', '8.0') % MATLAB R2012b = 8.0 193 | fullfilecell = @(P, C) cellfun(@(S) fullfile(P, S), C, ... 194 | 'UniformOutput', false); 195 | else 196 | fullfilecell = @fullfile; 197 | end 198 | end 199 | 200 | % Get current directory contents: 201 | 202 | rootData = dir(rootPath); 203 | dirIndex = [rootData.isdir]; 204 | subDirs = {}; 205 | validIndex = []; 206 | 207 | % Find valid subdirectories, only if necessary: 208 | 209 | if (depth < optionStruct.Depth) || optionStruct.ReturnDirs 210 | 211 | % Get subdirectories, not counting current or parent: 212 | 213 | dirData = rootData(dirIndex); 214 | subDirs = {dirData.name}.'; 215 | index = ~ismember(subDirs, {'.', '..'}); 216 | dirData = dirData(index); 217 | subDirs = subDirs(index); 218 | validIndex = true(size(subDirs)); 219 | if any(validIndex) 220 | % Apply directory name filter, if specified: 221 | nameFilter = optionStruct.DirFilter; 222 | if ~isempty(nameFilter) 223 | validIndex = ~cellfun(@isempty, regexp(subDirs, nameFilter)); 224 | end 225 | if any(validIndex) 226 | % Apply validation function to the directory list, if specified: 227 | validateFcn = optionStruct.ValidateDirFcn; 228 | if ~isempty(validateFcn) 229 | validIndex(validIndex) = arrayfun(validateFcn, ... 
230 | dirData(validIndex)); 231 | end 232 | end 233 | end 234 | end 235 | % Determine if files or subdirectories are being returned: 236 | if optionStruct.ReturnDirs % Return directories 237 | % Use structure data or prepend full path, if specified: 238 | if optionStruct.Struct 239 | output = {dirData(validIndex)}; 240 | elseif any(validIndex) && optionStruct.PrependPath 241 | output = fullfilecell(rootPath, subDirs(validIndex)); 242 | else 243 | output = subDirs(validIndex); 244 | end 245 | elseif isValid % Return files 246 | % Find all files in the current directory: 247 | fileData = rootData(~dirIndex); 248 | output = {fileData.name}.'; 249 | 250 | if ~isempty(output) 251 | 252 | % Apply file name filter, if specified: 253 | 254 | fileFilter = optionStruct.FileFilter; 255 | if ~isempty(fileFilter) 256 | filterIndex = ~cellfun(@isempty, regexp(output, fileFilter)); 257 | fileData = fileData(filterIndex); 258 | output = output(filterIndex); 259 | end 260 | 261 | if ~isempty(output) 262 | 263 | % Apply validation function to the file list, if specified: 264 | 265 | validateFcn = optionStruct.ValidateFileFcn; 266 | if ~isempty(validateFcn) 267 | validateIndex = arrayfun(validateFcn, fileData); 268 | fileData = fileData(validateIndex); 269 | output = output(validateIndex); 270 | end 271 | 272 | % Use structure data or prepend full path, if specified: 273 | 274 | if optionStruct.Struct 275 | output = {fileData}; 276 | elseif ~isempty(output) && optionStruct.PrependPath 277 | output = fullfilecell(rootPath, output); 278 | end 279 | 280 | end 281 | 282 | end 283 | 284 | else % Return nothing 285 | 286 | output = {}; 287 | 288 | end 289 | 290 | % Check recursion depth: 291 | 292 | if (depth < optionStruct.Depth) 293 | 294 | % Select subdirectories to recurse down: 295 | 296 | if ~optionStruct.RecurseInvalid 297 | subDirs = subDirs(validIndex); 298 | validIndex = validIndex(validIndex); 299 | end 300 | 301 | % Recursively collect output from subdirectories: 302 | 303 | nSubDirs = numel(subDirs); 304 | if (nSubDirs > 0) 305 | subDirs = fullfilecell(rootPath, subDirs); 306 | output = {output; cell(nSubDirs, 1)}; 307 | for iSub = 1:nSubDirs 308 | output{iSub+1} = dirPlus_core(subDirs{iSub}, optionStruct, ... 
309 | depth+1, validIndex(iSub)); 310 | end 311 | output = vertcat(output{:}); 312 | end 313 | 314 | end 315 | 316 | end 317 | 318 | %~~~End local functions~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -------------------------------------------------------------------------------- /scripts/matlab/generate_structure_images.m: -------------------------------------------------------------------------------- 1 | function generate_structure_images(dataset_path, output_path) 2 | addpath('code'); 3 | image_list = dirPlus(dataset_path, 'FileFilter', '\.(jpg|png|tif)$'); 4 | num_image = numel(image_list); 5 | for i=1:num_image 6 | image_name = image_list{i}; 7 | image = im2double(imread(image_name)); 8 | S = tsmooth(image, 0.015, 3, 0.001, 3); 9 | write_name = strrep(image_name, dataset_path, output_path); 10 | [filepath,~,~] = fileparts(write_name); 11 | if ~exist(filepath, 'dir') 12 | mkdir(filepath); 13 | end 14 | imwrite(S, write_name); 15 | 16 | if mod(i,100)==1 17 | fprintf('total: %d; output: %d; completed: %f%% \n',num_image, i, (i/num_image)*100) ; 18 | end 19 | end 20 | end 21 | 22 | -------------------------------------------------------------------------------- /scripts/metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import torch 4 | import numpy as np 5 | from scipy.misc import imread # removed in SciPy >= 1.2; imageio.imread is the suggested replacement 6 | from scipy import linalg 7 | from torch.nn.functional import adaptive_avg_pool2d 8 | from inception import InceptionV3 9 | from skimage.measure import compare_ssim # moved to skimage.metrics.structural_similarity in scikit-image >= 0.16 10 | from skimage.measure import compare_psnr # moved to skimage.metrics.peak_signal_noise_ratio in scikit-image >= 0.16 11 | import glob 12 | import argparse 13 | import matplotlib.pyplot as plt # used by the debug preview in calculate_from_disk 14 | 15 | 16 | class FID(): 17 | """docstring for FID 18 | Calculates the Frechet Inception Distance (FID) to evaluate GANs 19 | The FID metric calculates the distance between two distributions of images. 20 | Typically, we have summary statistics (mean & covariance matrix) of one 21 | of these distributions, while the 2nd distribution is given by a GAN. 22 | When run as a stand-alone program, it compares the distribution of 23 | images that are stored as PNG/JPEG at a specified location with a 24 | distribution given by summary statistics (in pickle format). 25 | The FID is calculated by assuming that X_1 and X_2 are the activations of 26 | the pool_3 layer of the inception net for generated samples and real world 27 | samples respectively. 28 | See --help to see further details. 29 | Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead 30 | of Tensorflow 31 | Copyright 2018 Institute of Bioinformatics, JKU Linz 32 | Licensed under the Apache License, Version 2.0 (the "License"); 33 | you may not use this file except in compliance with the License. 34 | You may obtain a copy of the License at 35 | http://www.apache.org/licenses/LICENSE-2.0 36 | Unless required by applicable law or agreed to in writing, software 37 | distributed under the License is distributed on an "AS IS" BASIS, 38 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 39 | See the License for the specific language governing permissions and 40 | limitations under the License. 
41 | """ 42 | def __init__(self): 43 | self.dims = 2048 44 | self.batch_size = 64 45 | self.cuda = True 46 | self.verbose=False 47 | 48 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[self.dims] 49 | self.model = InceptionV3([block_idx]) 50 | if self.cuda: 51 | # TODO: put model into specific GPU 52 | self.model.cuda() 53 | 54 | def __call__(self, images, gt_path): 55 | """ images: list of the generated image. The values must lie between 0 and 1. 56 | gt_path: the path of the ground truth images. The values must lie between 0 and 1. 57 | """ 58 | if not os.path.exists(gt_path): 59 | raise RuntimeError('Invalid path: %s' % gt_path) 60 | 61 | 62 | print('calculate gt_path statistics...') 63 | m1, s1 = self.compute_statistics_of_path(gt_path, self.verbose) 64 | print('calculate generated_images statistics...') 65 | m2, s2 = self.calculate_activation_statistics(images, self.verbose) 66 | fid_value = self.calculate_frechet_distance(m1, s1, m2, s2) 67 | return fid_value 68 | 69 | 70 | def calculate_from_disk(self, generated_path, gt_path): 71 | """ 72 | """ 73 | if not os.path.exists(gt_path): 74 | raise RuntimeError('Invalid path: %s' % gt_path) 75 | if not os.path.exists(generated_path): 76 | raise RuntimeError('Invalid path: %s' % generated_path) 77 | 78 | print('calculate gt_path statistics...') 79 | m1, s1 = self.compute_statistics_of_path(gt_path, self.verbose) 80 | print('calculate generated_path statistics...') 81 | m2, s2 = self.compute_statistics_of_path(generated_path, self.verbose) 82 | print('calculate frechet distance...') 83 | fid_value = self.calculate_frechet_distance(m1, s1, m2, s2) 84 | print('fid_distance %f' % (fid_value)) 85 | return fid_value 86 | 87 | 88 | def compute_statistics_of_path(self, path, verbose): 89 | npz_file = os.path.join(path, 'statistics.npz') 90 | if os.path.exists(npz_file): 91 | f = np.load(npz_file) 92 | m, s = f['mu'][:], f['sigma'][:] 93 | f.close() 94 | else: 95 | path = pathlib.Path(path) 96 | files = list(path.glob('*.jpg')) + list(path.glob('*.png')) 97 | imgs = np.array([imread(str(fn)).astype(np.float32) for fn in files]) 98 | 99 | # Bring images to shape (B, 3, H, W) 100 | imgs = imgs.transpose((0, 3, 1, 2)) 101 | 102 | # Rescale images to be between 0 and 1 103 | imgs /= 255 104 | 105 | m, s = self.calculate_activation_statistics(imgs, verbose) 106 | np.savez(npz_file, mu=m, sigma=s) 107 | 108 | return m, s 109 | 110 | def calculate_activation_statistics(self, images, verbose): 111 | """Calculation of the statistics used by the FID. 112 | Params: 113 | -- images : Numpy array of dimension (n_images, 3, hi, wi). The values 114 | must lie between 0 and 1. 115 | -- model : Instance of inception model 116 | -- batch_size : The images numpy array is split into batches with 117 | batch size batch_size. A reasonable batch size 118 | depends on the hardware. 119 | -- dims : Dimensionality of features returned by Inception 120 | -- cuda : If set to True, use GPU 121 | -- verbose : If set to True and parameter out_step is given, the 122 | number of calculated batches is reported. 123 | Returns: 124 | -- mu : The mean over samples of the activations of the pool_3 layer of 125 | the inception model. 126 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 127 | the inception model. 
128 | """ 129 | act = self.get_activations(images, verbose) 130 | mu = np.mean(act, axis=0) 131 | sigma = np.cov(act, rowvar=False) 132 | return mu, sigma 133 | 134 | 135 | 136 | def get_activations(self, images, verbose=False): 137 | """Calculates the activations of the pool_3 layer for all images. 138 | Params: 139 | -- images : Numpy array of dimension (n_images, 3, hi, wi). The values 140 | must lie between 0 and 1. 141 | -- model : Instance of inception model 142 | -- batch_size : the images numpy array is split into batches with 143 | batch size batch_size. A reasonable batch size depends 144 | on the hardware. 145 | -- dims : Dimensionality of features returned by Inception 146 | -- cuda : If set to True, use GPU 147 | -- verbose : If set to True and parameter out_step is given, the number 148 | of calculated batches is reported. 149 | Returns: 150 | -- A numpy array of dimension (num images, dims) that contains the 151 | activations of the given tensor when feeding inception with the 152 | query tensor. 153 | """ 154 | self.model.eval() 155 | 156 | d0 = images.shape[0] 157 | if self.batch_size > d0: 158 | print(('Warning: batch size is bigger than the data size. ' 159 | 'Setting batch size to data size')) 160 | self.batch_size = d0 161 | 162 | n_batches = d0 // self.batch_size 163 | n_used_imgs = n_batches * self.batch_size 164 | 165 | pred_arr = np.empty((n_used_imgs, self.dims)) 166 | for i in range(n_batches): 167 | if verbose: 168 | print('\rPropagating batch %d/%d' % (i + 1, n_batches)) 169 | # end='', flush=True) 170 | start = i * self.batch_size 171 | end = start + self.batch_size 172 | 173 | batch = torch.from_numpy(images[start:end]).type(torch.FloatTensor) 174 | # batch = Variable(batch, volatile=True) 175 | if self.cuda: 176 | batch = batch.cuda() 177 | 178 | pred = self.model(batch)[0] 179 | 180 | # If model output is not scalar, apply global spatial average pooling. 181 | # This happens if you choose a dimensionality not equal 2048. 182 | if pred.shape[2] != 1 or pred.shape[3] != 1: 183 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 184 | 185 | pred_arr[start:end] = pred.cpu().data.numpy().reshape(self.batch_size, -1) 186 | 187 | if verbose: 188 | print(' done') 189 | 190 | return pred_arr 191 | 192 | 193 | def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6): 194 | """Numpy implementation of the Frechet Distance. 195 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 196 | and X_2 ~ N(mu_2, C_2) is 197 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 198 | Stable version by Dougal J. Sutherland. 199 | Params: 200 | -- mu1 : Numpy array containing the activations of a layer of the 201 | inception net (like returned by the function 'get_predictions') 202 | for generated samples. 203 | -- mu2 : The sample mean over activations, precalculated on an 204 | representive data set. 205 | -- sigma1: The covariance matrix over activations for generated samples. 206 | -- sigma2: The covariance matrix over activations, precalculated on an 207 | representive data set. 208 | Returns: 209 | -- : The Frechet Distance. 
210 | """ 211 | 212 | mu1 = np.atleast_1d(mu1) 213 | mu2 = np.atleast_1d(mu2) 214 | 215 | sigma1 = np.atleast_2d(sigma1) 216 | sigma2 = np.atleast_2d(sigma2) 217 | 218 | assert mu1.shape == mu2.shape, \ 219 | 'Training and test mean vectors have different lengths' 220 | assert sigma1.shape == sigma2.shape, \ 221 | 'Training and test covariances have different dimensions' 222 | 223 | diff = mu1 - mu2 224 | 225 | # Product might be almost singular 226 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 227 | if not np.isfinite(covmean).all(): 228 | msg = ('fid calculation produces singular product; ' 229 | 'adding %s to diagonal of cov estimates') % eps 230 | print(msg) 231 | offset = np.eye(sigma1.shape[0]) * eps 232 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 233 | 234 | # Numerical error might give slight imaginary component 235 | if np.iscomplexobj(covmean): 236 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 237 | m = np.max(np.abs(covmean.imag)) 238 | raise ValueError('Imaginary component {}'.format(m)) 239 | covmean = covmean.real 240 | 241 | tr_covmean = np.trace(covmean) 242 | 243 | return (diff.dot(diff) + np.trace(sigma1) + 244 | np.trace(sigma2) - 2 * tr_covmean) 245 | 246 | 247 | class Reconstruction_Metrics(): 248 | def __init__(self, metric_list=['ssim', 'psnr', 'l1', 'mae'], data_range=1, win_size=51, multichannel=True): 249 | self.data_range = data_range 250 | self.win_size = win_size 251 | self.multichannel = multichannel 252 | for metric in metric_list: 253 | if metric in ['ssim', 'psnr', 'l1', 'mae']: 254 | setattr(self, metric, True) 255 | else: 256 | print('unsupport reconstruction metric: %s'%metric) 257 | 258 | 259 | def __call__(self, inputs, gts): 260 | """ 261 | inputs: the generated image, size (b,c,w,h), data range(0, data_range) 262 | gts: the ground-truth image, size (b,c,w,h), data range(0, data_range) 263 | """ 264 | result = dict() 265 | [b,n,w,h] = inputs.size() 266 | inputs = inputs.view(b*n, w, h).detach().cpu().numpy().astype(np.float32).transpose(1,2,0) 267 | gts = gts.view(b*n, w, h).detach().cpu().numpy().astype(np.float32).transpose(1,2,0) 268 | 269 | if hasattr(self, 'ssim'): 270 | ssim_value = compare_ssim(inputs, gts, data_range=self.data_range, 271 | win_size=self.win_size, multichannel=self.multichannel) 272 | result['ssim'] = ssim_value 273 | 274 | if hasattr(self, 'psnr'): 275 | psnr_value = compare_psnr(inputs, gts, self.data_range) 276 | result['psnr'] = psnr_value 277 | 278 | if hasattr(self, 'l1'): 279 | l1_value = compare_l1(inputs, gts) 280 | result['l1'] = l1_value 281 | 282 | if hasattr(self, 'mae'): 283 | mae_value = compare_mae(inputs, gts) 284 | result['mae'] = mae_value 285 | return result 286 | 287 | 288 | def calculate_from_disk(self, inputs, gts, save_path=None, debug=0): 289 | """ 290 | inputs: .txt files, floders, image files (string), image files (list) 291 | gts: .txt files, floders, image files (string), image files (list) 292 | """ 293 | input_image_list = sorted(get_image_list(inputs)) 294 | gt_image_list = sorted(get_image_list(gts)) 295 | print(len(input_image_list)) 296 | print(len(gt_image_list)) 297 | 298 | psnr = [] 299 | ssim = [] 300 | mae = [] 301 | l1 = [] 302 | names = [] 303 | 304 | for index in range(len(input_image_list)): 305 | name = os.path.basename(input_image_list[index]) 306 | names.append(name) 307 | 308 | img_gt = (imread(str(gt_image_list[index]))).astype(np.float32) / 255.0 309 | img_pred = (imread(str(input_image_list[index]))).astype(np.float32) / 255.0 
310 | 311 | 312 | if debug != 0: 313 | plt.subplot(121) 314 | plt.imshow(img_gt) 315 | plt.title('Ground truth') 316 | plt.subplot(122) 317 | plt.imshow(img_pred) 318 | plt.title('Output') 319 | plt.show() 320 | 321 | psnr.append(compare_psnr(img_gt, img_pred, data_range=self.data_range)) 322 | ssim.append(compare_ssim(img_gt, img_pred, data_range=self.data_range, 323 | win_size=self.win_size,multichannel=self.multichannel)) 324 | mae.append(compare_mae(img_gt, img_pred)) 325 | l1.append(compare_l1(img_gt, img_pred)) 326 | 327 | 328 | if np.mod(index, 200) == 0: 329 | print( 330 | str(index) + ' images processed', 331 | "PSNR: %.4f" % round(np.mean(psnr), 4), 332 | "SSIM: %.4f" % round(np.mean(ssim), 4), 333 | "MAE: %.4f" % round(np.mean(mae), 4), 334 | "l1: %.4f" % round(np.mean(l1), 4), 335 | ) 336 | 337 | if save_path: 338 | np.savez(save_path + '/metrics.npz', psnr=psnr, ssim=ssim, mae=mae, names=names) 339 | 340 | print( 341 | "PSNR: %.4f" % round(np.mean(psnr), 4), 342 | "PSNR Variance: %.4f" % round(np.var(psnr), 4), 343 | "SSIM: %.4f" % round(np.mean(ssim), 4), 344 | "SSIM Variance: %.4f" % round(np.var(ssim), 4), 345 | "MAE: %.4f" % round(np.mean(mae), 4), 346 | "MAE Variance: %.4f" % round(np.var(mae), 4), 347 | "l1: %.4f" % round(np.mean(l1), 4), 348 | "l1 Variance: %.4f" % round(np.var(l1), 4) 349 | ) 350 | 351 | 352 | def get_image_list(flist): 353 | if isinstance(flist, list): 354 | return flist 355 | 356 | # flist: image file path, image directory path, text file flist path 357 | if isinstance(flist, str): 358 | if os.path.isdir(flist): 359 | flist = list(glob.glob(flist + '/*.jpg')) + list(glob.glob(flist + '/*.png')) 360 | flist.sort() 361 | return flist 362 | 363 | if os.path.isfile(flist): 364 | try: 365 | return np.genfromtxt(flist, dtype=str) 366 | except: 367 | return [flist] 368 | print('cannot read files from %s; returning an empty list' % flist) 369 | return [] 370 | 371 | def compare_l1(img_true, img_test): 372 | img_true = img_true.astype(np.float32) 373 | img_test = img_test.astype(np.float32) 374 | return np.mean(np.abs(img_true - img_test)) 375 | 376 | def compare_mae(img_true, img_test): 377 | img_true = img_true.astype(np.float32) 378 | img_test = img_test.astype(np.float32) 379 | return np.sum(np.abs(img_true - img_test)) / np.sum(img_true + img_test) 380 | 381 | if __name__ == "__main__": 382 | fid = FID() 383 | rec = Reconstruction_Metrics() 384 | 385 | parser = argparse.ArgumentParser(description='script to compute all statistics') 386 | parser.add_argument('--input_path', help='Path to ground truth data', type=str) 387 | parser.add_argument('--output_path', help='Path to output data', type=str) 388 | parser.add_argument('--fid_real_path', help='Path to real images when calculating FID', type=str) 389 | args = parser.parse_args() 390 | 391 | for arg in vars(args): 392 | print('[%s] =' % arg, getattr(args, arg)) 393 | 394 | # rec.calculate_from_disk(args.input_path, args.output_path, save_path=args.output_path) 395 | fid.calculate_from_disk(args.output_path, args.fid_real_path) 396 | 397 | 398 | 399 | 400 | 401 | 402 | 403 | 404 | -------------------------------------------------------------------------------- /src/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from torch.optim import lr_scheduler 7 | from .utils import weights_init, get_iteration 8 | 9 | 10 | class BaseModel(nn.Module): 11 | def 
__init__(self, name, config): 12 | super(BaseModel, self).__init__() 13 | self.config = config 14 | self.samples_path = os.path.join(config.PATH, config.NAME, 'images') 15 | self.checkpoints_path = os.path.join(config.PATH, config.NAME, 'checkpoints') 16 | 17 | def save(self, which_epoch): 18 | """Save all the networks to the disk""" 19 | for net_name in self.net_name: 20 | if hasattr(self, net_name) and not(self.config.MODEL == 3 and 's_' in net_name): 21 | sub_net = getattr(self, net_name) 22 | save_filename = '%s_net_%s.pth' % (which_epoch, net_name) 23 | save_path = os.path.join(self.checkpoints_path, save_filename) 24 | torch.save(sub_net.state_dict(), save_path) 25 | 26 | 27 | def load(self, which_epoch): 28 | for net_name in self.net_name: 29 | if hasattr(self, net_name): 30 | sub_net = getattr(self, net_name) 31 | filename = '%s_net_%s.pth' % (which_epoch, net_name) 32 | model_name = os.path.join(self.checkpoints_path, filename) 33 | if not os.path.isfile(model_name): 34 | print('checkpoint %s do not exist'%model_name) 35 | continue 36 | self.load_networks(model_name, sub_net, net_name) 37 | self.iterations = get_iteration(self.checkpoints_path, filename, net_name) 38 | print('Resume %s from iteration %s' % (net_name, which_epoch)) 39 | 40 | sub_net_opt = getattr(self, net_name+'_opt') 41 | setattr(self, net_name+'_scheduler', self.get_scheduler(sub_net_opt)) 42 | 43 | 44 | def load_networks(self, path, net, name): 45 | """Load all the networks from the disk""" 46 | try: 47 | net.load_state_dict(torch.load(path)) 48 | except: 49 | pretrained_dict = torch.load(path) 50 | model_dict = net.state_dict() 51 | try: 52 | pretrained_dict = {k:v for k,v in pretrained_dict.items() if k in model_dict} 53 | net.load_state_dict(pretrained_dict) 54 | print('Pretrained network %s has excessive layers; Only loading layers that are used' % name) 55 | except: 56 | print('Pretrained network %s has fewer layers; The following are not initialized:' % name) 57 | not_initialized = set() 58 | for k, v in pretrained_dict.items(): 59 | if v.size() == model_dict[k].size(): 60 | model_dict[k] = v 61 | 62 | for k, v in model_dict.items(): 63 | if k not in pretrained_dict or v.size() != pretrained_dict[k].size(): 64 | not_initialized.add(k) 65 | print(sorted(not_initialized)) 66 | net.load_state_dict(model_dict) 67 | 68 | 69 | def init(self): 70 | for net_name in self.net_name: 71 | if hasattr(self, net_name): 72 | sub_net = getattr(self, net_name) 73 | sub_net.apply(weights_init(self.config.INIT_TYPE)) 74 | 75 | 76 | def define_optimizer(self): 77 | for net_name in self.net_name: 78 | if hasattr(self, net_name): 79 | sub_net = getattr(self, net_name) 80 | optimizer = optim.Adam(sub_net.parameters(),lr=self.config.LR,betas=(self.config.BETA1, self.config.BETA2)) 81 | scheduler = self.get_scheduler(optimizer) 82 | setattr(self, net_name+'_opt', optimizer) 83 | setattr(self, net_name+'_scheduler', scheduler) 84 | 85 | 86 | def get_scheduler(self, optimizer): 87 | if self.config.LR_POLICY == None or self.config.LR_POLICY == 'constant': 88 | scheduler = None 89 | elif self.config.LR_POLICY == 'step': 90 | scheduler = lr_scheduler.StepLR(optimizer, step_size=self.config.STEP_SIZE, 91 | gamma=self.config.GAMMA, last_epoch = self.iterations-1) 92 | else: 93 | return NotImplementedError('learning rate policy [%s] is not implemented', self.config.LR_POLICY) 94 | return scheduler 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | -------------------------------------------------------------------------------- 
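Checkpoints written by BaseModel.save follow the '%s_net_%s.pth' % (which_epoch, net_name) pattern under <PATH>/<NAME>/checkpoints, so a subclass only has to register its networks in net_name. A minimal sketch, assuming a hypothetical subclass with a single network called 'g_model' and an already-constructed Config instance:

import torch.nn as nn
from src.base_model import BaseModel

class ToyModel(BaseModel):
    def __init__(self, config):
        super(ToyModel, self).__init__('toy', config)
        self.net_name = ['g_model']      # networks that save()/load() should track
        self.g_model = nn.Linear(4, 4)   # stand-in for a real generator
        self.iterations = 0              # read by get_scheduler for the 'step' policy

# model = ToyModel(config)
# model.define_optimizer()   # creates model.g_model_opt and model.g_model_scheduler
# model.save('10')           # writes <checkpoints>/10_net_g_model.pth
# model.load('10')           # restores the weights and the iteration counter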
/src/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | 4 | class Config(dict): 5 | def __init__(self, opts, mode): 6 | with open(opts.config, 'r') as f: 7 | self._yaml = f.read() 8 | self._dict = yaml.load(self._yaml, Loader=yaml.SafeLoader) 9 | 10 | self.modify_param(opts, mode) 11 | 12 | 13 | def __getattr__(self, name): 14 | if self._dict.get(name) is not None: 15 | return self._dict[name] 16 | return None 17 | 18 | def print(self): 19 | print('Model configurations:') 20 | print('---------------------------------') 21 | print(self._yaml) 22 | print('') 23 | print('---------------------------------') 24 | print('') 25 | 26 | 27 | def modify_param(self, opts, mode): 28 | self._dict['PATH'] = opts.path 29 | self._dict['NAME'] = opts.name 30 | self._dict['RESUME_ALL'] = opts.resume_all 31 | 32 | if mode == 'test': 33 | self._dict['DATA_TEST_GT'] = opts.input 34 | self._dict['DATA_TEST_MASK'] = opts.mask 35 | self._dict['DATA_TEST_STRUCTURE'] = opts.structure 36 | self._dict['DATA_TEST_RESULTS'] = opts.output 37 | self._dict['MODEL'] = opts.model 38 | -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import torch 3 | import os 4 | import os.path 5 | import glob 6 | from torchvision import transforms 7 | import torchvision.transforms.functional as transFunc 8 | import random 9 | import numpy as np 10 | import torch.nn.functional as F 11 | import math 12 | import scipy.io as scio 13 | import torch.utils.data as data 14 | from PIL import Image 15 | from torch.utils.data import DataLoader 16 | import cv2 17 | 18 | 19 | ## TODO: choose whether to apply transformations in test mode 20 | class Dataset(data.Dataset): 21 | def __init__(self, gt_file, structure_file, config, mask_file=None): 22 | self.gt_image_files = self.load_file_list(gt_file) 23 | self.structure_image_files = self.load_file_list(structure_file) 24 | 25 | if len(self.gt_image_files) == 0: 26 | raise RuntimeError("Found 0 images in the input files") 27 | 28 | if config.MODE == 'test': 29 | self.transform_opt = {'crop': False, 'flip': False, 30 | 'resize': config.DATA_TEST_SIZE, 'random_load_mask': False} 31 | config._dict['DATA_MASK_TYPE'] = 'from_file' if mask_file is not None else config.DATA_MASK_TYPE # a mask file supplied at test time overrides the configured mask type 32 | else: 33 | self.transform_opt = {'crop': config.DATA_CROP, 'flip': config.DATA_FLIP, 34 | 'resize': config.DATA_TRAIN_SIZE, 'random_load_mask': True} 35 | 36 | self.mask_type = config.DATA_MASK_TYPE 37 | # generate random rectangle mask 38 | if self.mask_type == 'random_bbox': 39 | self.mask_setting = config.DATA_RANDOM_BBOX_SETTING 40 | # generate random free form mask 41 | elif self.mask_type == 'random_free_form': 42 | self.mask_setting = config.DATA_RANDOM_FF_SETTING 43 | # read masks from files 44 | elif self.mask_type == 'from_file': 45 | self.mask_image_files = self.load_file_list(mask_file) 46 | 47 | def __getitem__(self, index): 48 | try: 49 | item = self.load_item(index) 50 | except: 51 | print('loading error: ' + self.gt_image_files[index]) 52 | item = self.load_item(0) 53 | return item 54 | 55 | def __len__(self): 56 | return len(self.gt_image_files) 57 | 58 | 59 | def load_file_list(self, flist): 60 | if isinstance(flist, list): 61 | return flist 62 | 63 | # flist: image file path, image directory path, text file flist path 64 | if isinstance(flist, str): 65 | if os.path.isdir(flist): 66 | flist = 
list(glob.glob(flist + '/*.jpg')) + list(glob.glob(flist + '/*.png')) 67 | flist.sort() 68 | return flist 69 | 70 | if os.path.isfile(flist): 71 | try: 72 | return np.genfromtxt(flist, dtype=np.str, encoding='utf-8') 73 | except: 74 | return [flist] 75 | return [] 76 | 77 | 78 | def load_item(self, index): 79 | gt_path = self.gt_image_files[index] 80 | structure_path = self.structure_image_files[index] 81 | gt_image = loader(gt_path) 82 | structure_image = loader(structure_path) 83 | transform_param = get_params(gt_image.size, self.transform_opt) 84 | gt_image, structure_image = transform_image(transform_param, gt_image, structure_image) 85 | 86 | inpaint_map = self.load_mask(index, gt_image) 87 | input_image = gt_image*(1-inpaint_map) 88 | 89 | return input_image, structure_image, gt_image, inpaint_map 90 | 91 | 92 | def load_mask(self, index, img): 93 | _, w, h = img.shape 94 | image_shape = [w, h] 95 | if self.mask_type == 'random_bbox': 96 | bboxs = [] 97 | for i in range(self.mask_setting['num']): 98 | bbox = random_bbox(self.mask_setting, image_shape) 99 | bboxs.append(bbox) 100 | mask = bbox2mask(bboxs, image_shape, self.mask_setting) 101 | return torch.from_numpy(mask) 102 | 103 | elif self.mask_type == 'random_free_form': 104 | mask = random_ff_mask(self.mask_setting, image_shape) 105 | return torch.from_numpy(mask) 106 | 107 | elif self.mask_type == 'from_file': 108 | if self.transform_opt['random_load_mask']: 109 | index = np.random.randint(0, len(self.mask_image_files)) 110 | mask = gray_loader(self.mask_image_files[index]) 111 | if random.random() > 0.5: 112 | mask = transFunc.hflip(mask) 113 | if random.random() > 0.5: 114 | mask = transFunc.vflip(mask) 115 | else: 116 | mask = gray_loader(self.mask_image_files[index]) 117 | mask = transFunc.resize(mask, size=image_shape) 118 | mask = transFunc.to_tensor(mask) 119 | mask = (mask > 0).float() 120 | return mask 121 | else: 122 | raise(RuntimeError("No such mask type: %s"%self.mask_type)) 123 | 124 | def load_name(self, index, add_mask_name=False): 125 | name = self.gt_image_files[index] 126 | name = os.path.basename(name) 127 | 128 | if not add_mask_name: 129 | return name 130 | else: 131 | if len(self.mask_image_files)==0: 132 | return name 133 | else: 134 | mask_name = os.path.basename(self.mask_image_files[index]) 135 | mask_name, _ = os.path.splitext(mask_name) 136 | name, ext = os.path.splitext(name) 137 | name = name+'_'+mask_name+ext 138 | return name 139 | 140 | 141 | 142 | def create_iterator(self, batch_size): 143 | while True: 144 | sample_loader = DataLoader( 145 | dataset=self, 146 | batch_size=batch_size, 147 | drop_last=True 148 | ) 149 | 150 | for item in sample_loader: 151 | yield item 152 | 153 | 154 | def random_bbox(config, shape): 155 | """Generate a random tlhw with configuration. 156 | Args: 157 | config: Config should have configuration including DATA_NEW_SHAPE, 158 | VERTICAL_MARGIN, HEIGHT, HORIZONTAL_MARGIN, WIDTH. 159 | Returns: 160 | tuple: (top, left, height, width) 161 | """ 162 | img_height = shape[0] 163 | img_width = shape[1] 164 | height, width = config['shape'] 165 | ver_margin, hor_margin = config['margin'] 166 | maxt = img_height - ver_margin - height 167 | maxl = img_width - hor_margin - width 168 | t = np.random.randint(low=ver_margin, high=maxt) 169 | l = np.random.randint(low=hor_margin, high=maxl) 170 | h = height 171 | w = width 172 | return (t, l, h, w) 173 | 174 | def random_ff_mask( config, shape): 175 | """Generate a random free form mask with configuration. 
176 | 177 | Args: 178 | config: dict with the free-form settings: 'mv' (max extra vertices), 179 | 'ma' (max angle), 'ml' (max stroke length), 'mbw' (max brush width). 180 | 181 | Returns: 182 | numpy.ndarray: float32 mask of shape (1, H, W) 183 | """ 184 | 185 | h,w = shape 186 | mask = np.zeros((h,w)) 187 | num_v = 12+np.random.randint(config['mv']) #tf.random_uniform([], minval=0, maxval=config.MAXVERTEX, dtype=tf.int32) 188 | 189 | for i in range(num_v): 190 | start_x = np.random.randint(w) 191 | start_y = np.random.randint(h) 192 | for j in range(1+np.random.randint(5)): 193 | angle = 0.01+np.random.randint(config['ma']) 194 | if i % 2 == 0: 195 | angle = 2 * 3.1415926 - angle 196 | length = 10+np.random.randint(config['ml']) 197 | brush_w = 10+np.random.randint(config['mbw']) 198 | end_x = (start_x + length * np.sin(angle)).astype(np.int32) 199 | end_y = (start_y + length * np.cos(angle)).astype(np.int32) 200 | 201 | cv2.line(mask, (start_y, start_x), (end_y, end_x), 1.0, brush_w) 202 | start_x, start_y = end_x, end_y 203 | 204 | return mask.reshape((1,)+mask.shape).astype(np.float32) 205 | 206 | 207 | def bbox2mask( bboxs, shape, config): 208 | """Generate mask tensor from bbox. 209 | 210 | Args: 211 | bboxs: list of (top, left, height, width) tuples. 212 | shape: [H, W] of the output mask. 213 | config: dict; config['random_size'] jitters the hole size inside each bbox. 214 | 215 | Returns: 216 | numpy.ndarray: float32 mask of shape (1, H, W) 217 | 218 | """ 219 | height, width = shape 220 | mask = np.zeros(( height, width), np.float32) 221 | #print(mask.shape) 222 | for bbox in bboxs: 223 | if config['random_size']: 224 | h = int(0.1*bbox[2])+np.random.randint(int(bbox[2]*0.2+1)) 225 | w = int(0.1*bbox[3])+np.random.randint(int(bbox[3]*0.2)+1) 226 | else: 227 | h=0 228 | w=0 229 | mask[bbox[0]+h:bbox[0]+bbox[2]-h, 230 | bbox[1]+w:bbox[1]+bbox[3]-w] = 1. 
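# Worked example: with shape = [256, 256], one bbox = (64, 64, 128, 128) and
# random_size enabled, h and w each fall in [12, 37] (int(0.1*128) plus a
# randint over ~26 values), so the filled block shrinks from 128x128 to
# somewhere between 54x54 and 104x104, randomising the hole size around the
# configured bbox.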
231 | 
232 |     return mask.reshape((1,) + mask.shape).astype(np.float32)
233 | 
234 | 
235 | def gray_loader(path):
236 |     return Image.open(path)
237 | 
238 | def loader(path):
239 |     return Image.open(path).convert('RGB')
240 | 
241 | def get_params(size, transform_opt):
242 |     w, h = size
243 |     if transform_opt['flip']:
244 |         flip = random.random() > 0.5
245 |     else:
246 |         flip = False
247 |     if transform_opt['crop']:
248 |         transform_crop = transform_opt['crop'] \
249 |             if w >= transform_opt['crop'][0] and h >= transform_opt['crop'][1] else [w, h]  # fall back to the full (w, h) when the image is smaller than the crop
250 |         x = random.randint(0, np.maximum(0, w - transform_crop[0]))
251 |         y = random.randint(0, np.maximum(0, h - transform_crop[1]))
252 |         crop = [x, y, transform_crop[0], transform_crop[1]]
253 |     else:
254 |         crop = False
255 |     if transform_opt['resize']:
256 |         resize = [transform_opt['resize'], transform_opt['resize']]
257 |     else:
258 |         resize = False
259 |     param = {'crop': crop, 'flip': flip, 'resize': resize}
260 |     return param
261 | 
262 | def transform_image(transform_param, gt_image, structure_image, normalize=True, toTensor=True):
263 |     transform_list = []
264 | 
265 |     if transform_param['crop']:
266 |         crop_position = transform_param['crop'][:2]
267 |         crop_size = transform_param['crop'][2:]
268 |         transform_list.append(transforms.Lambda(lambda img: __crop(img, crop_position, crop_size)))
269 |     if transform_param['resize']:
270 |         transform_list.append(transforms.Resize(transform_param['resize']))
271 |     if transform_param['flip']:
272 |         transform_list.append(transforms.Lambda(lambda img: __flip(img, True)))
273 | 
274 |     if toTensor:
275 |         transform_list += [transforms.ToTensor()]
276 | 
277 |     if normalize:
278 |         transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
279 |                                                 (0.5, 0.5, 0.5))]
280 |     trans = transforms.Compose(transform_list)
281 |     if gt_image.size != structure_image.size:
282 |         structure_image = transFunc.resize(structure_image, size=gt_image.size)  # note: PIL .size is (w, h)
283 | 
284 |     gt_image = trans(gt_image)
285 |     structure_image = trans(structure_image)
286 | 
287 |     return gt_image, structure_image
288 | 
289 | def __crop(img, pos, size):
290 |     ow, oh = img.size
291 |     x1, y1 = pos
292 |     tw, th = size
293 |     return img.crop((x1, y1, x1 + tw, y1 + th))
294 | 
295 | def __flip(img, flip):
296 |     if flip:
297 |         return img.transpose(Image.FLIP_LEFT_RIGHT)
298 |     return img
299 | 
300 | 
301 | 
-------------------------------------------------------------------------------- /src/loss.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.models as models
4 | import torch.nn.functional as F
5 | from .resample2d import Resample2d
6 | 
7 | from .utils import write_2images
8 | 
9 | class AdversarialLoss(nn.Module):
10 |     r"""
11 |     Adversarial loss
12 |     https://arxiv.org/abs/1711.10337
13 |     """
14 | 
15 |     def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0):
16 |         r"""
17 |         type = nsgan | lsgan | hinge
18 |         """
19 |         super(AdversarialLoss, self).__init__()
20 | 
21 |         self.type = type
22 |         self.register_buffer('real_label', torch.tensor(target_real_label))
23 |         self.register_buffer('fake_label', torch.tensor(target_fake_label))
24 | 
25 |         if type == 'nsgan':
26 |             self.criterion = nn.BCELoss()
27 | 
28 |         elif type == 'lsgan':
29 |             self.criterion = nn.MSELoss()
30 | 
31 |         elif type == 'hinge':
32 |             self.criterion = nn.ReLU()
33 | 
34 |     def __call__(self, outputs, is_real, for_dis=None):
35 |         if self.type == 'hinge':
36 |             if for_dis:
37 |                 if is_real:
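# Editor's comment on the hinge branch below: for the discriminator the loss
# is E[relu(1 - D(real))] + E[relu(1 + D(fake))]; flipping the sign of the
# real outputs lets both cases share relu(1 + outputs).mean(). For the
# generator (for_dis falsy) the loss is -E[D(fake)], i.e. (-outputs).mean().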
38 |                     outputs = -outputs
39 |                 return self.criterion(1 + outputs).mean()
40 |             else:
41 |                 return (-outputs).mean()
42 | 
43 |         else:
44 |             labels = (self.real_label if is_real else self.fake_label).expand_as(outputs)
45 |             loss = self.criterion(outputs, labels)
46 |             return loss
47 | 
48 | 
49 | class StyleLoss(nn.Module):
50 |     r"""
51 |     Style loss (Gram-matrix L1), VGG-based
52 |     https://arxiv.org/abs/1603.08155
53 |     https://github.com/dxyang/StyleTransfer/blob/master/utils.py
54 |     """
55 | 
56 |     def __init__(self):
57 |         super(StyleLoss, self).__init__()
58 |         self.add_module('vgg', VGG19())
59 |         self.criterion = torch.nn.L1Loss()
60 | 
61 |     def compute_gram(self, x):
62 |         b, ch, h, w = x.size()
63 |         f = x.view(b, ch, w * h)
64 |         f_T = f.transpose(1, 2)
65 |         G = f.bmm(f_T) / (h * w * ch)
66 | 
67 |         return G
68 | 
69 |     def __call__(self, x, y):
70 |         # Compute features
71 |         x_vgg, y_vgg = self.vgg(x), self.vgg(y)
72 | 
73 |         # Compute loss
74 |         style_loss = 0.0
75 |         style_loss += self.criterion(self.compute_gram(x_vgg['relu2_2']), self.compute_gram(y_vgg['relu2_2']))
76 |         style_loss += self.criterion(self.compute_gram(x_vgg['relu3_4']), self.compute_gram(y_vgg['relu3_4']))
77 |         style_loss += self.criterion(self.compute_gram(x_vgg['relu4_4']), self.compute_gram(y_vgg['relu4_4']))
78 |         style_loss += self.criterion(self.compute_gram(x_vgg['relu5_2']), self.compute_gram(y_vgg['relu5_2']))
79 | 
80 |         return style_loss
81 | 
82 | 
83 | 
84 | class PerceptualLoss(nn.Module):
85 |     r"""
86 |     Perceptual loss, VGG-based
87 |     https://arxiv.org/abs/1603.08155
88 |     https://github.com/dxyang/StyleTransfer/blob/master/utils.py
89 |     """
90 | 
91 |     def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]):
92 |         super(PerceptualLoss, self).__init__()
93 |         self.add_module('vgg', VGG19())
94 |         self.criterion = torch.nn.L1Loss()
95 |         self.weights = weights
96 | 
97 |     def __call__(self, x, y):
98 |         # Compute features
99 |         x_vgg, y_vgg = self.vgg(x), self.vgg(y)
100 |         content_loss = 0.0
101 |         content_loss += self.weights[0] * self.criterion(x_vgg['relu1_1'], y_vgg['relu1_1'])
102 |         content_loss += self.weights[1] * self.criterion(x_vgg['relu2_1'], y_vgg['relu2_1'])
103 |         content_loss += self.weights[2] * self.criterion(x_vgg['relu3_1'], y_vgg['relu3_1'])
104 |         content_loss += self.weights[3] * self.criterion(x_vgg['relu4_1'], y_vgg['relu4_1'])
105 |         content_loss += self.weights[4] * self.criterion(x_vgg['relu5_1'], y_vgg['relu5_1'])
106 | 
107 |         return content_loss
108 | 
109 | 
110 | class PerceptualCorrectness(nn.Module):
111 |     r"""
112 |     Sampling-correctness loss for the appearance flow, VGG-based.
113 |     """
114 | 
115 |     def __init__(self, layer='relu3_1'):
116 |         super(PerceptualCorrectness, self).__init__()
117 |         self.add_module('vgg', VGG19())
118 |         self.layer = layer
119 |         self.eps = 1e-8
120 |         self.resample = Resample2d(4, 1, sigma=2)
121 | 
122 |     def __call__(self, gts, inputs, flow, maps):
123 |         gts_vgg, inputs_vgg = self.vgg(gts), self.vgg(inputs)
124 |         gts_vgg = gts_vgg[self.layer]
125 |         inputs_vgg = inputs_vgg[self.layer]
126 |         [b, c, h, w] = gts_vgg.shape
127 | 
128 |         maps = F.interpolate(maps, [h, w]).view(b, -1)
129 |         flow = F.adaptive_avg_pool2d(flow, [h, w])
130 | 
131 |         gts_all = gts_vgg.view(b, c, -1)  # [b C N2]
132 |         inputs_all = inputs_vgg.view(b, c, -1).transpose(1, 2)  # [b N2 C]
133 | 
134 | 
135 |         input_norm = inputs_all / (inputs_all.norm(dim=2, keepdim=True) + self.eps)
136 |         gt_norm = gts_all / (gts_all.norm(dim=1, keepdim=True) + self.eps)
137 |         correction = torch.bmm(input_norm, gt_norm)  # [b N2 N2] pairwise cosine similarities
138 |         (correction_max, max_indices) = torch.max(correction, dim=1)
139 | 
140 |         # interpolate with Gaussian sampling
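# Editor's note: the flow field warps the source features via Gaussian
# resampling; the loss below compares, per location, the cosine similarity
# of the warped feature against the ground-truth feature, normalized by the
# best similarity achievable from any source location (correction_max), and
# averages exp(-ratio) over the masked region only.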
141 |         input_sample = self.resample(inputs_vgg, flow)
142 |         input_sample = input_sample.view(b, c, -1)  # [b C N2]
143 | 
144 |         correction_sample = F.cosine_similarity(input_sample, gts_all)  # [b N2]
145 |         loss_map = torch.exp(-correction_sample / (correction_max + self.eps))
146 |         loss = torch.sum(loss_map * maps) / (torch.sum(maps) + self.eps)
147 | 
148 |         return loss
149 | 
150 | 
151 | 
152 | 
153 | class VGG19(torch.nn.Module):
154 |     def __init__(self):
155 |         super(VGG19, self).__init__()
156 |         features = models.vgg19(pretrained=True).features
157 |         self.relu1_1 = torch.nn.Sequential()
158 |         self.relu1_2 = torch.nn.Sequential()
159 | 
160 |         self.relu2_1 = torch.nn.Sequential()
161 |         self.relu2_2 = torch.nn.Sequential()
162 | 
163 |         self.relu3_1 = torch.nn.Sequential()
164 |         self.relu3_2 = torch.nn.Sequential()
165 |         self.relu3_3 = torch.nn.Sequential()
166 |         self.relu3_4 = torch.nn.Sequential()
167 | 
168 |         self.relu4_1 = torch.nn.Sequential()
169 |         self.relu4_2 = torch.nn.Sequential()
170 |         self.relu4_3 = torch.nn.Sequential()
171 |         self.relu4_4 = torch.nn.Sequential()
172 | 
173 |         self.relu5_1 = torch.nn.Sequential()
174 |         self.relu5_2 = torch.nn.Sequential()
175 |         self.relu5_3 = torch.nn.Sequential()
176 |         self.relu5_4 = torch.nn.Sequential()
177 | 
178 |         for x in range(2):
179 |             self.relu1_1.add_module(str(x), features[x])
180 | 
181 |         for x in range(2, 4):
182 |             self.relu1_2.add_module(str(x), features[x])
183 | 
184 |         for x in range(4, 7):
185 |             self.relu2_1.add_module(str(x), features[x])
186 | 
187 |         for x in range(7, 9):
188 |             self.relu2_2.add_module(str(x), features[x])
189 | 
190 |         for x in range(9, 12):
191 |             self.relu3_1.add_module(str(x), features[x])
192 | 
193 |         for x in range(12, 14):
194 |             self.relu3_2.add_module(str(x), features[x])
195 | 
196 |         for x in range(14, 16):
197 |             self.relu3_3.add_module(str(x), features[x])  # fixed: was relu3_2 again, which left relu3_3 empty
198 | 
199 |         for x in range(16, 18):
200 |             self.relu3_4.add_module(str(x), features[x])
201 | 
202 |         for x in range(18, 21):
203 |             self.relu4_1.add_module(str(x), features[x])
204 | 
205 |         for x in range(21, 23):
206 |             self.relu4_2.add_module(str(x), features[x])
207 | 
208 |         for x in range(23, 25):
209 |             self.relu4_3.add_module(str(x), features[x])
210 | 
211 |         for x in range(25, 27):
212 |             self.relu4_4.add_module(str(x), features[x])
213 | 
214 |         for x in range(27, 30):
215 |             self.relu5_1.add_module(str(x), features[x])
216 | 
217 |         for x in range(30, 32):
218 |             self.relu5_2.add_module(str(x), features[x])
219 | 
220 |         for x in range(32, 34):
221 |             self.relu5_3.add_module(str(x), features[x])
222 | 
223 |         for x in range(34, 36):
224 |             self.relu5_4.add_module(str(x), features[x])
225 | 
226 |         # don't need the gradients, just want the features
227 |         for param in self.parameters():
228 |             param.requires_grad = False
229 | 
230 |     def forward(self, x):
231 |         relu1_1 = self.relu1_1(x)
232 |         relu1_2 = self.relu1_2(relu1_1)
233 | 
234 |         relu2_1 = self.relu2_1(relu1_2)
235 |         relu2_2 = self.relu2_2(relu2_1)
236 | 
237 |         relu3_1 = self.relu3_1(relu2_2)
238 |         relu3_2 = self.relu3_2(relu3_1)
239 |         relu3_3 = self.relu3_3(relu3_2)
240 |         relu3_4 = self.relu3_4(relu3_3)
241 | 
242 |         relu4_1 = self.relu4_1(relu3_4)
243 |         relu4_2 = self.relu4_2(relu4_1)
244 |         relu4_3 = self.relu4_3(relu4_2)
245 |         relu4_4 = self.relu4_4(relu4_3)
246 | 
247 |         relu5_1 = self.relu5_1(relu4_4)
248 |         relu5_2 = self.relu5_2(relu5_1)
249 |         relu5_3 = self.relu5_3(relu5_2)
250 |         relu5_4 = self.relu5_4(relu5_3)
251 | 
252 |         out = {
253 |             'relu1_1': relu1_1,
254 |             'relu1_2': relu1_2,
255 | 
256 |             'relu2_1': relu2_1,
257 |             'relu2_2': relu2_2,
258 | 
259 |             'relu3_1': relu3_1,
260 |             'relu3_2': relu3_2,
261 |             'relu3_3': relu3_3,
262 |             'relu3_4': relu3_4,
263 | 
264 |             'relu4_1': relu4_1,
265 |             'relu4_2': relu4_2,
266 |             'relu4_3': relu4_3,
267 |             'relu4_4': relu4_4,
268 | 
269 |             'relu5_1': relu5_1,
270 |             'relu5_2': relu5_2,
271 |             'relu5_3': relu5_3,
272 |             'relu5_4': relu5_4,
273 |         }
274 |         return out
275 | 
-------------------------------------------------------------------------------- /src/metrics.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Calculates the Frechet Inception Distance (FID) to evaluate GANs
3 | The FID metric calculates the distance between two distributions of images.
4 | Typically, we have summary statistics (mean & covariance matrix) of one
5 | of these distributions, while the second distribution is given by a GAN.
6 | When run as a stand-alone program, it compares the distribution of
7 | images that are stored as PNG/JPEG at a specified location with a
8 | distribution given by summary statistics (in pickle format).
9 | The FID is calculated by assuming that X_1 and X_2 are the activations of
10 | the pool_3 layer of the inception net for generated samples and real world
11 | samples respectively.
12 | See --help to see further details.
13 | Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
14 | of Tensorflow
15 | Copyright 2018 Institute of Bioinformatics, JKU Linz
16 | Licensed under the Apache License, Version 2.0 (the "License");
17 | you may not use this file except in compliance with the License.
18 | You may obtain a copy of the License at
19 | http://www.apache.org/licenses/LICENSE-2.0
20 | Unless required by applicable law or agreed to in writing, software
21 | distributed under the License is distributed on an "AS IS" BASIS,
22 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23 | See the License for the specific language governing permissions and
24 | limitations under the License.
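Editor's note: in this repository the FID class below is used programmatically
(a batch of generated image arrays versus a ground-truth folder) rather than
via the stand-alone CLI described above.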
25 | """
26 | import os
27 | import pathlib
28 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
29 | 
30 | import torch
31 | import numpy as np
32 | from imageio import imread  # editor's fix: scipy.misc.imread was removed in SciPy >= 1.2
33 | from scipy import linalg
34 | from torch.autograd import Variable
35 | from torch.nn.functional import adaptive_avg_pool2d
36 | 
37 | from .inception import InceptionV3
38 | 
39 | 
40 | 
41 | class FID():
42 |     """Frechet Inception Distance between generated images and a ground-truth folder."""
43 |     def __init__(self):
44 |         self.dims = 2048
45 |         self.batch_size = 64
46 |         self.cuda = True
47 | 
48 | 
49 | 
50 | 
51 |     def __call__(self, images, gt_path):
52 |         if not os.path.exists(gt_path):
53 |             raise RuntimeError('Invalid path: %s' % gt_path)
54 | 
55 |         block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[self.dims]
56 | 
57 |         self.model = InceptionV3([block_idx])
58 |         if self.cuda:
59 |             # TODO: put model into specific GPU
60 |             self.model.cuda()
61 | 
62 |         print('calculating gt_path statistics...')
63 |         m1, s1 = self.compute_statistics_of_path(gt_path)
64 |         print('calculating generated_images statistics...')
65 |         m2, s2 = self.calculate_activation_statistics(images)
66 |         fid_value = self.calculate_frechet_distance(m1, s1, m2, s2)
67 |         return fid_value
68 | 
69 | 
70 | 
71 |     def compute_statistics_of_path(self, path):
72 |         npz_file = os.path.join(path, 'statistics.npz')
73 |         if os.path.exists(npz_file):
74 |             f = np.load(npz_file)
75 |             m, s = f['mu'][:], f['sigma'][:]
76 |             f.close()
77 |         else:
78 |             path = pathlib.Path(path)
79 |             files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
80 |             imgs = np.array([imread(str(fn)).astype(np.float32) for fn in files])
81 | 
82 |             # Bring images to shape (B, 3, H, W)
83 |             imgs = imgs.transpose((0, 3, 1, 2))
84 | 
85 |             # Rescale images to be between 0 and 1
86 |             imgs /= 255
87 | 
88 |             m, s = self.calculate_activation_statistics(imgs)
89 |             np.savez(npz_file, mu=m, sigma=s)
90 | 
91 |         return m, s
92 | 
93 |     def calculate_activation_statistics(self, images, verbose=False):
94 |         """Calculation of the statistics used by the FID.
95 |         Params:
96 |         -- images      : Numpy array of dimension (n_images, 3, hi, wi). The values
97 |                          must lie between 0 and 1.
98 |         -- model       : Instance of inception model
99 |         -- batch_size  : The images numpy array is split into batches with
100 |                          batch size batch_size. A reasonable batch size
101 |                          depends on the hardware.
102 |         -- dims        : Dimensionality of features returned by Inception
103 |         -- cuda        : If set to True, use GPU
104 |         -- verbose     : If set to True and parameter out_step is given, the
105 |                          number of calculated batches is reported.
106 |         Returns:
107 |         -- mu    : The mean over samples of the activations of the pool_3 layer of
108 |                    the inception model.
109 |         -- sigma : The covariance matrix of the activations of the pool_3 layer of
110 |                    the inception model.
111 |         """
112 |         act = self.get_activations(images, True)
113 |         mu = np.mean(act, axis=0)
114 |         sigma = np.cov(act, rowvar=False)
115 |         return mu, sigma
116 | 
117 | 
118 | 
119 |     def get_activations(self, images, verbose=False):
120 |         """Calculates the activations of the pool_3 layer for all images.
121 |         Params:
122 |         -- images      : Numpy array of dimension (n_images, 3, hi, wi). The values
123 |                          must lie between 0 and 1.
124 |         -- model       : Instance of inception model
125 |         -- batch_size  : the images numpy array is split into batches with
126 |                          batch size batch_size. A reasonable batch size depends
127 |                          on the hardware.
128 |         -- dims        : Dimensionality of features returned by Inception
129 |         -- cuda        : If set to True, use GPU
130 |         -- verbose     : If set to True and parameter out_step is given, the number
131 |                          of calculated batches is reported.
132 |         Returns:
133 |         -- A numpy array of dimension (num images, dims) that contains the
134 |            activations of the given tensor when feeding inception with the
135 |            query tensor. Images beyond n_batches * batch_size are dropped.
136 |         """
137 |         self.model.eval()
138 | 
139 |         d0 = images.shape[0]
140 |         if self.batch_size > d0:
141 |             print('Warning: batch size is bigger than the data size. '
142 |                   'Setting batch size to data size')
143 |             self.batch_size = d0
144 | 
145 |         n_batches = d0 // self.batch_size
146 |         n_used_imgs = n_batches * self.batch_size
147 | 
148 |         pred_arr = np.empty((n_used_imgs, self.dims))
149 |         for i in range(n_batches):
150 |             if verbose:
151 |                 print('\rPropagating batch %d/%d' % (i + 1, n_batches),
152 |                       end='', flush=True)
153 |             start = i * self.batch_size
154 |             end = start + self.batch_size
155 | 
156 |             batch = torch.from_numpy(images[start:end]).type(torch.FloatTensor)
157 | 
158 |             if self.cuda:
159 |                 batch = batch.cuda()
160 | 
161 |             pred = self.model(batch)[0]
162 | 
163 |             # If model output is not scalar, apply global spatial average pooling.
164 |             # This happens if you choose a dimensionality not equal 2048.
165 |             if pred.shape[2] != 1 or pred.shape[3] != 1:
166 |                 pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
167 | 
168 |             pred_arr[start:end] = pred.cpu().data.numpy().reshape(self.batch_size, -1)
169 | 
170 |         if verbose:
171 |             print(' done')
172 | 
173 |         return pred_arr
174 | 
175 | 
176 |     def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6):
177 |         """Numpy implementation of the Frechet Distance.
178 |         The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
179 |         and X_2 ~ N(mu_2, C_2) is
180 |             d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
181 |         Stable version by Dougal J. Sutherland.
182 |         Params:
183 |         -- mu1   : Numpy array containing the activations of a layer of the
184 |                    inception net (like returned by the function 'get_predictions')
185 |                    for generated samples.
186 |         -- mu2   : The sample mean over activations, precalculated on a
187 |                    representative data set.
188 |         -- sigma1: The covariance matrix over activations for generated samples.
189 |         -- sigma2: The covariance matrix over activations, precalculated on a
190 |                    representative data set.
191 |         Returns:
192 |         --       : The Frechet Distance.
193 | """ 194 | 195 | mu1 = np.atleast_1d(mu1) 196 | mu2 = np.atleast_1d(mu2) 197 | 198 | sigma1 = np.atleast_2d(sigma1) 199 | sigma2 = np.atleast_2d(sigma2) 200 | 201 | assert mu1.shape == mu2.shape, \ 202 | 'Training and test mean vectors have different lengths' 203 | assert sigma1.shape == sigma2.shape, \ 204 | 'Training and test covariances have different dimensions' 205 | 206 | diff = mu1 - mu2 207 | 208 | # Product might be almost singular 209 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 210 | if not np.isfinite(covmean).all(): 211 | msg = ('fid calculation produces singular product; ' 212 | 'adding %s to diagonal of cov estimates') % eps 213 | print(msg) 214 | offset = np.eye(sigma1.shape[0]) * eps 215 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 216 | 217 | # Numerical error might give slight imaginary component 218 | if np.iscomplexobj(covmean): 219 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 220 | m = np.max(np.abs(covmean.imag)) 221 | raise ValueError('Imaginary component {}'.format(m)) 222 | covmean = covmean.real 223 | 224 | tr_covmean = np.trace(covmean) 225 | 226 | return (diff.dot(diff) + np.trace(sigma1) + 227 | np.trace(sigma2) - 2 * tr_covmean) 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .base_model import BaseModel 4 | from .network import StructureGen 5 | from .network import MultiDiscriminator 6 | from .network import FlowGen 7 | from .loss import AdversarialLoss, PerceptualCorrectness, StyleLoss, PerceptualLoss 8 | 9 | 10 | class StructureFlowModel(BaseModel): 11 | def __init__(self, config): 12 | super(StructureFlowModel, self).__init__('StructureFlow', config) 13 | self.config = config 14 | self.net_name = ['s_gen', 's_dis', 'f_gen', 'f_dis'] 15 | 16 | self.structure_param = {'input_dim':3, 'dim':64, 'n_res':4, 'activ':'relu', 17 | 'norm':'in', 'pad_type':'reflect', 'use_sn':True} 18 | self.flow_param = {'input_dim':3, 'dim':64, 'n_res':2, 'activ':'relu', 19 | 'norm_conv':'ln', 'norm_flow':'in', 'pad_type':'reflect', 'use_sn':False} 20 | self.dis_param = {'input_dim':3, 'dim':64, 'n_layers':3, 21 | 'norm':'none', 'activ':'lrelu', 'pad_type':'reflect', 'use_sn':True} 22 | 23 | l1_loss = nn.L1Loss() 24 | adversarial_loss = AdversarialLoss(type=config.DIS_GAN_LOSS) 25 | correctness_loss = PerceptualCorrectness() 26 | vgg_style = StyleLoss() 27 | vgg_content = PerceptualLoss() 28 | self.use_correction_loss = True 29 | self.use_vgg_loss = True if self.config.MODEL == 3 else False 30 | 31 | self.add_module('l1_loss', l1_loss) 32 | self.add_module('adversarial_loss', adversarial_loss) 33 | self.add_module('correctness_loss', correctness_loss) 34 | self.add_module('vgg_style', vgg_style) 35 | self.add_module('vgg_content', vgg_content) 36 | 37 | self.build_model() 38 | 39 | def build_model(self): 40 | self.iterations = 0 41 | # structure model 42 | if self.config.MODEL == 1: 43 | self.s_gen = StructureGen(**self.structure_param) 44 | self.s_dis = MultiDiscriminator(**self.dis_param) 45 | # flow model with true input smooth 46 | elif self.config.MODEL == 2: 47 | self.f_gen = FlowGen(**self.flow_param) 48 | self.f_dis = MultiDiscriminator(**self.dis_param) 49 | # flow model with fake input smooth 50 | elif self.config.MODEL == 3: 51 | self.s_gen = 
StructureGen(**self.structure_param) 52 | self.f_gen = FlowGen(**self.flow_param) 53 | self.f_dis = MultiDiscriminator(**self.dis_param) 54 | 55 | self.define_optimizer() 56 | self.init() 57 | 58 | 59 | def structure_forward(self, inputs, smooths, maps): 60 | smooths_input = smooths*(1-maps) 61 | outputs = self.s_gen(torch.cat((inputs, smooths_input, maps),dim=1)) 62 | return outputs 63 | 64 | def flow_forward(self, inputs, stage_1, maps): 65 | outputs, flow = self.f_gen(torch.cat((inputs, stage_1, maps),dim=1)) 66 | return outputs, flow 67 | 68 | def sample(self, inputs, smooths, gts, maps): 69 | with torch.no_grad(): 70 | if self.config.MODEL == 1: 71 | outputs = self.structure_forward(inputs, smooths, maps) 72 | result =[inputs,smooths,gts,maps,outputs] 73 | flow = None 74 | elif self.config.MODEL == 2: 75 | outputs, flow = self.flow_forward(inputs, smooths, maps) 76 | result =[inputs,smooths,gts,maps,outputs] 77 | if flow is not None: 78 | flow = [flow[:,0,:,:].unsqueeze(1)/30, flow[:,1,:,:].unsqueeze(1)/30] 79 | 80 | elif self.config.MODEL == 3: 81 | smooth_stage_1 = self.structure_forward(inputs, smooths, maps) 82 | outputs, flow = self.flow_forward(inputs, smooth_stage_1, maps) 83 | result =[inputs,smooths,gts,maps,smooth_stage_1,outputs] 84 | if flow is not None: 85 | flow = [flow[:,0,:,:].unsqueeze(1)/30, flow[:,1,:,:].unsqueeze(1)/30] 86 | return result, flow 87 | 88 | def update_structure(self, inputs, smooths, maps): 89 | self.iterations += 1 90 | 91 | self.s_gen.zero_grad() 92 | self.s_dis.zero_grad() 93 | outputs = self.structure_forward(inputs, smooths, maps) 94 | 95 | dis_loss = 0 96 | dis_fake_input = outputs.detach() 97 | dis_real_input = smooths 98 | fake_labels = self.s_dis(dis_fake_input) 99 | real_labels = self.s_dis(dis_real_input) 100 | for i in range(len(fake_labels)): 101 | dis_real_loss = self.adversarial_loss(real_labels[i], True, True) 102 | dis_fake_loss = self.adversarial_loss(fake_labels[i], False, True) 103 | dis_loss += (dis_real_loss + dis_fake_loss) / 2 104 | self.structure_adv_dis_loss = dis_loss/len(fake_labels) 105 | 106 | self.structure_adv_dis_loss.backward() 107 | self.s_dis_opt.step() 108 | if self.s_dis_scheduler is not None: 109 | self.s_dis_scheduler.step() 110 | 111 | 112 | dis_gen_loss = 0 113 | fake_labels = self.s_dis(outputs) 114 | for i in range(len(fake_labels)): 115 | dis_fake_loss = self.adversarial_loss(fake_labels[i], True, False) 116 | dis_gen_loss += dis_fake_loss 117 | self.structure_adv_gen_loss = dis_gen_loss/len(fake_labels) * self.config.STRUCTURE_ADV_GEN 118 | self.structure_l1_loss = self.l1_loss(outputs, smooths) * self.config.STRUCTURE_L1 119 | self.structure_gen_loss = self.structure_l1_loss + self.structure_adv_gen_loss 120 | 121 | self.structure_gen_loss.backward() 122 | self.s_gen_opt.step() 123 | if self.s_gen_scheduler is not None: 124 | self.s_gen_scheduler.step() 125 | 126 | logs = [ 127 | ("l_s_adv_dis", self.structure_adv_dis_loss.item()), 128 | ("l_s_l1", self.structure_l1_loss.item()), 129 | ("l_s_adv_gen", self.structure_adv_gen_loss.item()), 130 | ("l_s_gen", self.structure_gen_loss.item()), 131 | ] 132 | return logs 133 | 134 | 135 | def update_flow(self, inputs, smooths, gts, maps, use_correction_loss, use_vgg_loss): 136 | self.iterations += 1 137 | 138 | self.f_dis.zero_grad() 139 | self.f_gen.zero_grad() 140 | outputs, flow_maps = self.flow_forward(inputs, smooths, maps) 141 | 142 | 143 | dis_loss = 0 144 | dis_fake_input = outputs.detach() 145 | dis_real_input = gts 146 | fake_labels = 
self.f_dis(dis_fake_input) 147 | real_labels = self.f_dis(dis_real_input) 148 | # self.flow_adv_dis_loss = (dis_real_loss + dis_fake_loss) / 2 149 | for i in range(len(fake_labels)): 150 | dis_real_loss = self.adversarial_loss(real_labels[i], True, True) 151 | dis_fake_loss = self.adversarial_loss(fake_labels[i], False, True) 152 | dis_loss += (dis_real_loss + dis_fake_loss) / 2 153 | self.flow_adv_dis_loss = dis_loss/len(fake_labels) 154 | 155 | self.flow_adv_dis_loss.backward() 156 | self.f_dis_opt.step() 157 | if self.f_dis_scheduler is not None: 158 | self.f_dis_scheduler.step() 159 | 160 | 161 | dis_gen_loss = 0 162 | fake_labels = self.f_dis(outputs) 163 | for i in range(len(fake_labels)): 164 | dis_fake_loss = self.adversarial_loss(fake_labels[i], True, False) 165 | dis_gen_loss += dis_fake_loss 166 | self.flow_adv_gen_loss = dis_gen_loss/len(fake_labels) * self.config.FLOW_ADV_GEN 167 | self.flow_l1_loss = self.l1_loss(outputs, gts) * self.config.FLOW_L1 168 | self.flow_correctness_loss = self.correctness_loss(gts, inputs, flow_maps, maps)* \ 169 | self.config.FLOW_CORRECTNESS if use_correction_loss else 0 170 | 171 | 172 | if use_vgg_loss: 173 | self.vgg_loss_style = self.vgg_style(outputs*maps, gts*maps)*self.config.VGG_STYLE 174 | self.vgg_loss_content = self.vgg_content(outputs, gts)*self.config.VGG_CONTENT 175 | self.vgg_loss = self.vgg_loss_style + self.vgg_loss_content 176 | else: 177 | self.vgg_loss = 0 178 | 179 | self.flow_loss = self.flow_adv_gen_loss + self.flow_l1_loss + self.flow_correctness_loss + self.vgg_loss 180 | 181 | self.flow_loss.backward() 182 | self.f_gen_opt.step() 183 | 184 | if self.f_gen_scheduler is not None: 185 | self.f_gen_scheduler.step() 186 | 187 | 188 | logs = [ 189 | ("l_f_adv_dis", self.flow_adv_dis_loss.item()), 190 | ("l_f_adv_gen", self.flow_adv_gen_loss.item()), 191 | ("l_f_l1_gen", self.flow_l1_loss.item()), 192 | ("l_f_total_gen", self.flow_loss.item()), 193 | ] 194 | if use_correction_loss: 195 | logs = logs + [("l_f_correctness_gen", self.flow_correctness_loss.item())] 196 | if use_vgg_loss: 197 | logs = logs + [("l_f_vgg_style", self.vgg_loss_style.item())] 198 | logs = logs + [("l_f_vgg_content", self.vgg_loss_content.item())] 199 | return logs 200 | -------------------------------------------------------------------------------- /src/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torchvision import models 5 | from torch.nn.utils.spectral_norm import spectral_norm 6 | # from spatial_correlation_sampler import spatial_correlation_sample 7 | from .resample2d import Resample2d 8 | import torchvision.utils as vutils 9 | 10 | 11 | class Discriminator(nn.Module): 12 | def __init__(self, input_dim=3, dim=64, n_layers=3, 13 | norm='none', activ='lrelu', pad_type='reflect', use_sn=True): 14 | super(Discriminator, self).__init__() 15 | 16 | self.model = nn.ModuleList() 17 | self.model.append(Conv2dBlock(input_dim,dim,4,2,1,'none',activ,pad_type,use_sn=use_sn)) 18 | dim_in = dim 19 | for i in range(n_layers - 1): 20 | dim_out = min(dim*8, dim_in*2) 21 | self.model.append(DownsampleResBlock(dim_in,dim_out,'none',activ,pad_type,use_sn)) 22 | dim_in = dim_out 23 | 24 | self.model.append(Conv2dBlock(dim_in,1,1,1,activation='none',use_bias=False, use_sn=use_sn)) 25 | self.model = nn.Sequential(*self.model) 26 | 27 | def forward(self, x): 28 | return self.model(x) 29 | 30 | class MultiDiscriminator(nn.Module): 31 | def 
__init__(self, **parameter_dic): 32 | super(MultiDiscriminator, self).__init__() 33 | self.model_1 = Discriminator(**parameter_dic) 34 | self.down = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False) 35 | self.model_2 = Discriminator(**parameter_dic) 36 | 37 | def forward(self, x): 38 | pre1 = self.model_1(x) 39 | pre2 = self.model_2(self.down(x)) 40 | return [pre1, pre2] 41 | 42 | 43 | class StructureGen(nn.Module): 44 | def __init__(self, input_dim=3, dim=64, n_res=4, activ='relu', 45 | norm='in', pad_type='reflect', use_sn=True): 46 | super(StructureGen, self).__init__() 47 | 48 | self.down_sample=nn.ModuleList() 49 | self.up_sample=nn.ModuleList() 50 | self.content_param=nn.ModuleList() 51 | 52 | self.input_layer = Conv2dBlock(input_dim*2+1, dim, 7, 1, 3, norm, activ, pad_type, use_sn=use_sn) 53 | self.down_sample += [nn.Sequential( 54 | Conv2dBlock(dim, 2*dim, 4, 2, 1, norm, activ, pad_type, use_sn=use_sn), 55 | Conv2dBlock(2*dim, 2*dim, 5, 1, 2, norm, activ, pad_type, use_sn=use_sn))] 56 | 57 | self.down_sample += [nn.Sequential( 58 | Conv2dBlock(2*dim, 4*dim, 4, 2, 1,norm, activ, pad_type, use_sn=use_sn), 59 | Conv2dBlock(4*dim, 4*dim, 5, 1, 2,norm, activ, pad_type, use_sn=use_sn))] 60 | 61 | self.down_sample += [nn.Sequential( 62 | Conv2dBlock(4*dim, 8*dim, 4, 2, 1,norm, activ, pad_type, use_sn=use_sn))] 63 | dim = 8*dim 64 | # content decoder 65 | self.up_sample += [(nn.Sequential( 66 | ResBlocks(n_res, dim, norm, activ, pad_type=pad_type), 67 | nn.Upsample(scale_factor=2), 68 | Conv2dBlock(dim, dim // 2, 5, 1, 2, norm, activ, pad_type, use_sn=use_sn)) )] 69 | 70 | self.up_sample += [(nn.Sequential( 71 | ResBlocks(n_res, dim//2, norm, activ, pad_type=pad_type), 72 | nn.Upsample(scale_factor=2), 73 | Conv2dBlock(dim//2, dim//4, 5, 1, 2,norm, activ, pad_type, use_sn=use_sn)) )] 74 | 75 | self.up_sample += [(nn.Sequential( 76 | ResBlocks(n_res, dim//4, norm, activ, pad_type=pad_type), 77 | nn.Upsample(scale_factor=2), 78 | Conv2dBlock(dim//4, dim//8, 5, 1, 2,norm, activ, pad_type, use_sn=use_sn)) )] 79 | 80 | self.content_param += [Conv2dBlock(dim//2, dim//2, 5, 1, 2, norm, activ, pad_type)] 81 | self.content_param += [Conv2dBlock(dim//4, dim//4, 5, 1, 2, norm, activ, pad_type)] 82 | self.content_param += [Conv2dBlock(dim//8, dim//8, 5, 1, 2, norm, activ, pad_type)] 83 | 84 | self.image_net = Get_image(dim//8, input_dim) 85 | 86 | def forward(self, inputs): 87 | x0 = self.input_layer(inputs) 88 | x1 = self.down_sample[0](x0) 89 | x2 = self.down_sample[1](x1) 90 | x3 = self.down_sample[2](x2) 91 | 92 | u1 = self.up_sample[0](x3) + self.content_param[0](x2) 93 | u2 = self.up_sample[1](u1) + self.content_param[1](x1) 94 | u3 = self.up_sample[2](u2) + self.content_param[2](x0) 95 | 96 | images_out = self.image_net(u3) 97 | return images_out 98 | 99 | 100 | class FlowGen(nn.Module): 101 | def __init__(self, input_dim=3, dim=64, n_res=2, activ='relu', 102 | norm_flow='ln', norm_conv='in', pad_type='reflect', use_sn=True): 103 | super(FlowGen, self).__init__() 104 | 105 | self.flow_column = FlowColumn(input_dim, dim, n_res, activ, 106 | norm_flow, pad_type, use_sn) 107 | self.conv_column = ConvColumn(input_dim, dim, n_res, activ, 108 | norm_conv, pad_type, use_sn) 109 | 110 | def forward(self, inputs): 111 | flow_map = self.flow_column(inputs) 112 | images_out = self.conv_column(inputs, flow_map) 113 | return images_out, flow_map 114 | 115 | 116 | 117 | class ConvColumn(nn.Module): 118 | def __init__(self, input_dim=3, dim=64, n_res=2, activ='lrelu', 119 | norm='ln', 
pad_type='reflect', use_sn=True): 120 | super(ConvColumn, self).__init__() 121 | 122 | self.down_sample = nn.ModuleList() 123 | self.up_sample = nn.ModuleList() 124 | 125 | 126 | self.down_sample += [nn.Sequential( 127 | Conv2dBlock(input_dim*2+1, dim//2, 7, 1, 3, norm, activ, pad_type, use_sn=use_sn), 128 | Conv2dBlock(dim//2, dim, 4, 2, 1, norm, activ, pad_type, use_sn=use_sn), 129 | Conv2dBlock(dim, dim, 5, 1, 2, norm, activ, pad_type, use_sn=use_sn), 130 | Conv2dBlock(dim, 2*dim, 4, 2, 1, norm, activ, pad_type, use_sn=use_sn), 131 | Conv2dBlock(2*dim, 2*dim, 5, 1, 2, norm, activ, pad_type, use_sn=use_sn))] 132 | 133 | self.down_sample += [nn.Sequential( 134 | Conv2dBlock(2*dim, 4*dim, 5, 1, 2, norm, activ, pad_type, use_sn=use_sn), 135 | Conv2dBlock(4*dim, 8*dim, 4, 2, 1, norm, activ, pad_type, use_sn=use_sn))] 136 | dim = 8*dim 137 | 138 | # content decoder 139 | self.up_sample += [(nn.Sequential( 140 | ResBlocks(n_res, dim, norm, activ, pad_type=pad_type), 141 | nn.Upsample(scale_factor=2), 142 | Conv2dBlock(dim, dim//2, 5, 1, 2, norm, activ, pad_type, use_sn=use_sn)) )] 143 | 144 | self.up_sample += [(nn.Sequential( 145 | Conv2dBlock(dim, dim//2, 5, 1, 2,norm, activ, pad_type, use_sn=use_sn), 146 | ResBlocks(n_res, dim//2, norm, activ, pad_type=pad_type), 147 | nn.Upsample(scale_factor=2), 148 | Conv2dBlock(dim//2, dim//4, 5, 1, 2,norm, activ, pad_type, use_sn=use_sn), 149 | 150 | ResBlocks(n_res, dim//4, norm, activ, pad_type=pad_type), 151 | nn.Upsample(scale_factor=2), 152 | Conv2dBlock(dim//4, dim//8, 5, 1, 2,norm, activ, pad_type, use_sn=use_sn), 153 | Get_image(dim//8, input_dim)) )] 154 | 155 | self.resample16 = Resample2d(16, 1, sigma=4) 156 | self.resample4 = Resample2d(4, 1, sigma=2) 157 | 158 | 159 | def forward(self, inputs, flow_maps): 160 | x1 = self.down_sample[0](inputs) 161 | x2 = self.down_sample[1](x1) 162 | flow_fea = self.resample_image(x1, flow_maps) 163 | 164 | u1 = torch.cat((self.up_sample[0](x2), flow_fea), 1) 165 | images_out = self.up_sample[1](u1) 166 | return images_out 167 | 168 | def resample_image(self, img, flow): 169 | output16 = self.resample16(img, flow) 170 | output4 = self.resample4 (img, flow) 171 | outputs = torch.cat((output16,output4), 1) 172 | return outputs 173 | 174 | 175 | 176 | class FlowColumn(nn.Module): 177 | def __init__(self, input_dim=3, dim=64, n_res=2, activ='lrelu', 178 | norm='in', pad_type='reflect', use_sn=True): 179 | super(FlowColumn, self).__init__() 180 | 181 | self.down_sample_flow = nn.ModuleList() 182 | self.up_sample_flow = nn.ModuleList() 183 | 184 | self.down_sample_flow.append( nn.Sequential( 185 | Conv2dBlock(input_dim*2+1, dim//2, 7, 1, 3, norm, activ, pad_type, use_sn=use_sn), 186 | Conv2dBlock(dim//2, dim, 4, 2, 1, norm, activ, pad_type, use_sn=use_sn), 187 | Conv2dBlock( dim, dim, 5, 1, 2, norm, activ, pad_type, use_sn=use_sn))) 188 | self.down_sample_flow.append( nn.Sequential( 189 | Conv2dBlock( dim, 2*dim, 4, 2, 1, norm, activ, pad_type, use_sn=use_sn), 190 | Conv2dBlock(2*dim, 2*dim, 5, 1, 2, norm, activ, pad_type, use_sn=use_sn))) 191 | self.down_sample_flow.append(nn.Sequential( 192 | Conv2dBlock(2*dim, 4*dim, 4, 2, 1, norm, activ, pad_type, use_sn=use_sn), 193 | Conv2dBlock(4*dim, 4*dim, 5, 1, 2, norm, activ, pad_type, use_sn=use_sn))) 194 | self.down_sample_flow.append(nn.Sequential( 195 | Conv2dBlock(4*dim, 8*dim, 4, 2, 1, norm, activ, pad_type, use_sn=use_sn), 196 | Conv2dBlock(8*dim, 8*dim, 5, 1, 2, norm, activ, pad_type, use_sn=use_sn))) 197 | dim = 8*dim 198 | 199 | # content decoder 200 | 
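# Editor's note: FlowColumn is a small U-Net-style encoder/decoder; its
# 'location' head (defined below) emits a 2-channel offset field (dx, dy).
# ConvColumn.resample_image then samples encoder features at these offsets
# using two Gaussian kernel sizes, Resample2d(16, ...) and Resample2d(4, ...).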
self.up_sample_flow.append(nn.Sequential( 201 | ResBlocks(n_res, dim, norm, activ, pad_type=pad_type), 202 | TransConv2dBlock(dim, dim//2, 6, 2, 2, norm=norm, activation=activ) )) 203 | 204 | self.up_sample_flow.append(nn.Sequential( 205 | Conv2dBlock(dim, dim//2, 5, 1, 2, norm, activ, pad_type, use_sn=use_sn), 206 | ResBlocks(n_res, dim//2, norm, activ, pad_type=pad_type), 207 | TransConv2dBlock(dim//2, dim//4, 6, 2, 2, norm=norm, activation=activ) )) 208 | 209 | self.location = nn.Sequential( 210 | Conv2dBlock(dim//2, dim//8, 5, 1, 2, norm, activ, pad_type, use_sn=use_sn), 211 | Conv2dBlock(dim//8, 2, 3, 1, 1, norm='none', activation='none', pad_type=pad_type, use_bias=False) ) 212 | 213 | def forward(self, inputs): 214 | f_x1 = self.down_sample_flow[0](inputs) 215 | f_x2 = self.down_sample_flow[1](f_x1) 216 | f_x3 = self.down_sample_flow[2](f_x2) 217 | f_x4 = self.down_sample_flow[3](f_x3) 218 | 219 | f_u1 = torch.cat((self.up_sample_flow[0](f_x4), f_x3), 1) 220 | f_u2 = torch.cat((self.up_sample_flow[1](f_u1), f_x2), 1) 221 | flow_map = self.location(f_u2) 222 | return flow_map 223 | 224 | 225 | 226 | ################################################################################## 227 | # Basic Blocks 228 | ################################################################################## 229 | class Get_image(nn.Module): 230 | def __init__(self, input_dim, output_dim, activation='tanh'): 231 | super(Get_image, self).__init__() 232 | self.conv = Conv2dBlock(input_dim, output_dim, kernel_size=3, stride=1, 233 | padding=1, pad_type='reflect', activation=activation) 234 | def forward(self, x): 235 | return self.conv(x) 236 | 237 | class ResBlocks(nn.Module): 238 | def __init__(self, num_blocks, dim, norm='in', activation='relu', pad_type='zero', use_sn=False): 239 | super(ResBlocks, self).__init__() 240 | self.model = [] 241 | for i in range(num_blocks): 242 | self.model += [ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type, use_sn=use_sn)] 243 | self.model = nn.Sequential(*self.model) 244 | 245 | def forward(self, x): 246 | return self.model(x) 247 | 248 | class ResBlock(nn.Module): 249 | def __init__(self, dim, norm='in', activation='relu', pad_type='zero', use_sn=False): 250 | super(ResBlock, self).__init__() 251 | 252 | model = [] 253 | model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type, use_sn=use_sn)] 254 | model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type, use_sn=use_sn)] 255 | self.model = nn.Sequential(*model) 256 | 257 | def forward(self, x): 258 | residual = x 259 | out = self.model(x) 260 | out += residual 261 | return out 262 | 263 | class DilationBlock(nn.Module): 264 | def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): 265 | super(DilationBlock, self).__init__() 266 | 267 | model = [] 268 | model += [Conv2dBlock(dim ,dim, 3, 1, 2, norm=norm, activation=activation, pad_type=pad_type, dilation=2)] 269 | model += [Conv2dBlock(dim ,dim, 3, 1, 4, norm=norm, activation=activation, pad_type=pad_type, dilation=4)] 270 | model += [Conv2dBlock(dim ,dim, 3, 1, 8, norm=norm, activation=activation, pad_type=pad_type, dilation=8)] 271 | self.model = nn.Sequential(*model) 272 | 273 | def forward(self, x): 274 | out = self.model(x) 275 | return out 276 | 277 | class Conv2dBlock(nn.Module): 278 | def __init__(self, input_dim ,output_dim, kernel_size, stride, 279 | padding=0, norm='none', activation='relu', pad_type='zero', dilation=1, 280 | use_bias=True, use_sn=False): 
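# Editor's note: Conv2dBlock applies explicit padding first (so 'reflect' and
# 'replicate' modes are possible), then the convolution (wrapped in
# spectral_norm when use_sn=True), then the chosen normalization and
# activation; see forward() below.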
281 | super(Conv2dBlock, self).__init__() 282 | self.use_bias = use_bias 283 | # initialize padding 284 | if pad_type == 'reflect': 285 | self.pad = nn.ReflectionPad2d(padding) 286 | elif pad_type == 'replicate': 287 | self.pad = nn.ReplicationPad2d(padding) 288 | elif pad_type == 'zero': 289 | self.pad = nn.ZeroPad2d(padding) 290 | else: 291 | assert 0, "Unsupported padding type: {}".format(pad_type) 292 | 293 | # initialize normalization 294 | norm_dim = output_dim 295 | if norm == 'bn': 296 | self.norm = nn.BatchNorm2d(norm_dim) 297 | elif norm == 'in': 298 | self.norm = nn.InstanceNorm2d(norm_dim) 299 | elif norm == 'ln': 300 | self.norm = LayerNorm(norm_dim) 301 | elif norm == 'adain': 302 | self.norm = AdaptiveInstanceNorm2d(norm_dim) 303 | elif norm == 'none': 304 | self.norm = None 305 | else: 306 | assert 0, "Unsupported normalization: {}".format(norm) 307 | 308 | # initialize activation 309 | if activation == 'relu': 310 | self.activation = nn.ReLU(inplace=True) 311 | elif activation == 'lrelu': 312 | self.activation = nn.LeakyReLU(0.2, inplace=True) 313 | elif activation == 'prelu': 314 | self.activation = nn.PReLU() 315 | elif activation == 'selu': 316 | self.activation = nn.SELU(inplace=True) 317 | elif activation == 'tanh': 318 | self.activation = nn.Tanh() 319 | elif activation == 'sigmoid': 320 | self.activation = nn.Sigmoid() 321 | elif activation == 'none': 322 | self.activation = None 323 | else: 324 | assert 0, "Unsupported activation: {}".format(activation) 325 | 326 | # initialize convolution 327 | if use_sn: 328 | self.conv = spectral_norm(nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias, dilation=dilation)) 329 | else: 330 | self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias, dilation=dilation) 331 | 332 | def forward(self, x): 333 | x = self.conv(self.pad(x)) 334 | if self.norm: 335 | x = self.norm(x) 336 | if self.activation: 337 | x = self.activation(x) 338 | return x 339 | 340 | class TransConv2dBlock(nn.Module): 341 | def __init__(self, input_dim ,output_dim, kernel_size, stride, 342 | padding=0, norm='none', activation='relu'): 343 | super(TransConv2dBlock, self).__init__() 344 | self.use_bias = True 345 | 346 | # initialize normalization 347 | norm_dim = output_dim 348 | if norm == 'bn': 349 | self.norm = nn.BatchNorm2d(norm_dim) 350 | elif norm == 'in': 351 | self.norm = nn.InstanceNorm2d(norm_dim) 352 | elif norm == 'in_affine': 353 | self.norm = nn.InstanceNorm2d(norm_dim, affine=True) 354 | elif norm == 'ln': 355 | self.norm = LayerNorm(norm_dim) 356 | elif norm == 'adain': 357 | self.norm = AdaptiveInstanceNorm2d(norm_dim) 358 | elif norm == 'none': 359 | self.norm = None 360 | else: 361 | assert 0, "Unsupported normalization: {}".format(norm) 362 | 363 | # initialize activation 364 | if activation == 'relu': 365 | self.activation = nn.ReLU(inplace=True) 366 | elif activation == 'lrelu': 367 | self.activation = nn.LeakyReLU(0.2, inplace=True) 368 | elif activation == 'prelu': 369 | self.activation = nn.PReLU() 370 | elif activation == 'selu': 371 | self.activation = nn.SELU(inplace=True) 372 | elif activation == 'tanh': 373 | self.activation = nn.Tanh() 374 | elif activation == 'sigmoid': 375 | self.activation = nn.Sigmoid() 376 | elif activation == 'none': 377 | self.activation = None 378 | else: 379 | assert 0, "Unsupported activation: {}".format(activation) 380 | 381 | # initialize convolution 382 | self.transConv = nn.ConvTranspose2d(input_dim, output_dim, kernel_size, stride, padding, 
bias=self.use_bias) 383 | 384 | def forward(self, x): 385 | x = self.transConv(x) 386 | if self.norm: 387 | x = self.norm(x) 388 | if self.activation: 389 | x = self.activation(x) 390 | return x 391 | 392 | class AdaptiveInstanceNorm2d(nn.Module): 393 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 394 | super(AdaptiveInstanceNorm2d, self).__init__() 395 | self.num_features = num_features 396 | self.eps = eps 397 | self.momentum = momentum 398 | # weight and bias are dynamically assigned 399 | self.weight = None 400 | self.bias = None 401 | # just dummy buffers, not used 402 | self.register_buffer('running_mean', torch.zeros(num_features)) 403 | self.register_buffer('running_var', torch.ones(num_features)) 404 | 405 | def forward(self, x): 406 | assert self.weight is not None and self.bias is not None, "Please assign weight and bias before calling AdaIN!" 407 | b, c = x.size(0), x.size(1) 408 | running_mean = self.running_mean.repeat(b) 409 | running_var = self.running_var.repeat(b) 410 | 411 | # Apply instance norm 412 | x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:]) 413 | 414 | out = F.batch_norm( 415 | x_reshaped, running_mean, running_var, self.weight, self.bias, 416 | True, self.momentum, self.eps) 417 | 418 | return out.view(b, c, *x.size()[2:]) 419 | 420 | def __repr__(self): 421 | return self.__class__.__name__ + '(' + str(self.num_features) + ')' 422 | 423 | class LayerNorm(nn.Module): 424 | def __init__(self, n_out, eps=1e-5, affine=True): 425 | super(LayerNorm, self).__init__() 426 | self.n_out = n_out 427 | self.affine = affine 428 | 429 | if self.affine: 430 | self.weight = nn.Parameter(torch.ones(n_out, 1, 1)) 431 | self.bias = nn.Parameter(torch.zeros(n_out, 1, 1)) 432 | 433 | def forward(self, x): 434 | normalized_shape = x.size()[1:] 435 | if self.affine: 436 | return F.layer_norm(x, normalized_shape, self.weight.expand(normalized_shape), self.bias.expand(normalized_shape)) 437 | else: 438 | return F.layer_norm(x, normalized_shape) 439 | 440 | class DownsampleResBlock(nn.Module): 441 | def __init__(self, input_dim, output_dim, norm='in', activation='relu', pad_type='zero', use_sn=False): 442 | super(DownsampleResBlock, self).__init__() 443 | self.conv_1 = nn.ModuleList() 444 | self.conv_2 = nn.ModuleList() 445 | 446 | self.conv_1.append(Conv2dBlock(input_dim,input_dim,3,1,1,'none',activation,pad_type,use_sn=use_sn)) 447 | self.conv_1.append(Conv2dBlock(input_dim,output_dim,3,1,1,'none',activation,pad_type,use_sn=use_sn)) 448 | self.conv_1.append(nn.AvgPool2d(kernel_size=2, stride=2)) 449 | self.conv_1 = nn.Sequential(*self.conv_1) 450 | 451 | 452 | self.conv_2.append(nn.AvgPool2d(kernel_size=2, stride=2)) 453 | self.conv_2.append(Conv2dBlock(input_dim,output_dim,1,1,0,'none',activation,pad_type,use_sn=use_sn)) 454 | self.conv_2 = nn.Sequential(*self.conv_2) 455 | 456 | 457 | def forward(self, x): 458 | out = self.conv_1(x) + self.conv_2(x) 459 | return out 460 | 461 | 462 | -------------------------------------------------------------------------------- /src/resample2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.modules.module import Module 3 | from torch.autograd import Function, Variable 4 | import resample2d_cuda 5 | 6 | class Resample2dFunction(Function): 7 | 8 | @staticmethod 9 | def forward(ctx, input1, input2, kernel_size=2, dilation=1): 10 | assert input1.is_contiguous() 11 | assert input2.is_contiguous() 12 | 13 | ctx.save_for_backward(input1, input2) 14 | 
ctx.kernel_size = kernel_size
15 |         ctx.dilation = dilation
16 | 
17 |         _, d, _, _ = input1.size()
18 |         b, _, h, w = input2.size()
19 |         output = input1.new(b, d, h, w).zero_()
20 | 
21 |         resample2d_cuda.forward(input1, input2, output, kernel_size, dilation)
22 | 
23 |         return output
24 | 
25 |     @staticmethod
26 |     def backward(ctx, grad_output):
27 |         if not grad_output.is_contiguous():
28 |             grad_output = grad_output.contiguous()  # fixed: the original discarded the contiguous copy
29 | 
30 |         input1, input2 = ctx.saved_tensors
31 | 
32 |         grad_input1 = Variable(input1.new(input1.size()).zero_())
33 |         grad_input2 = Variable(input1.new(input2.size()).zero_())
34 | 
35 |         resample2d_cuda.backward(input1, input2, grad_output.data,
36 |                                  grad_input1.data, grad_input2.data,
37 |                                  ctx.kernel_size, ctx.dilation)
38 | 
39 |         return grad_input1, grad_input2, None, None
40 | 
41 | class Resample2d(Module):
42 | 
43 |     def __init__(self, kernel_size=2, dilation=1, sigma=5):
44 |         super(Resample2d, self).__init__()
45 |         self.kernel_size = kernel_size
46 |         self.dilation = dilation
47 |         self.sigma = torch.tensor(sigma, dtype=torch.float).cuda()
48 | 
49 |     def forward(self, input1, input2):
50 |         input1_c = input1.contiguous()  # input1: feature maps (B, C, H, W)
51 |         sigma = self.sigma.expand(input2.size(0), 1, input2.size(2), input2.size(3)).type(input2.dtype)
52 |         input2 = torch.cat((input2, sigma), 1)  # append a constant sigma channel to the (B, 2, h, w) flow so the CUDA kernel can Gaussian-weight its samples
53 |         return Resample2dFunction.apply(input1_c, input2, self.kernel_size, self.dilation)
54 | 
-------------------------------------------------------------------------------- /src/structure_flow.py: --------------------------------------------------------------------------------
1 | from torch.autograd import Variable, grad
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import os
6 | import torch.nn.functional as F
7 | import glob
8 | import torchvision.utils as vutils
9 | import math
10 | import shutil
11 | import tensorboardX
12 | from itertools import islice
13 | from torch.utils.data import DataLoader
14 | from .data import Dataset
15 | from .utils import Progbar, write_2images, write_2tensorboard, create_dir, imsave
16 | from skimage.metrics import structural_similarity as compare_ssim  # editor's fix: skimage.measure.compare_ssim/compare_psnr were removed in scikit-image >= 0.18
17 | from skimage.metrics import peak_signal_noise_ratio as compare_psnr
18 | from .models import StructureFlowModel
19 | 
20 | 
21 | class StructureFlow():
22 |     def __init__(self, config):
23 |         self.config = config
24 |         self.debug = False
25 |         self.flow_model = StructureFlowModel(config).to(config.DEVICE)
26 | 
27 |         self.samples_path = os.path.join(config.PATH, config.NAME, 'images')
28 |         self.checkpoints_path = os.path.join(config.PATH, config.NAME, 'checkpoints')
29 |         self.test_image_path = os.path.join(config.PATH, config.NAME, 'test_result')
30 | 
31 |         # load saved weights unless this is a fresh (non-resumed) training run
32 |         if self.config.MODE != 'train' or self.config.RESUME_ALL:
33 |             self.flow_model.load(self.config.WHICH_ITER)
34 | 
35 | 
36 |         if self.config.MODEL == 1:
37 |             self.stage_name = 'structure_reconstructor'
38 |         elif self.config.MODEL == 2:
39 |             self.stage_name = 'texture_generator'
40 |         elif self.config.MODEL == 3:
41 |             self.stage_name = 'joint_train'
42 | 
43 |     def train(self):
44 |         train_writer = self.obtain_log(self.config)
45 |         train_dataset = Dataset(self.config.DATA_TRAIN_GT, self.config.DATA_TRAIN_STRUCTURE,
46 |                                 self.config, self.config.DATA_MASK_FILE)
47 |         train_loader = DataLoader(dataset=train_dataset, batch_size=self.config.TRAIN_BATCH_SIZE,
48 |                                   shuffle=True, drop_last=True, num_workers=8)
49 | 
50 |         val_dataset = Dataset(self.config.DATA_VAL_GT, self.config.DATA_VAL_STRUCTURE,
51 |                               self.config, self.config.DATA_MASK_FILE)
52 |         sample_iterator = val_dataset.create_iterator(self.config.SAMPLE_SIZE)
53 | 
54 | 
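# Editor's note: training resumes from the iteration counter stored with the
# model; the starting epoch below is recovered as
#   floor(iterations * batch_size / len(train_dataset)).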
55 | iterations = self.flow_model.iterations 56 | total = len(train_dataset) 57 | epoch = math.floor(iterations*self.config.TRAIN_BATCH_SIZE/total) 58 | keep_training = True 59 | model = self.config.MODEL 60 | max_iterations = int(float(self.config.MAX_ITERS)) 61 | 62 | while(keep_training): 63 | epoch += 1 64 | print('\n\nTraining epoch: %d' % epoch) 65 | 66 | progbar = Progbar(total, width=20, stateful_metrics=['epoch', 'iter']) 67 | 68 | for items in train_loader: 69 | inputs, smooths, gts, maps = self.cuda(*items) 70 | 71 | # structure model 72 | if model == 1: 73 | logs = self.flow_model.update_structure(inputs, smooths, maps) 74 | iterations = self.flow_model.iterations 75 | # flow model 76 | elif model == 2: 77 | logs = self.flow_model.update_flow(inputs, smooths, gts, maps, self.flow_model.use_correction_loss, self.flow_model.use_vgg_loss) 78 | iterations = self.flow_model.iterations 79 | # flow with structure model 80 | elif model == 3: 81 | with torch.no_grad(): 82 | smooth_stage_1 = self.flow_model.structure_forward(inputs, smooths, maps) 83 | logs = self.flow_model.update_flow(inputs, smooth_stage_1.detach(), gts, maps, self.flow_model.use_correction_loss, self.flow_model.use_vgg_loss) 84 | iterations = self.flow_model.iterations 85 | 86 | if iterations >= max_iterations: 87 | keep_training = False 88 | break 89 | 90 | # print(logs) 91 | logs = [ 92 | ("epoch", epoch), 93 | ("iter", iterations), 94 | ] + logs 95 | 96 | progbar.add(len(inputs), values=logs if self.config.VERBOSE else [x for x in logs if not x[0].startswith('l_')]) 97 | 98 | # log model 99 | if self.config.LOG_INTERVAL and iterations % self.config.LOG_INTERVAL == 0: 100 | self.write_loss(logs, train_writer) 101 | # sample model 102 | if self.config.SAMPLE_INTERVAL and iterations % self.config.SAMPLE_INTERVAL == 0: 103 | items = next(sample_iterator) 104 | inputs, smooths, gts, maps = self.cuda(*items) 105 | result,flow = self.flow_model.sample(inputs, smooths, gts, maps) 106 | self.write_image(result, train_writer, iterations, 'image') 107 | self.write_image(flow, train_writer, iterations, 'flow') 108 | # evaluate model 109 | if self.config.EVAL_INTERVAL and iterations % self.config.EVAL_INTERVAL == 0: 110 | self.flow_model.eval() 111 | print('\nstart eval...\n') 112 | self.eval(writer=train_writer) 113 | self.flow_model.train() 114 | 115 | # save the latest model 116 | if self.config.SAVE_LATEST and iterations % self.config.SAVE_LATEST == 0: 117 | print('\nsaving the latest model (total_steps %d)\n' % (iterations)) 118 | self.flow_model.save('latest') 119 | 120 | # save the model 121 | if self.config.SAVE_INTERVAL and iterations % self.config.SAVE_INTERVAL == 0: 122 | print('\nsaving the model of iterations %d\n' % iterations) 123 | self.flow_model.save(iterations) 124 | print('\nEnd training....') 125 | 126 | 127 | def eval(self, writer=None): 128 | val_dataset = Dataset(self.config.DATA_VAL_GT , self.config.DATA_VAL_STRUCTURE, self.config, self.config.DATA_VAL_MASK) 129 | val_loader = DataLoader( 130 | dataset=val_dataset, 131 | batch_size = self.config.TRAIN_BATCH_SIZE, 132 | shuffle=False 133 | ) 134 | model = self.config.MODEL 135 | total = len(val_dataset) 136 | iterations = self.flow_model.iterations 137 | 138 | progbar = Progbar(total, width=20, stateful_metrics=['it']) 139 | iteration = 0 140 | psnr_list = [] 141 | 142 | # TODO: add fid score to evaluate 143 | with torch.no_grad(): 144 | # for items in val_loader: 145 | for j, items in enumerate(islice(val_loader, 50)): 146 | 147 | logs = [] 148 | 
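# Editor's note: evaluation is capped at the first 50 validation batches via
# islice(val_loader, 50); SSIM and L1 are computed by self.metrics but only
# PSNR is averaged and logged below.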
iteration += 1 149 | inputs, smooths, gts, maps = self.cuda(*items) 150 | if model == 1: 151 | outputs_structure = self.flow_model.structure_forward(inputs, smooths, maps) 152 | psnr, ssim, l1 = self.metrics(outputs_structure, smooths) 153 | logs.append(('psnr', psnr.item())) 154 | psnr_list.append(psnr.item()) 155 | 156 | # inpaint model 157 | elif model == 2: 158 | outputs, flow_maps = self.flow_model.flow_forward(inputs, smooths, maps) 159 | psnr, ssim, l1 = self.metrics(outputs, gts) 160 | logs.append(('psnr', psnr.item())) 161 | psnr_list.append(psnr.item()) 162 | 163 | 164 | # inpaint with structure model 165 | elif model == 3: 166 | smooth_stage_1 = self.flow_model.structure_forward(inputs, smooths, maps) 167 | outputs, flow_maps = self.flow_model.flow_forward(inputs, smooth_stage_1, maps) 168 | psnr, ssim, l1 = self.metrics(outputs, gts) 169 | logs.append(('psnr', psnr.item())) 170 | psnr_list.append(psnr.item()) 171 | 172 | logs = [("it", iteration), ] + logs 173 | progbar.add(len(inputs), values=logs) 174 | 175 | avg_psnr = np.average(psnr_list) 176 | 177 | if writer is not None: 178 | writer.add_scalar('eval_psnr', avg_psnr, iterations) 179 | 180 | print('model eval at iterations:%d'%iterations) 181 | print('average psnr:%f'%avg_psnr) 182 | 183 | 184 | def test(self): 185 | self.flow_model.eval() 186 | 187 | model = self.config.MODEL 188 | print(self.config.DATA_TEST_RESULTS) 189 | create_dir(self.config.DATA_TEST_RESULTS) 190 | test_dataset = Dataset(self.config.DATA_TEST_GT, self.config.DATA_TEST_STRUCTURE, self.config, self.config.DATA_TEST_MASK) 191 | test_loader = DataLoader( 192 | dataset=test_dataset, 193 | batch_size=8, 194 | ) 195 | 196 | index = 0 197 | with torch.no_grad(): 198 | for items in test_loader: 199 | inputs, smooths, gts, maps = self.cuda(*items) 200 | 201 | # structure model 202 | if model == 1: 203 | outputs = self.flow_model.structure_forward(inputs, smooths, maps) 204 | outputs_merged = (outputs * maps) + (smooths * (1 - maps)) 205 | 206 | # flow model 207 | elif model == 2: 208 | outputs, flow_maps = self.flow_model.flow_forward(inputs, smooths, maps) 209 | outputs_merged = (outputs * maps) + (gts * (1 - maps)) 210 | 211 | 212 | # inpaint with structure model / joint model 213 | else: 214 | smooth_stage_1 = self.flow_model.structure_forward(inputs, smooths, maps) 215 | outputs, flow_maps = self.flow_model.flow_forward(inputs, smooth_stage_1, maps) 216 | outputs_merged = (outputs * maps) + (gts * (1 - maps)) 217 | 218 | outputs_merged = self.postprocess(outputs_merged)*255.0 219 | inputs_show = inputs + maps 220 | 221 | 222 | for i in range(outputs_merged.size(0)): 223 | name = test_dataset.load_name(index, self.debug) 224 | print(index, name) 225 | path = os.path.join(self.config.DATA_TEST_RESULTS, name) 226 | imsave(outputs_merged[i,:,:,:].unsqueeze(0), path) 227 | index += 1 228 | 229 | if self.debug and model == 3: 230 | smooth_ = self.postprocess(smooth_stage_1[i,:,:,:].unsqueeze(0))*255.0 231 | inputs_ = self.postprocess(inputs_show[i,:,:,:].unsqueeze(0))*255.0 232 | gts_ = self.postprocess(gts[i,:,:,:].unsqueeze(0))*255.0 233 | print(path) 234 | fname, fext = os.path.splitext(path) 235 | imsave(smooth_, fname+'_smooth.'+fext) 236 | imsave(inputs_, fname+'_inputs.'+fext) 237 | imsave(gts_, fname+'_gts.'+fext) 238 | 239 | print('\nEnd test....') 240 | 241 | 242 | 243 | def obtain_log(self, config): 244 | log_dir = os.path.join(config.PATH, config.NAME, self.stage_name+'_log') 245 | if os.path.exists(log_dir) and config.REMOVE_LOG: 246 | 
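# Editor's note: with REMOVE_LOG set, any existing tensorboard logs for this
# stage are deleted before a new SummaryWriter is created.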
246 |             shutil.rmtree(log_dir)
247 |         train_writer = tensorboardX.SummaryWriter(log_dir)
248 |         return train_writer
249 | 
250 | 
251 |     def cuda(self, *args):
252 |         return (item.to(self.config.DEVICE) for item in args)  # generator, consumed by tuple unpacking
253 | 
254 | 
255 |     def write_loss(self, logs, train_writer):
256 |         iteration = [x[1] for x in logs if x[0] == 'iter']
257 |         for x in logs:
258 |             if x[0].startswith('l_'):  # only loss entries go to tensorboard
259 |                 train_writer.add_scalar(x[0], x[1], iteration[-1])
260 | 
261 |     def write_image(self, result, train_writer, iterations, label):
262 |         if result:
263 |             name = '%s/model%d_sample_%08d' % (self.samples_path, self.config.MODEL, iterations) + label + '.jpg'
264 |             write_2images(result, self.config.SAMPLE_SIZE, name)
265 |             write_2tensorboard(iterations, result, train_writer, self.config.SAMPLE_SIZE, label)
266 | 
267 | 
268 |     def postprocess(self, x):
269 |         x = (x + 1) / 2  # map [-1, 1] network output to [0, 1]
270 |         x.clamp_(0, 1)
271 |         return x
272 | 
273 |     def metrics(self, inputs, gts):
274 |         inputs = self.postprocess(inputs)
275 |         gts = self.postprocess(gts)
276 |         psnr_value = []
277 |         l1_value = torch.mean(torch.abs(inputs-gts))
278 | 
279 |         [b,n,w,h] = inputs.size()  # batch, channels, height, width
280 |         inputs = (inputs*255.0).int().float()/255.0  # quantize to 8-bit levels, as if saved to disk
281 |         gts = (gts*255.0).int().float()/255.0
282 | 
283 |         for i in range(inputs.size(0)):
284 |             inputs_p = inputs[i,:,:,:].cpu().numpy().astype(np.float32).transpose(1,2,0)
285 |             gts_p = gts[i,:,:,:].cpu().numpy().astype(np.float32).transpose(1,2,0)
286 |             psnr_value.append(compare_psnr(inputs_p, gts_p, data_range=1))
287 | 
288 |         psnr_value = np.average(psnr_value)
289 |         inputs = inputs.view(b*n, w, h).cpu().numpy().astype(np.float32).transpose(1,2,0)  # fold batch and channels into one channel axis
290 |         gts = gts.view(b*n, w, h).cpu().numpy().astype(np.float32).transpose(1,2,0)
291 |         ssim_value = compare_ssim(inputs, gts, data_range=1, win_size=51, multichannel=True)
292 |         return psnr_value, ssim_value, l1_value
293 | 
294 | 
295 | 
296 | 
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import random
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | from PIL import Image
8 | import torch.nn.init as init
9 | import math
10 | import torch
11 | import torchvision.utils as vutils
12 | from natsort import natsorted
13 | 
14 | def create_dir(dir):
15 |     if not os.path.exists(dir):
16 |         os.makedirs(dir)
17 | 
18 | 
19 | def create_mask(width, height, mask_width, mask_height, x=None, y=None):
20 |     mask = np.zeros((height, width))
21 |     mask_x = x if x is not None else random.randint(0, width - mask_width)
22 |     mask_y = y if y is not None else random.randint(0, height - mask_height)
23 |     mask[mask_y:mask_y + mask_height, mask_x:mask_x + mask_width] = 1
24 |     return mask
25 | 
26 | 
27 | def stitch_images(inputs, *outputs, img_per_row=2):
28 |     gap = 5
29 |     columns = len(outputs) + 1
30 | 
31 |     width, height = inputs[0][:, :, 0].shape
32 |     img = Image.new('RGB', (width * img_per_row * columns + gap * (img_per_row - 1), height * int(len(inputs) / img_per_row)))
33 |     images = [inputs, *outputs]
34 | 
35 |     for ix in range(len(inputs)):
36 |         xoffset = int(ix % img_per_row) * width * columns + int(ix % img_per_row) * gap
37 |         yoffset = int(ix / img_per_row) * height
38 | 
39 |         for cat in range(len(images)):
40 |             im = np.array((images[cat][ix]).cpu()).astype(np.uint8).squeeze()
41 |             im = Image.fromarray(im)
42 |             img.paste(im, (xoffset + cat * width, yoffset))
43 | 
44 |     return img
45 | 
46 | 
47 | def imshow(img, title=''):
48 |     fig = plt.gcf()
49 
| fig.canvas.set_window_title(title) 50 | plt.axis('off') 51 | plt.imshow(img, interpolation='none') 52 | plt.show() 53 | 54 | 55 | def imsave(img, path): 56 | # print(img) 57 | # print(img.shape) 58 | img = img.permute(0, 2, 3, 1) 59 | im = Image.fromarray(img.cpu().detach().numpy().astype(np.uint8).squeeze()) 60 | im.save(path) 61 | 62 | 63 | class Progbar(object): 64 | """Displays a progress bar. 65 | 66 | Arguments: 67 | target: Total number of steps expected, None if unknown. 68 | width: Progress bar width on screen. 69 | verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose) 70 | stateful_metrics: Iterable of string names of metrics that 71 | should *not* be averaged over time. Metrics in this list 72 | will be displayed as-is. All others will be averaged 73 | by the progbar before display. 74 | interval: Minimum visual progress update interval (in seconds). 75 | """ 76 | 77 | def __init__(self, target, width=25, verbose=1, interval=0.05, 78 | stateful_metrics=None): 79 | self.target = target 80 | self.width = width 81 | self.verbose = verbose 82 | self.interval = interval 83 | if stateful_metrics: 84 | self.stateful_metrics = set(stateful_metrics) 85 | else: 86 | self.stateful_metrics = set() 87 | 88 | self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and 89 | sys.stdout.isatty()) or 90 | 'ipykernel' in sys.modules or 91 | 'posix' in sys.modules) 92 | self._total_width = 0 93 | self._seen_so_far = 0 94 | # We use a dict + list to avoid garbage collection 95 | # issues found in OrderedDict 96 | self._values = {} 97 | self._values_order = [] 98 | self._start = time.time() 99 | self._last_update = 0 100 | 101 | def update(self, current, values=None): 102 | """Updates the progress bar. 103 | 104 | Arguments: 105 | current: Index of current step. 106 | values: List of tuples: 107 | `(name, value_for_last_step)`. 108 | If `name` is in `stateful_metrics`, 109 | `value_for_last_step` will be displayed as-is. 110 | Else, an average of the metric over time will be displayed. 111 | """ 112 | values = values or [] 113 | for k, v in values: 114 | if k not in self._values_order: 115 | self._values_order.append(k) 116 | if k not in self.stateful_metrics: 117 | if k not in self._values: 118 | self._values[k] = [v * (current - self._seen_so_far), 119 | current - self._seen_so_far] 120 | else: 121 | self._values[k][0] += v * (current - self._seen_so_far) 122 | self._values[k][1] += (current - self._seen_so_far) 123 | else: 124 | self._values[k] = v 125 | self._seen_so_far = current 126 | 127 | now = time.time() 128 | info = ' - %.0fs' % (now - self._start) 129 | if self.verbose == 1: 130 | if (now - self._last_update < self.interval and 131 | self.target is not None and current < self.target): 132 | return 133 | 134 | prev_total_width = self._total_width 135 | if self._dynamic_display: 136 | sys.stdout.write('\b' * prev_total_width) 137 | sys.stdout.write('\r') 138 | else: 139 | sys.stdout.write('\n') 140 | 141 | if self.target is not None: 142 | numdigits = int(np.floor(np.log10(self.target))) + 1 143 | barstr = '%%%dd/%d [' % (numdigits, self.target) 144 | bar = barstr % current 145 | prog = float(current) / self.target 146 | prog_width = int(self.width * prog) 147 | if prog_width > 0: 148 | bar += ('=' * (prog_width - 1)) 149 | if current < self.target: 150 | bar += '>' 151 | else: 152 | bar += '=' 153 | bar += ('.' 
* (self.width - prog_width)) 154 | bar += ']' 155 | else: 156 | bar = '%7d/Unknown' % current 157 | 158 | self._total_width = len(bar) 159 | sys.stdout.write(bar) 160 | 161 | if current: 162 | time_per_unit = (now - self._start) / current 163 | else: 164 | time_per_unit = 0 165 | if self.target is not None and current < self.target: 166 | eta = time_per_unit * (self.target - current) 167 | if eta > 3600: 168 | eta_format = '%d:%02d:%02d' % (eta // 3600, 169 | (eta % 3600) // 60, 170 | eta % 60) 171 | elif eta > 60: 172 | eta_format = '%d:%02d' % (eta // 60, eta % 60) 173 | else: 174 | eta_format = '%ds' % eta 175 | 176 | info = ' - ETA: %s' % eta_format 177 | else: 178 | if time_per_unit >= 1: 179 | info += ' %.0fs/step' % time_per_unit 180 | elif time_per_unit >= 1e-3: 181 | info += ' %.0fms/step' % (time_per_unit * 1e3) 182 | else: 183 | info += ' %.0fus/step' % (time_per_unit * 1e6) 184 | 185 | for k in self._values_order: 186 | info += ' - %s:' % k 187 | if isinstance(self._values[k], list): 188 | avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) 189 | if abs(avg) > 1e-3: 190 | info += ' %.4f' % avg 191 | else: 192 | info += ' %.4e' % avg 193 | else: 194 | info += ' %s' % self._values[k] 195 | 196 | self._total_width += len(info) 197 | if prev_total_width > self._total_width: 198 | info += (' ' * (prev_total_width - self._total_width)) 199 | 200 | if self.target is not None and current >= self.target: 201 | info += '\n' 202 | 203 | sys.stdout.write(info) 204 | sys.stdout.flush() 205 | 206 | elif self.verbose == 2: 207 | if self.target is None or current >= self.target: 208 | for k in self._values_order: 209 | info += ' - %s:' % k 210 | avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) 211 | if avg > 1e-3: 212 | info += ' %.4f' % avg 213 | else: 214 | info += ' %.4e' % avg 215 | info += '\n' 216 | 217 | sys.stdout.write(info) 218 | sys.stdout.flush() 219 | 220 | self._last_update = now 221 | 222 | def add(self, n, values=None): 223 | self.update(self._seen_so_far + n, values) 224 | 225 | # network init function 226 | def weights_init(init_type='gaussian'): 227 | def init_fun(m): 228 | classname = m.__class__.__name__ 229 | if (classname.find('Conv') == 0 or classname.find('Linear') == 0) and hasattr(m, 'weight'): 230 | if init_type == 'gaussian': 231 | init.normal_(m.weight.data, 0.0, 0.02) 232 | elif init_type == 'xavier': 233 | init.xavier_normal_(m.weight.data, gain=math.sqrt(2)) 234 | elif init_type == 'kaiming': 235 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 236 | elif init_type == 'orthogonal': 237 | init.orthogonal_(m.weight.data, gain=math.sqrt(2)) 238 | elif init_type == 'default': 239 | pass 240 | else: 241 | assert 0, "Unsupported initialization: {}".format(init_type) 242 | if hasattr(m, 'bias') and m.bias is not None: 243 | init.constant_(m.bias.data, 0.0) 244 | return init_fun 245 | 246 | 247 | # Get model list for resume 248 | def get_model_list(dirname, key): 249 | if os.path.exists(dirname) is False: 250 | return None 251 | gen_models = [os.path.join(dirname, f) for f in os.listdir(dirname) if 252 | os.path.isfile(os.path.join(dirname, f)) and key in f and ".pt" in f] 253 | if gen_models == []: 254 | return None 255 | gen_models.sort() 256 | last_model_name = gen_models[-1] 257 | return last_model_name 258 | 259 | def get_iteration(dir_name, file_name, net_name): 260 | if os.path.exists(os.path.join(dir_name, file_name)) is False: 261 | return None 262 | if 'latest' in file_name: 263 | gen_models = [os.path.join(dir_name, f) for f 
in os.listdir(dir_name) if 264 | os.path.isfile(os.path.join(dir_name, f)) and (not 'latest' in f) and (".pt" in f) and (net_name in f)] 265 | if gen_models == []: 266 | return 0 267 | model_name = os.path.basename(natsorted(gen_models)[-1]) 268 | else: 269 | model_name = file_name 270 | iterations = int(model_name.replace('_net_'+net_name+'.pth', '')) 271 | return iterations 272 | 273 | def __denorm(x): 274 | x = (x + 1) / 2 275 | return x.clamp_(0, 1) 276 | 277 | 278 | def __write_images(image_outputs, display_image_num, file_name): 279 | image_outputs = [images.expand(-1, 3, -1, -1) for images in image_outputs] # expand gray-scale images to 3 channels 280 | image_tensor = torch.cat([images[:display_image_num] for images in image_outputs], 0) 281 | image_tensor = __denorm(image_tensor) 282 | image_grid = vutils.make_grid(image_tensor.data, nrow=display_image_num, padding=0, normalize=False) 283 | vutils.save_image(image_grid, file_name, nrow=1) 284 | 285 | 286 | def write_2images(image_outputs, display_image_num, name): 287 | n = len(image_outputs) 288 | __write_images(image_outputs[0:n], display_image_num, name) 289 | 290 | def write_2tensorboard(iterations, results, train_writer, display_image_num, name): 291 | results = [images.expand(-1, 3, -1, -1) for images in results] # expand gray-scale images to 3 channels 292 | image_tensor = torch.cat([images[:display_image_num] for images in results], 0) 293 | image_tensor = __denorm(image_tensor) 294 | image_grid = vutils.make_grid(image_tensor.data, nrow=display_image_num, padding=0, normalize=False) 295 | train_writer.add_image(name, image_grid, iterations) 296 | 297 | 298 | def write_flow_visualization(flow, images_input, images_gt, name): 299 | if not os.path.exists(name): 300 | os.mkdir(name) 301 | 302 | kernel_size = 16 303 | n, _, h, w = flow.size() 304 | 305 | x = torch.arange(w).view(1, -1).expand(h, -1) 306 | y = torch.arange(h).view(-1, 1).expand(-1, w) 307 | grid = torch.stack([x,y], dim=0).float().cuda() 308 | grid = grid.unsqueeze(0).expand(n, -1, -1, -1) 309 | grid = grid+flow 310 | 311 | images_input = torch.nn.functional.interpolate(images_input, size=(h,w)) 312 | images_gt = torch.nn.functional.interpolate(images_gt, size=(h,w)) 313 | 314 | index=0 315 | for i in range(0, h, kernel_size): 316 | for j in range(0, w, kernel_size): 317 | img_ = _write_flow_helper(i, j, grid, kernel_size) 318 | out = img_+images_input 319 | 320 | name_write = os.path.join(name, '%04d.jpg'%(index)) 321 | write_2images([out, images_gt,flow[:,0,:,:].unsqueeze(1)/30, flow[:,1,:,:].unsqueeze(1)/30], n, name_write) 322 | 323 | index=index+1 324 | 325 | 326 | def _write_flow_helper( i, j, grid, kernel_size): 327 | n, _, h, w = grid.size() 328 | img = torch.zeros(n,3,h,w).cuda() 329 | _grid = grid[:,:,i:i+kernel_size,j:j+kernel_size] 330 | img[:,:,i:i+kernel_size,j:j+kernel_size] = 0.5 331 | 332 | for _i in range(kernel_size): 333 | for _j in range(kernel_size): 334 | for _b in range(n): 335 | index_x = int(min(max(_grid[_b,0,_i,_j],0), w-1)) 336 | index_y = int(min(max(_grid[_b,1,_i,_j],0), h-1)) 337 | img[_b, :, index_y, index_x]=0.8 338 | 339 | return img -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from main import main 2 | 3 | main('test') -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from 
main import main 2 | 3 | main('train') --------------------------------------------------------------------------------
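The epoch counter in train() above is reconstructed from the persisted iteration count, so an interrupted run resumes with a sensible epoch number rather than restarting at zero. A minimal sketch of that arithmetic; the batch size and dataset size are hypothetical numbers chosen only for illustration:

import math

iterations = 12500   # optimizer steps restored from a checkpoint (hypothetical)
batch_size = 6       # TRAIN_BATCH_SIZE (hypothetical)
total = 30000        # len(train_dataset) (hypothetical)

# images seen so far / dataset size, rounded down -> epochs completed
epoch = math.floor(iterations * batch_size / total)
print(epoch)  # 2

The while loop then increments epoch once more before the next pass over the data, matching line 63 of src/structure_flow.py.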
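metrics() in src/structure_flow.py quantizes tensors to 8-bit levels before scoring, so PSNR matches what would be measured on images saved to disk. A standalone sketch of the same computation on dummy data; compare_psnr is assumed to come from scikit-image, whose newer releases expose it as peak_signal_noise_ratio:

import torch
from skimage.metrics import peak_signal_noise_ratio  # compare_psnr in older skimage

# fake [-1, 1] network output and ground truth, shape (N, C, H, W)
pred = torch.rand(1, 3, 64, 64) * 2 - 1
gt = torch.rand(1, 3, 64, 64) * 2 - 1

def postprocess(x):
    return ((x + 1) / 2).clamp(0, 1)  # same mapping as postprocess() above

pred, gt = postprocess(pred), postprocess(gt)

# the same 8-bit quantization metrics() applies before scoring
pred = (pred * 255.0).int().float() / 255.0
gt = (gt * 255.0).int().float() / 255.0

p = pred[0].numpy().transpose(1, 2, 0)  # HWC float32 in [0, 1]
g = gt[0].numpy().transpose(1, 2, 0)
print(peak_signal_noise_ratio(g, p, data_range=1))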
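create_mask() in src/utils.py above cuts a single axis-aligned rectangular hole into a zero canvas (1 marks the missing region). A quick usage sketch, assuming the repository root is on PYTHONPATH:

from src.utils import create_mask

# random 128x128 hole somewhere inside a 256x256 canvas
mask = create_mask(256, 256, 128, 128)
assert mask.shape == (256, 256)
assert mask.sum() == 128 * 128

# passing x/y pins the hole, e.g. exactly centered:
centered = create_mask(256, 256, 128, 128, x=64, y=64)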
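Progbar mirrors the Keras progress bar: values passed to add() are running-averaged over the epoch unless their name is registered in stateful_metrics, which is why train() lists 'epoch' and 'iter' there. A usage sketch, again assuming src.utils is importable:

import time
from src.utils import Progbar

bar = Progbar(100, width=20, stateful_metrics=['epoch'])
for step in range(10):
    time.sleep(0.05)                    # stand-in for a training step
    bar.add(10, values=[
        ('epoch', 1),                   # stateful: displayed as-is
        ('l_total', 1.0 / (step + 1)),  # averaged since the start of the bar
    ])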
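weights_init() returns a closure meant for torch.nn.Module.apply(); it only touches modules whose class name starts with 'Conv' or 'Linear', so activations and normalization layers pass through untouched. Usage sketch:

import torch.nn as nn
from src.utils import weights_init

net = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 3, kernel_size=3, padding=1),
)
# recursively re-initialize every Conv/Linear weight with N(0, 0.02)
# and zero the biases; 'xavier', 'kaiming' and 'orthogonal' select
# the other branches of the closure
net.apply(weights_init('gaussian'))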
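get_model_list() and get_iteration() encode the checkpoint naming convention '<iterations>_net_<name>.pth' plus a 'latest_net_<name>.pth' alias: when asked about 'latest', get_iteration() natsorts the numbered files and parses the newest prefix. The directory and net name below are hypothetical, chosen only to illustrate the parsing:

# hypothetical directory contents:
#   checkpoints/00010000_net_flow.pth
#   checkpoints/00020000_net_flow.pth
#   checkpoints/latest_net_flow.pth
#
# get_iteration('checkpoints', 'latest_net_flow.pth', 'flow') -> 20000,
# because a numbered name reduces to its iteration prefix:
iterations = int('00020000_net_flow.pth'.replace('_net_flow.pth', ''))
assert iterations == 20000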
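write_flow_visualization() turns a dense flow field into absolute sampling coordinates by adding the predicted offsets to a base meshgrid; _write_flow_helper() then paints where each small patch samples from. The same grid construction in isolation, on CPU instead of .cuda(), with the (N, 2, H, W) flow layout and x-offsets in channel 0 as in the function above:

import torch

n, h, w = 1, 4, 4
flow = torch.zeros(n, 2, h, w)
flow[:, 0] += 1.0                 # shift every pixel one column to the right

x = torch.arange(w).view(1, -1).expand(h, -1)
y = torch.arange(h).view(-1, 1).expand(-1, w)
grid = torch.stack([x, y], dim=0).float()
grid = grid.unsqueeze(0).expand(n, -1, -1, -1)

coords = grid + flow              # absolute (x, y) source coordinate per pixel
print(coords[0, :, 0, 0])         # tensor([1., 0.])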