├── LICENSE.md
├── README.md
├── config.yaml
├── example
│   ├── celeba
│   │   ├── 1.jpg
│   │   ├── 1_mask.png
│   │   ├── 1_tsmooth.png
│   │   ├── 2.jpg
│   │   ├── 2_mask.png
│   │   ├── 2_tsmooth.png
│   │   ├── 3.jpg
│   │   ├── 3_mask.png
│   │   ├── 3_tsmooth.png
│   │   ├── 4.jpg
│   │   ├── 4_mask.png
│   │   └── 4_tsmooth.jpg
│   ├── paris
│   │   ├── 1.jpg
│   │   ├── 1_mask.png
│   │   ├── 1_tsmooth.jpg
│   │   ├── 2.png
│   │   ├── 2_mask.png
│   │   └── 2_tsmooth.jpg
│   └── places
│       ├── 1.jpg
│       ├── 1_mask.png
│       ├── 1_tsmooth.png
│       ├── 2.jpg
│       ├── 2_mask.png
│       ├── 2_tsmooth.jpg
│       ├── 3.jpg
│       ├── 3_mask.png
│       ├── 3_tsmooth.jpg
│       ├── 4.jpg
│       ├── 4_mask.png
│       └── 4_tsmooth.jpg
├── main.py
├── model_config.yaml
├── resample2d_package
│   ├── __init__.py
│   ├── resample2d_cuda.cc
│   ├── resample2d_kernel.cu
│   ├── resample2d_kernel.cuh
│   └── setup.py
├── scripts
│   ├── flist.py
│   ├── inception.py
│   ├── matlab
│   │   ├── code
│   │   │   └── tsmooth.m
│   │   ├── dirPlus.m
│   │   └── generate_structure_images.m
│   └── metrics.py
├── src
│   ├── base_model.py
│   ├── config.py
│   ├── data.py
│   ├── loss.py
│   ├── metrics.py
│   ├── models.py
│   ├── network.py
│   ├── resample2d.py
│   ├── structure_flow.py
│   └── utils.py
├── test.py
└── train.py
/LICENSE.md:
--------------------------------------------------------------------------------
1 | ## creative commons
2 |
3 | # Attribution-NonCommercial 4.0 International
4 |
5 | Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.
6 |
7 | ### Using Creative Commons Public Licenses
8 |
9 | Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.
10 |
11 | * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors).
12 |
13 | * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees).
14 |
15 | ## Creative Commons Attribution-NonCommercial 4.0 International Public License
16 |
17 | By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
18 |
19 | ### Section 1 – Definitions.
20 |
21 | a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
22 |
23 | b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.
24 |
25 | c. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
26 |
27 | d. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
28 |
29 | e. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
30 |
31 | f. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
32 |
33 | g. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
34 |
35 | h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License.
36 |
37 | i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.
38 |
39 | j. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
40 |
41 | k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
42 |
43 | l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.
44 |
45 | ### Section 2 – Scope.
46 |
47 | a. ___License grant.___
48 |
49 | 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
50 |
51 | A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
52 |
53 | B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.
54 |
55 | 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
56 |
57 | 3. __Term.__ The term of this Public License is specified in Section 6(a).
58 |
59 | 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
60 |
61 | 5. __Downstream recipients.__
62 |
63 | A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
64 |
65 | B. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
66 |
67 | 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
68 |
69 | b. ___Other rights.___
70 |
71 | 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
72 |
73 | 2. Patent and trademark rights are not licensed under this Public License.
74 |
75 | 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.
76 |
77 | ### Section 3 – License Conditions.
78 |
79 | Your exercise of the Licensed Rights is expressly made subject to the following conditions.
80 |
81 | a. ___Attribution.___
82 |
83 | 1. If You Share the Licensed Material (including in modified form), You must:
84 |
85 | A. retain the following if it is supplied by the Licensor with the Licensed Material:
86 |
87 | i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
88 |
89 | ii. a copyright notice;
90 |
91 | iii. a notice that refers to this Public License;
92 |
93 | iv. a notice that refers to the disclaimer of warranties;
94 |
95 | v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
96 |
97 | B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
98 |
99 | C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
100 |
101 | 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
102 |
103 | 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
104 |
105 | 4. If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License.
106 |
107 | ### Section 4 – Sui Generis Database Rights.
108 |
109 | Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:
110 |
111 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;
112 |
113 | b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and
114 |
115 | c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
116 |
117 | For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
118 |
119 | ### Section 5 – Disclaimer of Warranties and Limitation of Liability.
120 |
121 | a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__
122 |
123 | b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__
124 |
125 | c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
126 |
127 | ### Section 6 – Term and Termination.
128 |
129 | a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
130 |
131 | b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
132 |
133 | 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
134 |
135 | 2. upon express reinstatement by the Licensor.
136 |
137 | For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.
138 |
139 | c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
140 |
141 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
142 |
143 | ### Section 7 – Other Terms and Conditions.
144 |
145 | a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
146 |
147 | b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
148 |
149 | ### Section 8 – Interpretation.
150 |
151 | a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
152 |
153 | b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
154 |
155 | c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
156 |
157 | d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
158 |
159 | > Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
160 | >
161 | > Creative Commons may be contacted at creativecommons.org
162 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # StructureFlow
2 | Code for our paper "[StructureFlow: Image Inpainting via Structure-aware Appearance Flow](https://arxiv.org/abs/1908.03852)" (ICCV 2019)
3 |
4 | ### Introduction
5 |
6 | We propose a two-stage image inpainting network which splits the task into two parts: **structure reconstruction** and **texture generation**. In the first stage, edge-preserved smooth images are employed to train a structure reconstructor which completes the missing structures of the inputs. In the second stage, based on the reconstructed structures, a texture generator using appearance flow is designed to yield image details.
7 |
8 |
9 | [example results figure]
10 |
11 |
12 | *(From left to right) Input corrupted images, reconstructed structure images, visualizations of the appearance flow fields, final output images. To visualize the appearance flow fields, we plot the sample points of some typical missing regions. The arrows show the direction of the appearance flow.*
13 |
14 | ### Requirements
15 |
16 | 1. PyTorch >= 1.0
17 | 2. Python 3
18 | 3. NVIDIA GPU + CUDA 9.0
19 | 4. Tensorboard
20 | 5. Matlab
21 |
22 | ### Installation
23 |
24 | 1. Clone this repository
25 |
26 | ```bash
27 | git clone https://github.com/RenYurui/StructureFlow
28 | ```
29 |
30 | 2. Build Gaussian Sampling CUDA package
31 |
32 | ```bash
33 | cd ./StructureFlow/resample2d_package
34 | python setup.py install --user
35 | ```
36 |
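If the build succeeds, the extension can be imported from Python. A minimal sanity check, assuming the module is named `resample2d_cuda` (the name used by the FlowNet2 package this code derives from; adjust if your setup.py differs):

```python
# Sanity check; the module name `resample2d_cuda` is an assumption.
import torch
import resample2d_cuda

print(torch.cuda.is_available())            # a CUDA device is required
print(hasattr(resample2d_cuda, 'forward'))  # binding defined in resample2d_cuda.cc
```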
37 |
38 | ### Running
39 |
40 | **1. Image Prepare**
41 |
42 | We train our model on three public datasets: Places2, CelebA, and Paris StreetView. We use the irregular mask dataset provided by [PConv](https://arxiv.org/abs/1804.07723). You can download these datasets from their project websites.
43 |
44 | 1. [Places2](http://places2.csail.mit.edu)
45 | 2. [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)
46 | 3. [Paris Street-View](https://github.com/pathak22/context-encoder)
47 | 4. [Irregular Masks](http://masc.cs.gmu.edu/wiki/partialconv)
48 |
49 | After downloading the datasets, the edge-preserved smooth images can be obtained using the [RTV smoothing method](http://www.cse.cuhk.edu.hk/~leojia/projects/texturesep/). Run the generation function [`scripts/matlab/generate_structure_images.m`](scripts/matlab/generate_structure_images.m) in Matlab. For example, to generate smooth images for Places2, run the following code:
50 |
51 | ```matlab
52 | generate_structure_images("path to Places2 dataset root", "path to output folder");
53 | ```
54 |
55 | Finally, generate the image lists used for training and testing with the script [`scripts/flist.py`](scripts/flist.py); a sketch of what such a generator does follows below.
56 |
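The exact interface of `scripts/flist.py` may differ; the following is a minimal sketch (with illustrative flag names) of what such a list generator does: walk a dataset folder and write one image path per line, matching the list files referenced in [model_config.yaml](model_config.yaml).

```python
# Minimal sketch of a file-list generator; scripts/flist.py is the
# authoritative script and its flag names may differ.
import argparse
import os

EXTS = {'.jpg', '.jpeg', '.png'}

def make_flist(dataset_dir, output_file):
    paths = []
    for root, _, files in os.walk(dataset_dir):
        for name in sorted(files):
            if os.path.splitext(name)[1].lower() in EXTS:
                paths.append(os.path.join(root, name))
    with open(output_file, 'w') as f:
        f.write('\n'.join(paths))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', required=True, help='dataset root directory')
    parser.add_argument('--output', required=True, help='output list file')
    args = parser.parse_args()
    make_flist(args.path, args.output)
```
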
57 | **2. Training**
58 |
59 | To train our model, modify the model config file [model_config.yaml](model_config.yaml) as needed (e.g. the dataset paths or the network parameters). Then run the following code:
60 |
61 | ```bash
62 | python train.py \
63 | --name=[the name of your experiment] \
64 | --path=[path save the results]
65 | ```
66 |
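To resume training from previously saved checkpoints, [main.py](main.py) also accepts `--resume_all` (load model from checkpoints) and `--remove_log` (remove previous tensorboard log files).
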
67 | **3. Testing**
68 |
69 | To generate results for given inputs, use [test.py](test.py) and run the following code:
70 |
71 | ```bash
72 | python test.py \
73 | --name=[the name of your experiment] \
74 | --path=[path of your experiments] \
75 | --input=[input images] \
76 | --mask=[mask images] \
77 | --structure=[structure images] \
78 | --output=[path to save the output images] \
79 | --model=[which model to be tested]
80 | ```
81 |
82 | To evaluate the model performance over a dataset, you can use the provided script [`./scripts/metrics.py`](scripts/metrics.py). This script reports the PSNR, SSIM, and Fréchet Inception Distance ([FID score](https://github.com/mseitzer/pytorch-fid)) of the results.
83 |
84 | ```bash
85 | python ./scripts/metrics.py \
86 | --input_path=[path to ground-truth images] \
87 | --output_path=[path to model outputs] \
88 | --fid_real_path=[path to the real images used to calculate FID]
89 | ```
90 |
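For reference, PSNR and SSIM for a single image pair can be computed along these lines (a sketch using scikit-image, not the repository's own implementation; FID additionally needs the Inception features handled by `scripts/inception.py`):

```python
# Sketch only; scripts/metrics.py is the authoritative implementation.
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

def psnr_ssim(gt, pred):
    """gt, pred: HxWx3 uint8 arrays of identical size."""
    psnr = peak_signal_noise_ratio(gt, pred, data_range=255)
    ssim = structural_similarity(gt, pred, channel_axis=-1, data_range=255)
    return psnr, ssim
```
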
91 | **The pre-trained weights can be downloaded from [Places2](https://drive.google.com/open?id=1K7U6fYthC4Acsx0GBde5iszHJWymyv1A), [CelebA](https://drive.google.com/open?id=1PrLgcEd964etxZcHIOE93uUONB9-b6pI), [Paris Street](https://drive.google.com/open?id=18AQpgsYZtA_eL-aJb6n8-geWLdihwXAi).**
92 |
93 | Download the checkpoints and save them to './path_of_your_experiments/name_of_your_experiment/checkpoints'.
94 |
95 | For example, you can download the checkpoints of Places2, save them to './results/places/checkpoints', and run the following code:
96 |
97 | ```bash
98 | python test.py \
99 | --name=places \
100 | --path=results \
101 | --input=./example/places/1.jpg \
102 | --mask=./example/places/1_mask.png \
103 | --structure=./example/places/1_tsmooth.png \
104 | --output=./result_images \
105 | --model=3
106 | ```
107 |
108 | ### Citation
109 |
110 | If you find this code helpful for your research, please cite our paper:
111 |
112 | ```
113 | @inproceedings{ren2019structureflow,
114 | author = {Ren, Yurui and Yu, Xiaoming and Zhang, Ruonan and Li, Thomas H. and Liu, Shan and Li, Ge},
115 | title = {StructureFlow: Image Inpainting via Structure-aware Appearance Flow},
116 | booktitle={IEEE International Conference on Computer Vision (ICCV)},
117 | year = {2019}
118 | }
119 | ```
120 |
121 |
122 |
123 | ### Acknowledgements
124 |
125 | We built our code based on [Edge-Connect](https://github.com/knazeri/edge-connect). Part of the code was derived from [FlowNet2](https://github.com/NVIDIA/flownet2-pytorch). Please consider citing their papers.
126 |
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
1 | MODEL: 3 # 1: structure model, 2: inpaint model, 3: structure-inpaint model
2 | VERBOSE: 1 # turns on verbose mode in the output console
3 | GPU: [1] # gpu ids
4 |
5 |
6 | MAX_ITERS: 8e6 # maximum number of iterations to train the model
7 | LR: 0.0001 # learning rate
8 | BETA1: 0 # adam optimizer beta1
9 | BETA2: 0.999 # adam optimizer beta2
10 | LR_POLICY: constant # the method to adjust learning rate (e.g. constant|step)
11 | STEP_SIZE: 100000 # period of learning rate decay (only used with the "step" lr adjustment method)
12 | GAMMA: 0.5 # multiplicative factor of learning rate decay (only used with the "step" lr adjustment method)
13 | INIT_TYPE: xavier # initialization [gaussian/kaiming/xavier/orthogonal]
14 |
15 | SAVE_INTERVAL: 30 # how many iterations to wait before saving model (0: never)
16 | SAVE_LATEST: 10 # how many iterations to wait before saving the latest model (0: never)
17 | SAMPLE_INTERVAL: 10 # how many iterations to wait before sampling (0: never)
18 | SAMPLE_SIZE: 4 # number of images to sample
19 | EVAL_INTERVAL: 1000 # how many iterations to wait before model evaluation (0: never)
20 | LOG_INTERVAL: 10 # how many iterations to wait before logging training status (0: never)
21 | WHICH_ITER: latest # which iteration to load
22 |
23 | DIS_GAN_LOSS: lsgan # type of gan loss
24 |
25 | STRUCTURE_L1: 4 # structure net parameter of l1 loss
26 | STRUCTURE_ADV_GEN: 1 # structure net parameter of gan loss
27 |
28 | FLOW_ADV_GEN: 1 # texture net parameter of gan loss
29 | FLOW_L1: 5 # texture net parameter of l1 loss
30 | FLOW_CORRECTNESS: 0.25 # texture net parameter of sampling correctness loss
31 | VGG_STYLE: 250 # texture net parameter of vgg_style loss (Optional loss on stage_3)
32 | VGG_CONTENT: 0.1 # texture net parameter of vgg_content loss (Optional loss on stage_3)
33 |
34 |
35 | TRAIN_BATCH_SIZE: 8 # batch size
36 | DATA_TRAIN_SIZE: 256 # image size for training
37 | DATA_TEST_SIZE: False # image size for testing (False to never resize)
38 | DATA_FLIP: False # flip image or not when training
39 | DATA_CROP: False #[537,537] # crop size when training (False to never crop)
40 | DATA_MASK_TYPE: from_file # mask type (random_bbox|random_free_form|from_file)
41 | DATA_RANDOM_BBOX_SETTING: # parameters for random_bbox
42 | random_size: False # random hole size according to shape [0.4*shape shape]
43 | shape: [80, 80] # hole size
44 | margin: [0, 0] # minimum distance from the image boundary
45 | num: 3 # number of random boxes
46 | DATA_RANDOM_FF_SETTING: #parameters for random_free_form
47 | mv: 5
48 | ma: 4.0
49 | ml: 40
50 | mbw: 10
51 | DATA_MASK_FILE: ./txt/irregular_mask.txt #parameters for from_file
52 |
53 | DATA_TRAIN_GT: ./txt/places_gt_train.txt
54 | DATA_TRAIN_STRUCTURE: ./txt/places_structure_train.txt
55 | DATA_VAL_GT: ./txt/places_gt_val.txt
56 | DATA_VAL_STRUCTURE: ./txt/places_structure_val.txt
57 | DATA_VAL_MASK: ./txt/places_mask_val.txt
58 |
59 |
60 |
61 |
--------------------------------------------------------------------------------
/example/celeba/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/celeba/1.jpg
--------------------------------------------------------------------------------
/example/celeba/1_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/celeba/1_mask.png
--------------------------------------------------------------------------------
/example/celeba/1_tsmooth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/celeba/1_tsmooth.png
--------------------------------------------------------------------------------
/example/celeba/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/celeba/2.jpg
--------------------------------------------------------------------------------
/example/celeba/2_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/celeba/2_mask.png
--------------------------------------------------------------------------------
/example/celeba/2_tsmooth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/celeba/2_tsmooth.png
--------------------------------------------------------------------------------
/example/celeba/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/celeba/3.jpg
--------------------------------------------------------------------------------
/example/celeba/3_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/celeba/3_mask.png
--------------------------------------------------------------------------------
/example/celeba/3_tsmooth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/celeba/3_tsmooth.png
--------------------------------------------------------------------------------
/example/celeba/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/celeba/4.jpg
--------------------------------------------------------------------------------
/example/celeba/4_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/celeba/4_mask.png
--------------------------------------------------------------------------------
/example/celeba/4_tsmooth.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/celeba/4_tsmooth.jpg
--------------------------------------------------------------------------------
/example/paris/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/paris/1.jpg
--------------------------------------------------------------------------------
/example/paris/1_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/paris/1_mask.png
--------------------------------------------------------------------------------
/example/paris/1_tsmooth.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/paris/1_tsmooth.jpg
--------------------------------------------------------------------------------
/example/paris/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/paris/2.png
--------------------------------------------------------------------------------
/example/paris/2_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/paris/2_mask.png
--------------------------------------------------------------------------------
/example/paris/2_tsmooth.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/paris/2_tsmooth.jpg
--------------------------------------------------------------------------------
/example/places/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/places/1.jpg
--------------------------------------------------------------------------------
/example/places/1_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/places/1_mask.png
--------------------------------------------------------------------------------
/example/places/1_tsmooth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/places/1_tsmooth.png
--------------------------------------------------------------------------------
/example/places/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/places/2.jpg
--------------------------------------------------------------------------------
/example/places/2_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/places/2_mask.png
--------------------------------------------------------------------------------
/example/places/2_tsmooth.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/places/2_tsmooth.jpg
--------------------------------------------------------------------------------
/example/places/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/places/3.jpg
--------------------------------------------------------------------------------
/example/places/3_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/places/3_mask.png
--------------------------------------------------------------------------------
/example/places/3_tsmooth.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/places/3_tsmooth.jpg
--------------------------------------------------------------------------------
/example/places/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/places/4.jpg
--------------------------------------------------------------------------------
/example/places/4_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/places/4_mask.png
--------------------------------------------------------------------------------
/example/places/4_tsmooth.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/example/places/4_tsmooth.jpg
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | import shutil
5 | from src.config import Config
6 | from src.structure_flow import StructureFlow
7 |
8 | def main(mode=None):
9 |     r"""starts the model
10 |     Args:
11 |         mode : train, test, eval, reads from config file if not specified
12 |     """
13 |
14 |     config = load_config(mode)
15 |     config.MODE = mode
16 |     # join with commas so multi-gpu id lists are parsed correctly
17 |     os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(e) for e in config.GPU)
18 |
19 |     if torch.cuda.is_available():
20 |         config.DEVICE = torch.device("cuda")
21 |         torch.backends.cudnn.benchmark = True  # cudnn auto-tuner
22 |     else:
23 |         config.DEVICE = torch.device("cpu")
24 |
25 |     model = StructureFlow(config)
26 |
27 |     if mode == 'train':
28 |         # config.print()
29 |         print('\nstart training...\n')
30 |         model.train()
31 |
32 |     elif mode == 'test':
33 |         print('\nstart test...\n')
34 |         model.test()
35 |
36 |     elif mode == 'eval':
37 |         print('\nstart eval...\n')
38 |         model.eval()
39 |
40 |
41 | def load_config(mode=None):
42 |     r"""loads model config
43 |     """
44 |     parser = argparse.ArgumentParser()
45 |     parser.add_argument('--name', type=str, help='output model name.')
46 |     parser.add_argument('--config', type=str, default='model_config.yaml', help='Path to the config file.')
47 |     parser.add_argument('--path', type=str, default='./results', help='outputs path')
48 |     parser.add_argument("--resume_all", action="store_true", help='load model from checkpoints')
49 |     parser.add_argument("--remove_log", action="store_true", help='remove previous tensorboard log files')
50 |
51 |     if mode == 'test':
52 |         parser.add_argument('--input', type=str, help='path to the input image files')
53 |         parser.add_argument('--mask', type=str, help='path to the mask files')
54 |         parser.add_argument('--structure', type=str, help='path to the structure files')
55 |         parser.add_argument('--output', type=str, help='path to the output directory')
56 |         parser.add_argument('--model', type=int, default=3, help='which model to test')
57 |
58 |     opts = parser.parse_args()
59 |     config = Config(opts, mode)
60 |     output_dir = os.path.join(opts.path, opts.name)
61 |     prepare_sub_folder(output_dir)
62 |
63 |     if mode == 'train':
64 |         config_dir = os.path.join(output_dir, 'config.yaml')
65 |         shutil.copyfile(opts.config, config_dir)
66 |     return config
67 |
68 |
69 | def prepare_sub_folder(output_path):
70 |     """create the images/ and checkpoints/ sub-folders of an experiment"""
71 |     img_dir = os.path.join(output_path, 'images')
72 |     if not os.path.exists(img_dir):
73 |         print("Creating directory: {}".format(img_dir))
74 |         os.makedirs(img_dir)
75 |
76 |     checkpoints_dir = os.path.join(output_path, 'checkpoints')
77 |     if not os.path.exists(checkpoints_dir):
78 |         print("Creating directory: {}".format(checkpoints_dir))
79 |         os.makedirs(checkpoints_dir)
80 |
81 |
82 | if __name__ == "__main__":
83 |     main()
84 |
--------------------------------------------------------------------------------
/model_config.yaml:
--------------------------------------------------------------------------------
1 | MODEL: 3 # 1: structure model, 2: inpaint model, 3: structure-inpaint model
2 | VERBOSE: 1 # turns on verbose mode in the output console
3 | GPU: [1] # gpu ids
4 |
5 |
6 | MAX_ITERS: 8e6 # maximum number of iterations to train the model
7 | LR: 0.0001 # learning rate
8 | BETA1: 0 # adam optimizer beta1
9 | BETA2: 0.999 # adam optimizer beta2
10 | LR_POLICY: constant # the method to adjust learning rate (e.g. constant|step)
11 | STEP_SIZE: 100000 # period of learning rate decay (only used with the "step" lr adjustment method)
12 | GAMMA: 0.5 # multiplicative factor of learning rate decay (only used with the "step" lr adjustment method)
13 | INIT_TYPE: xavier # initialization [gaussian/kaiming/xavier/orthogonal]
14 |
15 | SAVE_INTERVAL: 10000 # how many iterations to wait before saving model
16 | SAVE_LATEST: 1000 # how many iterations to wait before saving the latest model
17 | SAMPLE_INTERVAL: 1000 # how many iterations to wait before sampling
18 | SAMPLE_SIZE: 4 # number of images to sample
19 | EVAL_INTERVAL: 10000 # how many iterations to wait before model evaluation
20 | LOG_INTERVAL: 100 # how many iterations to wait before logging training status
21 | WHICH_ITER: latest # which iteration to load
22 |
23 | DIS_GAN_LOSS: lsgan # type of gan loss
24 |
25 | STRUCTURE_L1: 4 # structure net parameter of l1 loss
26 | STRUCTURE_ADV_GEN: 1 # structure net parameter of gan loss
27 |
28 | FLOW_ADV_GEN: 1 # texture net parameter of gan loss
29 | FLOW_L1: 5 # texture net parameter of l1 loss
30 | FLOW_CORRECTNESS: 0.25 # texture net parameter of sampling correctness loss
31 | VGG_STYLE: 250 # texture net parameter of vgg_style loss (Optional loss on stage_3)
32 | VGG_CONTENT: 0.1 # texture net parameter of vgg_content loss (Optional loss on stage_3)
33 |
34 |
35 | TRAIN_BATCH_SIZE: 8 # batch size
36 | DATA_TRAIN_SIZE: 256 # image size for training
37 | DATA_TEST_SIZE: False # image size for testing (False to never resize)
38 | DATA_FLIP: False # flip image or not when training
39 | DATA_CROP: False #[537,537] # crop size when training (False to never crop)
40 | DATA_MASK_TYPE: from_file # mask type (random_bbox|random_free_form|from_file)
41 | DATA_RANDOM_BBOX_SETTING: # parameters for random_bbox
42 | random_size: False # random hole size according to shape [0.4*shape shape]
43 | shape: [80, 80] # hole size
44 | margin: [0, 0] # minimum distance from the image boundary
45 | num: 3 # number of random boxes
46 | DATA_RANDOM_FF_SETTING: # parameters for random_free_form
47 | mv: 5
48 | ma: 4.0
49 | ml: 40
50 | mbw: 10
51 | DATA_MASK_FILE: ./txt/irregular_mask.list #parameters for from_file
52 |
53 | # use places365 dataset
54 | DATA_TRAIN_GT: ./txt/places_gt_train.list
55 | DATA_TRAIN_STRUCTURE: ./txt/places_structure_train.list
56 | DATA_VAL_GT: ./txt/places_gt_val.list
57 | DATA_VAL_STRUCTURE: ./txt/places_structure_val.list
58 | DATA_VAL_MASK: ./txt/places_mask_val.list
59 |
60 |
61 |
62 |
--------------------------------------------------------------------------------
/resample2d_package/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenYurui/StructureFlow/1ac8f559475452e6b674699671c6b34f000d9ebd/resample2d_package/__init__.py
--------------------------------------------------------------------------------
/resample2d_package/resample2d_cuda.cc:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 | #include <torch/torch.h>
3 |
4 | #include "resample2d_kernel.cuh"
5 |
6 | int resample2d_cuda_forward(
7 | at::Tensor& input1,
8 | at::Tensor& input2,
9 | at::Tensor& output,
10 | int kernel_size,
11 | int dilation) {
12 | resample2d_kernel_forward(input1, input2, output, kernel_size, dilation);
13 | return 1;
14 | }
15 |
16 | int resample2d_cuda_backward(
17 | at::Tensor& input1,
18 | at::Tensor& input2,
19 | at::Tensor& gradOutput,
20 | at::Tensor& gradInput1,
21 | at::Tensor& gradInput2,
22 | int kernel_size,
23 | int dilation) {
24 | resample2d_kernel_backward(input1, input2, gradOutput, gradInput1, gradInput2, kernel_size, dilation);
25 | return 1;
26 | }
27 |
28 |
29 |
30 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
31 | m.def("forward", &resample2d_cuda_forward, "Resample2D forward (CUDA)");
32 | m.def("backward", &resample2d_cuda_backward, "Resample2D backward (CUDA)");
33 | }
34 |
35 |
--------------------------------------------------------------------------------
/resample2d_package/resample2d_kernel.cu:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 | #include <ATen/Context.h>
3 | #include <ATen/cuda/CUDAContext.h>
4 |
5 | #define CUDA_NUM_THREADS 256
6 | #define THREADS_PER_BLOCK 64
7 |
8 | #define DIM0(TENSOR) ((TENSOR).x)
9 | #define DIM1(TENSOR) ((TENSOR).y)
10 | #define DIM2(TENSOR) ((TENSOR).z)
11 | #define DIM3(TENSOR) ((TENSOR).w)
12 |
13 | #define DIM3_INDEX(TENSOR, xx, yy, zz, ww) ((TENSOR)[((xx) * (TENSOR##_stride.x)) + ((yy) * (TENSOR##_stride.y)) + ((zz) * (TENSOR##_stride.z)) + ((ww) * (TENSOR##_stride.w))])
14 | #define EPS 1e-8
15 | #define SAFE_DIV(a, b) ( (b==0)? ( (a)/(EPS) ): ( (a)/(b) ) )
16 |
17 |
18 |
19 |
20 | template <typename scalar_t>
21 | __global__ void kernel_resample2d_update_output(const int n,
22 | const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride,
23 | const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride,
24 | scalar_t* __restrict__ output, const long4 output_size, const long4 output_stride, int kernel_size, int dilation) {
25 | int index = blockIdx.x * blockDim.x + threadIdx.x;
26 |
27 | if (index >= n) {
28 | return;
29 | }
30 |
31 | scalar_t val = 0.0f;
32 | scalar_t sum = 0.0f;
33 |
34 |
35 | int dim_b = DIM0(output_size);
36 | int dim_c = DIM1(output_size);
37 | int dim_h = DIM2(output_size);
38 | int dim_w = DIM3(output_size);
39 | int dim_chw = dim_c * dim_h * dim_w;
40 | int dim_hw = dim_h * dim_w;
41 |
42 | int b = ( index / dim_chw ) % dim_b;
43 | int c = ( index / dim_hw ) % dim_c;
44 | int y = ( index / dim_w ) % dim_h;
45 | int x = ( index ) % dim_w;
46 |
47 | scalar_t dx = DIM3_INDEX(input2, b, 0, y, x);
48 | scalar_t dy = DIM3_INDEX(input2, b, 1, y, x);
49 | scalar_t sigma = DIM3_INDEX(input2, b, 2, y, x);
50 |
51 |
52 | scalar_t xf = static_cast<scalar_t>(x) + dx;
53 | scalar_t yf = static_cast<scalar_t>(y) + dy;
54 | scalar_t alpha = xf - floor(xf); // alpha
55 | scalar_t beta = yf - floor(yf); // beta
56 |
57 |
58 | int idim_h = DIM2(input1_size);
59 | int idim_w = DIM3(input1_size);
60 |
61 |
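// Gaussian sampling: gather input1 around the flow-displaced position
// (xf, yf) = (x + dx, y + dy), weighting each tap with a Gaussian of
// per-pixel bandwidth sigma; the accumulated weight `sum` normalizes
// the result at the end of the kernel.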
62 | for (int fy = 0; fy < kernel_size/2; fy += 1) {
63 | int yT = max(min( int (floor(yf)-fy*dilation), idim_h-1), 0);
64 | int yB = max(min( int (floor(yf)+(fy+1)*dilation),idim_h-1), 0);
65 |
66 | for (int fx = 0; fx < kernel_size/2; fx += 1) {
67 | int xL = max(min( int (floor(xf)-fx*dilation ), idim_w-1), 0);
68 | int xR = max(min( int (floor(xf)+(fx+1)*dilation), idim_w-1), 0);
69 |
70 | scalar_t xL_ = ( static_cast<scalar_t>( fx *dilation)+alpha );
71 | scalar_t xR_ = ( static_cast<scalar_t>((1.+fx)*dilation)-alpha );
72 | scalar_t yT_ = ( static_cast<scalar_t>( fy *dilation)+beta );
73 | scalar_t yB_ = ( static_cast<scalar_t>((1.+fy)*dilation)-beta );
74 |
75 | scalar_t xL_P = exp(SAFE_DIV(-xL_*xL_, 2*sigma*sigma));
76 | scalar_t xR_P = exp(SAFE_DIV(-xR_*xR_, 2*sigma*sigma));
77 | scalar_t yT_P = exp(SAFE_DIV(-yT_*yT_, 2*sigma*sigma));
78 | scalar_t yB_P = exp(SAFE_DIV(-yB_*yB_, 2*sigma*sigma));
84 |
85 | val += static_cast<scalar_t>(yT_P*xL_P * DIM3_INDEX(input1, b, c, yT, xL));
86 | val += static_cast<scalar_t>(yT_P*xR_P * DIM3_INDEX(input1, b, c, yT, xR));
87 | val += static_cast<scalar_t>(yB_P*xL_P * DIM3_INDEX(input1, b, c, yB, xL));
88 | val += static_cast<scalar_t>(yB_P*xR_P * DIM3_INDEX(input1, b, c, yB, xR));
89 | sum += (yT_P*xL_P + yT_P*xR_P + yB_P*xL_P + yB_P*xR_P);
90 | }
91 | }
92 |
93 | output[index] = SAFE_DIV(val, sum);
94 |
95 | }
96 |
97 |
98 | template <typename scalar_t>
99 | __global__ void kernel_resample2d_backward_input1(
100 | const int n, const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride,
101 | const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride,
102 | const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride,
103 | scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4 gradInput_stride, int kernel_size, int dilation) {
104 |
105 | int index = blockIdx.x * blockDim.x + threadIdx.x;
106 |
107 | if (index >= n) {
108 | return;
109 | }
110 |
111 | scalar_t sum = 0.0f;
116 |
117 | int dim_b = DIM0(gradOutput_size);
118 | int dim_c = DIM1(gradOutput_size);
119 | int dim_h = DIM2(gradOutput_size);
120 | int dim_w = DIM3(gradOutput_size);
121 | int dim_chw = dim_c * dim_h * dim_w;
122 | int dim_hw = dim_h * dim_w;
123 |
124 | int b = ( index / dim_chw ) % dim_b;
125 | int c = ( index / dim_hw ) % dim_c;
126 | int y = ( index / dim_w ) % dim_h;
127 | int x = ( index ) % dim_w;
128 |
129 | scalar_t dx = DIM3_INDEX(input2, b, 0, y, x);
130 | scalar_t dy = DIM3_INDEX(input2, b, 1, y, x);
131 | scalar_t sigma = DIM3_INDEX(input2, b, 2, y, x);
132 |
133 |
134 |
135 | scalar_t xf = static_cast<scalar_t>(x) + dx;
136 | scalar_t yf = static_cast<scalar_t>(y) + dy;
137 | scalar_t alpha = xf - int(xf); // alpha
138 | scalar_t beta = yf - int(yf); // beta
139 |
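// First pass: accumulate the total Gaussian weight `sum` over the
// sampling window; it normalizes the gradients scattered below.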
140 | for (int fy = 0; fy < kernel_size/2; fy += 1) {
141 | for (int fx = 0; fx < kernel_size/2; fx += 1) {
142 | scalar_t xL_ = ( static_cast<scalar_t>( fx *dilation)+alpha );
143 | scalar_t xR_ = ( static_cast<scalar_t>((1.+fx)*dilation)-alpha );
144 | scalar_t yT_ = ( static_cast<scalar_t>( fy *dilation)+beta );
145 | scalar_t yB_ = ( static_cast<scalar_t>((1.+fy)*dilation)-beta );
150 |
151 | scalar_t xL_P = exp(SAFE_DIV(-xL_*xL_, 2*sigma*sigma));
152 | scalar_t xR_P = exp(SAFE_DIV(-xR_*xR_, 2*sigma*sigma));
153 | scalar_t yT_P = exp(SAFE_DIV(-yT_*yT_, 2*sigma*sigma));
154 | scalar_t yB_P = exp(SAFE_DIV(-yB_*yB_, 2*sigma*sigma));
159 | sum += (yT_P*xL_P + yT_P*xR_P + yB_P*xL_P + yB_P*xR_P);
161 | }
162 | }
163 |
164 | int idim_h = DIM2(input1_size);
165 | int idim_w = DIM3(input1_size);
166 |
167 |
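// Second pass: scatter gradOutput into gradInput with atomicAdd, each tap
// weighted by its normalized Gaussian weight (several output pixels may
// map to the same input location, hence the atomics).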
168 | for (int fy = 0; fy < kernel_size/2; fy += 1) {
169 | int yT = max(min( int (floor(yf)-fy*dilation), idim_h-1), 0);
170 | int yB = max(min( int (floor(yf)+(fy+1)*dilation),idim_h-1), 0);
173 |
174 | for (int fx = 0; fx < kernel_size/2; fx += 1) {
175 | int xL = max(min( int (floor(xf)-fx*dilation ), idim_w-1), 0);
176 | int xR = max(min( int (floor(xf)+(fx+1)*dilation), idim_w-1), 0);
179 |
180 | scalar_t xL_ = ( static_cast<scalar_t>( fx *dilation)+alpha );
181 | scalar_t xR_ = ( static_cast<scalar_t>((1.+fx)*dilation)-alpha );
182 | scalar_t yT_ = ( static_cast<scalar_t>( fy *dilation)+beta );
183 | scalar_t yB_ = ( static_cast<scalar_t>((1.+fy)*dilation)-beta );
188 |
189 | scalar_t xL_P = exp(SAFE_DIV(-xL_*xL_, 2*sigma*sigma));
190 | scalar_t xR_P = exp(SAFE_DIV(-xR_*xR_, 2*sigma*sigma));
191 | scalar_t yT_P = exp(SAFE_DIV(-yT_*yT_, 2*sigma*sigma));
192 | scalar_t yB_P = exp(SAFE_DIV(-yB_*yB_, 2*sigma*sigma));
193 |
194 |
195 | atomicAdd(&DIM3_INDEX(gradInput, b, c, (yT), (xL)), SAFE_DIV(yT_P*xL_P, sum) * DIM3_INDEX(gradOutput, b, c, y, x));
196 | atomicAdd(&DIM3_INDEX(gradInput, b, c, (yT), (xR)), SAFE_DIV(yT_P*xR_P, sum) * DIM3_INDEX(gradOutput, b, c, y, x));
197 | atomicAdd(&DIM3_INDEX(gradInput, b, c, (yB), (xL)), SAFE_DIV(yB_P*xL_P, sum) * DIM3_INDEX(gradOutput, b, c, y, x));
198 | atomicAdd(&DIM3_INDEX(gradInput, b, c, (yB), (xR)), SAFE_DIV(yB_P*xR_P, sum) * DIM3_INDEX(gradOutput, b, c, y, x));
199 | }
200 | }
201 |
202 | }
203 |
204 | template <typename scalar_t>
205 | __global__ void kernel_resample2d_backward_input2(
206 | const int n, const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride,
207 | const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride,
208 | const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride,
209 | scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4 gradInput_stride, int kernel_size, int dilation) {
210 |
211 | int index = blockIdx.x * blockDim.x + threadIdx.x;
212 |
213 | if (index >= n) {
214 | return;
215 | }
216 |
217 | scalar_t grad1 = 0.0;
218 | scalar_t grad2 = 0.0;
219 | scalar_t sum = 0.0;
220 |
221 |
222 |
223 | int dim_b = DIM0(gradInput_size);
224 | int dim_c = DIM1(gradInput_size);
225 | int dim_h = DIM2(gradInput_size);
226 | int dim_w = DIM3(gradInput_size);
227 | int dim_chw = dim_c * dim_h * dim_w;
228 | int dim_hw = dim_h * dim_w;
229 |
230 | int b = ( index / dim_chw ) % dim_b;
231 | int c = ( index / dim_hw ) % dim_c;
232 | int y = ( index / dim_w ) % dim_h;
233 | int x = ( index ) % dim_w;
234 |
235 | int odim_c = DIM1(gradOutput_size);
236 |
237 | scalar_t dx = DIM3_INDEX(input2, b, 0, y, x);
238 | scalar_t dy = DIM3_INDEX(input2, b, 1, y, x);
239 | scalar_t sigma = DIM3_INDEX(input2, b, 2, y, x);
240 |
241 |
242 | scalar_t xf = static_cast<scalar_t>(x) + dx;
243 | scalar_t yf = static_cast<scalar_t>(y) + dy;
244 | scalar_t alpha = xf - floor(xf); // alpha
245 | scalar_t beta = yf - floor(yf); // beta
246 |
247 |
248 | int idim_h = DIM2(input1_size);
249 | int idim_w = DIM3(input1_size);
250 | scalar_t sumgrad = 0.0;
251 |
252 | for (int fy = 0; fy < kernel_size/2; fy += 1) {
253 | int yT = max(min( int (floor(yf)-fy*dilation), idim_h-1), 0);
254 | int yB = max(min( int (floor(yf)+(fy+1)*dilation),idim_h-1), 0);
255 |
256 | for (int fx = 0; fx < kernel_size/2; fx += 1) {
257 | int xL = max(min( int (floor(xf)-fx*dilation ), idim_w-1), 0);
258 | int xR = max(min( int (floor(xf)+(fx+1)*dilation), idim_w-1), 0);
259 |
260 | scalar_t xL_ = ( static_cast<scalar_t>( fx *dilation)+alpha );
261 | scalar_t xR_ = ( static_cast<scalar_t>((1.+fx)*dilation)-alpha );
262 | scalar_t yT_ = ( static_cast<scalar_t>( fy *dilation)+beta );
263 | scalar_t yB_ = ( static_cast<scalar_t>((1.+fy)*dilation)-beta );
264 |
265 | scalar_t xL_P = exp(SAFE_DIV(-xL_*xL_, 2*sigma*sigma));
266 | scalar_t xR_P = exp(SAFE_DIV(-xR_*xR_, 2*sigma*sigma));
267 | scalar_t yT_P = exp(SAFE_DIV(-yT_*yT_, 2*sigma*sigma));
268 | scalar_t yB_P = exp(SAFE_DIV(-yB_*yB_, 2*sigma*sigma));
269 | sum += (yT_P*xL_P + yT_P*xR_P + yB_P*xL_P + yB_P*xR_P);
270 |
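// input2 channels: c==0 is the x-offset dx, c==1 the y-offset dy, and
// c==2 the Gaussian bandwidth sigma; each channel has its own gradient.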
271 | for (int ch = 0; ch < odim_c; ++ch) {
272 | if (c==0) {
273 | grad1 += SAFE_DIV(xL_ * yT_P * xL_P * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, yT, xL), -sigma*sigma);
274 | grad1 -= SAFE_DIV(xR_ * yT_P * xR_P * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, yT, xR), -sigma*sigma);
275 | grad1 += SAFE_DIV(xL_ * yB_P * xL_P * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, yB, xL), -sigma*sigma);
276 | grad1 -= SAFE_DIV(xR_ * yB_P * xR_P * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, yB, xR), -sigma*sigma);
277 | sumgrad += SAFE_DIV(( xL_*yT_P*xL_P - xR_*yT_P*xR_P + xL_*yB_P*xL_P - xR_*yB_P*xR_P ), -sigma*sigma);
278 | }
279 | else if (c==1) {
280 | grad1 += SAFE_DIV(yT_ * yT_P * xL_P * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, yT, xL), -sigma*sigma);
281 | grad1 += SAFE_DIV(yT_ * yT_P * xR_P * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, yT, xR), -sigma*sigma);
282 | grad1 -= SAFE_DIV(yB_ * yB_P * xL_P * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, yB, xL), -sigma*sigma);
283 | grad1 -= SAFE_DIV(yB_ * yB_P * xR_P * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, yB, xR), -sigma*sigma);
284 | sumgrad += SAFE_DIV(( yT_*yT_P*xL_P + yT_*yT_P*xR_P - yB_*yB_P*xL_P - yB_*yB_P*xR_P ), -sigma*sigma);
285 | }
286 | else if (c==2) {
287 | grad1 += SAFE_DIV((yT_*yT_+xL_*xL_) * yT_P * xL_P * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, yT, xL), sigma*sigma*sigma);
288 | grad1 += SAFE_DIV((yT_*yT_+xR_*xR_) * yT_P * xR_P * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, yT, xR), sigma*sigma*sigma);
289 | grad1 += SAFE_DIV((yB_*yB_+xL_*xL_) * yB_P * xL_P * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, yB, xL), sigma*sigma*sigma);
290 | grad1 += SAFE_DIV((yB_*yB_+xR_*xR_) * yB_P * xR_P * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, yB, xR), sigma*sigma*sigma);
291 | sumgrad += SAFE_DIV(( (yT_*yT_+xL_*xL_)*yT_P*xL_P + (yT_*yT_+xR_*xR_)*yT_P*xR_P + (yB_*yB_+xL_*xL_)*yB_P*xL_P + (yB_*yB_+xR_*xR_)*yB_P*xR_P ), sigma*sigma*sigma);
292 |
293 | }
294 | }
295 | }
296 | }
297 |
298 |
299 |
300 | for (int fy = 0; fy < kernel_size/2; fy += 1) {
301 | int yT = max(min( int (floor(yf)-fy*dilation), idim_h-1), 0);
302 | int yB = max(min( int (floor(yf)+(fy+1)*dilation),idim_h-1), 0);
303 |
304 | for (int fx = 0; fx < kernel_size/2; fx += 1) {
305 | int xL = max(min( int (floor(xf)-fx*dilation ), idim_w-1), 0);
306 | int xR = max(min( int (floor(xf)+(fx+1)*dilation), idim_w-1), 0);
307 |
308 |             scalar_t xL_ = ( static_cast<scalar_t>( fx    *dilation)+alpha );
309 |             scalar_t xR_ = ( static_cast<scalar_t>((1.+fx)*dilation)-alpha );
310 |             scalar_t yT_ = ( static_cast<scalar_t>( fy    *dilation)+beta );
311 |             scalar_t yB_ = ( static_cast<scalar_t>((1.+fy)*dilation)-beta );
312 |
313 | scalar_t xL_P = exp(SAFE_DIV(-xL_*xL_, 2*sigma*sigma));
314 | scalar_t xR_P = exp(SAFE_DIV(-xR_*xR_, 2*sigma*sigma));
315 | scalar_t yT_P = exp(SAFE_DIV(-yT_*yT_, 2*sigma*sigma));
316 | scalar_t yB_P = exp(SAFE_DIV(-yB_*yB_, 2*sigma*sigma));
317 | for (int ch = 0; ch < odim_c; ++ch) {
318 | grad2 += sumgrad/odim_c * yT_P * xL_P * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, yT, xL);
319 | grad2 += sumgrad/odim_c * yT_P * xR_P * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, yT, xR);
320 | grad2 += sumgrad/odim_c * yB_P * xL_P * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, yB, xL);
321 | grad2 += sumgrad/odim_c * yB_P * xR_P * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, yB, xR);
322 | }
323 |
324 | }
325 | }
326 |
327 |
328 | gradInput[index] = SAFE_DIV(grad1, sum) - SAFE_DIV(grad2, sum*sum);
329 |
330 | }
331 |
332 |
333 |
334 |
335 | void resample2d_kernel_forward(
336 | at::Tensor& input1,
337 | at::Tensor& input2,
338 | at::Tensor& output,
339 | int kernel_size,
340 | int dilation) {
341 |
342 | int n = output.numel();
343 |
344 | const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3));
345 | const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3));
346 |
347 | const long4 input2_size = make_long4(input2.size(0), input2.size(1), input2.size(2), input2.size(3));
348 | const long4 input2_stride = make_long4(input2.stride(0), input2.stride(1), input2.stride(2), input2.stride(3));
349 |
350 | const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3));
351 | const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3));
352 |
353 | // TODO: when atomicAdd gets resolved, change to AT_DISPATCH_FLOATING_TYPES_AND_HALF
354 | AT_DISPATCH_FLOATING_TYPES(input1.type(), "resample_forward_kernel", ([&] {
355 |     kernel_resample2d_update_output<scalar_t><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>(
356 |         n,
357 |         input1.data<scalar_t>(),
358 |         input1_size,
359 |         input1_stride,
360 |         input2.data<scalar_t>(),
361 |         input2_size,
362 |         input2_stride,
363 |         output.data<scalar_t>(),
364 |         output_size,
365 |         output_stride,
366 |         kernel_size,
367 |         dilation);
368 |
369 | }));
370 |
371 | // TODO: ATen-equivalent check
372 |
373 | // THCudaCheck(cudaGetLastError());
374 |
375 | }
376 |
377 | void resample2d_kernel_backward(
378 | at::Tensor& input1,
379 | at::Tensor& input2,
380 | at::Tensor& gradOutput,
381 | at::Tensor& gradInput1,
382 | at::Tensor& gradInput2,
383 | int kernel_size,
384 | int dilation) {
385 |
386 | int n = gradOutput.numel();
387 |
388 | const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3));
389 | const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3));
390 |
391 | const long4 input2_size = make_long4(input2.size(0), input2.size(1), input2.size(2), input2.size(3));
392 | const long4 input2_stride = make_long4(input2.stride(0), input2.stride(1), input2.stride(2), input2.stride(3));
393 |
394 | const long4 gradOutput_size = make_long4(gradOutput.size(0), gradOutput.size(1), gradOutput.size(2), gradOutput.size(3));
395 | const long4 gradOutput_stride = make_long4(gradOutput.stride(0), gradOutput.stride(1), gradOutput.stride(2), gradOutput.stride(3));
396 |
397 | const long4 gradInput1_size = make_long4(gradInput1.size(0), gradInput1.size(1), gradInput1.size(2), gradInput1.size(3));
398 | const long4 gradInput1_stride = make_long4(gradInput1.stride(0), gradInput1.stride(1), gradInput1.stride(2), gradInput1.stride(3));
399 |
400 | AT_DISPATCH_FLOATING_TYPES(input1.type(), "resample_backward_input1", ([&] {
401 |
402 |         kernel_resample2d_backward_input1<scalar_t><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>(
403 |             n,
404 |             input1.data<scalar_t>(),
405 |             input1_size,
406 |             input1_stride,
407 |             input2.data<scalar_t>(),
408 |             input2_size,
409 |             input2_stride,
410 |             gradOutput.data<scalar_t>(),
411 |             gradOutput_size,
412 |             gradOutput_stride,
413 |             gradInput1.data<scalar_t>(),
414 | gradInput1_size,
415 | gradInput1_stride,
416 | kernel_size,
417 | dilation
418 | );
419 |
420 | }));
421 |
422 | const long4 gradInput2_size = make_long4(gradInput2.size(0), gradInput2.size(1), gradInput2.size(2), gradInput2.size(3));
423 | const long4 gradInput2_stride = make_long4(gradInput2.stride(0), gradInput2.stride(1), gradInput2.stride(2), gradInput2.stride(3));
424 |
425 | n = gradInput2.numel();
426 |
427 | AT_DISPATCH_FLOATING_TYPES(gradInput2.type(), "resample_backward_input2", ([&] {
428 |
429 |
430 |         kernel_resample2d_backward_input2<scalar_t><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>(
431 |             n,
432 |             input1.data<scalar_t>(),
433 |             input1_size,
434 |             input1_stride,
435 |             input2.data<scalar_t>(),
436 |             input2_size,
437 |             input2_stride,
438 |             gradOutput.data<scalar_t>(),
439 |             gradOutput_size,
440 |             gradOutput_stride,
441 |             gradInput2.data<scalar_t>(),
442 | gradInput2_size,
443 | gradInput2_stride,
444 | kernel_size,
445 | dilation
446 | );
447 |
448 | }));
449 |
450 | // TODO: Use the ATen equivalent to get last error
451 |
452 | // THCudaCheck(cudaGetLastError());
453 |
454 | }
455 |
--------------------------------------------------------------------------------
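Reading the kernels above is easier with the sampling rule made explicit: `input2` carries a per-pixel displacement (dx, dy) plus a learned bandwidth sigma, and each output pixel is a Gaussian-weighted average of `input1` over the integer neighbours of the displaced location, normalized by the weight sum (this is exactly the `SAFE_DIV(w, sum)` structure of the backward passes). Below is a rough NumPy sketch of the forward pass for the kernel_size=2, dilation=1 case; it is a reference for reading the CUDA, not part of the package, and it assumes sigma > 0 (the CUDA code guards its divisions with SAFE_DIV):

```python
import numpy as np

def resample2d_ref(input1, input2):
    """Reference Gaussian-weighted resampling for kernel_size=2, dilation=1.
    input1: (B, C, H, W); input2: (B, 3, H, W) with channels (dx, dy, sigma)."""
    B, C, H, W = input1.shape
    out = np.zeros_like(input1)
    for b in range(B):
        for y in range(H):
            for x in range(W):
                dx, dy, sigma = input2[b, :, y, x]
                xf, yf = x + dx, y + dy
                alpha, beta = xf - np.floor(xf), yf - np.floor(yf)
                # Clamp the four integer neighbours to the image, as the kernel does
                xL = int(np.clip(np.floor(xf), 0, W - 1))
                xR = int(np.clip(np.floor(xf) + 1, 0, W - 1))
                yT = int(np.clip(np.floor(yf), 0, H - 1))
                yB = int(np.clip(np.floor(yf) + 1, 0, H - 1))
                # Separable Gaussian weights on the 1-D distances (xL_P, xR_P, ...)
                wxL = np.exp(-alpha ** 2 / (2 * sigma ** 2))
                wxR = np.exp(-(1 - alpha) ** 2 / (2 * sigma ** 2))
                wyT = np.exp(-beta ** 2 / (2 * sigma ** 2))
                wyB = np.exp(-(1 - beta) ** 2 / (2 * sigma ** 2))
                s = wyT * wxL + wyT * wxR + wyB * wxL + wyB * wxR
                out[b, :, y, x] = (wyT * wxL * input1[b, :, yT, xL] +
                                   wyT * wxR * input1[b, :, yT, xR] +
                                   wyB * wxL * input1[b, :, yB, xL] +
                                   wyB * wxR * input1[b, :, yB, xR]) / s
    return out
```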
/resample2d_package/resample2d_kernel.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 |
5 | void resample2d_kernel_forward(
6 | at::Tensor& input1,
7 | at::Tensor& input2,
8 | at::Tensor& output,
9 | int kernel_size,
10 | int dilation);
11 |
12 | void resample2d_kernel_backward(
13 | at::Tensor& input1,
14 | at::Tensor& input2,
15 | at::Tensor& gradOutput,
16 | at::Tensor& gradInput1,
17 | at::Tensor& gradInput2,
18 | int kernel_size,
19 | int dilation);
20 |
--------------------------------------------------------------------------------
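The shape contract implied by the kernel indexing is worth stating: `input1` is (B, C, H, W), `input2` is (B, 3, H, W) with channels (dx, dy, sigma), and the output and gradient tensors match these. A small hedged helper for asserting that contract from Python before calling into the extension (the helper itself is illustrative, not part of the package):

```python
def check_resample2d_shapes(input1, input2):
    """Assert the (B, C, H, W) / (B, 3, H, W) contract the kernels index by."""
    assert input1.dim() == 4 and input2.dim() == 4, 'expected 4-D tensors'
    assert input2.size(1) == 3, 'input2 channels must be (dx, dy, sigma)'
    assert input1.size(0) == input2.size(0), 'batch sizes must match'
    assert input1.shape[2:] == input2.shape[2:], 'spatial sizes must match'
```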
/resample2d_package/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 | import torch
4 |
5 | from setuptools import setup
6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
7 |
8 | cxx_args = ['-std=c++11']
9 |
10 | nvcc_args = [
11 | #'-gencode', 'arch=compute_50,code=sm_50',
12 | #'-gencode', 'arch=compute_52,code=sm_52',
13 | '-gencode', 'arch=compute_60,code=sm_60',
14 | '-gencode', 'arch=compute_61,code=sm_61',
15 | #'-gencode', 'arch=compute_70,code=sm_70',
16 | #'-gencode', 'arch=compute_70,code=compute_70'
17 | ]
18 |
19 | setup(
20 | name='resample2d_cuda',
21 | ext_modules=[
22 | CUDAExtension('resample2d_cuda', [
23 | 'resample2d_cuda.cc',
24 | 'resample2d_kernel.cu'
25 | ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})
26 | ],
27 | cmdclass={
28 | 'build_ext': BuildExtension
29 | })
30 |
31 |
--------------------------------------------------------------------------------
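The hard-coded `-gencode` pairs above only cover sm_60/sm_61; building on other GPUs (e.g. Volta or Turing) means uncommenting or adding the matching pair. A minimal sketch of deriving the pair from the locally visible device instead, assuming `torch` can see a CUDA GPU at build time:

```python
# Sketch: derive the -gencode pair for the local GPU instead of hard-coding it.
import torch

major, minor = torch.cuda.get_device_capability(0)   # e.g. (7, 0) on a V100
arch = f'{major}{minor}'
nvcc_args = ['-gencode', f'arch=compute_{arch},code=sm_{arch}']
```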
/scripts/flist.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 |
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument('--path', type=str, help='path to the dataset')
7 | parser.add_argument('--output', type=str, help='path to the file list')
8 | args = parser.parse_args()
9 |
10 | ext = {'.jpg', '.png'}
11 |
12 | images = []
13 | for root, dirs, files in os.walk(args.path):
14 | print('loading ' + root)
15 | for file in files:
16 | if os.path.splitext(file)[1] in ext:
17 | images.append(os.path.join(root, file))
18 |
19 | images = sorted(images)
20 | np.savetxt(args.output, images, fmt='%s')
--------------------------------------------------------------------------------
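The resulting file list is plain text with one path per line, which is how `src/data.py` reads it back via `np.genfromtxt`. A quick round-trip check (the `train.flist` path is illustrative):

```python
import numpy as np

# One path per line, exactly as written by np.savetxt above.
files = np.genfromtxt('train.flist', dtype=str, encoding='utf-8')
print(len(files), files[0])
```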
/scripts/inception.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | from torchvision import models
4 |
5 |
6 | class InceptionV3(nn.Module):
7 | """Pretrained InceptionV3 network returning feature maps"""
8 |
9 | # Index of default block of inception to return,
10 | # corresponds to output of final average pooling
11 | DEFAULT_BLOCK_INDEX = 3
12 |
13 | # Maps feature dimensionality to their output blocks indices
14 | BLOCK_INDEX_BY_DIM = {
15 | 64: 0, # First max pooling features
16 |         192: 1,   # Second max pooling features
17 | 768: 2, # Pre-aux classifier features
18 | 2048: 3 # Final average pooling features
19 | }
20 |
21 | def __init__(self,
22 | output_blocks=[DEFAULT_BLOCK_INDEX],
23 | resize_input=True,
24 | normalize_input=True,
25 | requires_grad=False):
26 | """Build pretrained InceptionV3
27 | Parameters
28 | ----------
29 | output_blocks : list of int
30 | Indices of blocks to return features of. Possible values are:
31 | - 0: corresponds to output of first max pooling
32 | - 1: corresponds to output of second max pooling
33 | - 2: corresponds to output which is fed to aux classifier
34 | - 3: corresponds to output of final average pooling
35 | resize_input : bool
36 | If true, bilinearly resizes input to width and height 299 before
37 | feeding input to model. As the network without fully connected
38 | layers is fully convolutional, it should be able to handle inputs
39 | of arbitrary size, so resizing might not be strictly needed
40 | normalize_input : bool
41 | If true, normalizes the input to the statistics the pretrained
42 | Inception network expects
43 | requires_grad : bool
44 | If true, parameters of the model require gradient. Possibly useful
45 | for finetuning the network
46 | """
47 | super(InceptionV3, self).__init__()
48 |
49 | self.resize_input = resize_input
50 | self.normalize_input = normalize_input
51 | self.output_blocks = sorted(output_blocks)
52 | self.last_needed_block = max(output_blocks)
53 |
54 | assert self.last_needed_block <= 3, \
55 | 'Last possible output block index is 3'
56 |
57 | self.blocks = nn.ModuleList()
58 |
59 | inception = models.inception_v3(pretrained=True)
60 |
61 | # Block 0: input to maxpool1
62 | block0 = [
63 | inception.Conv2d_1a_3x3,
64 | inception.Conv2d_2a_3x3,
65 | inception.Conv2d_2b_3x3,
66 | nn.MaxPool2d(kernel_size=3, stride=2)
67 | ]
68 | self.blocks.append(nn.Sequential(*block0))
69 |
70 | # Block 1: maxpool1 to maxpool2
71 | if self.last_needed_block >= 1:
72 | block1 = [
73 | inception.Conv2d_3b_1x1,
74 | inception.Conv2d_4a_3x3,
75 | nn.MaxPool2d(kernel_size=3, stride=2)
76 | ]
77 | self.blocks.append(nn.Sequential(*block1))
78 |
79 | # Block 2: maxpool2 to aux classifier
80 | if self.last_needed_block >= 2:
81 | block2 = [
82 | inception.Mixed_5b,
83 | inception.Mixed_5c,
84 | inception.Mixed_5d,
85 | inception.Mixed_6a,
86 | inception.Mixed_6b,
87 | inception.Mixed_6c,
88 | inception.Mixed_6d,
89 | inception.Mixed_6e,
90 | ]
91 | self.blocks.append(nn.Sequential(*block2))
92 |
93 | # Block 3: aux classifier to final avgpool
94 | if self.last_needed_block >= 3:
95 | block3 = [
96 | inception.Mixed_7a,
97 | inception.Mixed_7b,
98 | inception.Mixed_7c,
99 | nn.AdaptiveAvgPool2d(output_size=(1, 1))
100 | ]
101 | self.blocks.append(nn.Sequential(*block3))
102 |
103 | for param in self.parameters():
104 | param.requires_grad = requires_grad
105 |
106 | def forward(self, inp):
107 | """Get Inception feature maps
108 | Parameters
109 | ----------
110 | inp : torch.autograd.Variable
111 | Input tensor of shape Bx3xHxW. Values are expected to be in
112 | range (0, 1)
113 | Returns
114 | -------
115 | List of torch.autograd.Variable, corresponding to the selected output
116 | block, sorted ascending by index
117 | """
118 | outp = []
119 | x = inp
120 |
121 | if self.resize_input:
122 | x = F.upsample(x, size=(299, 299), mode='bilinear')
123 |
124 | if self.normalize_input:
125 | x = x.clone()
126 | x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
127 | x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
128 | x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
129 |
130 | for idx, block in enumerate(self.blocks):
131 | x = block(x)
132 | if idx in self.output_blocks:
133 | outp.append(x)
134 |
135 | if idx == self.last_needed_block:
136 | break
137 |
138 | return outp
139 |
--------------------------------------------------------------------------------
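For reference, a minimal sketch of pulling 2048-dim pool_3 features out of this wrapper, the way `scripts/metrics.py` uses it for FID (the batch shape and CUDA usage are illustrative):

```python
import torch
from inception import InceptionV3

block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]   # final average pooling block
model = InceptionV3([block_idx]).cuda().eval()

imgs = torch.rand(4, 3, 256, 256).cuda()           # values expected in (0, 1)
with torch.no_grad():
    feats = model(imgs)[0]                         # shape (4, 2048, 1, 1)
```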
/scripts/matlab/code/tsmooth.m:
--------------------------------------------------------------------------------
1 | function S = tsmooth(I,lambda,sigma,sharpness,maxIter,mask)
2 | %tsmooth - Structure Extraction from Texture via Relative Total Variation
3 | % modified for StructureFlow image inpainting.
4 | %   S = tsmooth(I, lambda, sigma, sharpness, maxIter, mask) extracts structure S from
5 | %   structure+texture input I, with smoothness weight lambda, scale
6 | % parameter sigma and iteration number maxIter.
7 | %
8 | % Paras:
9 | % @I : Input UINT8 image, both grayscale and color images are acceptable.
10 | %   @lambda    : Parameter controlling the degree of smoothing.
11 | % Range (0, 0.05], 0.01 by default.
12 | % @sigma : Parameter specifying the maximum size of texture elements.
13 | %                Range (0, 6], 3 by default.
14 | % @sharpness : Parameter controlling the sharpness of the final results,
15 | % which corresponds to \epsilon_s in the paper [1]. The smaller the value, the sharper the result.
16 | %                Range (1e-3, 0.03], 0.02 by default.
17 | %   @maxIter   : Number of iterations, 4 by default.
18 | % @mask : The mask of input image of inpainting task.
19 | %
20 | % Example
21 | % ==========
22 | % I = imread('Bishapur_zan.jpg');
23 | % S = tsmooth(I); % Default Parameters (lambda = 0.01, sigma = 3, sharpness = 0.02, maxIter = 4)
24 | % figure, imshow(I), figure, imshow(S);
25 | %
26 | % ==========
27 | % The Code is created based on the method described in the following paper
28 | % [1] "Structure Extraction from Texture via Relative Total Variation", Li Xu, Qiong Yan, Yang Xia, Jiaya Jia, ACM Transactions on Graphics,
29 | % (SIGGRAPH Asia 2012), 2012.
30 | %  The code and the algorithm are for non-commercial use only.
31 | %
32 | % Author: Li Xu (xuli@cse.cuhk.edu.hk)
33 | % Date : 08/25/2012
34 | % Version : 1.0
35 | % Copyright 2012, The Chinese University of Hong Kong.
36 | %
37 |
38 | if (~exist('lambda','var'))
39 | lambda=0.01;
40 | end
41 | if (~exist('sigma','var'))
42 | sigma=3.0;
43 | end
44 | if (~exist('sharpness','var'))
45 | sharpness = 0.02;
46 | end
47 | if (~exist('maxIter','var'))
48 | maxIter=4;
49 | end
50 | if (~exist('mask','var'))
51 | [w,h,~] = size(I);
52 | mask=zeros(w,h);
53 | end
54 | I = im2double(I);
55 | mask = im2double(mask);
56 | x = I;
57 | sigma_iter = sigma;
58 | lambda = lambda/2.0;
59 | dec=2.0;
60 | for iter = 1:maxIter
61 | [wx, wy] = computeTextureWeights(x, mask, sigma_iter, sharpness);
62 | x = solveLinearEquation(I, wx, wy, lambda);
63 | sigma_iter = sigma_iter/dec;
64 | if sigma_iter < 0.5
65 | sigma_iter = 0.5;
66 | end
67 | end
68 | S = x;
69 | end
70 |
71 | % function [retx, rety] = computeTextureWeights(fin, sigma,sharpness)
72 | %
73 | % fx = diff(fin,1,2);
74 | % fx = padarray(fx, [0 1 0], 'post');
75 | % fy = diff(fin,1,1);
76 | % fy = padarray(fy, [1 0 0], 'post');
77 | %
78 | % vareps_s = sharpness;
79 | % vareps = 0.001;
80 | %
81 | % wto = max(sum(sqrt(fx.^2+fy.^2),3)/size(fin,3),vareps_s).^(-1);
82 | % fbin = lpfilter(fin, sigma);
83 | % gfx = diff(fbin,1,2);
84 | % gfx = padarray(gfx, [0 1], 'post');
85 | % gfy = diff(fbin,1,1);
86 | % gfy = padarray(gfy, [1 0], 'post');
87 | % wtbx = max(sum(abs(gfx),3)/size(fin,3),vareps).^(-1);
88 | % wtby = max(sum(abs(gfy),3)/size(fin,3),vareps).^(-1);
89 | % retx = wtbx.*wto;
90 | % rety = wtby.*wto;
91 | %
92 | % retx(:,end) = 0;
93 | % rety(end,:) = 0;
94 | %
95 | % end
96 | function [retx, rety] = computeTextureWeights(fin, mask, sigma,sharpness)
97 |
98 | fx = diff(fin,1,2);
99 | fx = padarray(fx, [0 1 0], 'post');
100 | fy = diff(fin,1,1);
101 | fy = padarray(fy, [1 0 0], 'post');
102 |
103 | mask(mask>0)=1;
104 | fx=fx.*(1-mask);
105 | fy=fy.*(1-mask);
106 |
107 | vareps_s = sharpness;
108 | vareps = 0.001;
109 |
110 | wto = max(sum(sqrt(fx.^2+fy.^2),3)/size(fin,3),vareps_s).^(-1);
111 | fbin = lpfilter(fin, sigma);
112 | gfx = diff(fbin,1,2);
113 | gfx = padarray(gfx, [0 1], 'post');
114 | gfy = diff(fbin,1,1);
115 | gfy = padarray(gfy, [1 0], 'post');
116 | wtbx = max(sum(abs(gfx),3)/size(fin,3),vareps).^(-1);
117 | wtby = max(sum(abs(gfy),3)/size(fin,3),vareps).^(-1);
118 | retx = wtbx.*wto;
119 | rety = wtby.*wto;
120 |
121 | retx(:,end) = 0;
122 | rety(end,:) = 0;
123 |
124 | end
125 |
126 | function ret = conv2_sep(im, sigma)
127 | ksize = bitor(round(5*sigma),1);
128 | g = fspecial('gaussian', [1,ksize], sigma);
129 | ret = conv2(im,g,'same');
130 | ret = conv2(ret,g','same');
131 | end
132 |
133 | function FBImg = lpfilter(FImg, sigma)
134 | FBImg = FImg;
135 | for ic = 1:size(FBImg,3)
136 | FBImg(:,:,ic) = conv2_sep(FImg(:,:,ic), sigma);
137 | end
138 | end
139 |
140 | function OUT = solveLinearEquation(IN, wx, wy, lambda)
141 | %
142 | % The code for constructing the inhomogeneous Laplacian is adapted from
143 | % the implementation of the wlsFilter.
144 | %
145 | % For color images, we enforce wx and wy to be the same for all three channels,
146 | % so the pre-conditioner only needs to be computed once.
147 | %
148 | [r,c,ch] = size(IN);
149 | k = r*c;
150 | dx = -lambda*wx(:);
151 | dy = -lambda*wy(:);
152 | B(:,1) = dx;
153 | B(:,2) = dy;
154 | d = [-r,-1];
155 | A = spdiags(B,d,k,k);
156 | e = dx;
157 | w = padarray(dx, r, 'pre'); w = w(1:end-r);
158 | s = dy;
159 | n = padarray(dy, 1, 'pre'); n = n(1:end-1);
160 | D = 1-(e+w+s+n);
161 | A = A + A' + spdiags(D, 0, k, k);
162 | if exist('ichol','builtin')
163 | L = ichol(A,struct('michol','on'));
164 | OUT = IN;
165 | for ii=1:ch
166 | tin = IN(:,:,ii);
167 | [tout, flag] = pcg(A, tin(:),0.1,100, L, L');
168 | OUT(:,:,ii) = reshape(tout, r, c);
169 | end
170 | else
171 | OUT = IN;
172 | for ii=1:ch
173 | tin = IN(:,:,ii);
174 | tout = A\tin(:);
175 | OUT(:,:,ii) = reshape(tout, r, c);
176 | end
177 | end
178 |
179 | end
--------------------------------------------------------------------------------
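The iteration loop runs a coarse-to-fine schedule on the texture-scale parameter: sigma is divided by dec=2 after every pass and clamped at 0.5. A one-line Python rendering of that schedule for the default sigma=3, maxIter=4:

```python
# Mirrors the MATLAB loop above: sigma_iter halves each iteration, floored at 0.5.
sigma, dec = 3.0, 2.0
for it in range(4):
    print(it + 1, sigma)            # 3.0, 1.5, 0.75, 0.5
    sigma = max(sigma / dec, 0.5)
```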
/scripts/matlab/dirPlus.m:
--------------------------------------------------------------------------------
1 | function output = dirPlus(rootPath, varargin)
2 | %dirPlus Recursively collect files or directories within a folder.
3 | % LIST = dirPlus(ROOTPATH) will search recursively through the folder
4 | % tree beneath ROOTPATH and collect a cell array LIST of all files it
5 | % finds. The list will contain the absolute paths to each file starting
6 | % at ROOTPATH.
7 | %
8 | % LIST = dirPlus(ROOTPATH, 'PropertyName', PropertyValue, ...) will
9 | % modify how files and directories are selected, as well as the format of
10 | % LIST, based on the property/value pairs specified. Valid properties
11 | % that the user can set are:
12 | %
13 | % GENERAL:
14 | % 'Struct' - A logical value determining if the output LIST should
15 | % instead be a structure array of the form returned by
16 | % the DIR function. If TRUE, LIST will be an N-by-1
17 | % structure array instead of a cell array.
18 | % 'Depth' - A non-negative integer value for the maximum folder
19 | % tree depth that dirPlus will search through. A value
20 | % of 0 will only search in ROOTPATH, a value of 1 will
21 | % search in ROOTPATH and its subfolders, etc. Default
22 | % (and maximum allowable) value is the current
23 | % recursion limit set on the root object (i.e.
24 | % get(0, 'RecursionLimit')).
25 | % 'ReturnDirs' - A logical value determining if the output will be a
26 | % list of files or subdirectories. If TRUE, LIST will
27 | % be a cell array of subdirectory names/paths. Default
28 | % is FALSE.
29 | % 'PrependPath' - A logical value determining if the full path from
30 | % ROOTPATH to the file/subdirectory is prepended to
31 | % each item in LIST. The default TRUE will prepend the
32 | % full path, otherwise just the file/subdirectory name
33 | % is returned. This setting is ignored if the 'Struct'
34 | % argument is TRUE.
35 | %
36 | % FILE-SPECIFIC:
37 | % 'FileFilter' - A string defining a regular-expression pattern
38 | % that will be applied to the file name. Only files
39 | % matching the pattern will be included in LIST.
40 | % Default is '' (i.e. all files are included).
41 | % 'ValidateFileFcn' - A handle to a function that takes as input a
42 | % structure of the form returned by the DIR
43 | % function and returns a logical value. This
44 | % function will be applied to all files found and
45 | % only files that have a TRUE return value will be
46 | % included in LIST. Default is [] (i.e. all files
47 | % are included).
48 | %
49 | % DIRECTORY-SPECIFIC:
50 | % 'DirFilter' - A string defining a regular-expression pattern
51 | % that will be applied to the subdirectory name.
52 | % Only subdirectories matching the pattern will be
53 | % considered valid (i.e. included in LIST themselves
54 | % or having their files included in LIST). Default
55 | % is '' (i.e. all subdirectories are valid). The
56 | % setting of the 'RecurseInvalid' argument
57 | % determines if invalid subdirectories are still
58 | % recursed down.
59 | % 'ValidateDirFcn' - A handle to a function that takes as input a
60 | % structure of the form returned by the DIR function
61 | % and returns a logical value. This function will be
62 | % applied to all subdirectories found and only
63 | % subdirectories that have a TRUE return value will
64 | % be considered valid (i.e. included in LIST
65 | % themselves or having their files included in
66 | % LIST). Default is [] (i.e. all subdirectories are
67 | % valid). The setting of the 'RecurseInvalid'
68 | % argument determines if invalid subdirectories are
69 | % still recursed down.
70 | % 'RecurseInvalid' - A logical value determining if invalid
71 | % subdirectories (as identified by the 'DirFilter'
72 | % and 'ValidateDirFcn' arguments) should still be
73 | %                        recursed down. Default is FALSE (i.e. the recursive
74 | % searching stops at invalid subdirectories).
75 | %
76 | % Examples:
77 | %
78 | % 1) Find all '.m' files:
79 | %
80 | % fileList = dirPlus(rootPath, 'FileFilter', '\.m$');
81 | %
82 | % 2) Find all '.m' files, returning the list as a structure array:
83 | %
84 | % fileList = dirPlus(rootPath, 'Struct', true, ...
85 | % 'FileFilter', '\.m$');
86 | %
87 | % 3) Find all '.jpg', '.png', and '.tif' files:
88 | %
89 | % fileList = dirPlus(rootPath, 'FileFilter', '\.(jpg|png|tif)$');
90 | %
91 | % 4) Find all '.m' files in the given folder and its subfolders:
92 | %
93 | % fileList = dirPlus(rootPath, 'Depth', 1, 'FileFilter', '\.m$');
94 | %
95 | % 5) Find all '.m' files, returning only the file names:
96 | %
97 | % fileList = dirPlus(rootPath, 'FileFilter', '\.m$', ...
98 | % 'PrependPath', false);
99 | %
100 | % 6) Find all '.jpg' files with a size of more than 1MB:
101 | %
102 | % bigFcn = @(s) (s.bytes > 1024^2);
103 | % fileList = dirPlus(rootPath, 'FileFilter', '\.jpg$', ...
104 | %                          'ValidateFileFcn', bigFcn);
105 | %
106 | % 7) Find all '.m' files contained in folders containing the string
107 | % 'addons', recursing without restriction:
108 | %
109 | % fileList = dirPlus(rootPath, 'DirFilter', 'addons', ...
110 | % 'FileFilter', '\.m$', ...
111 | % 'RecurseInvalid', true);
112 | %
113 | % See also dir, regexp.
114 |
115 | % Author: Ken Eaton
116 | % Version: MATLAB R2016b - R2011a
117 | % Last modified: 4/14/17
118 | % Copyright 2017 by Kenneth P. Eaton
119 | % Copyright 2017 by Stephen Larroque - backwards compatibility
120 | %--------------------------------------------------------------------------
121 |
122 | % Create input parser (only have to do this once, hence the use of a
123 | % persistent variable):
124 |
125 | persistent parser
126 | if isempty(parser)
127 | recursionLimit = get(0, 'RecursionLimit');
128 | parser = inputParser();
129 | parser.FunctionName = 'dirPlus';
130 | if verLessThan('matlab', '8.2') % MATLAB R2013b = 8.2
131 | addPVPair = @addParamValue;
132 | else
133 | parser.PartialMatching = true;
134 | addPVPair = @addParameter;
135 | end
136 |
137 | % Add general parameters:
138 |
139 | addRequired(parser, 'rootPath', ...
140 | @(s) validateattributes(s, {'char'}, {'nonempty'}));
141 | addPVPair(parser, 'Struct', false, ...
142 | @(b) validateattributes(b, {'logical'}, {'scalar'}));
143 | addPVPair(parser, 'Depth', recursionLimit, ...
144 | @(s) validateattributes(s, {'numeric'}, ...
145 | {'scalar', 'nonnegative', ...
146 | 'nonnan', 'integer', ...
147 | '<=', recursionLimit}));
148 | addPVPair(parser, 'ReturnDirs', false, ...
149 | @(b) validateattributes(b, {'logical'}, {'scalar'}));
150 | addPVPair(parser, 'PrependPath', true, ...
151 | @(b) validateattributes(b, {'logical'}, {'scalar'}));
152 |
153 | % Add file-specific parameters:
154 |
155 | addPVPair(parser, 'FileFilter', '', ...
156 | @(s) validateattributes(s, {'char'}, {'row'}));
157 | addPVPair(parser, 'ValidateFileFcn', [], ...
158 | @(f) validateattributes(f, {'function_handle'}, {'scalar'}));
159 |
160 | % Add directory-specific parameters:
161 |
162 | addPVPair(parser, 'DirFilter', '', ...
163 | @(s) validateattributes(s, {'char'}, {'row'}));
164 | addPVPair(parser, 'ValidateDirFcn', [], ...
165 | @(f) validateattributes(f, {'function_handle'}, {'scalar'}));
166 | addPVPair(parser, 'RecurseInvalid', false, ...
167 | @(b) validateattributes(b, {'logical'}, {'scalar'}));
168 |
169 | end
170 |
171 | % Parse input and recursively find contents:
172 |
173 | parse(parser, rootPath, varargin{:});
174 | output = dirPlus_core(parser.Results.rootPath, ...
175 | rmfield(parser.Results, 'rootPath'), 0, true);
176 | if parser.Results.Struct
177 | output = vertcat(output{:});
178 | end
179 |
180 | end
181 |
182 | %~~~Begin local functions~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
183 |
184 | %--------------------------------------------------------------------------
185 | % Core recursive function to find files and directories.
186 | function output = dirPlus_core(rootPath, optionStruct, depth, isValid)
187 |
188 | % Backwards compatibility for fullfile:
189 |
190 | persistent fullfilecell
191 | if isempty(fullfilecell)
192 | if verLessThan('matlab', '8.0') % MATLAB R2012b = 8.0
193 | fullfilecell = @(P, C) cellfun(@(S) fullfile(P, S), C, ...
194 | 'UniformOutput', false);
195 | else
196 | fullfilecell = @fullfile;
197 | end
198 | end
199 |
200 | % Get current directory contents:
201 |
202 | rootData = dir(rootPath);
203 | dirIndex = [rootData.isdir];
204 | subDirs = {};
205 | validIndex = [];
206 |
207 | % Find valid subdirectories, only if necessary:
208 |
209 | if (depth < optionStruct.Depth) || optionStruct.ReturnDirs
210 |
211 | % Get subdirectories, not counting current or parent:
212 |
213 | dirData = rootData(dirIndex);
214 | subDirs = {dirData.name}.';
215 | index = ~ismember(subDirs, {'.', '..'});
216 | dirData = dirData(index);
217 | subDirs = subDirs(index);
218 | validIndex = true(size(subDirs));
219 | if any(validIndex)
220 | % Apply directory name filter, if specified:
221 | nameFilter = optionStruct.DirFilter;
222 | if ~isempty(nameFilter)
223 | validIndex = ~cellfun(@isempty, regexp(subDirs, nameFilter));
224 | end
225 | if any(validIndex)
226 | % Apply validation function to the directory list, if specified:
227 | validateFcn = optionStruct.ValidateDirFcn;
228 | if ~isempty(validateFcn)
229 | validIndex(validIndex) = arrayfun(validateFcn, ...
230 | dirData(validIndex));
231 | end
232 | end
233 | end
234 | end
235 | % Determine if files or subdirectories are being returned:
236 | if optionStruct.ReturnDirs % Return directories
237 | % Use structure data or prepend full path, if specified:
238 | if optionStruct.Struct
239 | output = {dirData(validIndex)};
240 | elseif any(validIndex) && optionStruct.PrependPath
241 | output = fullfilecell(rootPath, subDirs(validIndex));
242 | else
243 | output = subDirs(validIndex);
244 | end
245 | elseif isValid % Return files
246 | % Find all files in the current directory:
247 | fileData = rootData(~dirIndex);
248 | output = {fileData.name}.';
249 |
250 | if ~isempty(output)
251 |
252 | % Apply file name filter, if specified:
253 |
254 | fileFilter = optionStruct.FileFilter;
255 | if ~isempty(fileFilter)
256 | filterIndex = ~cellfun(@isempty, regexp(output, fileFilter));
257 | fileData = fileData(filterIndex);
258 | output = output(filterIndex);
259 | end
260 |
261 | if ~isempty(output)
262 |
263 | % Apply validation function to the file list, if specified:
264 |
265 | validateFcn = optionStruct.ValidateFileFcn;
266 | if ~isempty(validateFcn)
267 | validateIndex = arrayfun(validateFcn, fileData);
268 | fileData = fileData(validateIndex);
269 | output = output(validateIndex);
270 | end
271 |
272 | % Use structure data or prepend full path, if specified:
273 |
274 | if optionStruct.Struct
275 | output = {fileData};
276 | elseif ~isempty(output) && optionStruct.PrependPath
277 | output = fullfilecell(rootPath, output);
278 | end
279 |
280 | end
281 |
282 | end
283 |
284 | else % Return nothing
285 |
286 | output = {};
287 |
288 | end
289 |
290 | % Check recursion depth:
291 |
292 | if (depth < optionStruct.Depth)
293 |
294 | % Select subdirectories to recurse down:
295 |
296 | if ~optionStruct.RecurseInvalid
297 | subDirs = subDirs(validIndex);
298 | validIndex = validIndex(validIndex);
299 | end
300 |
301 | % Recursively collect output from subdirectories:
302 |
303 | nSubDirs = numel(subDirs);
304 | if (nSubDirs > 0)
305 | subDirs = fullfilecell(rootPath, subDirs);
306 | output = {output; cell(nSubDirs, 1)};
307 | for iSub = 1:nSubDirs
308 | output{iSub+1} = dirPlus_core(subDirs{iSub}, optionStruct, ...
309 | depth+1, validIndex(iSub));
310 | end
311 | output = vertcat(output{:});
312 | end
313 |
314 | end
315 |
316 | end
317 |
318 | %~~~End local functions~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
--------------------------------------------------------------------------------
/scripts/matlab/generate_structure_images.m:
--------------------------------------------------------------------------------
1 | function generate_structure_images(dataset_path, output_path)
2 | addpath('code');
3 | image_list = dirPlus(dataset_path, 'FileFilter', '\.(jpg|png|tif)$');
4 | num_image = numel(image_list);
5 | for i=1:num_image
6 | image_name = image_list{i};
7 | image = im2double(imread(image_name));
8 | S = tsmooth(image, 0.015, 3, 0.001, 3);
9 | write_name = strrep(image_name, dataset_path, output_path);
10 | [filepath,~,~] = fileparts(write_name);
11 | if ~exist(filepath, 'dir')
12 | mkdir(filepath);
13 | end
14 | imwrite(S, write_name);
15 |
16 | if mod(i,100)==1
17 | fprintf('total: %d; output: %d; completed: %f%% \n',num_image, i, (i/num_image)*100) ;
18 | end
19 | end
20 | end
21 |
22 |
--------------------------------------------------------------------------------
/scripts/metrics.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | import torch
4 | import numpy as np
5 | from scipy.misc import imread
6 | from scipy import linalg
7 | from torch.nn.functional import adaptive_avg_pool2d
8 | from inception import InceptionV3
9 | from skimage.measure import compare_ssim
10 | from skimage.measure import compare_psnr
11 | import glob
12 | import argparse
13 | import matplotlib.pyplot as plt
14 |
15 |
16 | class FID():
17 | """docstring for FID
18 |     Calculates the Frechet Inception Distance (FID) to evaluate GANs
19 | The FID metric calculates the distance between two distributions of images.
20 | Typically, we have summary statistics (mean & covariance matrix) of one
21 | of these distributions, while the 2nd distribution is given by a GAN.
22 | When run as a stand-alone program, it compares the distribution of
23 | images that are stored as PNG/JPEG at a specified location with a
24 | distribution given by summary statistics (in pickle format).
25 | The FID is calculated by assuming that X_1 and X_2 are the activations of
26 | the pool_3 layer of the inception net for generated samples and real world
27 |     samples respectively.
28 | See --help to see further details.
29 |     Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
30 | of Tensorflow
31 | Copyright 2018 Institute of Bioinformatics, JKU Linz
32 | Licensed under the Apache License, Version 2.0 (the "License");
33 | you may not use this file except in compliance with the License.
34 | You may obtain a copy of the License at
35 | http://www.apache.org/licenses/LICENSE-2.0
36 | Unless required by applicable law or agreed to in writing, software
37 | distributed under the License is distributed on an "AS IS" BASIS,
38 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
39 | See the License for the specific language governing permissions and
40 | limitations under the License.
41 | """
42 | def __init__(self):
43 | self.dims = 2048
44 | self.batch_size = 64
45 | self.cuda = True
46 | self.verbose=False
47 |
48 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[self.dims]
49 | self.model = InceptionV3([block_idx])
50 | if self.cuda:
51 | # TODO: put model into specific GPU
52 | self.model.cuda()
53 |
54 | def __call__(self, images, gt_path):
55 | """ images: list of the generated image. The values must lie between 0 and 1.
56 | gt_path: the path of the ground truth images. The values must lie between 0 and 1.
57 | """
58 | if not os.path.exists(gt_path):
59 | raise RuntimeError('Invalid path: %s' % gt_path)
60 |
61 |
62 | print('calculate gt_path statistics...')
63 | m1, s1 = self.compute_statistics_of_path(gt_path, self.verbose)
64 | print('calculate generated_images statistics...')
65 | m2, s2 = self.calculate_activation_statistics(images, self.verbose)
66 | fid_value = self.calculate_frechet_distance(m1, s1, m2, s2)
67 | return fid_value
68 |
69 |
70 | def calculate_from_disk(self, generated_path, gt_path):
71 | """
72 | """
73 | if not os.path.exists(gt_path):
74 | raise RuntimeError('Invalid path: %s' % gt_path)
75 | if not os.path.exists(generated_path):
76 | raise RuntimeError('Invalid path: %s' % generated_path)
77 |
78 | print('calculate gt_path statistics...')
79 | m1, s1 = self.compute_statistics_of_path(gt_path, self.verbose)
80 | print('calculate generated_path statistics...')
81 | m2, s2 = self.compute_statistics_of_path(generated_path, self.verbose)
82 | print('calculate frechet distance...')
83 | fid_value = self.calculate_frechet_distance(m1, s1, m2, s2)
84 | print('fid_distance %f' % (fid_value))
85 | return fid_value
86 |
87 |
88 | def compute_statistics_of_path(self, path, verbose):
89 | npz_file = os.path.join(path, 'statistics.npz')
90 | if os.path.exists(npz_file):
91 | f = np.load(npz_file)
92 | m, s = f['mu'][:], f['sigma'][:]
93 | f.close()
94 | else:
95 | path = pathlib.Path(path)
96 | files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
97 | imgs = np.array([imread(str(fn)).astype(np.float32) for fn in files])
98 |
99 | # Bring images to shape (B, 3, H, W)
100 | imgs = imgs.transpose((0, 3, 1, 2))
101 |
102 | # Rescale images to be between 0 and 1
103 | imgs /= 255
104 |
105 | m, s = self.calculate_activation_statistics(imgs, verbose)
106 | np.savez(npz_file, mu=m, sigma=s)
107 |
108 | return m, s
109 |
110 | def calculate_activation_statistics(self, images, verbose):
111 | """Calculation of the statistics used by the FID.
112 | Params:
113 | -- images : Numpy array of dimension (n_images, 3, hi, wi). The values
114 | must lie between 0 and 1.
115 | -- model : Instance of inception model
116 | -- batch_size : The images numpy array is split into batches with
117 | batch size batch_size. A reasonable batch size
118 | depends on the hardware.
119 | -- dims : Dimensionality of features returned by Inception
120 | -- cuda : If set to True, use GPU
121 | -- verbose : If set to True and parameter out_step is given, the
122 | number of calculated batches is reported.
123 | Returns:
124 | -- mu : The mean over samples of the activations of the pool_3 layer of
125 | the inception model.
126 | -- sigma : The covariance matrix of the activations of the pool_3 layer of
127 | the inception model.
128 | """
129 | act = self.get_activations(images, verbose)
130 | mu = np.mean(act, axis=0)
131 | sigma = np.cov(act, rowvar=False)
132 | return mu, sigma
133 |
134 |
135 |
136 | def get_activations(self, images, verbose=False):
137 | """Calculates the activations of the pool_3 layer for all images.
138 | Params:
139 | -- images : Numpy array of dimension (n_images, 3, hi, wi). The values
140 | must lie between 0 and 1.
141 | -- model : Instance of inception model
142 | -- batch_size : the images numpy array is split into batches with
143 | batch size batch_size. A reasonable batch size depends
144 | on the hardware.
145 | -- dims : Dimensionality of features returned by Inception
146 | -- cuda : If set to True, use GPU
147 | -- verbose : If set to True and parameter out_step is given, the number
148 | of calculated batches is reported.
149 | Returns:
150 | -- A numpy array of dimension (num images, dims) that contains the
151 | activations of the given tensor when feeding inception with the
152 | query tensor.
153 | """
154 | self.model.eval()
155 |
156 | d0 = images.shape[0]
157 | if self.batch_size > d0:
158 | print(('Warning: batch size is bigger than the data size. '
159 | 'Setting batch size to data size'))
160 | self.batch_size = d0
161 |
162 | n_batches = d0 // self.batch_size
163 | n_used_imgs = n_batches * self.batch_size
164 |
165 | pred_arr = np.empty((n_used_imgs, self.dims))
166 | for i in range(n_batches):
167 | if verbose:
168 | print('\rPropagating batch %d/%d' % (i + 1, n_batches))
169 | # end='', flush=True)
170 | start = i * self.batch_size
171 | end = start + self.batch_size
172 |
173 | batch = torch.from_numpy(images[start:end]).type(torch.FloatTensor)
174 | # batch = Variable(batch, volatile=True)
175 | if self.cuda:
176 | batch = batch.cuda()
177 |
178 | pred = self.model(batch)[0]
179 |
180 | # If model output is not scalar, apply global spatial average pooling.
181 | # This happens if you choose a dimensionality not equal 2048.
182 | if pred.shape[2] != 1 or pred.shape[3] != 1:
183 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
184 |
185 | pred_arr[start:end] = pred.cpu().data.numpy().reshape(self.batch_size, -1)
186 |
187 | if verbose:
188 | print(' done')
189 |
190 | return pred_arr
191 |
192 |
193 | def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6):
194 | """Numpy implementation of the Frechet Distance.
195 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
196 | and X_2 ~ N(mu_2, C_2) is
197 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
198 | Stable version by Dougal J. Sutherland.
199 | Params:
200 | -- mu1 : Numpy array containing the activations of a layer of the
201 | inception net (like returned by the function 'get_predictions')
202 | for generated samples.
203 |     -- mu2   : The sample mean over activations, precalculated on a
204 |                representative data set.
205 | -- sigma1: The covariance matrix over activations for generated samples.
206 |     -- sigma2: The covariance matrix over activations, precalculated on a
207 |                representative data set.
208 | Returns:
209 | -- : The Frechet Distance.
210 | """
211 |
212 | mu1 = np.atleast_1d(mu1)
213 | mu2 = np.atleast_1d(mu2)
214 |
215 | sigma1 = np.atleast_2d(sigma1)
216 | sigma2 = np.atleast_2d(sigma2)
217 |
218 | assert mu1.shape == mu2.shape, \
219 | 'Training and test mean vectors have different lengths'
220 | assert sigma1.shape == sigma2.shape, \
221 | 'Training and test covariances have different dimensions'
222 |
223 | diff = mu1 - mu2
224 |
225 | # Product might be almost singular
226 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
227 | if not np.isfinite(covmean).all():
228 | msg = ('fid calculation produces singular product; '
229 | 'adding %s to diagonal of cov estimates') % eps
230 | print(msg)
231 | offset = np.eye(sigma1.shape[0]) * eps
232 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
233 |
234 | # Numerical error might give slight imaginary component
235 | if np.iscomplexobj(covmean):
236 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
237 | m = np.max(np.abs(covmean.imag))
238 | raise ValueError('Imaginary component {}'.format(m))
239 | covmean = covmean.real
240 |
241 | tr_covmean = np.trace(covmean)
242 |
243 | return (diff.dot(diff) + np.trace(sigma1) +
244 | np.trace(sigma2) - 2 * tr_covmean)
245 |
246 |
247 | class Reconstruction_Metrics():
248 | def __init__(self, metric_list=['ssim', 'psnr', 'l1', 'mae'], data_range=1, win_size=51, multichannel=True):
249 | self.data_range = data_range
250 | self.win_size = win_size
251 | self.multichannel = multichannel
252 | for metric in metric_list:
253 | if metric in ['ssim', 'psnr', 'l1', 'mae']:
254 | setattr(self, metric, True)
255 | else:
256 |                 print('unsupported reconstruction metric: %s' % metric)
257 |
258 |
259 | def __call__(self, inputs, gts):
260 | """
261 | inputs: the generated image, size (b,c,w,h), data range(0, data_range)
262 | gts: the ground-truth image, size (b,c,w,h), data range(0, data_range)
263 | """
264 | result = dict()
265 | [b,n,w,h] = inputs.size()
266 | inputs = inputs.view(b*n, w, h).detach().cpu().numpy().astype(np.float32).transpose(1,2,0)
267 | gts = gts.view(b*n, w, h).detach().cpu().numpy().astype(np.float32).transpose(1,2,0)
268 |
269 | if hasattr(self, 'ssim'):
270 | ssim_value = compare_ssim(inputs, gts, data_range=self.data_range,
271 | win_size=self.win_size, multichannel=self.multichannel)
272 | result['ssim'] = ssim_value
273 |
274 | if hasattr(self, 'psnr'):
275 | psnr_value = compare_psnr(inputs, gts, self.data_range)
276 | result['psnr'] = psnr_value
277 |
278 | if hasattr(self, 'l1'):
279 | l1_value = compare_l1(inputs, gts)
280 | result['l1'] = l1_value
281 |
282 | if hasattr(self, 'mae'):
283 | mae_value = compare_mae(inputs, gts)
284 | result['mae'] = mae_value
285 | return result
286 |
287 |
288 | def calculate_from_disk(self, inputs, gts, save_path=None, debug=0):
289 | """
290 |         inputs: .txt files, folders, image files (string), image files (list)
291 |         gts: .txt files, folders, image files (string), image files (list)
292 | """
293 | input_image_list = sorted(get_image_list(inputs))
294 | gt_image_list = sorted(get_image_list(gts))
295 | print(len(input_image_list))
296 | print(len(gt_image_list))
297 |
298 | psnr = []
299 | ssim = []
300 | mae = []
301 | l1 = []
302 | names = []
303 |
304 | for index in range(len(input_image_list)):
305 | name = os.path.basename(input_image_list[index])
306 | names.append(name)
307 |
308 | img_gt = (imread(str(gt_image_list[index]))).astype(np.float32) / 255.0
309 | img_pred = (imread(str(input_image_list[index]))).astype(np.float32) / 255.0
310 |
311 |
312 | if debug != 0:
313 |             plt.subplot(121)
314 |             plt.imshow(img_gt)
315 |             plt.title('Ground truth')
316 |             plt.subplot(122)
317 |             plt.imshow(img_pred)
318 |             plt.title('Output')
319 | plt.show()
320 |
321 | psnr.append(compare_psnr(img_gt, img_pred, data_range=self.data_range))
322 | ssim.append(compare_ssim(img_gt, img_pred, data_range=self.data_range,
323 | win_size=self.win_size,multichannel=self.multichannel))
324 | mae.append(compare_mae(img_gt, img_pred))
325 | l1.append(compare_l1(img_gt, img_pred))
326 |
327 |
328 | if np.mod(index, 200) == 0:
329 | print(
330 | str(index) + ' images processed',
331 | "PSNR: %.4f" % round(np.mean(psnr), 4),
332 | "SSIM: %.4f" % round(np.mean(ssim), 4),
333 | "MAE: %.4f" % round(np.mean(mae), 4),
334 | "l1: %.4f" % round(np.mean(l1), 4),
335 | )
336 |
337 | if save_path:
338 | np.savez(save_path + '/metrics.npz', psnr=psnr, ssim=ssim, mae=mae, names=names)
339 |
340 | print(
341 | "PSNR: %.4f" % round(np.mean(psnr), 4),
342 | "PSNR Variance: %.4f" % round(np.var(psnr), 4),
343 | "SSIM: %.4f" % round(np.mean(ssim), 4),
344 | "SSIM Variance: %.4f" % round(np.var(ssim), 4),
345 | "MAE: %.4f" % round(np.mean(mae), 4),
346 | "MAE Variance: %.4f" % round(np.var(mae), 4),
347 | "l1: %.4f" % round(np.mean(l1), 4),
348 | "l1 Variance: %.4f" % round(np.var(l1), 4)
349 | )
350 |
351 |
352 | def get_image_list(flist):
353 | if isinstance(flist, list):
354 | return flist
355 |
356 | # flist: image file path, image directory path, text file flist path
357 | if isinstance(flist, str):
358 | if os.path.isdir(flist):
359 | flist = list(glob.glob(flist + '/*.jpg')) + list(glob.glob(flist + '/*.png'))
360 | flist.sort()
361 | return flist
362 |
363 | if os.path.isfile(flist):
364 | try:
365 |             return np.genfromtxt(flist, dtype=str)
366 | except:
367 | return [flist]
368 |     print('cannot read files from %s, returning an empty list' % flist)
369 | return []
370 |
371 | def compare_l1(img_true, img_test):
372 | img_true = img_true.astype(np.float32)
373 | img_test = img_test.astype(np.float32)
374 | return np.mean(np.abs(img_true - img_test))
375 |
376 | def compare_mae(img_true, img_test):
377 | img_true = img_true.astype(np.float32)
378 | img_test = img_test.astype(np.float32)
379 | return np.sum(np.abs(img_true - img_test)) / np.sum(img_true + img_test)
380 |
381 | if __name__ == "__main__":
382 | fid = FID()
383 | rec = Reconstruction_Metrics()
384 |
385 | parser = argparse.ArgumentParser(description='script to compute all statistics')
386 | parser.add_argument('--input_path', help='Path to ground truth data', type=str)
387 | parser.add_argument('--output_path', help='Path to output data', type=str)
388 | parser.add_argument('--fid_real_path', help='Path to real images when calculate FID', type=str)
389 | args = parser.parse_args()
390 |
391 | for arg in vars(args):
392 | print('[%s] =' % arg, getattr(args, arg))
393 |
394 | # rec.calculate_from_disk(args.input_path, args.output_path, save_path=args.output_path)
395 | fid.calculate_from_disk(args.output_path, args.fid_real_path)
396 |
397 |
398 |
399 |
400 |
401 |
402 |
403 |
404 |
--------------------------------------------------------------------------------
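As a sanity check on `calculate_frechet_distance`, the distance between a distribution and itself should be numerically zero. A standalone sketch of the same formula, kept separate so it runs without loading InceptionV3:

```python
import numpy as np
from scipy import linalg

def frechet(mu1, sigma1, mu2, sigma2):
    # d^2 = ||mu1 - mu2||^2 + Tr(C1 + C2 - 2*sqrt(C1*C2)), as in the class above.
    diff = mu1 - mu2
    covmean = linalg.sqrtm(sigma1.dot(sigma2)).real
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)

mu, sigma = np.zeros(8), np.eye(8)
print(frechet(mu, sigma, mu, sigma))   # ~0.0
```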
/src/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | from torch.optim import lr_scheduler
7 | from .utils import weights_init, get_iteration
8 |
9 |
10 | class BaseModel(nn.Module):
11 | def __init__(self, name, config):
12 | super(BaseModel, self).__init__()
13 | self.config = config
14 | self.samples_path = os.path.join(config.PATH, config.NAME, 'images')
15 | self.checkpoints_path = os.path.join(config.PATH, config.NAME, 'checkpoints')
16 |
17 | def save(self, which_epoch):
18 | """Save all the networks to the disk"""
19 | for net_name in self.net_name:
20 | if hasattr(self, net_name) and not(self.config.MODEL == 3 and 's_' in net_name):
21 | sub_net = getattr(self, net_name)
22 | save_filename = '%s_net_%s.pth' % (which_epoch, net_name)
23 | save_path = os.path.join(self.checkpoints_path, save_filename)
24 | torch.save(sub_net.state_dict(), save_path)
25 |
26 |
27 | def load(self, which_epoch):
28 | for net_name in self.net_name:
29 | if hasattr(self, net_name):
30 | sub_net = getattr(self, net_name)
31 | filename = '%s_net_%s.pth' % (which_epoch, net_name)
32 | model_name = os.path.join(self.checkpoints_path, filename)
33 | if not os.path.isfile(model_name):
34 |                     print('checkpoint %s does not exist' % model_name)
35 | continue
36 | self.load_networks(model_name, sub_net, net_name)
37 | self.iterations = get_iteration(self.checkpoints_path, filename, net_name)
38 | print('Resume %s from iteration %s' % (net_name, which_epoch))
39 |
40 | sub_net_opt = getattr(self, net_name+'_opt')
41 | setattr(self, net_name+'_scheduler', self.get_scheduler(sub_net_opt))
42 |
43 |
44 | def load_networks(self, path, net, name):
45 | """Load all the networks from the disk"""
46 | try:
47 | net.load_state_dict(torch.load(path))
48 | except:
49 | pretrained_dict = torch.load(path)
50 | model_dict = net.state_dict()
51 | try:
52 | pretrained_dict = {k:v for k,v in pretrained_dict.items() if k in model_dict}
53 | net.load_state_dict(pretrained_dict)
54 |                 print('Pretrained network %s has extra layers; only loading the layers that are used' % name)
55 | except:
56 | print('Pretrained network %s has fewer layers; The following are not initialized:' % name)
57 | not_initialized = set()
58 | for k, v in pretrained_dict.items():
59 | if v.size() == model_dict[k].size():
60 | model_dict[k] = v
61 |
62 | for k, v in model_dict.items():
63 | if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
64 | not_initialized.add(k)
65 | print(sorted(not_initialized))
66 | net.load_state_dict(model_dict)
67 |
68 |
69 | def init(self):
70 | for net_name in self.net_name:
71 | if hasattr(self, net_name):
72 | sub_net = getattr(self, net_name)
73 | sub_net.apply(weights_init(self.config.INIT_TYPE))
74 |
75 |
76 | def define_optimizer(self):
77 | for net_name in self.net_name:
78 | if hasattr(self, net_name):
79 | sub_net = getattr(self, net_name)
80 | optimizer = optim.Adam(sub_net.parameters(),lr=self.config.LR,betas=(self.config.BETA1, self.config.BETA2))
81 | scheduler = self.get_scheduler(optimizer)
82 | setattr(self, net_name+'_opt', optimizer)
83 | setattr(self, net_name+'_scheduler', scheduler)
84 |
85 |
86 | def get_scheduler(self, optimizer):
87 |         if self.config.LR_POLICY is None or self.config.LR_POLICY == 'constant':
88 | scheduler = None
89 | elif self.config.LR_POLICY == 'step':
90 | scheduler = lr_scheduler.StepLR(optimizer, step_size=self.config.STEP_SIZE,
91 | gamma=self.config.GAMMA, last_epoch = self.iterations-1)
92 | else:
93 |             raise NotImplementedError('learning rate policy [%s] is not implemented' % self.config.LR_POLICY)
94 | return scheduler
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
--------------------------------------------------------------------------------
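For reference, a small sketch of what the 'step' policy above does to the learning rate (the lr, step_size, and gamma values here are illustrative, not taken from config.yaml):

```python
import torch
from torch.optim import lr_scheduler

net = torch.nn.Linear(4, 4)
opt = torch.optim.Adam(net.parameters(), lr=1e-4)
sched = lr_scheduler.StepLR(opt, step_size=10, gamma=0.5)

for it in range(30):
    opt.step()
    sched.step()                      # lr is halved every 10 iterations

print(opt.param_groups[0]['lr'])      # 1e-4 * 0.5**3
```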
/src/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 |
4 | class Config(dict):
5 | def __init__(self, opts, mode):
6 | with open(opts.config, 'r') as f:
7 | self._yaml = f.read()
8 |         self._dict = yaml.load(self._yaml, Loader=yaml.SafeLoader)
9 |
10 | self.modify_param(opts, mode)
11 |
12 |
13 | def __getattr__(self, name):
14 | if self._dict.get(name) is not None:
15 | return self._dict[name]
16 | return None
17 |
18 | def print(self):
19 | print('Model configurations:')
20 | print('---------------------------------')
21 | print(self._yaml)
22 | print('')
23 | print('---------------------------------')
24 | print('')
25 |
26 |
27 | def modify_param(self, opts, mode):
28 | self._dict['PATH'] = opts.path
29 | self._dict['NAME'] = opts.name
30 | self._dict['RESUME_ALL'] = opts.resume_all
31 |
32 | if mode == 'test':
33 | self._dict['DATA_TEST_GT'] = opts.input
34 | self._dict['DATA_TEST_MASK'] = opts.mask
35 | self._dict['DATA_TEST_STRUCTURE'] = opts.structure
36 | self._dict['DATA_TEST_RESULTS'] = opts.output
37 | self._dict['MODEL'] = opts.model
38 |
--------------------------------------------------------------------------------
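The attribute lookup above means every key in the YAML file becomes readable as `config.KEY`, and missing keys silently return None. A self-contained sketch of that behavior (the class and keys here are illustrative stand-ins, not the project's Config):

```python
import yaml

class MiniConfig:
    """Stripped-down stand-in for src/config.py's Config."""
    def __init__(self, text):
        self._dict = yaml.safe_load(text)
    def __getattr__(self, name):
        return self._dict.get(name)      # None for missing keys, as above

cfg = MiniConfig("LR: 0.0001\nDATA_MASK_TYPE: random_bbox\n")
print(cfg.LR, cfg.MISSING)               # 0.0001 None
```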
/src/data.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | import torch
3 | import os
4 | import os.path
5 | import glob
6 | from torchvision import transforms
7 | import torchvision.transforms.functional as transFunc
8 | import random
9 | import numpy as np
10 | import torch.nn.functional as F
11 | import math
12 | import scipy.io as scio
13 | import torch.utils.data as data
14 | from PIL import Image
15 | from torch.utils.data import DataLoader
16 | import cv2
17 |
18 |
19 | ## TODO: choose with or without transformation at test mode
20 | class Dataset(data.Dataset):
21 | def __init__(self, gt_file, structure_file, config, mask_file=None):
22 | self.gt_image_files = self.load_file_list(gt_file)
23 | self.structure_image_files = self.load_file_list(structure_file)
24 |
25 | if len(self.gt_image_files) == 0:
26 |             raise RuntimeError("Found 0 images in the input files")
27 |
28 | if config.MODE == 'test':
29 | self.transform_opt = {'crop': False, 'flip': False,
30 | 'resize': config.DATA_TEST_SIZE, 'random_load_mask': False}
31 |             config._dict['DATA_MASK_TYPE'] = 'from_file' if mask_file is not None else config.DATA_MASK_TYPE
32 | else:
33 | self.transform_opt = {'crop': config.DATA_CROP, 'flip': config.DATA_FLIP,
34 | 'resize': config.DATA_TRAIN_SIZE, 'random_load_mask': True}
35 |
36 | self.mask_type = config.DATA_MASK_TYPE
37 | # generate random rectangle mask
38 | if self.mask_type == 'random_bbox':
39 | self.mask_setting = config.DATA_RANDOM_BBOX_SETTING
40 | # generate random free form mask
41 | elif self.mask_type == 'random_free_form':
42 | self.mask_setting = config.DATA_RANDOM_FF_SETTING
43 | # read masks from files
44 | elif self.mask_type == 'from_file':
45 | self.mask_image_files = self.load_file_list(mask_file)
46 |
47 | def __getitem__(self, index):
48 | try:
49 | item = self.load_item(index)
50 | except:
51 | print('loading error: ' + self.gt_image_files[index])
52 | item = self.load_item(0)
53 | return item
54 |
55 | def __len__(self):
56 | return len(self.gt_image_files)
57 |
58 |
59 | def load_file_list(self, flist):
60 | if isinstance(flist, list):
61 | return flist
62 |
63 | # flist: image file path, image directory path, text file flist path
64 | if isinstance(flist, str):
65 | if os.path.isdir(flist):
66 | flist = list(glob.glob(flist + '/*.jpg')) + list(glob.glob(flist + '/*.png'))
67 | flist.sort()
68 | return flist
69 |
70 | if os.path.isfile(flist):
71 | try:
72 |                 return np.genfromtxt(flist, dtype=str, encoding='utf-8')
73 | except:
74 | return [flist]
75 | return []
76 |
77 |
78 | def load_item(self, index):
79 | gt_path = self.gt_image_files[index]
80 | structure_path = self.structure_image_files[index]
81 | gt_image = loader(gt_path)
82 | structure_image = loader(structure_path)
83 | transform_param = get_params(gt_image.size, self.transform_opt)
84 | gt_image, structure_image = transform_image(transform_param, gt_image, structure_image)
85 |
86 | inpaint_map = self.load_mask(index, gt_image)
87 | input_image = gt_image*(1-inpaint_map)
88 |
89 | return input_image, structure_image, gt_image, inpaint_map
90 |
91 |
92 | def load_mask(self, index, img):
93 | _, w, h = img.shape
94 | image_shape = [w, h]
95 | if self.mask_type == 'random_bbox':
96 | bboxs = []
97 | for i in range(self.mask_setting['num']):
98 | bbox = random_bbox(self.mask_setting, image_shape)
99 | bboxs.append(bbox)
100 | mask = bbox2mask(bboxs, image_shape, self.mask_setting)
101 | return torch.from_numpy(mask)
102 |
103 | elif self.mask_type == 'random_free_form':
104 | mask = random_ff_mask(self.mask_setting, image_shape)
105 | return torch.from_numpy(mask)
106 |
107 | elif self.mask_type == 'from_file':
108 | if self.transform_opt['random_load_mask']:
109 | index = np.random.randint(0, len(self.mask_image_files))
110 | mask = gray_loader(self.mask_image_files[index])
111 | if random.random() > 0.5:
112 | mask = transFunc.hflip(mask)
113 | if random.random() > 0.5:
114 | mask = transFunc.vflip(mask)
115 | else:
116 | mask = gray_loader(self.mask_image_files[index])
117 | mask = transFunc.resize(mask, size=image_shape)
118 | mask = transFunc.to_tensor(mask)
119 | mask = (mask > 0).float()
120 | return mask
121 | else:
122 | raise(RuntimeError("No such mask type: %s"%self.mask_type))
123 |
124 | def load_name(self, index, add_mask_name=False):
125 | name = self.gt_image_files[index]
126 | name = os.path.basename(name)
127 |
128 | if not add_mask_name:
129 | return name
130 | else:
131 | if len(self.mask_image_files)==0:
132 | return name
133 | else:
134 | mask_name = os.path.basename(self.mask_image_files[index])
135 | mask_name, _ = os.path.splitext(mask_name)
136 | name, ext = os.path.splitext(name)
137 | name = name+'_'+mask_name+ext
138 | return name
139 |
140 |
141 |
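142 |     # Infinite batch generator: restarts a fresh DataLoader whenever it is
143 |     # exhausted; used to draw visualization samples during training.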
142 | def create_iterator(self, batch_size):
143 | while True:
144 | sample_loader = DataLoader(
145 | dataset=self,
146 | batch_size=batch_size,
147 | drop_last=True
148 | )
149 |
150 | for item in sample_loader:
151 | yield item
152 |
153 |
154 | def random_bbox(config, shape):
155 |     """Generate a random (top, left, height, width) box with the given configuration.
156 |     Args:
157 |         config: dict with 'shape' (box height and width) and 'margin'
158 |                  (vertical and horizontal margins) entries.
159 |         shape: image shape as [height, width].
160 |     Returns:
161 |         tuple: (top, left, height, width)
162 |     """
162 | img_height = shape[0]
163 | img_width = shape[1]
164 | height, width = config['shape']
165 | ver_margin, hor_margin = config['margin']
166 | maxt = img_height - ver_margin - height
167 | maxl = img_width - hor_margin - width
168 | t = np.random.randint(low=ver_margin, high=maxt)
169 | l = np.random.randint(low=hor_margin, high=maxl)
170 | h = height
171 | w = width
172 | return (t, l, h, w)
173 |
174 | def random_ff_mask(config, shape):
175 |     """Generate a random free-form (brush-stroke) mask with the given configuration.
176 | 
177 |     Args:
178 |         config: dict with 'mv' (extra vertices), 'ma' (max angle),
179 |                  'ml' (max stroke length) and 'mbw' (max brush width) entries.
180 |         shape: image shape as [height, width].
181 | 
182 |     Returns:
183 |         np.ndarray: float32 mask of shape (1, height, width)
184 |     """
184 |
185 | h,w = shape
186 | mask = np.zeros((h,w))
187 |     num_v = 12+np.random.randint(config['mv'])  # number of strokes
188 |
189 | for i in range(num_v):
190 | start_x = np.random.randint(w)
191 | start_y = np.random.randint(h)
192 | for j in range(1+np.random.randint(5)):
193 | angle = 0.01+np.random.randint(config['ma'])
194 | if i % 2 == 0:
195 |                 angle = 2 * math.pi - angle
196 | length = 10+np.random.randint(config['ml'])
197 | brush_w = 10+np.random.randint(config['mbw'])
198 | end_x = (start_x + length * np.sin(angle)).astype(np.int32)
199 | end_y = (start_y + length * np.cos(angle)).astype(np.int32)
200 |
201 | cv2.line(mask, (start_y, start_x), (end_y, end_x), 1.0, brush_w)
202 | start_x, start_y = end_x, end_y
203 |
204 | return mask.reshape((1,)+mask.shape).astype(np.float32)
205 |
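206 | # Illustrative usage (the setting values here are examples, not repo defaults):
207 | #   ff_setting = {'mv': 5, 'ma': 4, 'ml': 40, 'mbw': 10}
208 | #   mask = random_ff_mask(ff_setting, [256, 256])  # float32, shape (1, 256, 256)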
206 |
207 | def bbox2mask(bboxs, shape, config):
208 |     """Generate a mask array from bounding boxes.
209 | 
210 |     Args:
211 |         bboxs: list of (top, left, height, width) tuples.
212 |         shape: image shape as [height, width].
213 |         config: dict with a 'random_size' entry; when True, each box is
214 |                  shrunk by a random margin before being filled in.
215 | 
216 |     Returns:
217 |         np.ndarray: float32 mask of shape (1, height, width)
218 |     """
219 | height, width = shape
220 | mask = np.zeros(( height, width), np.float32)
222 | for bbox in bboxs:
223 | if config['random_size']:
224 |             h = int(0.1*bbox[2]) + np.random.randint(int(bbox[2]*0.2)+1)
225 |             w = int(0.1*bbox[3]) + np.random.randint(int(bbox[3]*0.2)+1)
226 | else:
227 | h=0
228 | w=0
229 | mask[bbox[0]+h:bbox[0]+bbox[2]-h,
230 | bbox[1]+w:bbox[1]+bbox[3]-w] = 1.
232 | return mask.reshape((1,)+mask.shape).astype(np.float32)
233 |
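234 | # Illustrative pipeline mirroring Dataset.load_mask (values are examples):
235 | #   setting = {'shape': [64, 64], 'margin': [10, 10], 'num': 2, 'random_size': True}
236 | #   boxes = [random_bbox(setting, [256, 256]) for _ in range(setting['num'])]
237 | #   mask = bbox2mask(boxes, [256, 256], setting)  # float32, shape (1, 256, 256)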
234 |
235 | def gray_loader(path):
236 |     return Image.open(path)
237 | 
238 | def loader(path):
239 |     return Image.open(path).convert('RGB')
240 |
241 | def get_params(size, transform_opt):
242 | w, h = size
243 | if transform_opt['flip']:
244 | flip = random.random() > 0.5
245 | else:
246 | flip = False
247 | if transform_opt['crop']:
248 | transform_crop = transform_opt['crop'] \
249 | if w>=transform_opt['crop'][0] and h>=transform_opt['crop'][1] else [h, w]
250 | x = random.randint(0, np.maximum(0, w - transform_crop[0]))
251 | y = random.randint(0, np.maximum(0, h - transform_crop[1]))
252 | crop = [x, y, transform_crop[0], transform_crop[1]]
253 | else:
254 | crop = False
255 | if transform_opt['resize']:
256 | resize = [transform_opt['resize'], transform_opt['resize'],]
257 | else:
258 | resize = False
259 | param = {'crop': crop, 'flip': flip, 'resize': resize}
260 | return param
261 |
262 | def transform_image(transform_param, gt_image, structure_image, normalize=True, toTensor=True):
263 | transform_list = []
264 |
265 | if transform_param['crop']:
266 | crop_position = transform_param['crop'][:2]
267 | crop_size = transform_param['crop'][2:]
268 | transform_list.append(transforms.Lambda(lambda img: __crop(img, crop_position, crop_size)))
269 | if transform_param['resize']:
270 | transform_list.append(transforms.Resize(transform_param['resize']))
271 | if transform_param['flip']:
272 | transform_list.append(transforms.Lambda(lambda img: __flip(img, True)))
273 |
274 | if toTensor:
275 | transform_list += [transforms.ToTensor()]
276 |
277 | if normalize:
278 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
279 | (0.5, 0.5, 0.5))]
280 | trans = transforms.Compose(transform_list)
281 |     if gt_image.size != structure_image.size:
282 |         # PIL's .size is (width, height); torchvision's resize expects (height, width)
283 |         structure_image = transFunc.resize(structure_image, size=gt_image.size[::-1])
283 |
284 | gt_image = trans(gt_image)
285 | structure_image = trans(structure_image)
286 |
287 | return gt_image, structure_image
288 |
289 | def __crop(img, pos, size):
290 | ow, oh = img.size
291 | x1, y1 = pos
292 | tw, th = size
293 | return img.crop((x1, y1, x1 + tw, y1 + th))
294 |
295 | def __flip(img, flip):
296 | if flip:
297 | return img.transpose(Image.FLIP_LEFT_RIGHT)
298 | return img
299 |
300 |
301 |
--------------------------------------------------------------------------------
/src/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.models as models
4 | import torch.nn.functional as F
5 | from .resample2d import Resample2d
6 |
7 | from .utils import write_2images
8 |
9 | class AdversarialLoss(nn.Module):
10 | r"""
11 | Adversarial loss
12 | https://arxiv.org/abs/1711.10337
13 | """
14 |
15 | def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0):
16 | r"""
17 | type = nsgan | lsgan | hinge
18 | """
19 | super(AdversarialLoss, self).__init__()
20 |
21 | self.type = type
22 | self.register_buffer('real_label', torch.tensor(target_real_label))
23 | self.register_buffer('fake_label', torch.tensor(target_fake_label))
24 |
25 | if type == 'nsgan':
26 | self.criterion = nn.BCELoss()
27 |
28 | elif type == 'lsgan':
29 | self.criterion = nn.MSELoss()
30 |
31 | elif type == 'hinge':
32 | self.criterion = nn.ReLU()
33 |
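34 |     # Hinge formulation as implemented below:
35 |     #   D loss (per call): mean(relu(1 - D(real))) or mean(relu(1 + D(fake)))
36 |     #   G loss: -mean(D(fake))
37 |     # For nsgan/lsgan, the BCE/MSE criterion is applied against constant real/fake label maps.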
34 | def __call__(self, outputs, is_real, for_dis=None):
35 | if self.type == 'hinge':
36 | if for_dis:
37 | if is_real:
38 | outputs = -outputs
39 | return self.criterion(1 + outputs).mean()
40 | else:
41 | return (-outputs).mean()
42 |
43 | else:
44 | labels = (self.real_label if is_real else self.fake_label).expand_as(outputs)
45 | loss = self.criterion(outputs, labels)
46 | return loss
47 |
48 |
49 | class StyleLoss(nn.Module):
50 | r"""
51 |     Style loss on VGG feature Gram matrices
52 | https://arxiv.org/abs/1603.08155
53 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py
54 | """
55 |
56 | def __init__(self):
57 | super(StyleLoss, self).__init__()
58 | self.add_module('vgg', VGG19())
59 | self.criterion = torch.nn.L1Loss()
60 |
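61 |     # Gram matrix of VGG features: G = F F^T / (C*H*W), with F the (C x H*W)
62 |     # feature map of each sample; the style loss is the L1 distance between
63 |     # Gram matrices at the selected layers.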
61 | def compute_gram(self, x):
62 | b, ch, h, w = x.size()
63 | f = x.view(b, ch, w * h)
64 | f_T = f.transpose(1, 2)
65 | G = f.bmm(f_T) / (h * w * ch)
66 |
67 | return G
68 |
69 | def __call__(self, x, y):
70 | # Compute features
71 | x_vgg, y_vgg = self.vgg(x), self.vgg(y)
72 |
73 | # Compute loss
74 | style_loss = 0.0
75 | style_loss += self.criterion(self.compute_gram(x_vgg['relu2_2']), self.compute_gram(y_vgg['relu2_2']))
76 | style_loss += self.criterion(self.compute_gram(x_vgg['relu3_4']), self.compute_gram(y_vgg['relu3_4']))
77 | style_loss += self.criterion(self.compute_gram(x_vgg['relu4_4']), self.compute_gram(y_vgg['relu4_4']))
78 | style_loss += self.criterion(self.compute_gram(x_vgg['relu5_2']), self.compute_gram(y_vgg['relu5_2']))
79 |
80 | return style_loss
81 |
82 |
83 |
84 | class PerceptualLoss(nn.Module):
85 | r"""
86 | Perceptual loss, VGG-based
87 | https://arxiv.org/abs/1603.08155
88 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py
89 | """
90 |
91 | def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]):
92 | super(PerceptualLoss, self).__init__()
93 | self.add_module('vgg', VGG19())
94 | self.criterion = torch.nn.L1Loss()
95 | self.weights = weights
96 |
97 | def __call__(self, x, y):
98 | # Compute features
99 | x_vgg, y_vgg = self.vgg(x), self.vgg(y)
100 | content_loss = 0.0
101 | content_loss += self.weights[0] * self.criterion(x_vgg['relu1_1'], y_vgg['relu1_1'])
102 | content_loss += self.weights[1] * self.criterion(x_vgg['relu2_1'], y_vgg['relu2_1'])
103 | content_loss += self.weights[2] * self.criterion(x_vgg['relu3_1'], y_vgg['relu3_1'])
104 | content_loss += self.weights[3] * self.criterion(x_vgg['relu4_1'], y_vgg['relu4_1'])
105 | content_loss += self.weights[4] * self.criterion(x_vgg['relu5_1'], y_vgg['relu5_1'])
106 |
107 | return content_loss
108 |
109 |
110 | class PerceptualCorrectness(nn.Module):
111 |     r"""
112 |     Sampling-correctness loss: the input features warped by the predicted flow
113 |     should match the ground-truth features as well as the best attainable match.
114 |     """
114 |
115 | def __init__(self, layer='relu3_1'):
116 | super(PerceptualCorrectness, self).__init__()
117 | self.add_module('vgg', VGG19())
118 | self.layer = layer
119 | self.eps=1e-8
120 | self.resample = Resample2d(4, 1, sigma=2)
121 |
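122 |     # For each masked location, compare (a) the cosine similarity between the
123 |     # ground-truth feature and the input feature sampled by the predicted flow
124 |     # with (b) the best similarity attainable over all input features, and
125 |     # penalize exp(-a / b), averaged over the mask.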
122 | def __call__(self, gts, inputs, flow, maps):
123 | gts_vgg, inputs_vgg = self.vgg(gts), self.vgg(inputs)
124 | gts_vgg = gts_vgg[self.layer]
125 | inputs_vgg = inputs_vgg[self.layer]
126 | [b, c, h, w] = gts_vgg.shape
127 |
128 | maps = F.interpolate(maps, [h,w]).view(b,-1)
129 | flow = F.adaptive_avg_pool2d(flow, [h,w])
130 |
131 | gts_all = gts_vgg.view(b, c, -1) #[b C N2]
132 | inputs_all = inputs_vgg.view(b, c, -1).transpose(1,2) #[b N2 C]
133 |
134 |
135 | input_norm = inputs_all/(inputs_all.norm(dim=2, keepdim=True)+self.eps)
136 | gt_norm = gts_all/(gts_all.norm(dim=1, keepdim=True)+self.eps)
137 | correction = torch.bmm(input_norm, gt_norm) #[b N2 N2]
138 | (correction_max,max_indices) = torch.max(correction, dim=1)
139 |
140 | # interple with gaussian sampling
141 | input_sample = self.resample(inputs_vgg, flow)
142 | input_sample = input_sample.view(b, c, -1) #[b C N2]
143 |
144 | correction_sample = F.cosine_similarity(input_sample, gts_all) #[b 1 N2]
145 | loss_map = torch.exp(-correction_sample/(correction_max+self.eps))
146 | loss = torch.sum(loss_map*maps)/(torch.sum(maps)+self.eps)
147 |
148 | return loss
149 |
150 |
151 |
152 |
153 | class VGG19(torch.nn.Module):
154 | def __init__(self):
155 | super(VGG19, self).__init__()
156 | features = models.vgg19(pretrained=True).features
157 | self.relu1_1 = torch.nn.Sequential()
158 | self.relu1_2 = torch.nn.Sequential()
159 |
160 | self.relu2_1 = torch.nn.Sequential()
161 | self.relu2_2 = torch.nn.Sequential()
162 |
163 | self.relu3_1 = torch.nn.Sequential()
164 | self.relu3_2 = torch.nn.Sequential()
165 | self.relu3_3 = torch.nn.Sequential()
166 | self.relu3_4 = torch.nn.Sequential()
167 |
168 | self.relu4_1 = torch.nn.Sequential()
169 | self.relu4_2 = torch.nn.Sequential()
170 | self.relu4_3 = torch.nn.Sequential()
171 | self.relu4_4 = torch.nn.Sequential()
172 |
173 | self.relu5_1 = torch.nn.Sequential()
174 | self.relu5_2 = torch.nn.Sequential()
175 | self.relu5_3 = torch.nn.Sequential()
176 | self.relu5_4 = torch.nn.Sequential()
177 |
178 | for x in range(2):
179 | self.relu1_1.add_module(str(x), features[x])
180 |
181 | for x in range(2, 4):
182 | self.relu1_2.add_module(str(x), features[x])
183 |
184 | for x in range(4, 7):
185 | self.relu2_1.add_module(str(x), features[x])
186 |
187 | for x in range(7, 9):
188 | self.relu2_2.add_module(str(x), features[x])
189 |
190 | for x in range(9, 12):
191 | self.relu3_1.add_module(str(x), features[x])
192 |
193 | for x in range(12, 14):
194 | self.relu3_2.add_module(str(x), features[x])
195 |
196 | for x in range(14, 16):
197 |             self.relu3_3.add_module(str(x), features[x])
198 |
199 | for x in range(16, 18):
200 | self.relu3_4.add_module(str(x), features[x])
201 |
202 | for x in range(18, 21):
203 | self.relu4_1.add_module(str(x), features[x])
204 |
205 | for x in range(21, 23):
206 | self.relu4_2.add_module(str(x), features[x])
207 |
208 | for x in range(23, 25):
209 | self.relu4_3.add_module(str(x), features[x])
210 |
211 | for x in range(25, 27):
212 | self.relu4_4.add_module(str(x), features[x])
213 |
214 | for x in range(27, 30):
215 | self.relu5_1.add_module(str(x), features[x])
216 |
217 | for x in range(30, 32):
218 | self.relu5_2.add_module(str(x), features[x])
219 |
220 | for x in range(32, 34):
221 | self.relu5_3.add_module(str(x), features[x])
222 |
223 | for x in range(34, 36):
224 | self.relu5_4.add_module(str(x), features[x])
225 |
226 | # don't need the gradients, just want the features
227 | for param in self.parameters():
228 | param.requires_grad = False
229 |
230 | def forward(self, x):
231 | relu1_1 = self.relu1_1(x)
232 | relu1_2 = self.relu1_2(relu1_1)
233 |
234 | relu2_1 = self.relu2_1(relu1_2)
235 | relu2_2 = self.relu2_2(relu2_1)
236 |
237 | relu3_1 = self.relu3_1(relu2_2)
238 | relu3_2 = self.relu3_2(relu3_1)
239 | relu3_3 = self.relu3_3(relu3_2)
240 | relu3_4 = self.relu3_4(relu3_3)
241 |
242 | relu4_1 = self.relu4_1(relu3_4)
243 | relu4_2 = self.relu4_2(relu4_1)
244 | relu4_3 = self.relu4_3(relu4_2)
245 | relu4_4 = self.relu4_4(relu4_3)
246 |
247 | relu5_1 = self.relu5_1(relu4_4)
248 | relu5_2 = self.relu5_2(relu5_1)
249 | relu5_3 = self.relu5_3(relu5_2)
250 | relu5_4 = self.relu5_4(relu5_3)
251 |
252 | out = {
253 | 'relu1_1': relu1_1,
254 | 'relu1_2': relu1_2,
255 |
256 | 'relu2_1': relu2_1,
257 | 'relu2_2': relu2_2,
258 |
259 | 'relu3_1': relu3_1,
260 | 'relu3_2': relu3_2,
261 | 'relu3_3': relu3_3,
262 | 'relu3_4': relu3_4,
263 |
264 | 'relu4_1': relu4_1,
265 | 'relu4_2': relu4_2,
266 | 'relu4_3': relu4_3,
267 | 'relu4_4': relu4_4,
268 |
269 | 'relu5_1': relu5_1,
270 | 'relu5_2': relu5_2,
271 | 'relu5_3': relu5_3,
272 | 'relu5_4': relu5_4,
273 | }
274 | return out
275 |
--------------------------------------------------------------------------------
/src/metrics.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Calculates the Frechet Inception Distance (FID) to evaluate GANs
3 | The FID metric calculates the distance between two distributions of images.
4 | Typically, we have summary statistics (mean & covariance matrix) of one
5 | of these distributions, while the 2nd distribution is given by a GAN.
6 | When run as a stand-alone program, it compares the distribution of
7 | images that are stored as PNG/JPEG at a specified location with a
8 | distribution given by summary statistics (in pickle format).
9 | The FID is calculated by assuming that X_1 and X_2 are the activations of
10 | the pool_3 layer of the inception net for generated samples and real world
11 | samples respectively.
12 | See --help to see further details.
13 | Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
14 | of Tensorflow
15 | Copyright 2018 Institute of Bioinformatics, JKU Linz
16 | Licensed under the Apache License, Version 2.0 (the "License");
17 | you may not use this file except in compliance with the License.
18 | You may obtain a copy of the License at
19 | http://www.apache.org/licenses/LICENSE-2.0
20 | Unless required by applicable law or agreed to in writing, software
21 | distributed under the License is distributed on an "AS IS" BASIS,
22 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23 | See the License for the specific language governing permissions and
24 | limitations under the License.
25 | """
26 | import os
27 | import pathlib
28 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
29 |
30 | import torch
31 | import numpy as np
32 | from imageio import imread  # scipy.misc.imread was removed in SciPy >= 1.2
33 | from scipy import linalg
34 | from torch.autograd import Variable
35 | from torch.nn.functional import adaptive_avg_pool2d
36 |
37 | from .inception import InceptionV3
38 |
39 |
40 |
41 | class FID():
42 |     """Frechet Inception Distance between generated images and a ground-truth image folder."""
43 | def __init__(self):
44 | self.dims = 2048
45 | self.batch_size = 64
46 | self.cuda = True
47 |
48 |
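49 |     # Sketch of intended use (`images` is assumed to be an (N, 3, H, W)
50 |     # float array with values in [0, 1]; `gt_path` is a folder of PNG/JPEG
51 |     # ground-truth images):
52 |     #   fid = FID()
53 |     #   score = fid(images, gt_path)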
51 | def __call__(self, images, gt_path):
52 | if not os.path.exists(gt_path):
53 | raise RuntimeError('Invalid path: %s' % gt_path)
54 |
55 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[self.dims]
56 |
57 | self.model = InceptionV3([block_idx])
58 | if self.cuda:
59 | # TODO: put model into specific GPU
60 | self.model.cuda()
61 |
62 | print('calculate gt_path statistics...')
63 | m1, s1 = self.compute_statistics_of_path(gt_path)
64 | print('calculate generated_images statistics...')
65 | m2, s2 = self.calculate_activation_statistics(images)
66 | fid_value = self.calculate_frechet_distance(m1, s1, m2, s2)
67 | return fid_value
68 |
69 |
70 |
71 | def compute_statistics_of_path(self, path):
72 | npz_file = os.path.join(path, 'statistics.npz')
73 | if os.path.exists(npz_file):
74 | f = np.load(npz_file)
75 | m, s = f['mu'][:], f['sigma'][:]
76 | f.close()
77 | else:
78 | path = pathlib.Path(path)
79 | files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
80 | imgs = np.array([imread(str(fn)).astype(np.float32) for fn in files])
81 |
82 | # Bring images to shape (B, 3, H, W)
83 | imgs = imgs.transpose((0, 3, 1, 2))
84 |
85 | # Rescale images to be between 0 and 1
86 | imgs /= 255
87 |
88 | m, s = self.calculate_activation_statistics(imgs)
89 | np.savez(npz_file, mu=m, sigma=s)
90 |
91 | return m, s
92 |
93 | def calculate_activation_statistics(self, images, verbose=False):
94 | """Calculation of the statistics used by the FID.
95 | Params:
96 | -- images : Numpy array of dimension (n_images, 3, hi, wi). The values
97 | must lie between 0 and 1.
98 | -- model : Instance of inception model
99 | -- batch_size : The images numpy array is split into batches with
100 | batch size batch_size. A reasonable batch size
101 | depends on the hardware.
102 | -- dims : Dimensionality of features returned by Inception
103 | -- cuda : If set to True, use GPU
104 | -- verbose : If set to True and parameter out_step is given, the
105 | number of calculated batches is reported.
106 | Returns:
107 | -- mu : The mean over samples of the activations of the pool_3 layer of
108 | the inception model.
109 | -- sigma : The covariance matrix of the activations of the pool_3 layer of
110 | the inception model.
111 | """
112 | act = self.get_activations(images, True)
113 | mu = np.mean(act, axis=0)
114 | sigma = np.cov(act, rowvar=False)
115 | return mu, sigma
116 |
117 |
118 |
119 | def get_activations(self, images, verbose=False):
120 | """Calculates the activations of the pool_3 layer for all images.
121 | Params:
122 | -- images : Numpy array of dimension (n_images, 3, hi, wi). The values
123 | must lie between 0 and 1.
124 | -- model : Instance of inception model
125 | -- batch_size : the images numpy array is split into batches with
126 | batch size batch_size. A reasonable batch size depends
127 | on the hardware.
128 | -- dims : Dimensionality of features returned by Inception
129 | -- cuda : If set to True, use GPU
130 | -- verbose : If set to True and parameter out_step is given, the number
131 | of calculated batches is reported.
132 | Returns:
133 | -- A numpy array of dimension (num images, dims) that contains the
134 | activations of the given tensor when feeding inception with the
135 | query tensor.
136 | """
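137 |         # Note: only n_batches * batch_size images are used; a trailing
138 |         # remainder that does not fill a full batch is dropped.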
137 | self.model.eval()
138 |
139 | d0 = images.shape[0]
140 | if self.batch_size > d0:
141 | print(('Warning: batch size is bigger than the data size. '
142 | 'Setting batch size to data size'))
143 | self.batch_size = d0
144 |
145 | n_batches = d0 // self.batch_size
146 | n_used_imgs = n_batches * self.batch_size
147 |
148 | pred_arr = np.empty((n_used_imgs, self.dims))
149 | for i in range(n_batches):
150 | if verbose:
151 | print('\rPropagating batch %d/%d' % (i + 1, n_batches),
152 | end='', flush=True)
153 | start = i * self.batch_size
154 | end = start + self.batch_size
155 |
156 | batch = torch.from_numpy(images[start:end]).type(torch.FloatTensor)
158 | if self.cuda:
159 | batch = batch.cuda()
160 |
161 | pred = self.model(batch)[0]
162 |
163 | # If model output is not scalar, apply global spatial average pooling.
164 | # This happens if you choose a dimensionality not equal 2048.
165 | if pred.shape[2] != 1 or pred.shape[3] != 1:
166 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
167 |
168 | pred_arr[start:end] = pred.cpu().data.numpy().reshape(self.batch_size, -1)
169 |
170 | if verbose:
171 | print(' done')
172 |
173 | return pred_arr
174 |
175 |
176 | def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6):
177 | """Numpy implementation of the Frechet Distance.
178 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
179 | and X_2 ~ N(mu_2, C_2) is
180 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
181 | Stable version by Dougal J. Sutherland.
182 | Params:
183 | -- mu1 : Numpy array containing the activations of a layer of the
184 | inception net (like returned by the function 'get_predictions')
185 | for generated samples.
186 |     -- mu2   : The sample mean over activations, precalculated on a
187 |                representative data set.
188 |     -- sigma1: The covariance matrix over activations for generated samples.
189 |     -- sigma2: The covariance matrix over activations, precalculated on a
190 |                representative data set.
191 | Returns:
192 | -- : The Frechet Distance.
193 | """
194 |
195 | mu1 = np.atleast_1d(mu1)
196 | mu2 = np.atleast_1d(mu2)
197 |
198 | sigma1 = np.atleast_2d(sigma1)
199 | sigma2 = np.atleast_2d(sigma2)
200 |
201 | assert mu1.shape == mu2.shape, \
202 | 'Training and test mean vectors have different lengths'
203 | assert sigma1.shape == sigma2.shape, \
204 | 'Training and test covariances have different dimensions'
205 |
206 | diff = mu1 - mu2
207 |
208 | # Product might be almost singular
209 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
210 | if not np.isfinite(covmean).all():
211 | msg = ('fid calculation produces singular product; '
212 | 'adding %s to diagonal of cov estimates') % eps
213 | print(msg)
214 | offset = np.eye(sigma1.shape[0]) * eps
215 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
216 |
217 | # Numerical error might give slight imaginary component
218 | if np.iscomplexobj(covmean):
219 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
220 | m = np.max(np.abs(covmean.imag))
221 | raise ValueError('Imaginary component {}'.format(m))
222 | covmean = covmean.real
223 |
224 | tr_covmean = np.trace(covmean)
225 |
226 | return (diff.dot(diff) + np.trace(sigma1) +
227 | np.trace(sigma2) - 2 * tr_covmean)
228 |
229 |
--------------------------------------------------------------------------------
/src/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .base_model import BaseModel
4 | from .network import StructureGen
5 | from .network import MultiDiscriminator
6 | from .network import FlowGen
7 | from .loss import AdversarialLoss, PerceptualCorrectness, StyleLoss, PerceptualLoss
8 |
9 |
10 | class StructureFlowModel(BaseModel):
11 | def __init__(self, config):
12 | super(StructureFlowModel, self).__init__('StructureFlow', config)
13 | self.config = config
14 | self.net_name = ['s_gen', 's_dis', 'f_gen', 'f_dis']
15 |
16 | self.structure_param = {'input_dim':3, 'dim':64, 'n_res':4, 'activ':'relu',
17 | 'norm':'in', 'pad_type':'reflect', 'use_sn':True}
18 | self.flow_param = {'input_dim':3, 'dim':64, 'n_res':2, 'activ':'relu',
19 | 'norm_conv':'ln', 'norm_flow':'in', 'pad_type':'reflect', 'use_sn':False}
20 | self.dis_param = {'input_dim':3, 'dim':64, 'n_layers':3,
21 | 'norm':'none', 'activ':'lrelu', 'pad_type':'reflect', 'use_sn':True}
22 |
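23 |         # MODEL selects the training stage: 1 = structure reconstructor,
24 |         # 2 = texture generator trained with ground-truth smooth images,
25 |         # 3 = joint model feeding stage-1 structure outputs to the flow model.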
23 | l1_loss = nn.L1Loss()
24 | adversarial_loss = AdversarialLoss(type=config.DIS_GAN_LOSS)
25 | correctness_loss = PerceptualCorrectness()
26 | vgg_style = StyleLoss()
27 | vgg_content = PerceptualLoss()
28 | self.use_correction_loss = True
29 |         self.use_vgg_loss = (self.config.MODEL == 3)
30 |
31 | self.add_module('l1_loss', l1_loss)
32 | self.add_module('adversarial_loss', adversarial_loss)
33 | self.add_module('correctness_loss', correctness_loss)
34 | self.add_module('vgg_style', vgg_style)
35 | self.add_module('vgg_content', vgg_content)
36 |
37 | self.build_model()
38 |
39 | def build_model(self):
40 | self.iterations = 0
41 | # structure model
42 | if self.config.MODEL == 1:
43 | self.s_gen = StructureGen(**self.structure_param)
44 | self.s_dis = MultiDiscriminator(**self.dis_param)
45 | # flow model with true input smooth
46 | elif self.config.MODEL == 2:
47 | self.f_gen = FlowGen(**self.flow_param)
48 | self.f_dis = MultiDiscriminator(**self.dis_param)
49 | # flow model with fake input smooth
50 | elif self.config.MODEL == 3:
51 | self.s_gen = StructureGen(**self.structure_param)
52 | self.f_gen = FlowGen(**self.flow_param)
53 | self.f_dis = MultiDiscriminator(**self.dis_param)
54 |
55 | self.define_optimizer()
56 | self.init()
57 |
58 |
59 | def structure_forward(self, inputs, smooths, maps):
60 | smooths_input = smooths*(1-maps)
61 | outputs = self.s_gen(torch.cat((inputs, smooths_input, maps),dim=1))
62 | return outputs
63 |
64 | def flow_forward(self, inputs, stage_1, maps):
65 | outputs, flow = self.f_gen(torch.cat((inputs, stage_1, maps),dim=1))
66 | return outputs, flow
67 |
68 | def sample(self, inputs, smooths, gts, maps):
69 | with torch.no_grad():
70 | if self.config.MODEL == 1:
71 | outputs = self.structure_forward(inputs, smooths, maps)
72 | result =[inputs,smooths,gts,maps,outputs]
73 | flow = None
74 | elif self.config.MODEL == 2:
75 | outputs, flow = self.flow_forward(inputs, smooths, maps)
76 | result =[inputs,smooths,gts,maps,outputs]
77 | if flow is not None:
78 | flow = [flow[:,0,:,:].unsqueeze(1)/30, flow[:,1,:,:].unsqueeze(1)/30]
79 |
80 | elif self.config.MODEL == 3:
81 | smooth_stage_1 = self.structure_forward(inputs, smooths, maps)
82 | outputs, flow = self.flow_forward(inputs, smooth_stage_1, maps)
83 | result =[inputs,smooths,gts,maps,smooth_stage_1,outputs]
84 | if flow is not None:
85 | flow = [flow[:,0,:,:].unsqueeze(1)/30, flow[:,1,:,:].unsqueeze(1)/30]
86 | return result, flow
87 |
88 | def update_structure(self, inputs, smooths, maps):
89 | self.iterations += 1
90 |
91 | self.s_gen.zero_grad()
92 | self.s_dis.zero_grad()
93 | outputs = self.structure_forward(inputs, smooths, maps)
94 |
95 | dis_loss = 0
96 | dis_fake_input = outputs.detach()
97 | dis_real_input = smooths
98 | fake_labels = self.s_dis(dis_fake_input)
99 | real_labels = self.s_dis(dis_real_input)
100 | for i in range(len(fake_labels)):
101 | dis_real_loss = self.adversarial_loss(real_labels[i], True, True)
102 | dis_fake_loss = self.adversarial_loss(fake_labels[i], False, True)
103 | dis_loss += (dis_real_loss + dis_fake_loss) / 2
104 | self.structure_adv_dis_loss = dis_loss/len(fake_labels)
105 |
106 | self.structure_adv_dis_loss.backward()
107 | self.s_dis_opt.step()
108 | if self.s_dis_scheduler is not None:
109 | self.s_dis_scheduler.step()
110 |
111 |
112 | dis_gen_loss = 0
113 | fake_labels = self.s_dis(outputs)
114 | for i in range(len(fake_labels)):
115 | dis_fake_loss = self.adversarial_loss(fake_labels[i], True, False)
116 | dis_gen_loss += dis_fake_loss
117 | self.structure_adv_gen_loss = dis_gen_loss/len(fake_labels) * self.config.STRUCTURE_ADV_GEN
118 | self.structure_l1_loss = self.l1_loss(outputs, smooths) * self.config.STRUCTURE_L1
119 | self.structure_gen_loss = self.structure_l1_loss + self.structure_adv_gen_loss
120 |
121 | self.structure_gen_loss.backward()
122 | self.s_gen_opt.step()
123 | if self.s_gen_scheduler is not None:
124 | self.s_gen_scheduler.step()
125 |
126 | logs = [
127 | ("l_s_adv_dis", self.structure_adv_dis_loss.item()),
128 | ("l_s_l1", self.structure_l1_loss.item()),
129 | ("l_s_adv_gen", self.structure_adv_gen_loss.item()),
130 | ("l_s_gen", self.structure_gen_loss.item()),
131 | ]
132 | return logs
133 |
134 |
135 | def update_flow(self, inputs, smooths, gts, maps, use_correction_loss, use_vgg_loss):
136 | self.iterations += 1
137 |
138 | self.f_dis.zero_grad()
139 | self.f_gen.zero_grad()
140 | outputs, flow_maps = self.flow_forward(inputs, smooths, maps)
141 |
142 |
143 | dis_loss = 0
144 | dis_fake_input = outputs.detach()
145 | dis_real_input = gts
146 | fake_labels = self.f_dis(dis_fake_input)
147 | real_labels = self.f_dis(dis_real_input)
149 | for i in range(len(fake_labels)):
150 | dis_real_loss = self.adversarial_loss(real_labels[i], True, True)
151 | dis_fake_loss = self.adversarial_loss(fake_labels[i], False, True)
152 | dis_loss += (dis_real_loss + dis_fake_loss) / 2
153 | self.flow_adv_dis_loss = dis_loss/len(fake_labels)
154 |
155 | self.flow_adv_dis_loss.backward()
156 | self.f_dis_opt.step()
157 | if self.f_dis_scheduler is not None:
158 | self.f_dis_scheduler.step()
159 |
160 |
161 | dis_gen_loss = 0
162 | fake_labels = self.f_dis(outputs)
163 | for i in range(len(fake_labels)):
164 | dis_fake_loss = self.adversarial_loss(fake_labels[i], True, False)
165 | dis_gen_loss += dis_fake_loss
166 | self.flow_adv_gen_loss = dis_gen_loss/len(fake_labels) * self.config.FLOW_ADV_GEN
167 | self.flow_l1_loss = self.l1_loss(outputs, gts) * self.config.FLOW_L1
168 | self.flow_correctness_loss = self.correctness_loss(gts, inputs, flow_maps, maps)* \
169 | self.config.FLOW_CORRECTNESS if use_correction_loss else 0
170 |
171 |
172 | if use_vgg_loss:
173 | self.vgg_loss_style = self.vgg_style(outputs*maps, gts*maps)*self.config.VGG_STYLE
174 | self.vgg_loss_content = self.vgg_content(outputs, gts)*self.config.VGG_CONTENT
175 | self.vgg_loss = self.vgg_loss_style + self.vgg_loss_content
176 | else:
177 | self.vgg_loss = 0
178 |
179 | self.flow_loss = self.flow_adv_gen_loss + self.flow_l1_loss + self.flow_correctness_loss + self.vgg_loss
180 |
181 | self.flow_loss.backward()
182 | self.f_gen_opt.step()
183 |
184 | if self.f_gen_scheduler is not None:
185 | self.f_gen_scheduler.step()
186 |
187 |
188 | logs = [
189 | ("l_f_adv_dis", self.flow_adv_dis_loss.item()),
190 | ("l_f_adv_gen", self.flow_adv_gen_loss.item()),
191 | ("l_f_l1_gen", self.flow_l1_loss.item()),
192 | ("l_f_total_gen", self.flow_loss.item()),
193 | ]
194 | if use_correction_loss:
195 | logs = logs + [("l_f_correctness_gen", self.flow_correctness_loss.item())]
196 | if use_vgg_loss:
197 | logs = logs + [("l_f_vgg_style", self.vgg_loss_style.item())]
198 | logs = logs + [("l_f_vgg_content", self.vgg_loss_content.item())]
199 | return logs
200 |
--------------------------------------------------------------------------------
/src/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from torchvision import models
5 | from torch.nn.utils.spectral_norm import spectral_norm
7 | from .resample2d import Resample2d
8 | import torchvision.utils as vutils
9 |
10 |
11 | class Discriminator(nn.Module):
12 | def __init__(self, input_dim=3, dim=64, n_layers=3,
13 | norm='none', activ='lrelu', pad_type='reflect', use_sn=True):
14 | super(Discriminator, self).__init__()
15 |
16 | self.model = nn.ModuleList()
17 | self.model.append(Conv2dBlock(input_dim,dim,4,2,1,'none',activ,pad_type,use_sn=use_sn))
18 | dim_in = dim
19 | for i in range(n_layers - 1):
20 | dim_out = min(dim*8, dim_in*2)
21 | self.model.append(DownsampleResBlock(dim_in,dim_out,'none',activ,pad_type,use_sn))
22 | dim_in = dim_out
23 |
24 | self.model.append(Conv2dBlock(dim_in,1,1,1,activation='none',use_bias=False, use_sn=use_sn))
25 | self.model = nn.Sequential(*self.model)
26 |
27 | def forward(self, x):
28 | return self.model(x)
29 |
30 | class MultiDiscriminator(nn.Module):
31 | def __init__(self, **parameter_dic):
32 | super(MultiDiscriminator, self).__init__()
33 | self.model_1 = Discriminator(**parameter_dic)
34 | self.down = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
35 | self.model_2 = Discriminator(**parameter_dic)
36 |
37 | def forward(self, x):
38 | pre1 = self.model_1(x)
39 | pre2 = self.model_2(self.down(x))
40 | return [pre1, pre2]
41 |
42 |
43 | class StructureGen(nn.Module):
44 | def __init__(self, input_dim=3, dim=64, n_res=4, activ='relu',
45 | norm='in', pad_type='reflect', use_sn=True):
46 | super(StructureGen, self).__init__()
47 |
48 | self.down_sample=nn.ModuleList()
49 | self.up_sample=nn.ModuleList()
50 | self.content_param=nn.ModuleList()
51 |
52 | self.input_layer = Conv2dBlock(input_dim*2+1, dim, 7, 1, 3, norm, activ, pad_type, use_sn=use_sn)
53 | self.down_sample += [nn.Sequential(
54 | Conv2dBlock(dim, 2*dim, 4, 2, 1, norm, activ, pad_type, use_sn=use_sn),
55 | Conv2dBlock(2*dim, 2*dim, 5, 1, 2, norm, activ, pad_type, use_sn=use_sn))]
56 |
57 | self.down_sample += [nn.Sequential(
58 | Conv2dBlock(2*dim, 4*dim, 4, 2, 1,norm, activ, pad_type, use_sn=use_sn),
59 | Conv2dBlock(4*dim, 4*dim, 5, 1, 2,norm, activ, pad_type, use_sn=use_sn))]
60 |
61 | self.down_sample += [nn.Sequential(
62 | Conv2dBlock(4*dim, 8*dim, 4, 2, 1,norm, activ, pad_type, use_sn=use_sn))]
63 | dim = 8*dim
64 | # content decoder
65 | self.up_sample += [(nn.Sequential(
66 | ResBlocks(n_res, dim, norm, activ, pad_type=pad_type),
67 | nn.Upsample(scale_factor=2),
68 | Conv2dBlock(dim, dim // 2, 5, 1, 2, norm, activ, pad_type, use_sn=use_sn)) )]
69 |
70 | self.up_sample += [(nn.Sequential(
71 | ResBlocks(n_res, dim//2, norm, activ, pad_type=pad_type),
72 | nn.Upsample(scale_factor=2),
73 | Conv2dBlock(dim//2, dim//4, 5, 1, 2,norm, activ, pad_type, use_sn=use_sn)) )]
74 |
75 | self.up_sample += [(nn.Sequential(
76 | ResBlocks(n_res, dim//4, norm, activ, pad_type=pad_type),
77 | nn.Upsample(scale_factor=2),
78 | Conv2dBlock(dim//4, dim//8, 5, 1, 2,norm, activ, pad_type, use_sn=use_sn)) )]
79 |
80 | self.content_param += [Conv2dBlock(dim//2, dim//2, 5, 1, 2, norm, activ, pad_type)]
81 | self.content_param += [Conv2dBlock(dim//4, dim//4, 5, 1, 2, norm, activ, pad_type)]
82 | self.content_param += [Conv2dBlock(dim//8, dim//8, 5, 1, 2, norm, activ, pad_type)]
83 |
84 | self.image_net = Get_image(dim//8, input_dim)
85 |
86 | def forward(self, inputs):
87 | x0 = self.input_layer(inputs)
88 | x1 = self.down_sample[0](x0)
89 | x2 = self.down_sample[1](x1)
90 | x3 = self.down_sample[2](x2)
91 |
92 | u1 = self.up_sample[0](x3) + self.content_param[0](x2)
93 | u2 = self.up_sample[1](u1) + self.content_param[1](x1)
94 | u3 = self.up_sample[2](u2) + self.content_param[2](x0)
95 |
96 | images_out = self.image_net(u3)
97 | return images_out
98 |
99 |
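100 | # FlowGen is a two-column generator: FlowColumn predicts a 2-channel
101 | # appearance-flow map, and ConvColumn synthesizes the output image after
102 | # warping its encoder features with that flow via Gaussian Resample2d.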
100 | class FlowGen(nn.Module):
101 | def __init__(self, input_dim=3, dim=64, n_res=2, activ='relu',
102 | norm_flow='ln', norm_conv='in', pad_type='reflect', use_sn=True):
103 | super(FlowGen, self).__init__()
104 |
105 | self.flow_column = FlowColumn(input_dim, dim, n_res, activ,
106 | norm_flow, pad_type, use_sn)
107 | self.conv_column = ConvColumn(input_dim, dim, n_res, activ,
108 | norm_conv, pad_type, use_sn)
109 |
110 | def forward(self, inputs):
111 | flow_map = self.flow_column(inputs)
112 | images_out = self.conv_column(inputs, flow_map)
113 | return images_out, flow_map
114 |
115 |
116 |
117 | class ConvColumn(nn.Module):
118 | def __init__(self, input_dim=3, dim=64, n_res=2, activ='lrelu',
119 | norm='ln', pad_type='reflect', use_sn=True):
120 | super(ConvColumn, self).__init__()
121 |
122 | self.down_sample = nn.ModuleList()
123 | self.up_sample = nn.ModuleList()
124 |
125 |
126 | self.down_sample += [nn.Sequential(
127 | Conv2dBlock(input_dim*2+1, dim//2, 7, 1, 3, norm, activ, pad_type, use_sn=use_sn),
128 | Conv2dBlock(dim//2, dim, 4, 2, 1, norm, activ, pad_type, use_sn=use_sn),
129 | Conv2dBlock(dim, dim, 5, 1, 2, norm, activ, pad_type, use_sn=use_sn),
130 | Conv2dBlock(dim, 2*dim, 4, 2, 1, norm, activ, pad_type, use_sn=use_sn),
131 | Conv2dBlock(2*dim, 2*dim, 5, 1, 2, norm, activ, pad_type, use_sn=use_sn))]
132 |
133 | self.down_sample += [nn.Sequential(
134 | Conv2dBlock(2*dim, 4*dim, 5, 1, 2, norm, activ, pad_type, use_sn=use_sn),
135 | Conv2dBlock(4*dim, 8*dim, 4, 2, 1, norm, activ, pad_type, use_sn=use_sn))]
136 | dim = 8*dim
137 |
138 | # content decoder
139 | self.up_sample += [(nn.Sequential(
140 | ResBlocks(n_res, dim, norm, activ, pad_type=pad_type),
141 | nn.Upsample(scale_factor=2),
142 | Conv2dBlock(dim, dim//2, 5, 1, 2, norm, activ, pad_type, use_sn=use_sn)) )]
143 |
144 | self.up_sample += [(nn.Sequential(
145 | Conv2dBlock(dim, dim//2, 5, 1, 2,norm, activ, pad_type, use_sn=use_sn),
146 | ResBlocks(n_res, dim//2, norm, activ, pad_type=pad_type),
147 | nn.Upsample(scale_factor=2),
148 | Conv2dBlock(dim//2, dim//4, 5, 1, 2,norm, activ, pad_type, use_sn=use_sn),
149 |
150 | ResBlocks(n_res, dim//4, norm, activ, pad_type=pad_type),
151 | nn.Upsample(scale_factor=2),
152 | Conv2dBlock(dim//4, dim//8, 5, 1, 2,norm, activ, pad_type, use_sn=use_sn),
153 | Get_image(dim//8, input_dim)) )]
154 |
155 | self.resample16 = Resample2d(16, 1, sigma=4)
156 | self.resample4 = Resample2d(4, 1, sigma=2)
157 |
158 |
159 | def forward(self, inputs, flow_maps):
160 | x1 = self.down_sample[0](inputs)
161 | x2 = self.down_sample[1](x1)
162 | flow_fea = self.resample_image(x1, flow_maps)
163 |
164 | u1 = torch.cat((self.up_sample[0](x2), flow_fea), 1)
165 | images_out = self.up_sample[1](u1)
166 | return images_out
167 |
168 | def resample_image(self, img, flow):
169 | output16 = self.resample16(img, flow)
170 | output4 = self.resample4 (img, flow)
171 | outputs = torch.cat((output16,output4), 1)
172 | return outputs
173 |
174 |
175 |
176 | class FlowColumn(nn.Module):
177 | def __init__(self, input_dim=3, dim=64, n_res=2, activ='lrelu',
178 | norm='in', pad_type='reflect', use_sn=True):
179 | super(FlowColumn, self).__init__()
180 |
181 | self.down_sample_flow = nn.ModuleList()
182 | self.up_sample_flow = nn.ModuleList()
183 |
184 | self.down_sample_flow.append( nn.Sequential(
185 | Conv2dBlock(input_dim*2+1, dim//2, 7, 1, 3, norm, activ, pad_type, use_sn=use_sn),
186 | Conv2dBlock(dim//2, dim, 4, 2, 1, norm, activ, pad_type, use_sn=use_sn),
187 | Conv2dBlock( dim, dim, 5, 1, 2, norm, activ, pad_type, use_sn=use_sn)))
188 | self.down_sample_flow.append( nn.Sequential(
189 | Conv2dBlock( dim, 2*dim, 4, 2, 1, norm, activ, pad_type, use_sn=use_sn),
190 | Conv2dBlock(2*dim, 2*dim, 5, 1, 2, norm, activ, pad_type, use_sn=use_sn)))
191 | self.down_sample_flow.append(nn.Sequential(
192 | Conv2dBlock(2*dim, 4*dim, 4, 2, 1, norm, activ, pad_type, use_sn=use_sn),
193 | Conv2dBlock(4*dim, 4*dim, 5, 1, 2, norm, activ, pad_type, use_sn=use_sn)))
194 | self.down_sample_flow.append(nn.Sequential(
195 | Conv2dBlock(4*dim, 8*dim, 4, 2, 1, norm, activ, pad_type, use_sn=use_sn),
196 | Conv2dBlock(8*dim, 8*dim, 5, 1, 2, norm, activ, pad_type, use_sn=use_sn)))
197 | dim = 8*dim
198 |
199 | # content decoder
200 | self.up_sample_flow.append(nn.Sequential(
201 | ResBlocks(n_res, dim, norm, activ, pad_type=pad_type),
202 | TransConv2dBlock(dim, dim//2, 6, 2, 2, norm=norm, activation=activ) ))
203 |
204 | self.up_sample_flow.append(nn.Sequential(
205 | Conv2dBlock(dim, dim//2, 5, 1, 2, norm, activ, pad_type, use_sn=use_sn),
206 | ResBlocks(n_res, dim//2, norm, activ, pad_type=pad_type),
207 | TransConv2dBlock(dim//2, dim//4, 6, 2, 2, norm=norm, activation=activ) ))
208 |
209 | self.location = nn.Sequential(
210 | Conv2dBlock(dim//2, dim//8, 5, 1, 2, norm, activ, pad_type, use_sn=use_sn),
211 | Conv2dBlock(dim//8, 2, 3, 1, 1, norm='none', activation='none', pad_type=pad_type, use_bias=False) )
212 |
213 | def forward(self, inputs):
214 | f_x1 = self.down_sample_flow[0](inputs)
215 | f_x2 = self.down_sample_flow[1](f_x1)
216 | f_x3 = self.down_sample_flow[2](f_x2)
217 | f_x4 = self.down_sample_flow[3](f_x3)
218 |
219 | f_u1 = torch.cat((self.up_sample_flow[0](f_x4), f_x3), 1)
220 | f_u2 = torch.cat((self.up_sample_flow[1](f_u1), f_x2), 1)
221 | flow_map = self.location(f_u2)
222 | return flow_map
223 |
224 |
225 |
226 | ##################################################################################
227 | # Basic Blocks
228 | ##################################################################################
229 | class Get_image(nn.Module):
230 | def __init__(self, input_dim, output_dim, activation='tanh'):
231 | super(Get_image, self).__init__()
232 | self.conv = Conv2dBlock(input_dim, output_dim, kernel_size=3, stride=1,
233 | padding=1, pad_type='reflect', activation=activation)
234 | def forward(self, x):
235 | return self.conv(x)
236 |
237 | class ResBlocks(nn.Module):
238 | def __init__(self, num_blocks, dim, norm='in', activation='relu', pad_type='zero', use_sn=False):
239 | super(ResBlocks, self).__init__()
240 | self.model = []
241 | for i in range(num_blocks):
242 | self.model += [ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type, use_sn=use_sn)]
243 | self.model = nn.Sequential(*self.model)
244 |
245 | def forward(self, x):
246 | return self.model(x)
247 |
248 | class ResBlock(nn.Module):
249 | def __init__(self, dim, norm='in', activation='relu', pad_type='zero', use_sn=False):
250 | super(ResBlock, self).__init__()
251 |
252 | model = []
253 | model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type, use_sn=use_sn)]
254 | model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type, use_sn=use_sn)]
255 | self.model = nn.Sequential(*model)
256 |
257 | def forward(self, x):
258 | residual = x
259 | out = self.model(x)
260 | out += residual
261 | return out
262 |
263 | class DilationBlock(nn.Module):
264 | def __init__(self, dim, norm='in', activation='relu', pad_type='zero'):
265 | super(DilationBlock, self).__init__()
266 |
267 | model = []
268 | model += [Conv2dBlock(dim ,dim, 3, 1, 2, norm=norm, activation=activation, pad_type=pad_type, dilation=2)]
269 | model += [Conv2dBlock(dim ,dim, 3, 1, 4, norm=norm, activation=activation, pad_type=pad_type, dilation=4)]
270 | model += [Conv2dBlock(dim ,dim, 3, 1, 8, norm=norm, activation=activation, pad_type=pad_type, dilation=8)]
271 | self.model = nn.Sequential(*model)
272 |
273 | def forward(self, x):
274 | out = self.model(x)
275 | return out
276 |
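277 | # Conv2dBlock applies, in order: padding -> (optionally spectrally normalized)
278 | # convolution -> normalization -> activation, each chosen by string arguments.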
277 | class Conv2dBlock(nn.Module):
278 | def __init__(self, input_dim ,output_dim, kernel_size, stride,
279 | padding=0, norm='none', activation='relu', pad_type='zero', dilation=1,
280 | use_bias=True, use_sn=False):
281 | super(Conv2dBlock, self).__init__()
282 | self.use_bias = use_bias
283 | # initialize padding
284 | if pad_type == 'reflect':
285 | self.pad = nn.ReflectionPad2d(padding)
286 | elif pad_type == 'replicate':
287 | self.pad = nn.ReplicationPad2d(padding)
288 | elif pad_type == 'zero':
289 | self.pad = nn.ZeroPad2d(padding)
290 | else:
291 | assert 0, "Unsupported padding type: {}".format(pad_type)
292 |
293 | # initialize normalization
294 | norm_dim = output_dim
295 | if norm == 'bn':
296 | self.norm = nn.BatchNorm2d(norm_dim)
297 | elif norm == 'in':
298 | self.norm = nn.InstanceNorm2d(norm_dim)
299 | elif norm == 'ln':
300 | self.norm = LayerNorm(norm_dim)
301 | elif norm == 'adain':
302 | self.norm = AdaptiveInstanceNorm2d(norm_dim)
303 | elif norm == 'none':
304 | self.norm = None
305 | else:
306 | assert 0, "Unsupported normalization: {}".format(norm)
307 |
308 | # initialize activation
309 | if activation == 'relu':
310 | self.activation = nn.ReLU(inplace=True)
311 | elif activation == 'lrelu':
312 | self.activation = nn.LeakyReLU(0.2, inplace=True)
313 | elif activation == 'prelu':
314 | self.activation = nn.PReLU()
315 | elif activation == 'selu':
316 | self.activation = nn.SELU(inplace=True)
317 | elif activation == 'tanh':
318 | self.activation = nn.Tanh()
319 | elif activation == 'sigmoid':
320 | self.activation = nn.Sigmoid()
321 | elif activation == 'none':
322 | self.activation = None
323 | else:
324 | assert 0, "Unsupported activation: {}".format(activation)
325 |
326 | # initialize convolution
327 | if use_sn:
328 | self.conv = spectral_norm(nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias, dilation=dilation))
329 | else:
330 | self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias, dilation=dilation)
331 |
332 | def forward(self, x):
333 | x = self.conv(self.pad(x))
334 | if self.norm:
335 | x = self.norm(x)
336 | if self.activation:
337 | x = self.activation(x)
338 | return x
339 |
340 | class TransConv2dBlock(nn.Module):
341 | def __init__(self, input_dim ,output_dim, kernel_size, stride,
342 | padding=0, norm='none', activation='relu'):
343 | super(TransConv2dBlock, self).__init__()
344 | self.use_bias = True
345 |
346 | # initialize normalization
347 | norm_dim = output_dim
348 | if norm == 'bn':
349 | self.norm = nn.BatchNorm2d(norm_dim)
350 | elif norm == 'in':
351 | self.norm = nn.InstanceNorm2d(norm_dim)
352 | elif norm == 'in_affine':
353 | self.norm = nn.InstanceNorm2d(norm_dim, affine=True)
354 | elif norm == 'ln':
355 | self.norm = LayerNorm(norm_dim)
356 | elif norm == 'adain':
357 | self.norm = AdaptiveInstanceNorm2d(norm_dim)
358 | elif norm == 'none':
359 | self.norm = None
360 | else:
361 | assert 0, "Unsupported normalization: {}".format(norm)
362 |
363 | # initialize activation
364 | if activation == 'relu':
365 | self.activation = nn.ReLU(inplace=True)
366 | elif activation == 'lrelu':
367 | self.activation = nn.LeakyReLU(0.2, inplace=True)
368 | elif activation == 'prelu':
369 | self.activation = nn.PReLU()
370 | elif activation == 'selu':
371 | self.activation = nn.SELU(inplace=True)
372 | elif activation == 'tanh':
373 | self.activation = nn.Tanh()
374 | elif activation == 'sigmoid':
375 | self.activation = nn.Sigmoid()
376 | elif activation == 'none':
377 | self.activation = None
378 | else:
379 | assert 0, "Unsupported activation: {}".format(activation)
380 |
381 | # initialize convolution
382 | self.transConv = nn.ConvTranspose2d(input_dim, output_dim, kernel_size, stride, padding, bias=self.use_bias)
383 |
384 | def forward(self, x):
385 | x = self.transConv(x)
386 | if self.norm:
387 | x = self.norm(x)
388 | if self.activation:
389 | x = self.activation(x)
390 | return x
391 |
392 | class AdaptiveInstanceNorm2d(nn.Module):
393 | def __init__(self, num_features, eps=1e-5, momentum=0.1):
394 | super(AdaptiveInstanceNorm2d, self).__init__()
395 | self.num_features = num_features
396 | self.eps = eps
397 | self.momentum = momentum
398 | # weight and bias are dynamically assigned
399 | self.weight = None
400 | self.bias = None
401 | # just dummy buffers, not used
402 | self.register_buffer('running_mean', torch.zeros(num_features))
403 | self.register_buffer('running_var', torch.ones(num_features))
404 |
405 | def forward(self, x):
406 | assert self.weight is not None and self.bias is not None, "Please assign weight and bias before calling AdaIN!"
407 | b, c = x.size(0), x.size(1)
408 | running_mean = self.running_mean.repeat(b)
409 | running_var = self.running_var.repeat(b)
410 |
411 | # Apply instance norm
412 | x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])
413 |
414 | out = F.batch_norm(
415 | x_reshaped, running_mean, running_var, self.weight, self.bias,
416 | True, self.momentum, self.eps)
417 |
418 | return out.view(b, c, *x.size()[2:])
419 |
420 | def __repr__(self):
421 | return self.__class__.__name__ + '(' + str(self.num_features) + ')'
422 |
423 | class LayerNorm(nn.Module):
424 | def __init__(self, n_out, eps=1e-5, affine=True):
425 | super(LayerNorm, self).__init__()
426 | self.n_out = n_out
427 | self.affine = affine
428 |
429 | if self.affine:
430 | self.weight = nn.Parameter(torch.ones(n_out, 1, 1))
431 | self.bias = nn.Parameter(torch.zeros(n_out, 1, 1))
432 |
433 | def forward(self, x):
434 | normalized_shape = x.size()[1:]
435 | if self.affine:
436 | return F.layer_norm(x, normalized_shape, self.weight.expand(normalized_shape), self.bias.expand(normalized_shape))
437 | else:
438 | return F.layer_norm(x, normalized_shape)
439 |
440 | class DownsampleResBlock(nn.Module):
441 | def __init__(self, input_dim, output_dim, norm='in', activation='relu', pad_type='zero', use_sn=False):
442 | super(DownsampleResBlock, self).__init__()
443 | self.conv_1 = nn.ModuleList()
444 | self.conv_2 = nn.ModuleList()
445 |
446 | self.conv_1.append(Conv2dBlock(input_dim,input_dim,3,1,1,'none',activation,pad_type,use_sn=use_sn))
447 | self.conv_1.append(Conv2dBlock(input_dim,output_dim,3,1,1,'none',activation,pad_type,use_sn=use_sn))
448 | self.conv_1.append(nn.AvgPool2d(kernel_size=2, stride=2))
449 | self.conv_1 = nn.Sequential(*self.conv_1)
450 |
451 |
452 | self.conv_2.append(nn.AvgPool2d(kernel_size=2, stride=2))
453 | self.conv_2.append(Conv2dBlock(input_dim,output_dim,1,1,0,'none',activation,pad_type,use_sn=use_sn))
454 | self.conv_2 = nn.Sequential(*self.conv_2)
455 |
456 |
457 | def forward(self, x):
458 | out = self.conv_1(x) + self.conv_2(x)
459 | return out
460 |
461 |
462 |
--------------------------------------------------------------------------------
/src/resample2d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn.modules.module import Module
3 | from torch.autograd import Function
4 | import resample2d_cuda
5 |
6 | class Resample2dFunction(Function):
7 |
8 | @staticmethod
9 | def forward(ctx, input1, input2, kernel_size=2, dilation=1):
10 | assert input1.is_contiguous()
11 | assert input2.is_contiguous()
12 |
13 | ctx.save_for_backward(input1, input2)
14 | ctx.kernel_size = kernel_size
15 | ctx.dilation = dilation
16 |
17 | _, d, _, _ = input1.size()
18 | b, _, h, w = input2.size()
19 | output = input1.new(b, d, h, w).zero_()
20 |
21 | resample2d_cuda.forward(input1, input2, output, kernel_size, dilation)
22 |
23 | return output
24 |
25 | @staticmethod
26 |     def backward(ctx, grad_output):
27 |         if not grad_output.is_contiguous():
28 |             grad_output = grad_output.contiguous()  # contiguous() returns a copy; rebind it
29 |
30 | input1, input2 = ctx.saved_tensors
31 |
32 |         grad_input1 = input1.new_zeros(input1.size())
33 |         grad_input2 = input1.new_zeros(input2.size())
34 |
35 | resample2d_cuda.backward(input1, input2, grad_output.data,
36 | grad_input1.data, grad_input2.data,
37 | ctx.kernel_size, ctx.dilation)
38 |
39 | return grad_input1, grad_input2, None, None
40 |
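41 | # Module wrapper around the CUDA op. A constant per-pixel sigma channel is
42 | # appended to the flow so the kernel can perform Gaussian-weighted sampling.
43 | # Sketch (assuming CUDA tensors: `features` of shape (B, C, H, W) and a flow
44 | # map `flow` of shape (B, 2, H, W)):
45 | #   warp = Resample2d(kernel_size=4, dilation=1, sigma=2)
46 | #   warped = warp(features, flow)  # (B, C, H, W)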
41 | class Resample2d(Module):
42 |
43 | def __init__(self, kernel_size=2, dilation=1, sigma=5 ):
44 | super(Resample2d, self).__init__()
45 | self.kernel_size = kernel_size
46 | self.dilation = dilation
47 | self.sigma = torch.tensor(sigma, dtype=torch.float).cuda()
48 |
49 | def forward(self, input1, input2):
50 | input1_c = input1.contiguous()
51 | sigma = self.sigma.expand(input2.size(0), 1, input2.size(2), input2.size(3)).type(input2.dtype)
52 | input2 = torch.cat((input2,sigma), 1)
53 | return Resample2dFunction.apply(input1_c, input2, self.kernel_size, self.dilation)
54 |
--------------------------------------------------------------------------------
/src/structure_flow.py:
--------------------------------------------------------------------------------
1 | from torch.autograd import Variable, grad
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import os
6 | import torch.nn.functional as F
7 | import glob
8 | import torchvision.utils as vutils
9 | import math
10 | import shutil
11 | import tensorboardX
12 | from itertools import islice
13 | from torch.utils.data import DataLoader
14 | from .data import Dataset
15 | from .utils import Progbar, write_2images, write_2tensorboard, create_dir, imsave
16 | # compare_ssim / compare_psnr moved in scikit-image >= 0.16 and were removed in 0.18
17 | from skimage.metrics import structural_similarity as compare_ssim
18 | from skimage.metrics import peak_signal_noise_ratio as compare_psnr
18 | from .models import StructureFlowModel
19 |
20 |
21 | class StructureFlow():
22 | def __init__(self, config):
23 | self.config = config
24 | self.debug=False
25 | self.flow_model = StructureFlowModel(config).to(config.DEVICE)
26 |
27 | self.samples_path = os.path.join(config.PATH, config.NAME, 'images')
28 | self.checkpoints_path = os.path.join(config.PATH, config.NAME, 'checkpoints')
29 | self.test_image_path = os.path.join(config.PATH, config.NAME, 'test_result')
30 |
31 | if self.config.MODE == 'train' and not self.config.RESUME_ALL:
32 | pass
33 | else:
34 | self.flow_model.load(self.config.WHICH_ITER)
35 |
36 | if self.config.MODEL == 1:
37 | self.stage_name='structure_reconstructor'
38 | elif self.config.MODEL == 2:
39 | self.stage_name='texture_generator'
40 | elif self.config.MODEL == 3:
41 | self.stage_name='joint_train'
42 |
43 | def train(self):
44 | train_writer = self.obtain_log(self.config)
45 | train_dataset = Dataset(self.config.DATA_TRAIN_GT, self.config.DATA_TRAIN_STRUCTURE,
46 | self.config, self.config.DATA_MASK_FILE)
47 | train_loader = DataLoader(dataset=train_dataset, batch_size=self.config.TRAIN_BATCH_SIZE,
48 | shuffle=True, drop_last=True, num_workers=8)
49 |
50 | val_dataset = Dataset(self.config.DATA_VAL_GT, self.config.DATA_VAL_STRUCTURE,
51 | self.config, self.config.DATA_MASK_FILE)
52 | sample_iterator = val_dataset.create_iterator(self.config.SAMPLE_SIZE)
53 |
54 |
55 | iterations = self.flow_model.iterations
56 | total = len(train_dataset)
57 | epoch = math.floor(iterations*self.config.TRAIN_BATCH_SIZE/total)
58 | keep_training = True
59 | model = self.config.MODEL
60 | max_iterations = int(float(self.config.MAX_ITERS))
61 |
62 |         while keep_training:
63 | epoch += 1
64 | print('\n\nTraining epoch: %d' % epoch)
65 |
66 | progbar = Progbar(total, width=20, stateful_metrics=['epoch', 'iter'])
67 |
68 | for items in train_loader:
69 | inputs, smooths, gts, maps = self.cuda(*items)
70 |
71 | # structure model
72 | if model == 1:
73 | logs = self.flow_model.update_structure(inputs, smooths, maps)
74 | iterations = self.flow_model.iterations
75 | # flow model
76 | elif model == 2:
77 | logs = self.flow_model.update_flow(inputs, smooths, gts, maps, self.flow_model.use_correction_loss, self.flow_model.use_vgg_loss)
78 | iterations = self.flow_model.iterations
79 | # flow with structure model
80 | elif model == 3:
81 | with torch.no_grad():
82 | smooth_stage_1 = self.flow_model.structure_forward(inputs, smooths, maps)
83 | logs = self.flow_model.update_flow(inputs, smooth_stage_1.detach(), gts, maps, self.flow_model.use_correction_loss, self.flow_model.use_vgg_loss)
84 | iterations = self.flow_model.iterations
85 |
86 | if iterations >= max_iterations:
87 | keep_training = False
88 | break
89 |
90 | # print(logs)
91 | logs = [
92 | ("epoch", epoch),
93 | ("iter", iterations),
94 | ] + logs
95 |
96 | progbar.add(len(inputs), values=logs if self.config.VERBOSE else [x for x in logs if not x[0].startswith('l_')])
97 |
98 | # log model
99 | if self.config.LOG_INTERVAL and iterations % self.config.LOG_INTERVAL == 0:
100 | self.write_loss(logs, train_writer)
101 | # sample model
102 | if self.config.SAMPLE_INTERVAL and iterations % self.config.SAMPLE_INTERVAL == 0:
103 | items = next(sample_iterator)
104 | inputs, smooths, gts, maps = self.cuda(*items)
105 | result,flow = self.flow_model.sample(inputs, smooths, gts, maps)
106 | self.write_image(result, train_writer, iterations, 'image')
107 | self.write_image(flow, train_writer, iterations, 'flow')
108 | # evaluate model
109 | if self.config.EVAL_INTERVAL and iterations % self.config.EVAL_INTERVAL == 0:
110 | self.flow_model.eval()
111 | print('\nstart eval...\n')
112 | self.eval(writer=train_writer)
113 | self.flow_model.train()
114 |
115 | # save the latest model
116 | if self.config.SAVE_LATEST and iterations % self.config.SAVE_LATEST == 0:
117 | print('\nsaving the latest model (total_steps %d)\n' % (iterations))
118 | self.flow_model.save('latest')
119 |
120 | # save the model
121 | if self.config.SAVE_INTERVAL and iterations % self.config.SAVE_INTERVAL == 0:
122 | print('\nsaving the model of iterations %d\n' % iterations)
123 | self.flow_model.save(iterations)
124 |         print('\nEnd training...')
125 |
126 |
127 | def eval(self, writer=None):
128 |         val_dataset = Dataset(self.config.DATA_VAL_GT, self.config.DATA_VAL_STRUCTURE, self.config, self.config.DATA_VAL_MASK)
129 |         val_loader = DataLoader(
130 |             dataset=val_dataset,
131 |             batch_size=self.config.TRAIN_BATCH_SIZE,
132 |             shuffle=False
133 |         )
134 | model = self.config.MODEL
135 | total = len(val_dataset)
136 | iterations = self.flow_model.iterations
137 |
138 | progbar = Progbar(total, width=20, stateful_metrics=['it'])
139 | iteration = 0
140 | psnr_list = []
141 |
142 | # TODO: add fid score to evaluate
143 | with torch.no_grad():
144 |             # evaluate on at most 50 validation batches to keep this pass fast
145 | for j, items in enumerate(islice(val_loader, 50)):
146 |
147 | logs = []
148 | iteration += 1
149 | inputs, smooths, gts, maps = self.cuda(*items)
150 | if model == 1:
151 | outputs_structure = self.flow_model.structure_forward(inputs, smooths, maps)
152 | psnr, ssim, l1 = self.metrics(outputs_structure, smooths)
153 | logs.append(('psnr', psnr.item()))
154 | psnr_list.append(psnr.item())
155 |
156 | # inpaint model
157 | elif model == 2:
158 | outputs, flow_maps = self.flow_model.flow_forward(inputs, smooths, maps)
159 | psnr, ssim, l1 = self.metrics(outputs, gts)
160 | logs.append(('psnr', psnr.item()))
161 | psnr_list.append(psnr.item())
162 |
163 |
164 | # inpaint with structure model
165 | elif model == 3:
166 | smooth_stage_1 = self.flow_model.structure_forward(inputs, smooths, maps)
167 | outputs, flow_maps = self.flow_model.flow_forward(inputs, smooth_stage_1, maps)
168 | psnr, ssim, l1 = self.metrics(outputs, gts)
169 | logs.append(('psnr', psnr.item()))
170 | psnr_list.append(psnr.item())
171 |
172 | logs = [("it", iteration), ] + logs
173 | progbar.add(len(inputs), values=logs)
174 |
175 | avg_psnr = np.average(psnr_list)
176 |
177 | if writer is not None:
178 | writer.add_scalar('eval_psnr', avg_psnr, iterations)
179 |
180 |         print('model eval at iteration %d' % iterations)
181 |         print('average psnr: %f' % avg_psnr)
182 |
183 |
184 | def test(self):
185 | self.flow_model.eval()
186 |
187 | model = self.config.MODEL
188 |         print('saving test results to %s' % self.config.DATA_TEST_RESULTS)
189 | create_dir(self.config.DATA_TEST_RESULTS)
190 | test_dataset = Dataset(self.config.DATA_TEST_GT, self.config.DATA_TEST_STRUCTURE, self.config, self.config.DATA_TEST_MASK)
191 | test_loader = DataLoader(
192 | dataset=test_dataset,
193 | batch_size=8,
194 | )
195 |
196 | index = 0
197 | with torch.no_grad():
198 | for items in test_loader:
199 | inputs, smooths, gts, maps = self.cuda(*items)
200 |
201 | # structure model
202 | if model == 1:
203 | outputs = self.flow_model.structure_forward(inputs, smooths, maps)
204 | outputs_merged = (outputs * maps) + (smooths * (1 - maps))
205 |
206 | # flow model
207 | elif model == 2:
208 | outputs, flow_maps = self.flow_model.flow_forward(inputs, smooths, maps)
209 | outputs_merged = (outputs * maps) + (gts * (1 - maps))
210 |
211 |
212 | # inpaint with structure model / joint model
213 | else:
214 | smooth_stage_1 = self.flow_model.structure_forward(inputs, smooths, maps)
215 | outputs, flow_maps = self.flow_model.flow_forward(inputs, smooth_stage_1, maps)
216 | outputs_merged = (outputs * maps) + (gts * (1 - maps))
217 |
218 | outputs_merged = self.postprocess(outputs_merged)*255.0
219 |                 inputs_show = inputs + maps  # adding the mask turns the hole white in the saved input image
220 |
221 |
222 | for i in range(outputs_merged.size(0)):
223 | name = test_dataset.load_name(index, self.debug)
224 | print(index, name)
225 | path = os.path.join(self.config.DATA_TEST_RESULTS, name)
226 | imsave(outputs_merged[i,:,:,:].unsqueeze(0), path)
227 | index += 1
228 |
229 | if self.debug and model == 3:
230 | smooth_ = self.postprocess(smooth_stage_1[i,:,:,:].unsqueeze(0))*255.0
231 | inputs_ = self.postprocess(inputs_show[i,:,:,:].unsqueeze(0))*255.0
232 | gts_ = self.postprocess(gts[i,:,:,:].unsqueeze(0))*255.0
233 | print(path)
234 |                         fname, fext = os.path.splitext(path)  # fext keeps its leading dot
235 |                         imsave(smooth_, fname+'_smooth'+fext)
236 |                         imsave(inputs_, fname+'_inputs'+fext)
237 |                         imsave(gts_, fname+'_gts'+fext)
238 |
239 |         print('\nEnd test...')
240 |
241 |
242 |
243 | def obtain_log(self, config):
244 | log_dir = os.path.join(config.PATH, config.NAME, self.stage_name+'_log')
245 | if os.path.exists(log_dir) and config.REMOVE_LOG:
246 | shutil.rmtree(log_dir)
247 | train_writer = tensorboardX.SummaryWriter(log_dir)
248 | return train_writer
249 |
250 |
251 | def cuda(self, *args):
252 | return (item.to(self.config.DEVICE) for item in args)
253 |
254 |
255 | def write_loss(self, logs, train_writer):
256 | iteration = [x[1] for x in logs if x[0]=='iter']
257 | for x in logs:
258 | if x[0].startswith('l_'):
259 | train_writer.add_scalar(x[0], x[1], iteration[-1])
260 |
261 | def write_image(self, result, train_writer, iterations, label):
262 | if result:
263 |             name = '%s/model%d_sample_%08d_%s.jpg' % (self.samples_path, self.config.MODEL, iterations, label)
264 | write_2images(result, self.config.SAMPLE_SIZE, name)
265 | write_2tensorboard(iterations, result, train_writer, self.config.SAMPLE_SIZE, label)
266 |
267 |
268 | def postprocess(self, x):
269 | x = (x + 1) / 2
270 | x.clamp_(0, 1)
271 | return x
272 |
273 | def metrics(self, inputs, gts):
274 | inputs = self.postprocess(inputs)
275 | gts = self.postprocess(gts)
276 | psnr_value=[]
277 | l1_value = torch.mean(torch.abs(inputs-gts))
278 |
279 |         b, c, h, w = inputs.size()  # tensors are laid out as (batch, channel, height, width)
280 |         inputs = (inputs*255.0).int().float()/255.0  # quantize to 8-bit levels, matching saved images
281 |         gts = (gts*255.0).int().float()/255.0
282 |
283 | for i in range(inputs.size(0)):
284 | inputs_p = inputs[i,:,:,:].cpu().numpy().astype(np.float32).transpose(1,2,0)
285 | gts_p = gts[i,:,:,:].cpu().numpy().astype(np.float32).transpose(1,2,0)
286 | psnr_value.append(compare_psnr(inputs_p, gts_p, data_range=1))
287 |
288 | psnr_value = np.average(psnr_value)
289 |         inputs = inputs.view(b*c, h, w).cpu().numpy().astype(np.float32).transpose(1,2,0)
290 |         gts = gts.view(b*c, h, w).cpu().numpy().astype(np.float32).transpose(1,2,0)
291 | ssim_value = compare_ssim(inputs, gts, data_range=1, win_size=51, multichannel=True)
292 | return psnr_value, ssim_value, l1_value
293 |
294 |
295 |
296 |
--------------------------------------------------------------------------------
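
[Editor's note] StructureFlow.metrics() above imports compare_psnr / compare_ssim from skimage.measure. Those names were deprecated in scikit-image 0.16 and removed in 0.18; on current versions the same metrics live in skimage.metrics. A minimal sketch of the per-sample computation under the newer API (assuming scikit-image >= 0.19 for the channel_axis argument), not part of the repository:

    import numpy as np
    from skimage.metrics import peak_signal_noise_ratio, structural_similarity

    def psnr_ssim(pred, gt):
        # pred, gt: float32 (H, W, C) arrays in [0, 1], i.e. one transposed
        # sample exactly as metrics() prepares it before calling compare_psnr
        psnr = peak_signal_noise_ratio(gt, pred, data_range=1)
        ssim = structural_similarity(gt, pred, data_range=1, win_size=51,
                                     channel_axis=-1)
        return psnr, ssim
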
/src/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import random
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | from PIL import Image
8 | import torch.nn.init as init
9 | import math
10 | import torch
11 | import torchvision.utils as vutils
12 | from natsort import natsorted
13 |
14 | def create_dir(directory):
15 |     if not os.path.exists(directory):
16 |         os.makedirs(directory)
17 |
18 |
19 | def create_mask(width, height, mask_width, mask_height, x=None, y=None):
20 | mask = np.zeros((height, width))
21 | mask_x = x if x is not None else random.randint(0, width - mask_width)
22 | mask_y = y if y is not None else random.randint(0, height - mask_height)
23 | mask[mask_y:mask_y + mask_height, mask_x:mask_x + mask_width] = 1
24 | return mask
25 |
26 |
27 | def stitch_images(inputs, *outputs, img_per_row=2):
28 | gap = 5
29 | columns = len(outputs) + 1
30 |
31 |     height, width = inputs[0][:, :, 0].shape  # numpy shape is (rows, cols) = (H, W)
32 | img = Image.new('RGB', (width * img_per_row * columns + gap * (img_per_row - 1), height * int(len(inputs) / img_per_row)))
33 | images = [inputs, *outputs]
34 |
35 | for ix in range(len(inputs)):
36 | xoffset = int(ix % img_per_row) * width * columns + int(ix % img_per_row) * gap
37 | yoffset = int(ix / img_per_row) * height
38 |
39 | for cat in range(len(images)):
40 | im = np.array((images[cat][ix]).cpu()).astype(np.uint8).squeeze()
41 | im = Image.fromarray(im)
42 | img.paste(im, (xoffset + cat * width, yoffset))
43 |
44 | return img
45 |
46 |
47 | def imshow(img, title=''):
48 | fig = plt.gcf()
49 | fig.canvas.set_window_title(title)
50 | plt.axis('off')
51 | plt.imshow(img, interpolation='none')
52 | plt.show()
53 |
54 |
55 | def imsave(img, path):
56 |     # expects a (B, C, H, W) tensor scaled to [0, 255]
57 |     # reorder to (B, H, W, C) before handing the array to PIL
58 |     img = img.permute(0, 2, 3, 1)
59 | im = Image.fromarray(img.cpu().detach().numpy().astype(np.uint8).squeeze())
60 | im.save(path)
61 |
62 |
63 | class Progbar(object):
64 | """Displays a progress bar.
65 |
66 | Arguments:
67 | target: Total number of steps expected, None if unknown.
68 | width: Progress bar width on screen.
69 | verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
70 | stateful_metrics: Iterable of string names of metrics that
71 | should *not* be averaged over time. Metrics in this list
72 | will be displayed as-is. All others will be averaged
73 | by the progbar before display.
74 | interval: Minimum visual progress update interval (in seconds).
75 | """
76 |
77 | def __init__(self, target, width=25, verbose=1, interval=0.05,
78 | stateful_metrics=None):
79 | self.target = target
80 | self.width = width
81 | self.verbose = verbose
82 | self.interval = interval
83 | if stateful_metrics:
84 | self.stateful_metrics = set(stateful_metrics)
85 | else:
86 | self.stateful_metrics = set()
87 |
88 | self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
89 | sys.stdout.isatty()) or
90 | 'ipykernel' in sys.modules or
91 | 'posix' in sys.modules)
92 | self._total_width = 0
93 | self._seen_so_far = 0
94 | # We use a dict + list to avoid garbage collection
95 | # issues found in OrderedDict
96 | self._values = {}
97 | self._values_order = []
98 | self._start = time.time()
99 | self._last_update = 0
100 |
101 | def update(self, current, values=None):
102 | """Updates the progress bar.
103 |
104 | Arguments:
105 | current: Index of current step.
106 | values: List of tuples:
107 | `(name, value_for_last_step)`.
108 | If `name` is in `stateful_metrics`,
109 | `value_for_last_step` will be displayed as-is.
110 | Else, an average of the metric over time will be displayed.
111 | """
112 | values = values or []
113 | for k, v in values:
114 | if k not in self._values_order:
115 | self._values_order.append(k)
116 | if k not in self.stateful_metrics:
117 | if k not in self._values:
118 | self._values[k] = [v * (current - self._seen_so_far),
119 | current - self._seen_so_far]
120 | else:
121 | self._values[k][0] += v * (current - self._seen_so_far)
122 | self._values[k][1] += (current - self._seen_so_far)
123 | else:
124 | self._values[k] = v
125 | self._seen_so_far = current
126 |
127 | now = time.time()
128 | info = ' - %.0fs' % (now - self._start)
129 | if self.verbose == 1:
130 | if (now - self._last_update < self.interval and
131 | self.target is not None and current < self.target):
132 | return
133 |
134 | prev_total_width = self._total_width
135 | if self._dynamic_display:
136 | sys.stdout.write('\b' * prev_total_width)
137 | sys.stdout.write('\r')
138 | else:
139 | sys.stdout.write('\n')
140 |
141 | if self.target is not None:
142 | numdigits = int(np.floor(np.log10(self.target))) + 1
143 | barstr = '%%%dd/%d [' % (numdigits, self.target)
144 | bar = barstr % current
145 | prog = float(current) / self.target
146 | prog_width = int(self.width * prog)
147 | if prog_width > 0:
148 | bar += ('=' * (prog_width - 1))
149 | if current < self.target:
150 | bar += '>'
151 | else:
152 | bar += '='
153 | bar += ('.' * (self.width - prog_width))
154 | bar += ']'
155 | else:
156 | bar = '%7d/Unknown' % current
157 |
158 | self._total_width = len(bar)
159 | sys.stdout.write(bar)
160 |
161 | if current:
162 | time_per_unit = (now - self._start) / current
163 | else:
164 | time_per_unit = 0
165 | if self.target is not None and current < self.target:
166 | eta = time_per_unit * (self.target - current)
167 | if eta > 3600:
168 | eta_format = '%d:%02d:%02d' % (eta // 3600,
169 | (eta % 3600) // 60,
170 | eta % 60)
171 | elif eta > 60:
172 | eta_format = '%d:%02d' % (eta // 60, eta % 60)
173 | else:
174 | eta_format = '%ds' % eta
175 |
176 | info = ' - ETA: %s' % eta_format
177 | else:
178 | if time_per_unit >= 1:
179 | info += ' %.0fs/step' % time_per_unit
180 | elif time_per_unit >= 1e-3:
181 | info += ' %.0fms/step' % (time_per_unit * 1e3)
182 | else:
183 | info += ' %.0fus/step' % (time_per_unit * 1e6)
184 |
185 | for k in self._values_order:
186 | info += ' - %s:' % k
187 | if isinstance(self._values[k], list):
188 | avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
189 | if abs(avg) > 1e-3:
190 | info += ' %.4f' % avg
191 | else:
192 | info += ' %.4e' % avg
193 | else:
194 | info += ' %s' % self._values[k]
195 |
196 | self._total_width += len(info)
197 | if prev_total_width > self._total_width:
198 | info += (' ' * (prev_total_width - self._total_width))
199 |
200 | if self.target is not None and current >= self.target:
201 | info += '\n'
202 |
203 | sys.stdout.write(info)
204 | sys.stdout.flush()
205 |
206 | elif self.verbose == 2:
207 | if self.target is None or current >= self.target:
208 | for k in self._values_order:
209 | info += ' - %s:' % k
210 | avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
211 | if avg > 1e-3:
212 | info += ' %.4f' % avg
213 | else:
214 | info += ' %.4e' % avg
215 | info += '\n'
216 |
217 | sys.stdout.write(info)
218 | sys.stdout.flush()
219 |
220 | self._last_update = now
221 |
222 | def add(self, n, values=None):
223 | self.update(self._seen_so_far + n, values)
224 |
225 | # network init function
226 | def weights_init(init_type='gaussian'):
227 | def init_fun(m):
228 | classname = m.__class__.__name__
229 | if (classname.find('Conv') == 0 or classname.find('Linear') == 0) and hasattr(m, 'weight'):
230 | if init_type == 'gaussian':
231 | init.normal_(m.weight.data, 0.0, 0.02)
232 | elif init_type == 'xavier':
233 | init.xavier_normal_(m.weight.data, gain=math.sqrt(2))
234 | elif init_type == 'kaiming':
235 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
236 | elif init_type == 'orthogonal':
237 | init.orthogonal_(m.weight.data, gain=math.sqrt(2))
238 | elif init_type == 'default':
239 | pass
240 | else:
241 | assert 0, "Unsupported initialization: {}".format(init_type)
242 | if hasattr(m, 'bias') and m.bias is not None:
243 | init.constant_(m.bias.data, 0.0)
244 | return init_fun
245 |
246 |
247 | # Get model list for resume
248 | def get_model_list(dirname, key):
249 | if os.path.exists(dirname) is False:
250 | return None
251 | gen_models = [os.path.join(dirname, f) for f in os.listdir(dirname) if
252 | os.path.isfile(os.path.join(dirname, f)) and key in f and ".pt" in f]
253 | if gen_models == []:
254 | return None
255 | gen_models.sort()
256 | last_model_name = gen_models[-1]
257 | return last_model_name
258 |
259 | def get_iteration(dir_name, file_name, net_name):
260 | if os.path.exists(os.path.join(dir_name, file_name)) is False:
261 | return None
262 | if 'latest' in file_name:
263 | gen_models = [os.path.join(dir_name, f) for f in os.listdir(dir_name) if
264 | os.path.isfile(os.path.join(dir_name, f)) and (not 'latest' in f) and (".pt" in f) and (net_name in f)]
265 | if gen_models == []:
266 | return 0
267 | model_name = os.path.basename(natsorted(gen_models)[-1])
268 | else:
269 | model_name = file_name
270 | iterations = int(model_name.replace('_net_'+net_name+'.pth', ''))
271 | return iterations
272 |
273 | def __denorm(x):
274 | x = (x + 1) / 2
275 | return x.clamp_(0, 1)
276 |
277 |
278 | def __write_images(image_outputs, display_image_num, file_name):
279 | image_outputs = [images.expand(-1, 3, -1, -1) for images in image_outputs] # expand gray-scale images to 3 channels
280 | image_tensor = torch.cat([images[:display_image_num] for images in image_outputs], 0)
281 | image_tensor = __denorm(image_tensor)
282 | image_grid = vutils.make_grid(image_tensor.data, nrow=display_image_num, padding=0, normalize=False)
283 | vutils.save_image(image_grid, file_name, nrow=1)
284 |
285 |
286 | def write_2images(image_outputs, display_image_num, name):
287 | n = len(image_outputs)
288 | __write_images(image_outputs[0:n], display_image_num, name)
289 |
290 | def write_2tensorboard(iterations, results, train_writer, display_image_num, name):
291 | results = [images.expand(-1, 3, -1, -1) for images in results] # expand gray-scale images to 3 channels
292 | image_tensor = torch.cat([images[:display_image_num] for images in results], 0)
293 | image_tensor = __denorm(image_tensor)
294 | image_grid = vutils.make_grid(image_tensor.data, nrow=display_image_num, padding=0, normalize=False)
295 | train_writer.add_image(name, image_grid, iterations)
296 |
297 |
298 | def write_flow_visualization(flow, images_input, images_gt, name):
299 | if not os.path.exists(name):
300 | os.mkdir(name)
301 |
302 | kernel_size = 16
303 | n, _, h, w = flow.size()
304 |
305 | x = torch.arange(w).view(1, -1).expand(h, -1)
306 | y = torch.arange(h).view(-1, 1).expand(-1, w)
307 | grid = torch.stack([x,y], dim=0).float().cuda()
308 | grid = grid.unsqueeze(0).expand(n, -1, -1, -1)
309 | grid = grid+flow
310 |
311 | images_input = torch.nn.functional.interpolate(images_input, size=(h,w))
312 | images_gt = torch.nn.functional.interpolate(images_gt, size=(h,w))
313 |
314 | index=0
315 | for i in range(0, h, kernel_size):
316 | for j in range(0, w, kernel_size):
317 | img_ = _write_flow_helper(i, j, grid, kernel_size)
318 | out = img_+images_input
319 |
320 | name_write = os.path.join(name, '%04d.jpg'%(index))
321 | write_2images([out, images_gt,flow[:,0,:,:].unsqueeze(1)/30, flow[:,1,:,:].unsqueeze(1)/30], n, name_write)
322 |
323 | index=index+1
324 |
325 |
326 | def _write_flow_helper( i, j, grid, kernel_size):
327 | n, _, h, w = grid.size()
328 | img = torch.zeros(n,3,h,w).cuda()
329 | _grid = grid[:,:,i:i+kernel_size,j:j+kernel_size]
330 | img[:,:,i:i+kernel_size,j:j+kernel_size] = 0.5
331 |
332 | for _i in range(kernel_size):
333 | for _j in range(kernel_size):
334 | for _b in range(n):
335 | index_x = int(min(max(_grid[_b,0,_i,_j],0), w-1))
336 | index_y = int(min(max(_grid[_b,1,_i,_j],0), h-1))
337 | img[_b, :, index_y, index_x]=0.8
338 |
339 | return img
--------------------------------------------------------------------------------
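
[Editor's note] The Progbar above is the Keras-style progress bar that train() and eval() feed: metrics listed in stateful_metrics are printed as-is, everything else is averaged over the steps seen so far. A minimal usage sketch with a hypothetical 100-step loop and the repository's 'l_'-prefixed loss naming:

    import time
    from src.utils import Progbar

    bar = Progbar(target=100, width=20, stateful_metrics=['epoch'])
    for step in range(25):
        time.sleep(0.01)                       # stand-in for a training step
        bar.add(4, values=[('epoch', 1),       # stateful: displayed as-is
                           ('l_total', 0.5)])  # averaged before display
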
/test.py:
--------------------------------------------------------------------------------
1 | from main import main
2 |
3 | main('test')
--------------------------------------------------------------------------------
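
[Editor's note] test.py is a one-line wrapper that ends up in StructureFlow.test() above. The composition step there keeps ground-truth pixels wherever the image is known and takes the network output only inside the hole (maps is 1 where pixels are missing). A self-contained sketch of that merge with toy tensors:

    import torch

    b, c, h, w = 1, 3, 4, 4
    outputs = torch.rand(b, c, h, w)      # network prediction
    gts     = torch.rand(b, c, h, w)      # ground-truth image
    maps    = torch.zeros(b, 1, h, w)
    maps[..., 1:3, 1:3] = 1               # mask: 1 marks the missing region

    # same formula as in test(): prediction in the hole, ground truth elsewhere
    merged = outputs * maps + gts * (1 - maps)
    assert torch.equal(merged[..., 0, 0], gts[..., 0, 0])  # known pixel kept
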
/train.py:
--------------------------------------------------------------------------------
1 | from main import main
2 |
3 | main('train')
--------------------------------------------------------------------------------
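
[Editor's note] train.py and test.py both delegate to main() from main.py, which appears earlier in this dump. A hedged skeleton of what such an entry point has to wire together, based only on the attributes the StructureFlow constructor reads; the config object and how it is loaded are assumptions, not the repository's verified API:

    import torch
    from src.structure_flow import StructureFlow

    def run(mode, config):
        # `config` stands for any object exposing the attributes StructureFlow
        # reads: MODE, DEVICE, PATH, NAME, MODEL, RESUME_ALL, WHICH_ITER, ...
        config.MODE = mode
        config.DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = StructureFlow(config)
        model.train() if mode == 'train' else model.test()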